Module tau2.orbach_calculator

Functions

def build_orbach_gamma_task(*,
field_mag,
B_unit,
initial_data_future,
params: OrbachParams)
Expand source code
def build_orbach_gamma_task(*, field_mag, B_unit, initial_data_future, params: OrbachParams):
    """
    Builds the Dask graph for calculating the full Orbach gamma matrix for a single field magnitude.

    Args:
        field_mag (float): The magnitude of the applied magnetic field (T).
        B_unit (np.ndarray): The normalised direction vector of the field.
        initial_data_future (Future/Delayed): The initial data object (scattered to workers).
            Expected to be a mapping providing at least 'initial_energies', 'angmom',
            'spin' and 'symmetrised_couplings_array'.
        params (OrbachParams): The immutable object with core calculation parameters.

    Returns:
        dask.delayed: A single Dask delayed object representing the complete
                      Orbach gamma matrix calculation for the given field.
    """
    
    # 1. Define a delayed setup function. This runs on the worker.
    @dask.delayed
    def setup_zeeman_on_worker(initial_data):
        # Full field vector for this task's magnitude/orientation.
        B_vec = B_unit * field_mag
        # Apply Zeeman splitting to get energies and the eigenvectors (U_matrix)
        # Note: We do NOT pass couplings here anymore. We handle the transform manually below.
        energies, _, state_mu, state_mJ, U_matrix = apply_zeeman_splitting(
            initial_data['initial_energies'], initial_data['angmom'], initial_data['spin'],
            B_vec
        )
        
        # Transform the dense coupling array: V_zeeman = U^dag @ V_initial @ U
        # This aligns the couplings with the Zeeman eigenstates.
        # initial_data['symmetrised_couplings_array'] shape is (N_modes, N_states, N_states)
        
        U_dagger = U_matrix.conj().T
        couplings_array = initial_data['symmetrised_couplings_array']
        
        # Transform the coupling array into the Zeeman basis
        # einsum: 'ij,mjk,kl->mil'
        # i,l = indices of U_dagger/U (electronic states)
        # m = mode index
        # j,k = indices of original coupling array
        couplings_transformed = np.einsum('ij,mjk,kl->mil', 
                                          U_dagger, 
                                          couplings_array, 
                                          U_matrix, 
                                          optimize=True)

        # Slice to the requested number of states
        sliced_energies = energies[:params.max_state]
        sliced_couplings = couplings_transformed[:, :params.max_state, :params.max_state]
        
        return sliced_energies, sliced_couplings, state_mu, state_mJ, B_vec

    # 2. Call the delayed setup
    setup_result = setup_zeeman_on_worker(initial_data_future)
    
    # Unpack results (these are still delayed objects)
    # NOTE: indexing a Delayed adds small getitem tasks to the graph rather
    # than computing anything here.
    energies_d = setup_result[0]
    couplings_d = setup_result[1]
    state_mu_d = setup_result[2]
    state_mJ_d = setup_result[3]
    B_vec_d = setup_result[4]

    # 3. Create the calculation task, using the delayed outputs of the setup
    # The full initial-data mapping is passed as mode_energies_all; the callee
    # extracts the 'all_mode_energies' entry from it.
    task = dask.delayed(calculate_orbach_gamma_for_field)(
        energies=energies_d,
        couplings=couplings_d, 
        mode_energies_all=initial_data_future, 
        params=params,
        field_mag=field_mag,
        state_mu=state_mu_d,
        state_mJ=state_mJ_d,
        B_vec=B_vec_d
    )

    return task

Builds the Dask graph for calculating the full Orbach gamma matrix for a single field magnitude.

Args

field_mag : float
The magnitude of the applied magnetic field (T).
B_unit : np.ndarray
The normalised direction vector of the field.
initial_data_future : Future/Delayed
The initial data object (scattered to workers).
params : OrbachParams
The immutable object with core calculation parameters.

Returns

