Module tau2.orbach_calculator
Functions
def build_orbach_gamma_task(*,
field_mag,
B_unit,
initial_data_future,
params: OrbachParams)-
Expand source code
def build_orbach_gamma_task(*, field_mag, B_unit, initial_data_future, params: OrbachParams):
    """
    Builds the Dask graph for calculating the full Orbach gamma matrix for a single field magnitude.

    The graph has two stages:
      1. a delayed setup step, run on a worker, that applies the Zeeman splitting
         and rotates the spin-phonon couplings into the Zeeman eigenbasis;
      2. ``calculate_orbach_gamma_for_field`` applied to the setup outputs.

    Args:
        field_mag (float): The magnitude of the applied magnetic field (T).
        B_unit (np.ndarray): The normalised direction vector of the field.
        initial_data_future (Future/Delayed): The initial data object (scattered to workers).
            It is indexed with the keys 'initial_energies', 'angmom', 'spin',
            'symmetrised_couplings_array' and (downstream) 'all_mode_energies'.
        params (OrbachParams): The immutable object with core calculation parameters.

    Returns:
        dask.delayed: A single Dask delayed object representing the complete
        Orbach gamma matrix calculation for the given field.
    """
    # 1. Delayed setup step. It runs on the worker, so the heavy initial data is
    #    read from the scattered copy instead of being pickled into the graph.
    @dask.delayed
    def setup_zeeman_on_worker(initial_data):
        B_vec = B_unit * field_mag

        # Zeeman splitting gives the energies and the eigenvectors (U_matrix).
        # Couplings are not passed here; they are transformed explicitly below.
        energies, _, state_mu, state_mJ, U_matrix = apply_zeeman_splitting(
            initial_data['initial_energies'],
            initial_data['angmom'],
            initial_data['spin'],
            B_vec,
        )

        # Only the first `max_state` Zeeman states are used downstream, so only
        # those columns of U are needed. Slicing U *before* the transform yields
        # exactly the block that transforming everything and slicing afterwards
        # would, without building the full (N_modes, N_states, N_states) result.
        U_kept = U_matrix[:, :params.max_state]
        U_kept_dagger = U_kept.conj().T

        # Shape (N_modes, N_states, N_states) in the original (zero-field) basis.
        couplings_array = initial_data['symmetrised_couplings_array']

        # V_zeeman[m] = U^dag @ V_initial[m] @ U for every mode m.
        #   i, l: Zeeman-basis state indices (kept states only)
        #   j, k: original-basis state indices
        #   m:    mode index
        sliced_couplings = np.einsum(
            'ij,mjk,kl->mil', U_kept_dagger, couplings_array, U_kept, optimize=True
        )
        sliced_energies = energies[:params.max_state]

        # NOTE(review): state_mu / state_mJ are returned unsliced while energies
        # and couplings are truncated to max_state; confirm this is intended.
        return sliced_energies, sliced_couplings, state_mu, state_mJ, B_vec

    # 2. Call the delayed setup; indexing the Delayed result yields lazy handles
    #    to each element of the returned tuple.
    setup_result = setup_zeeman_on_worker(initial_data_future)
    energies_d, couplings_d, state_mu_d, state_mJ_d, B_vec_d = (
        setup_result[k] for k in range(5)
    )

    # 3. The rate calculation consumes the setup outputs. The whole initial data
    #    object is passed as `mode_energies_all`; the callee reads
    #    'all_mode_energies' from it.
    task = dask.delayed(calculate_orbach_gamma_for_field)(
        energies=energies_d,
        couplings=couplings_d,
        mode_energies_all=initial_data_future,
        params=params,
        field_mag=field_mag,
        state_mu=state_mu_d,
        state_mJ=state_mJ_d,
        B_vec=B_vec_d,
    )
    return task
Args
field_mag:float- The magnitude of the applied magnetic field (T).
B_unit:np.ndarray- The normalised direction vector of the field.
initial_data_future:Future/Delayed- The initial data object (scattered to workers).
params:OrbachParams- The immutable object with core calculation parameters.
Returns
dask.delayed- A single Dask delayed object representing the complete Orbach gamma matrix calculation for the given field.
def calculate_orbach_gamma_for_field(*, energies, couplings, mode_energies_all, params, field_mag, state_mu, state_mJ, B_vec)-
Expand source code
def calculate_orbach_gamma_for_field(*, energies, couplings, mode_energies_all, params, field_mag, state_mu, state_mJ, B_vec):
    """
    Calculates the Orbach relaxation rate matrix and returns a structured dictionary.

    This function computes the gamma matrix for each specified temperature,
    diagonalises it to find the relaxation rates, and packages all results for
    the current magnetic field into a single dictionary.

    Args:
        energies (np.ndarray): Zeeman-split electronic energies, shape (N,).
        couplings (np.ndarray): Dense, symmetrised coupling array (transformed to
            Zeeman basis), shape (N_modes, N, N).
        mode_energies_all (Mapping): The initial data object; its
            'all_mode_energies' entry holds the phonon mode energies, which are
            truncated to the number of modes in `couplings`.
        params (OrbachParams): Dataclass with calculation parameters (reads
            lineshape, fwhm, width, temperatures, gamma_diagonalisation and
            gamma_precision).
        field_mag (float): Magnitude of the magnetic field (T).
        state_mu (np.ndarray): Magnetic moments of the states.
        state_mJ (np.ndarray): mJ values of the states.
        B_vec (np.ndarray): The applied field vector; stored in the results unchanged.

    Returns:
        dict: A dictionary containing all results for the field, including
        energies, magnetic moments, gamma matrices (keyed by temperature), the
        final rates (one per temperature, in `params.temperatures` order) and the
        detailed-balance diagnostics.
    """
    # Slice mode energies to match the number of modes in the coupling array
    mode_energies = mode_energies_all['all_mode_energies'][:len(couplings)]

    # Size everything from the actual number of states: the upstream slice
    # `[:max_state]` returns fewer than max_state states when fewer exist.
    n_states = len(energies)

    gamma_matrices = {}
    det_balance_matrices = {}
    det_values = {}
    relaxation_rates_list = []

    lineshape = get_lineshape_function(params.lineshape, params.fwhm)

    # Delta_E[i, j] = E_i - E_j. Gamma[i, j] is the j -> i rate (columns sum to
    # zero below), so delta_E < 0 is a downward (emission) transition and
    # delta_E > 0 an upward (absorption) one.
    delta_E = energies[:, np.newaxis] - energies[np.newaxis, :]
    abs_delta_E = np.abs(delta_E)
    emission_mask = delta_E < 0
    absorption_mask = delta_E > 0

    # Temperature-independent part of the Orbach rate:
    # W+ (w_p, emission) and W- (w_m, absorption).
    w_p = np.zeros((n_states, n_states), dtype=np.float64)
    w_m = np.zeros((n_states, n_states), dtype=np.float64)

    for mode_idx, mode_energy in enumerate(mode_energies):
        # Only energy gaps within +/- width of this mode's energy are integrated
        window_mask = (abs_delta_E > max(0, mode_energy - params.width)) & \
                      (abs_delta_E < mode_energy + params.width)

        # Skip this mode if no state pairs have an energy gap in the window
        if not np.any(window_mask):
            continue

        phonon_DOS = lineshape(abs_delta_E, mode_energy)
        rate_term = np.abs(couplings[mode_idx])**2 * phonon_DOS * window_mask

        w_p += rate_term * emission_mask
        w_m += rate_term * absorption_mask

    for T in params.temperatures:
        thermal_pop_matrix = get_thermal_pop(abs_delta_E, T)

        # Emission scales with (n + 1), absorption with n. np.where keeps entries
        # outside each mask at exactly zero, so a non-finite Bose factor at
        # Delta_E = 0 (diagonal or degenerate pairs) cannot leak in as 0 * inf = NaN.
        gamma_p_t = np.where(emission_mask, w_p * (thermal_pop_matrix + 1), 0.0)
        gamma_m_t = np.where(absorption_mask, w_m * thermal_pop_matrix, 0.0)
        gamma_matrix_t = (gamma_p_t + gamma_m_t) * TWO_PI_OVER_HBAR

        # Diagonal = minus the total out-rate of each state (column sums vanish)
        np.fill_diagonal(gamma_matrix_t, -np.sum(gamma_matrix_t, axis=0))
        gamma_matrices[T] = gamma_matrix_t

        method = params.gamma_diagonalisation
        if method == 'deflation':
            all_eigs = diagonalise_gamma_deflation(gamma_matrix_t, energies, T, KB_INV_CM)
        elif method == 'mpmath':
            all_eigs = diagonalise_gamma_mpmath(gamma_matrix_t, prec=params.gamma_precision)
        else:
            # Any other value falls back to python-flint arb
            all_eigs = diagonalise_gamma_arb(gamma_matrix_t, prec=params.gamma_precision)

        # The solvers return sorted absolute eigenvalues; entry 0 is the
        # (near-)zero equilibrium mode, so entry 1 is the slowest relaxation rate.
        rate_value = all_eigs[1] if len(all_eigs) > 1 else 0.0
        relaxation_rates_list.append(rate_value)

        det_balance_matrices[T], det_values[T] = calculate_detailed_balance(gamma_matrix_t, energies, T)

    final_rates_array = np.array(relaxation_rates_list)

    # Package all results for this field into a single dictionary
    results_for_field = {
        'field_mag': field_mag,
        'energies': energies,
        'state_mu': state_mu,
        'state_mJ': state_mJ,
        'B_vec': B_vec,
        'gamma_matrices': gamma_matrices,
        'relaxation_rates': final_rates_array,
        'det_balance': det_balance_matrices,
        'det_values': det_values,
    }
    return results_for_field
This function computes the gamma matrix for each specified temperature, diagonalises it to find the relaxation rates, and packages all results for the current magnetic field into a single dictionary.
Args
energies:np.ndarray- Zeeman-split electronic energies.
couplings:np.ndarray- Dense, symmetrised coupling array (transformed to Zeeman basis).
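mode_energies_all:Mapping- The initial data object; its 'all_mode_energies' entry holds the phonon mode energies, which are truncated to the number of modes in couplings.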
params:OrbachParams- Dataclass with calculation parameters.
field_mag:float- Magnitude of the magnetic field (T).
state_mu:np.ndarray- Magnetic moments of the states.
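state_mJ:np.ndarray- mJ values of the states.
B_vec:np.ndarray- The applied magnetic field vector (field direction scaled by field_mag); stored in the results unchanged.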
Returns
dict- A dictionary containing all results for the field, including energies, magnetic moments, gamma matrices, and the final rates.
def diagonalise_gamma_arb(gamma_matrix, prec=113)-
Expand source code
def diagonalise_gamma_arb(gamma_matrix, prec=113):
    """
    Diagonalises a matrix using python-flint arb for arbitrary precision.

    Arb is compiled C and is much faster than mpmath, hence is the default.
    flint's global working precision is set to `prec` only for the duration of
    this call and is restored afterwards, so other flint users in the same
    process are unaffected.

    Args:
        gamma_matrix (np.ndarray): The matrix to diagonalise.
        prec (int): The desired binary precision for the calculation.
            113 bits corresponds to quad precision.

    Returns:
        np.ndarray: A sorted array of the absolute values of the eigenvalues.
    """
    import flint

    previous_prec = flint.ctx.prec
    flint.ctx.prec = prec
    try:
        # Element-wise conversion of the numpy matrix to a flint complex matrix
        gamma_flint = flint.acb_mat([[flint.acb(x) for x in row] for row in gamma_matrix])

        # eig() returns a list of acb values; abs() of each is a real arb
        eigs_flint = gamma_flint.eig(algorithm="approx")
        abs_rates = [float(abs(eig)) for eig in eigs_flint]
    finally:
        # Restore the caller's precision even if conversion or eig() fails
        flint.ctx.prec = previous_prec

    # Sort plain floats rather than arb balls, whose ordering comparisons are
    # interval-based and therefore not a strict total order.
    return np.sort(np.array(abs_rates, dtype=np.float64))
Args
gamma_matrix:np.ndarray- The matrix to diagonalise.
prec:int- The desired binary precision for the calculation. 113 bits corresponds to quad precision.
Returns
np.ndarray- A sorted array of the absolute values of the eigenvalues.
def diagonalise_gamma_deflation(gamma_matrix, energies, temp, kB)-
Expand source code
def diagonalise_gamma_deflation(gamma_matrix, energies, temp, kB):
    """
    Diagonalises a rate matrix using the symmetrisation and deflation method.

    The rate matrix is first symmetrised with the similarity transform
    G' = D^-1 G D, where D = diag(sqrt(p_eq)) and p_eq are the Boltzmann
    populations. This gives a symmetric matrix when G obeys detailed balance.
    The known zero eigenvalue (the equilibrium mode, eigenvector ~ sqrt(p_eq))
    is then rotated out. The remaining (N-1)x(N-1) block is solved with a
    symmetric eigensolver.

    Args:
        gamma_matrix (np.ndarray): The N x N rate matrix (NumPy float/complex array).
        energies (np.ndarray): The N energies of the states.
        temp (float): The temperature in Kelvin.
        kB (float): The Boltzmann constant, in the units of `energies` per kelvin.

    Returns:
        np.ndarray: A sorted array of the absolute values of the N eigenvalues
        (relaxation rates). The first entry is the exact zero of the equilibrium mode.
    """
    import scipy.linalg

    n_states = gamma_matrix.shape[0]
    if n_states < 2:
        # Only the trivial equilibrium mode exists; there is nothing to deflate.
        return np.zeros(n_states)

    # --- 1. Symmetrisation ---
    # Equilibrium Boltzmann populations. Energies are shifted by their minimum
    # before exponentiating. The shift cancels on normalisation, but it keeps
    # the exponent <= 0, so exp() cannot overflow (which would give inf/inf = NaN).
    eq_vector = np.exp(-(energies - np.min(energies)) / (kB * temp))
    eq_vector /= np.sum(eq_vector)

    # Small epsilon avoids division by zero for populations that underflow to zero
    boltz_diag = np.sqrt(eq_vector + 1e-30)
    inv_boltz_diag = 1.0 / boltz_diag

    # G' = D^-1 G D with diagonal D, applied by broadcasting (O(N^2)) instead of
    # two dense matrix products. Then enforce exact symmetry to remove noise.
    gamma_symmetric = inv_boltz_diag[:, np.newaxis] * gamma_matrix * boltz_diag[np.newaxis, :]
    gamma_symmetric = 0.5 * (gamma_symmetric + gamma_symmetric.T)

    # --- 2. Deflation basis construction ---
    # The zero-rate eigenvector of G' is D^-1 p_eq; it is the first basis vector
    q1 = inv_boltz_diag * eq_vector
    q1 /= np.linalg.norm(q1)

    # Orthonormal basis of the complement of q1 completes the rotation matrix Q
    ortho_complement = scipy.linalg.null_space(q1[np.newaxis, :])
    Q = np.hstack((q1[:, np.newaxis], ortho_complement))

    # --- 3. Deflation ---
    # In the rotated basis the zero eigenvalue is isolated in the [0, 0] entry
    gamma_deflated = Q.T @ gamma_symmetric @ Q
    gamma_submatrix = gamma_deflated[1:, 1:]

    # --- 4. Solve the symmetric (N-1)x(N-1) submatrix ---
    nonzero_eigenvalues = scipy.linalg.eigh(gamma_submatrix, eigvals_only=True)

    # --- 5. Combine and return ---
    all_eigenvalues = np.concatenate(([0.0], nonzero_eigenvalues))
    return np.sort(np.abs(all_eigenvalues))
Args
gamma_matrix:np.ndarray- The N x N rate matrix (NumPy float/complex array).
energies:np.ndarray- The N energies of the states.
temp:float- The temperature in Kelvin.
kB:float- The Boltzmann constant, in the units of energies per kelvin (calculate_orbach_gamma_for_field passes KB_INV_CM).
Returns
np.ndarray- A sorted array of the absolute values of the N eigenvalues (relaxation rates).
def diagonalise_gamma_mpmath(gamma_matrix, prec=113)-
Expand source code
def diagonalise_gamma_mpmath(gamma_matrix, prec=113):
    """
    Diagonalises a matrix using mpmath for arbitrary precision.

    The working precision is raised to `prec` only inside this call, via
    ``mp.workprec``. The global ``mp.prec`` seen by other code is left unchanged.

    Args:
        gamma_matrix (np.ndarray): The matrix to diagonalise.
        prec (int): The desired binary precision for the calculation.
            113 bits corresponds to quad precision.

    Returns:
        np.ndarray: A sorted array of the absolute values of the eigenvalues.
    """
    from mpmath import mp

    with mp.workprec(prec):
        # Convert the numpy matrix to an mpmath matrix
        gamma_matrix_mpmath = mp.matrix(gamma_matrix)

        # Diagonalise using mpmath's stable eigensolver (eigenvectors discarded)
        eigs_mpmath, _ = mp.eig(gamma_matrix_mpmath)

        # abs() is evaluated at the raised precision before converting to float
        abs_rates = [float(abs(eig)) for eig in eigs_mpmath]

    # Return the absolute rates, sorted
    return np.sort(np.array(abs_rates, dtype=np.float64))
Args
gamma_matrix:np.ndarray- The matrix to diagonalise.
prec:int- The desired binary precision for the calculation. 113 bits corresponds to quad precision.
Returns
np.ndarray- A sorted array of the absolute values of the eigenvalues.
def run_orbach_calculation(*, initial_data, params, orientation_vectors, magnitudes_to_run, args)-
Expand source code
def run_orbach_calculation(*, initial_data, params, orientation_vectors, magnitudes_to_run, args):
    """
    Top-level runner for the full Orbach calculation.

    Handles data scattering, graph building, and batched execution to manage memory.

    Two execution modes are used:
      * serial mode: orientations are processed one at a time, and each
        orientation's field magnitudes are computed in batches of
        ``field_batch_size``, gathering after every batch;
      * parallel mode: tasks for every (orientation, magnitude) pair are built
        once and computed together.
    Serial mode is taken from ``args.serial_orientations`` when it is not None.
    Otherwise it is enabled automatically for more than 1000
    (orientation, magnitude) pairs.

    Args:
        initial_data: Initial data object; scattered (broadcast) to all workers once.
        params (OrbachParams): Calculation parameters passed to every task.
        orientation_vectors: Sequence of normalised field direction vectors.
        magnitudes_to_run: Sequence of field magnitudes (T).
        args: Parsed command-line arguments. ``serial_orientations`` is read here,
            and the whole object is passed to ``_setup_dask_client``.

    Returns:
        dict: Maps each orientation index (position in ``orientation_vectors``)
        to the list of per-field result dictionaries, in ``magnitudes_to_run`` order.
    """
    cluster, client = _setup_dask_client(args)
    with cluster, client:
        # Scatter the heavy initial data ONCE to all workers, so the task graph
        # references it instead of pickling it into every task.
        initial_data_future = client.scatter(initial_data, broadcast=True)

        # Maximum number of field magnitudes computed together in serial mode.
        # A batch size of ~5000 fields is usually safe for memory while maintaining parallelism.
        field_batch_size = 5000

        all_field_results = {}

        if args.serial_orientations is not None:
            serial_mode = args.serial_orientations
        else:
            serial_mode = (len(magnitudes_to_run) * len(orientation_vectors) > 1000)

        if serial_mode:
            print(f"Running orientations serially to conserve memory (Active mode: {'Explicit --serial_orientations' if args.serial_orientations else 'Auto-detected large job'}).")

        num_orientations = len(orientation_vectors)
        num_magnitudes = len(magnitudes_to_run)
        print(f"\n--- Calculating Orbach rates for {num_orientations} orientation{'' if num_orientations == 1 else 's'}, "
              f"{num_magnitudes} field magnitude{'' if num_magnitudes == 1 else 's'} ---\n")

        def build_tasks(B_unit, magnitudes):
            """Build one delayed Orbach task per field magnitude for one orientation."""
            return [
                build_orbach_gamma_task(
                    field_mag=mag,
                    B_unit=B_unit,
                    initial_data_future=initial_data_future,
                    params=params,
                )
                for mag in magnitudes
            ]

        if serial_mode:
            # Gather after every batch so only one batch's results are held by
            # the cluster at a time.
            for i, B_unit in enumerate(orientation_vectors):
                print(f"Processing Orientation {i+1}/{num_orientations}...")
                orientation_results = []
                for b_start in range(0, num_magnitudes, field_batch_size):
                    batch_mags = magnitudes_to_run[b_start:b_start + field_batch_size]
                    futures = client.compute(build_tasks(B_unit, batch_mags))
                    # Simple progress logging, for the first batch of each orientation
                    if b_start == 0:
                        log_progress(futures, interval=max(1, len(futures) // 5))
                    orientation_results.extend(client.gather(futures))
                    # Explicitly release the futures so workers can free this batch
                    del futures
                all_field_results[i] = orientation_results
        else:
            # Build every task exactly once, remembering each orientation's slice
            all_flat_tasks = []
            task_map = []  # (orient_idx, start_idx, end_idx) in the flat list
            for i, B_unit in enumerate(orientation_vectors):
                start_idx = len(all_flat_tasks)
                all_flat_tasks.extend(build_tasks(B_unit, magnitudes_to_run))
                task_map.append((i, start_idx, len(all_flat_tasks)))

            # Compute all at once
            futures = client.compute(all_flat_tasks)
            log_progress(futures, interval=max(1, len(all_flat_tasks) // 10))
            all_results_flat = client.gather(futures)

            for o_idx, start, end in task_map:
                all_field_results[o_idx] = all_results_flat[start:end]

        return all_field_results
Handles data scattering, graph building, and batched execution to manage memory.