dask.delayed
A single Dask delayed object representing the complete Orbach gamma matrix calculation for the given field.
def calculate_orbach_gamma_for_field(*, energies, couplings, mode_energies_all, params, field_mag, state_mu, state_mJ, B_vec)
Expand source code
def calculate_orbach_gamma_for_field(*, energies, couplings, mode_energies_all, params, field_mag, state_mu, state_mJ, B_vec):
    """
    Calculates the Orbach relaxation rate matrix and returns a structured dictionary.
    
    This function computes the gamma matrix for each specified temperature,
    diagonalises it to find the relaxation rates, and packages all results
    for the current magnetic field into a single dictionary.

    Args:
        energies (np.ndarray): Zeeman-split electronic energies.
        couplings (np.ndarray): Dense, symmetrised coupling array (transformed to
            Zeeman basis), shape (N_modes, N_states, N_states).
        mode_energies_all (dict): Mapping holding the phonon mode energies under
            the 'all_mode_energies' key.
        params (OrbachParams): Dataclass with calculation parameters.
        field_mag (float): Magnitude of the magnetic field (T).
        state_mu (np.ndarray): Magnetic moments of the states.
        state_mJ (np.ndarray): mJ values of the states.
        B_vec (np.ndarray): Full magnetic field vector for this task (B_unit * field_mag).

    Returns:
        dict: A dictionary containing all results for the field, including
              energies, magnetic moments, gamma matrices, and the final rates.
    """
    # Slice mode energies to match the number of modes in the coupling array
    mode_energies = mode_energies_all['all_mode_energies'][:len(couplings)]
    
    # empty dictionaries for gamma, detailed balance + detailed values at each temperature
    gamma_matrices = {}
    det_balance_matrices = {}
    det_values = {}

    relaxation_rates_list = []
    lineshape = get_lineshape_function(params.lineshape, params.fwhm)

    # Pre-calculate the matrix of energy differences, Delta_E_ij = E_i - E_j
    E_col = energies[:, np.newaxis]
    E_row = energies[np.newaxis, :]
    delta_E = E_col - E_row

    # Pre-calculate masks for upward (absorption) and downward (emission) transitions
    pos_delta_mask = delta_E > 0
    neg_delta_mask = delta_E < 0

    # Pre-calculate temperature independent part of Orbach rate, W±
    # Initialise W matrices for both processes
    # w_p ('plus') accumulates the delta_E < 0 (emission) terms and
    # w_m ('minus') the delta_E > 0 (absorption) terms, per the masks below.
    w_p = np.zeros((params.max_state, params.max_state), dtype=np.float64)
    w_m = np.zeros((params.max_state, params.max_state), dtype=np.float64)

    for mode_idx, mode_energy in enumerate(mode_energies):
        # Mask for delta_E within mode energy integration width
        window_mask = (np.abs(delta_E) > max(0, mode_energy - params.width)) & \
                        (np.abs(delta_E) < mode_energy + params.width)

        # Skip this mode if no state pairs have an energy gap in the window
        if not np.any(window_mask):
            continue
        
        # Phonon density of states evaluated at each |delta_E| around this mode
        phonon_DOS = lineshape(np.abs(delta_E), mode_energy)
        
        square_couplings = np.abs(couplings[mode_idx])**2

        rate_term = square_couplings * phonon_DOS * window_mask
        # Calculate emission W (p for 'plus', delta_E < 0)
        w_p += rate_term * neg_delta_mask
        # Calculate absorption W (m for 'minus', delta_E > 0)
        w_m += rate_term * pos_delta_mask

    # Calculate temperature dependent rate using W
    for T in params.temperatures:
        # Calculate temperature-dependent terms once per temperature
        thermal_pop_matrix = get_thermal_pop(np.abs(delta_E), T)

        # Calculate emission rates (p for 'plus', delta_E < 0)
        # (n + 1): stimulated + spontaneous emission factor
        gamma_p_t = w_p * (thermal_pop_matrix + 1)
        # Calculate absorption rates (m for 'minus', delta_E > 0)
        gamma_m_t = w_m * thermal_pop_matrix
    
        gamma_matrix_t = (gamma_p_t + gamma_m_t) * TWO_PI_OVER_HBAR
        
        # Fill the diagonal so each column sums to zero (probability conservation)
        np.fill_diagonal(gamma_matrix_t, -np.sum(gamma_matrix_t, axis=0))
        gamma_matrices[T] = gamma_matrix_t        

        # Diagonalise with the method requested in params; 'arb' (flint) is the default.
        all_eigs = None
        if params.gamma_diagonalisation == 'deflation':
            all_eigs = diagonalise_gamma_deflation(gamma_matrix_t, energies, T, KB_INV_CM)
        elif params.gamma_diagonalisation == 'mpmath':
            all_eigs = diagonalise_gamma_mpmath(gamma_matrix_t, prec=params.gamma_precision)
        else:
            all_eigs = diagonalise_gamma_arb(gamma_matrix_t, prec=params.gamma_precision)

        # Eigenvalues are sorted by |value|; index 0 is the (near-)zero equilibrium
        # eigenvalue, so index 1 is the slowest non-zero rate — the relaxation rate.
        rate_value = all_eigs[1] if len(all_eigs) > 1 else 0.0
        relaxation_rates_list.append(rate_value)
        det_balance_matrices[T], det_values[T] = calculate_detailed_balance(gamma_matrix_t, energies, T)

    final_rates_array = np.array(relaxation_rates_list)

    # Package all results for this field into a single dictionary
    results_for_field = {
        'field_mag': field_mag,
        'energies': energies,
        'state_mu': state_mu,
        'state_mJ': state_mJ,
        'B_vec': B_vec,
        'gamma_matrices': gamma_matrices,
        'relaxation_rates': final_rates_array,
        'det_balance': det_balance_matrices,
        'det_values': det_values,
    }
    
    return results_for_field

Calculates the Orbach relaxation rate matrix and returns a structured dictionary.

This function computes the gamma matrix for each specified temperature, diagonalises it to find the relaxation rates, and packages all results for the current magnetic field into a single dictionary.

Args

energies : np.ndarray
Zeeman-split electronic energies.
couplings : np.ndarray
Dense, symmetrised coupling array (transformed to Zeeman basis).
mode_energies_all : dict
Mapping holding the phonon mode energies under the 'all_mode_energies' key.
params : OrbachParams
Dataclass with calculation parameters.
field_mag : float
Magnitude of the magnetic field (T).
state_mu : np.ndarray
Magnetic moments of the states.
state_mJ : np.ndarray
mJ values of the states.
B_vec : np.ndarray
The full magnetic field vector for this task.

Returns

dict
A dictionary containing all results for the field, including energies, magnetic moments, gamma matrices, and the final rates.
def diagonalise_gamma_arb(gamma_matrix, prec=113)
Expand source code
def diagonalise_gamma_arb(gamma_matrix, prec=113):
    """
    Diagonalises a matrix using python-flint arb for arbitrary precision.
    Arb is compiled C and is much faster than mpmath, hence is the default. 

    Args:
        gamma_matrix (np.ndarray): The matrix to diagonalise.
        prec (int): The desired binary precision for the calculation.
                    113 bits corresponds to quad precision.

    Returns:
        np.ndarray: A sorted array of the absolute values of the eigenvalues.
    """
    import flint

    # Set flint's global context precision in bits
    flint.ctx.prec = prec
    
    # Convert the numpy matrix to a flint complex matrix (acb_mat)
    # The list comprehension is a robust way to handle the conversion.
    gamma_flint = flint.acb_mat([[flint.acb(x) for x in row] for row in gamma_matrix])
    
    # Calculate eigenvalues. eig() returns a list of acb objects.
    # "approx" selects flint's approximate (non-certified) eigensolver.
    eigs_flint = gamma_flint.eig(algorithm="approx")

    # abs() on an acb object returns an arb (arbitrary-precision real) object
    abs_eigs = [abs(eig) for eig in eigs_flint]
    
    # Sort the list of absolute eigenvalues
    abs_eigs.sort()
    
    # Return as a numpy array of standard floats
    return np.array([float(rate) for rate in abs_eigs])

Diagonalises a matrix using python-flint arb for arbitrary precision. Arb is compiled C and is much faster than mpmath, hence is the default.

Args

gamma_matrix : np.ndarray
The matrix to diagonalise.
prec : int
The desired binary precision for the calculation. 113 bits corresponds to quad precision.

Returns

np.ndarray
A sorted array of the absolute values of the eigenvalues.
def diagonalise_gamma_deflation(gamma_matrix, energies, temp, kB)
Expand source code
def diagonalise_gamma_deflation(gamma_matrix, energies, temp, kB):
    """
    Diagonalises a rate matrix using the symmetrisation and deflation method.

    The rate matrix is similarity-transformed into a symmetric form using the
    Boltzmann equilibrium populations, the known zero-eigenvalue direction is
    rotated out (deflated), and the remaining (N-1)x(N-1) symmetric block is
    solved with a standard symmetric eigensolver.

    Args:
        gamma_matrix (np.ndarray): The N x N rate matrix (NumPy float/complex array).
        energies (np.ndarray): The N energies of the states.
        temp (float): The temperature in Kelvin.
        kB (float): The Boltzmann constant.

    Returns:
        np.ndarray: A sorted array of the N eigenvalues (relaxation rates).
    """
    import scipy.linalg

    # Equilibrium Boltzmann populations, normalised to sum to one.
    populations = np.exp(-energies / (kB * temp))
    populations = populations / np.sum(populations)

    # Diagonal similarity transform G' = D^-1 G D with D = diag(sqrt(p_eq)).
    # A tiny epsilon guards against division by zero for vanishing populations.
    sqrt_pops = np.sqrt(populations + 1e-30)
    inv_sqrt_pops = 1.0 / sqrt_pops

    transformed = np.diag(inv_sqrt_pops) @ gamma_matrix @ np.diag(sqrt_pops)
    # Enforce exact symmetry to scrub numerical noise from the transform.
    transformed = 0.5 * (transformed + transformed.T)

    # The transformed equilibrium vector spans the kernel of G'; it becomes
    # the first column of the rotation matrix.
    kernel_vec = inv_sqrt_pops * populations
    kernel_vec = kernel_vec / np.linalg.norm(kernel_vec)

    # Orthonormal basis for the subspace orthogonal to the kernel direction.
    complement = scipy.linalg.null_space(kernel_vec[np.newaxis, :])
    rotation = np.hstack((kernel_vec[:, np.newaxis], complement))

    # Rotate so the zero eigenvalue is isolated in the (0, 0) entry, then
    # drop the first row/column and solve the well-conditioned remainder.
    deflated = rotation.T @ transformed @ rotation
    nonzero_rates = scipy.linalg.eigh(deflated[1:, 1:], eigvals_only=True)

    # Reattach the exact zero eigenvalue that was deflated away and return
    # the absolute rates, sorted.
    return np.sort(np.abs(np.concatenate(([0.0], nonzero_rates))))

Diagonalises a rate matrix using the symmetrisation and deflation method.

Args

gamma_matrix : np.ndarray
The N x N rate matrix (NumPy float/complex array).
energies : np.ndarray
The N energies of the states.
temp : float
The temperature in Kelvin.
kB : float
The Boltzmann constant.

Returns

np.ndarray
A sorted array of the N eigenvalues (relaxation rates).
def diagonalise_gamma_mpmath(gamma_matrix, prec=113)
Expand source code
def diagonalise_gamma_mpmath(gamma_matrix, prec=113):
    """
    Diagonalises a matrix using mpmath for arbitrary precision.

    Args:
        gamma_matrix (np.ndarray): The matrix to diagonalise.
        prec (int): The desired binary precision for the calculation.
                    113 bits corresponds to quad precision.

    Returns:
        np.ndarray: A sorted array of the absolute values of the eigenvalues.
    """
    from mpmath import mp

    # Set mpmath binary precision.
    mp.prec = prec
    
    # Convert the final numpy matrix to an mpmath matrix
    gamma_matrix_mpmath = mp.matrix(gamma_matrix)
    
    # Diagonalise using mpmath's stable eigensolver
    # (eigenvectors are discarded; only the eigenvalues are needed)
    eigs_mpmath, _ = mp.eig(gamma_matrix_mpmath)

    sorted_rates_mpmath = sorted(np.abs(eigs_mpmath))
    
    # Return the absolute rates, sorted
    return np.array([float(rate) for rate in sorted_rates_mpmath])

Diagonalises a matrix using mpmath for arbitrary precision.

Args

gamma_matrix : np.ndarray
The matrix to diagonalise.
prec : int
The desired binary precision for the calculation. 113 bits corresponds to quad precision.

Returns

np.ndarray
A sorted array of the absolute values of the eigenvalues.
def run_orbach_calculation(*, initial_data, params, orientation_vectors, magnitudes_to_run, args)
Expand source code
def run_orbach_calculation(*, initial_data, params, orientation_vectors, magnitudes_to_run, args):
    """
    Top-level runner for the full Orbach calculation.

    Handles data scattering, graph building, and batched execution to manage memory.

    Args:
        initial_data (dict): Heavy initial data; scattered once to all workers.
        params (OrbachParams): Immutable core calculation parameters.
        orientation_vectors: Sequence of field direction unit vectors.
        magnitudes_to_run: Sequence of field magnitudes (T) to compute.
        args: Parsed CLI arguments (Dask cluster settings, serial_orientations flag).

    Returns:
        dict: Maps orientation index -> list of per-field result dictionaries,
              one per magnitude, in the order of magnitudes_to_run.
    """
    cluster, client = _setup_dask_client(args)
    with cluster, client:
        # Scatter the heavy initial data ONCE to all workers.
        # This prevents the task graph from becoming huge by not pickling the data into every task.
        initial_data_future = client.scatter(initial_data, broadcast=True)

        # Determine batch size for fields to prevent too large a task graph.
        # A batch size of ~5000 fields is usually safe for memory while maintaining parallelism.
        field_batch_size = 5000

        all_field_results = {}

        # Serial mode gathers after every orientation to cap peak memory.
        # An explicit flag wins; otherwise auto-enable for large jobs.
        if args.serial_orientations is not None:
            serial_mode = args.serial_orientations
        else:
            serial_mode = (len(magnitudes_to_run) * len(orientation_vectors) > 1000)

        if serial_mode:
            print(f"Running orientations serially to conserve memory (Active mode: {'Explicit --serial_orientations' if args.serial_orientations else 'Auto-detected large job'}).")

        num_orientations = len(orientation_vectors)
        num_magnitudes = len(magnitudes_to_run)

        print(f"\n--- Calculating Orbach rates for {num_orientations} orientation{'' if num_orientations == 1 else 's'}, "
            f"{num_magnitudes} field magnitude{'' if num_magnitudes == 1 else 's'} ---\n")

        if serial_mode:
            # Serial path: build and compute one orientation at a time, in field batches.
            for i, B_unit in enumerate(orientation_vectors):
                print(f"Processing Orientation {i+1}/{num_orientations}...")

                orientation_results = []

                # Split magnitudes into batches to bound graph size per compute call.
                for b_start in range(0, num_magnitudes, field_batch_size):
                    b_end = min(b_start + field_batch_size, num_magnitudes)
                    batch_tasks = [
                        build_orbach_gamma_task(
                            field_mag=mag,
                            B_unit=B_unit,
                            initial_data_future=initial_data_future,
                            params=params
                        )
                        for mag in magnitudes_to_run[b_start:b_end]
                    ]

                    futures = client.compute(batch_tasks)
                    # Simple progress logging (first batch of each orientation only)
                    if b_start == 0:
                        log_progress(futures, interval=max(1, len(futures) // 5))
                    batch_res = client.gather(futures)
                    orientation_results.extend(batch_res)

                    # Explicitly release memory before the next batch
                    del futures, batch_res

                all_field_results[i] = orientation_results
        else:
            # Parallel path: build ALL tasks up front and compute in one pass.
            # (Previously the serial loop above also built — and then discarded —
            # every batch of tasks in this mode, doubling graph-construction work.)
            all_flat_tasks = []
            task_map = []  # (orient_idx, start_idx, end_idx) into the flat list

            current_idx = 0
            for i, B_unit in enumerate(orientation_vectors):
                orient_tasks = [
                    build_orbach_gamma_task(
                        field_mag=mag, B_unit=B_unit,
                        initial_data_future=initial_data_future, params=params
                    )
                    for mag in magnitudes_to_run
                ]
                all_flat_tasks.extend(orient_tasks)
                task_map.append((i, current_idx, current_idx + len(orient_tasks)))
                current_idx += len(orient_tasks)

            # Compute all at once
            futures = client.compute(all_flat_tasks)
            log_progress(futures, interval=max(1, len(all_flat_tasks) // 10))
            all_results_flat = client.gather(futures)

            for o_idx, start, end in task_map:
                all_field_results[o_idx] = all_results_flat[start:end]

    return all_field_results

Top-level runner for the full Orbach calculation.

Handles data scattering, graph building, and batched execution to manage memory.