Module tau2.raman_calculator

Functions

def build_field_calculation_graph(*,
setup_data_future,
coupling_slices_future,
num_phonon_chunks,
args,
params: RamanParams)
Expand source code
def build_field_calculation_graph(*, setup_data_future, coupling_slices_future,
                                  num_phonon_chunks, args, params: RamanParams):
    """Assemble the Dask graph for the second stage of one magnetic-field calculation.

    The phonon-pair loop is partitioned into ``num_phonon_chunks`` q-ranges via
    ``_partition_q_loop``; these are dealt out round-robin to ``args.n_cores``
    workers. Every worker that receives at least one range gets a
    ``worker_task`` whose coupling slices start at the smallest q-range start
    it handles (passed along as ``v_offset``). The worker outputs are merged
    with ``merge_worker_results`` and post-processed by
    ``_process_and_finalise_field_results``.

    Args:
        setup_data_future: Object supporting item access (a dict of futures in
            the large-system path) providing the keys 'energies',
            'modes_for_calc', 'omega_grid', 'B_vec' and 'field_mag'. It is also
            passed whole to the finalisation step.
        coupling_slices_future: Lazy object whose items [0] and [1] are the
            stacked coupling rows and columns (all_V_rows, all_V_cols).
        num_phonon_chunks (int): Total number of q-ranges to split the phonon
            pair loop into.
        args (argparse.Namespace): Parsed CLI arguments; ``n_cores`` sets the
            number of worker tasks. Also forwarded to finalisation.
        params (RamanParams): Raman calculation parameters (``num_modes`` is
            used here; the object is forwarded to every stage).

    Returns:
        dask.Delayed: Final task for this field. Computing it yields the
            finalised results.
    """
    q_ranges = _partition_q_loop(params.num_modes, num_phonon_chunks)
    n_workers = args.n_cores

    # Lazy handles onto the shared inputs every worker needs.
    V_rows_all = coupling_slices_future[0]
    V_cols_all = coupling_slices_future[1]
    energies_f = setup_data_future['energies']
    modes_f = setup_data_future['modes_for_calc']
    omega_grid_f = setup_data_future['omega_grid']
    B_vec_f = setup_data_future['B_vec']
    field_mag_f = setup_data_future['field_mag']

    worker_tasks = []
    for worker_idx in range(n_workers):
        # Round-robin share of q-ranges for this worker: rows of (start, end).
        assigned = np.array(q_ranges[worker_idx::n_workers])
        if len(assigned) == 0:
            continue

        # The worker only touches coupling matrices for q >= its smallest
        # start, so ship it a slice beginning there and record the offset.
        offset = int(np.min(assigned[:, 0]))
        worker_tasks.append(
            worker_task(
                assigned.tolist(),
                all_V_rows=V_rows_all[offset:],
                all_V_cols=V_cols_all[offset:],
                energies=energies_f,
                modes_for_calc=modes_f,
                omega_grid=omega_grid_f,
                B_vec=B_vec_f,
                params=params,
                v_offset=offset,
            )
        )

    merged = merge_worker_results(worker_tasks, omega_grid_f, params)
    return _process_and_finalise_field_results(
        merged, setup_data_future, field_mag_f, args, params
    )

Builds the Dask graph for the second stage of a single magnetic field calculation.

This function constructs a Dask task graph to calculate the Raman spectral density and rates (if temperatures are given) for one specific magnetic field magnitude and orientation.

Args

setup_data_future : dask.Future
Future pointing to the Zeeman-split electronic structure (energies, states, etc.).
coupling_slices_future : dask.Future
Future pointing to the tuple of (all_V_rows, all_V_cols).
num_phonon_chunks : int
The total number of chunks to divide the phonon pair loop into across all cores.
args : argparse.Namespace
Parsed command-line arguments, used to access settings like n_cores.
params : RamanParams
Dataclass containing Raman calculation parameters.

Returns

dask.Delayed
The final delayed task in the graph for this field. Computing this task will yield a dictionary of results.
def calculate_W_sq(process,
V_q_rows,
V_q_cols,
V_r_rows,
V_r_cols,
energies,
E_i,
E_f,
row_idx_f,
col_idx_i,
hw_q_grid,
hw_r_grid,
max_state=None)
Expand source code
@jit(nopython=True, cache=True, nogil=True)
def calculate_W_sq(process, V_q_rows, V_q_cols, V_r_rows, V_r_cols, energies, E_i, E_f,
                   row_idx_f, col_idx_i, hw_q_grid, hw_r_grid, max_state=None):
    """
    Calculates the squared electronic matrix element for the Raman process.

    The transition amplitude is summed over intermediate electronic states |c>
    (the initial and final states themselves are skipped), following
    second-order perturbation theory:

    W^2 = | sum_c <f|V_r|c><c|V_q|i> / (E_c - E_i +- hw_q + i*eta) +
                  <f|V_q|c><c|V_r|i> / (E_c - E_i +- hw_r + i*eta) |^2

    The signs in front of hw_q / hw_r are chosen by `process`:
    'pp' -> (+, +), 'mm' -> (-, -), 'pm' -> (+, -), anything else
    (i.e. 'mp') -> (-, +). eta = 5e-8 is a fixed imaginary broadening added
    to every denominator (same units as `energies`).

    Args:
        process (str): Raman process type ('pp', 'mm', 'pm', 'mp').
        V_q_rows (np.ndarray): Relevant rows of the coupling matrix for q.
        V_q_cols (np.ndarray): Relevant columns of the coupling matrix for q.
        V_r_rows (np.ndarray): Relevant rows of the coupling matrix for r.
        V_r_cols (np.ndarray): Relevant columns of the coupling matrix for r.
        energies (np.ndarray): Zeeman-split electronic energies.
        E_i (int): Index of the initial electronic state.
        E_f (int): Index of the final electronic state.
        row_idx_f (int): Row index for E_f in the sliced V matrices.
        col_idx_i (int): Column index for E_i in the sliced V matrices.
        hw_q_grid (np.ndarray): Energy grid for the first phonon (omega_1).
        hw_r_grid (np.ndarray): Energy grid for the second phonon (omega_2).
        max_state (int, optional): Number of intermediate states to sum over
            (indices 0 .. max_state - 1). Values above len(energies) are
            clamped to len(energies). If None, all states are used.

    Returns:
        np.ndarray: |amplitude|^2 as real values, one per grid point
            (length len(hw_q_grid)).
    """
    n_energies = len(energies)
    if max_state is None:
        n_intermediate = n_energies
    else:
        # Numba nopython code does not bounds-check array accesses, so never
        # let the loop run past the available states.
        n_intermediate = min(max_state, n_energies)

    # Signs of the phonon terms in the two energy denominators.
    if process == 'pp':
        sign_q = 1.0
        sign_r = 1.0
    elif process == 'mm':
        sign_q = -1.0
        sign_r = -1.0
    elif process == 'pm':
        sign_q = 1.0
        sign_r = -1.0
    else:  # 'mp'
        sign_q = -1.0
        sign_r = 1.0

    # The phonon parts of the denominators do not depend on c: build them once.
    # Multiplying by +-1.0 is exact, so (gap + shift) equals the original
    # (gap +- hw) bit-for-bit.
    shift_q = sign_q * hw_q_grid
    shift_r = sign_r * hw_r_grid

    electronic_hwhm = 5e-8
    damping = 1j * electronic_hwhm

    E_i_val = energies[E_i]
    accumulator = np.zeros(len(hw_q_grid), dtype=np.complex128)

    for c in range(n_intermediate):
        if c == E_i or c == E_f:
            continue

        # <f|V_r|c><c|V_q|i> and <f|V_q|c><c|V_r|i>
        num1 = V_r_rows[row_idx_f, c] * V_q_cols[c, col_idx_i]
        num2 = V_q_rows[row_idx_f, c] * V_r_cols[c, col_idx_i]

        # Skip states that contribute nothing; avoids the array work below.
        if np.abs(num1) < 1e-30 and np.abs(num2) < 1e-30:
            continue

        gap = energies[c] - E_i_val
        accumulator += (num1 / (gap + shift_q + damping)) + (num2 / (gap + shift_r + damping))

    return np.abs(accumulator)**2

Calculates the squared electronic matrix element for the Raman process.

This function computes the transition amplitude by summing over all intermediate electronic states |c>, according to second-order perturbation theory.

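$$W^2 = \left| \sum_c \frac{\langle f|V_r|c\rangle\langle c|V_q|i\rangle}{E_c - E_i \pm \hbar\omega_q} + \frac{\langle f|V_q|c\rangle\langle c|V_r|i\rangle}{E_c - E_i \pm \hbar\omega_r} \right|^2$$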

Args

process : str
The Raman process type ('pp', 'mm', 'pm', 'mp'), which determines the signs in the energy denominators.
V_q_rows : np.ndarray
The relevant rows of the coupling matrix for q.
V_q_cols : np.ndarray
The relevant columns of the coupling matrix for q.
V_r_rows : np.ndarray
The relevant rows of the coupling matrix for r.
V_r_cols : np.ndarray
The relevant columns of the coupling matrix for r.
energies : np.ndarray
The Zeeman-split electronic energies.
E_i : int
The index of the initial electronic state.
E_f : int
The index of the final electronic state.
row_idx_f : int
The row index corresponding to E_f in the sliced V matrices.
col_idx_i : int
The column index corresponding to E_i in the sliced V matrices.
hw_q_grid : np.ndarray
The grid of energies for the first phonon (omega_1).
hw_r_grid : np.ndarray
The grid of energies for the second phonon (omega_2).
max_state : int, optional
The number of intermediate states to include: only indices 0 to max_state - 1 are summed over. If None, all available states are used.

Returns

np.ndarray
The squared matrix element, evaluated over the energy grids.
def driver_compute_zeeman_and_save(energies, angmom, spin, B_vec, temp_file_path)
Expand source code
def driver_compute_zeeman_and_save(energies, angmom, spin, B_vec, temp_file_path):
    """
    Zeeman-split the electronic structure on the driver and store U in HDF5.

    The memory-intensive diagonalisation runs here, synchronously, instead of
    as a task in the Dask graph. The transformation matrix U is written to
    ``temp_file_path`` (dataset 'U'), overwriting any existing file, so that
    downstream tasks can read it from disk.

    Args:
        energies: Zero-field electronic energies passed to apply_zeeman_splitting.
        angmom: Angular-momentum operators passed to apply_zeeman_splitting.
        spin: Spin operators passed to apply_zeeman_splitting.
        B_vec: Magnetic-field vector.
        temp_file_path (str): Path of the HDF5 file that receives U.

    Returns:
        tuple: (zeeman_energies, state_mu, state_mJ), i.e. elements 0, 2 and 3
            of the apply_zeeman_splitting result.
    """
    # Couplings are not needed here, only energies, state labels and U.
    zeeman = apply_zeeman_splitting(energies, angmom, spin, B_vec, couplings=None)
    # Layout: (new_energies, new_couplings, state_mu, state_mJ, U)
    transform = zeeman[4]

    # Mode 'w' truncates/overwrites the file if it already exists.
    with h5py.File(temp_file_path, 'w') as handle:
        handle.create_dataset('U', data=transform)

    new_energies, state_mu, state_mJ = zeeman[0], zeeman[2], zeeman[3]
    return new_energies, state_mu, state_mJ

Performs Zeeman splitting on the Driver process and saves U to a file.

This moves the memory-intensive diagonalisation out of the Dask graph.

def merge_worker_results(results_from_workers,
omega_grid,
params: RamanParams)
Expand source code
@dask.delayed
def merge_worker_results(results_from_workers, omega_grid, params: RamanParams):
    """Combine the partial outputs of all worker tasks (final Dask graph step).

    Each worker result is a pair ``(integrands, summary)`` where ``integrands``
    maps transition -> {process: partial spectral density}. Partial densities
    for the same (transition, process) are summed into an accumulator created
    with ``np.zeros_like(omega_grid)``. When ``params.debug_summary`` is set,
    every worker's summary entries are concatenated in worker order.

    Args:
        results_from_workers (list): ``(integrands, summary)`` pairs, one per worker.
        omega_grid (np.ndarray): Frequency grid; used as the template for the
            accumulator arrays.
        params (RamanParams): Calculation parameters (``debug_summary`` is read).

    Returns:
        tuple: (final_integrands, all_summary_data)
    """
    merged = {}
    summaries = []

    for partial_integrands, partial_summary in results_from_workers:
        if params.debug_summary:
            summaries.extend(partial_summary)

        for transition, per_process in partial_integrands.items():
            bucket = merged.setdefault(transition, {})
            for process, contribution in per_process.items():
                if process not in bucket:
                    bucket[process] = np.zeros_like(omega_grid)
                bucket[process] += contribution

    return merged, summaries

Final aggregation step on the Dask graph.

This function receives the pre-aggregated results from all parallel worker tasks and merges them into a single, final result. It sums the partial spectral densities and collects debug info.

Args

results_from_workers : list
List of results from worker tasks.
omega_grid : np.ndarray
The frequency integration grid.
params : RamanParams
The calculation parameters.

Returns

tuple
(final_integrands, all_summary_data)
def process_q_range_jit(q_start,
q_end,
num_modes,
mode_energies,
energies,
all_V_rows,
all_V_cols,
main_grid,
width,
hwhm,
lineshape_code,
grid_variable_code,
num_c_states,
debug_summary,
summary_list,
active_indices,
window_type,
v_offset=0)
Expand source code
@jit(nopython=True, cache=True, nogil=True)
def process_q_range_jit(q_start, q_end, num_modes, mode_energies, energies,
                        all_V_rows, all_V_cols, main_grid, width, hwhm, lineshape_code, 
                        grid_variable_code, num_c_states, debug_summary, summary_list, 
                        active_indices, window_type, v_offset=0):
    """JIT-compiled mega-task to process a range of primary phonon modes (q).

    This function iterates through a given range of q modes, and for each q,
    iterates through all secondary r modes (r >= q). For every ordered pair of
    distinct active electronic states (i -> f) it screens the pp, mm, pm and mp
    processes against the energy window ``2 * width`` and accumulates
    ``_calculate_pair_contribution_jit`` results into one integrand array per
    (process, transition) key.

    Process codes: 0 = pp (only if E_f < E_i), 1 = mm (only if E_f > E_i),
    2 = pm and 3 = mp (always registered).

    NOTE(review): each transition gets its own result slot, which assumes
    ``active_indices`` contains no duplicate entries — confirm with callers.

    Args:
        q_start (int): Start index of the q loop.
        q_end (int): End index (exclusive) of the q loop.
        num_modes (int): Total number of modes (upper bound of the r loop).
        mode_energies (np.ndarray): Energies of the phonon modes.
        energies (np.ndarray): Zeeman-split electronic energies.
        all_V_rows (np.ndarray): Sliced array of coupling matrix rows;
            entry 0 corresponds to mode ``v_offset``.
        all_V_cols (np.ndarray): Sliced array of coupling matrix columns;
            entry 0 corresponds to mode ``v_offset``.
        main_grid (np.ndarray): The main energy integration grid.
        width (float): Integration width (screening window is 2 * width).
        hwhm (float): HWHM for lineshapes.
        lineshape_code (int): Code for lineshape function.
        grid_variable_code (int): Code for grid variable.
        num_c_states (int): Number of intermediate states.
        debug_summary (bool): Whether to collect debug summary.
        summary_list (List): List passed to every pair-contribution call
            (presumably appended to when ``debug_summary`` is set).
        active_indices (np.ndarray): Array of active electronic state indices.
        window_type (str): Windowing strategy.
        v_offset (int, optional): Offset for indexing V arrays. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - A list of unique (process_code, E_i, E_f) keys.
            - A list of the corresponding integrand arrays for this q-range.
            - ``summary_list`` (summary data tuples for debugging).
    """
    # Ordered transitions i -> f (i != f) between active states, stored as
    # (E_i, E_f, local_i, local_f): global indices plus positions in active_indices.
    transition_data = List()
    n_active = len(active_indices)
    for i in range(n_active):
        E_i = active_indices[i]
        for f in range(n_active):
            E_f = active_indices[f]
            if E_i == E_f:
                continue
            transition_data.append((E_i, E_f, i, f))

    n_trans = len(transition_data)
    two_width = 2 * width

    # Register every (process, transition) key that can occur, in the same
    # order as before (pm, mp, then pp / mm if allowed), and remember the
    # position of each key so the hot loop needs no list search.
    # key_slot[t, code] = position of (code, E_i, E_f) in unique_keys, or -1.
    unique_keys = List()
    key_slot = np.empty((n_trans, 4), dtype=np.int64)
    key_slot[:, :] = -1
    for t in range(n_trans):
        E_i, E_f, _li, _lf = transition_data[t]
        delta_E = energies[E_f] - energies[E_i]
        # pm, mp are always possible
        key_slot[t, 2] = len(unique_keys)
        unique_keys.append((2, E_i, E_f))
        key_slot[t, 3] = len(unique_keys)
        unique_keys.append((3, E_i, E_f))
        if delta_E < 0:
            key_slot[t, 0] = len(unique_keys)
            unique_keys.append((0, E_i, E_f))  # pp
        if delta_E > 0:
            key_slot[t, 1] = len(unique_keys)
            unique_keys.append((1, E_i, E_f))  # mm

    integrands = [np.zeros_like(main_grid) for _ in unique_keys]

    # Main loop over the assigned range of primary q modes
    for idx_q in range(q_start, q_end):
        hw_q = mode_energies[idx_q]
        # all_V_rows[0] corresponds to mode `v_offset`
        V_q_rows = all_V_rows[idx_q - v_offset]
        V_q_cols = all_V_cols[idx_q - v_offset]

        # Inner loop over secondary r modes (r >= q)
        for idx_r in range(idx_q, num_modes):
            hw_r = mode_energies[idx_r]
            V_r_rows = all_V_rows[idx_r - v_offset]
            V_r_cols = all_V_cols[idx_r - v_offset]

            for t in range(n_trans):
                E_i, E_f, local_i, local_f = transition_data[t]
                delta_E = energies[E_f] - energies[E_i]

                # pp process
                if delta_E < 0 and abs((hw_q + hw_r) + delta_E) <= two_width:
                    integrands[key_slot[t, 0]] += _calculate_pair_contribution_jit(
                        "pp", idx_q, idx_r, hw_q, hw_r, V_q_rows, V_q_cols, V_r_rows, V_r_cols, energies, E_i, E_f, local_i, local_f,
                        main_grid, width, hwhm, lineshape_code, grid_variable_code,
                        summary_list, debug_summary, num_c_states, window_type)

                # mm process
                if delta_E > 0 and abs((hw_q + hw_r) - delta_E) <= two_width:
                    integrands[key_slot[t, 1]] += _calculate_pair_contribution_jit(
                        "mm", idx_q, idx_r, hw_q, hw_r, V_q_rows, V_q_cols, V_r_rows, V_r_cols, energies, E_i, E_f, local_i, local_f,
                        main_grid, width, hwhm, lineshape_code, grid_variable_code,
                        summary_list, debug_summary, num_c_states, window_type)

                # pm process
                if abs(delta_E + hw_q - hw_r) <= two_width:
                    integrands[key_slot[t, 2]] += _calculate_pair_contribution_jit(
                        "pm", idx_q, idx_r, hw_q, hw_r, V_q_rows, V_q_cols, V_r_rows, V_r_cols, energies, E_i, E_f, local_i, local_f,
                        main_grid, width, hwhm, lineshape_code, grid_variable_code,
                        summary_list, debug_summary, num_c_states, window_type)

                # mp process
                if abs(delta_E - hw_q + hw_r) <= two_width:
                    integrands[key_slot[t, 3]] += _calculate_pair_contribution_jit(
                        "mp", idx_q, idx_r, hw_q, hw_r, V_q_rows, V_q_cols, V_r_rows, V_r_cols, energies, E_i, E_f, local_i, local_f,
                        main_grid, width, hwhm, lineshape_code, grid_variable_code,
                        summary_list, debug_summary, num_c_states, window_type)

    return unique_keys, integrands, summary_list

JIT-compiled mega-task to process a range of primary phonon modes (q).

This function iterates through a given range of q modes, and for each q, iterates through all possible secondary r modes (r >= q). It performs the screening and calculation for all transitions and processes, and accumulates the results into one integrand array per (process, transition) key.

Args

q_start : int
Start index of the q loop.
q_end : int
End index of the q loop.
num_modes : int
Total number of modes.
mode_energies : np.ndarray
Energies of the phonon modes.
energies : np.ndarray
Zeeman-split electronic energies.
all_V_rows : np.ndarray
Sliced array of coupling matrix rows.
all_V_cols : np.ndarray
Sliced array of coupling matrix columns.
main_grid : np.ndarray
The main energy integration grid.
width : float
Integration width.
hwhm : float
HWHM for lineshapes.
lineshape_code : int
Code for lineshape function.
grid_variable_code : int
Code for grid variable.
num_c_states : int
Number of intermediate states.
debug_summary : bool
Whether to collect debug summary.
summary_list : List
List to append summary data to.
active_indices : np.ndarray
Array of active electronic state indices.
window_type : str
Windowing strategy.
v_offset : int, optional
Offset for indexing V arrays. Defaults to 0.

Returns

tuple
A tuple containing:
- A list of unique (process_code, E_i, E_f) keys.
- A list of the corresponding final integrand arrays for this q-range.
- A list of summary data tuples for debugging.
def run_raman_calculation(*,
initial_data,
params,
orientation_vectors,
magnitudes_to_run,
num_phonon_chunks,
args,
client,
precomputed_first_field_data=None)
Expand source code
def run_raman_calculation(*, initial_data, params, orientation_vectors, magnitudes_to_run, num_phonon_chunks, args, client, precomputed_first_field_data=None):
    """Choose and run the execution strategy for the Raman calculation.

    The strategy is picked from the keys present in ``initial_data``:

    - ``'symmetrised_couplings_array'``: in-memory path
      (``run_raman_small_system``). This key is checked first.
    - ``'coupling_info'``: file-based deferred-I/O path
      (``run_raman_large_system``). Only this path receives
      ``precomputed_first_field_data``.

    Args:
        initial_data (dict): Initial electronic structure and coupling data.
        params (RamanParams): Calculation parameters.
        orientation_vectors (list): B-field orientation vectors.
        magnitudes_to_run (list): B-field magnitudes in Tesla.
        num_phonon_chunks (int): Number of chunks for partitioning phonon loops.
        args (argparse.Namespace): Parsed command-line arguments.
        client (dask.distributed.Client): Active Dask client.
        precomputed_first_field_data (dict, optional): Precomputed Zeeman data
            for the first field (large-system path only).

    Returns:
        dict: Whatever the selected strategy returns (results keyed by
            orientation index).

    Raises:
        ValueError: If neither recognised key is present in ``initial_data``.
    """
    shared_kwargs = dict(
        initial_data=initial_data,
        params=params,
        orientation_vectors=orientation_vectors,
        magnitudes_to_run=magnitudes_to_run,
        num_phonon_chunks=num_phonon_chunks,
        args=args,
        client=client,
    )

    if 'symmetrised_couplings_array' in initial_data:
        return run_raman_small_system(**shared_kwargs)

    if 'coupling_info' in initial_data:
        return run_raman_large_system(
            **shared_kwargs,
            precomputed_first_field_data=precomputed_first_field_data,
        )

    raise ValueError("Could not determine data loading strategy from initial_data keys.")

Orchestrates the adaptive parallel Dask execution of the Raman calculation.

This function dispatches to one of two strategies based on the prepared data:
- If symmetrised_couplings_array is present, it uses the fast in-memory path suitable for small systems.
- If coupling_info is present, it uses the memory-conserving deferred I/O path suitable for large systems.

Args

initial_data : dict
The initial electronic structure and coupling data.
params : RamanParams
Dataclass object with calculation parameters.
orientation_vectors : list
List of B-field orientation vectors.
magnitudes_to_run : list
List of B-field magnitudes in Tesla.
num_phonon_chunks : int
Number of chunks for partitioning phonon loops.
args : argparse.Namespace
Parsed command-line arguments.
client : dask.distributed.Client
An active Dask client.
precomputed_first_field_data : dict, optional
Precomputed Zeeman data for first field.

Returns

dict
A dictionary of results, with orientation indices as keys.
def run_raman_large_system(*,
initial_data,
params,
orientation_vectors,
magnitudes_to_run,
num_phonon_chunks,
args,
client,
precomputed_first_field_data=None)
Expand source code
def run_raman_large_system(*, initial_data, params, orientation_vectors, magnitudes_to_run, num_phonon_chunks, args, client, precomputed_first_field_data=None):
    """
    Executes the Raman calculation using the file-based 'Large System' strategy.

    For every (orientation, magnitude) pair, the Zeeman diagonalisation runs
    synchronously on the driver, and its U matrix is written to a temporary
    HDF5 file. Dask tasks then:

    - read the couplings in chunks of modes,
    - transform and slice them using U and ``params.raman_states``
      (``transform_and_slice_chunk``),
    - feed the stacked rows/columns into ``build_field_calculation_graph``.

    All tasks are executed via ``_execute_tasks``. The temporary U files are
    removed afterwards, or on error, on a best-effort basis.

    Args:
        initial_data (dict): Initial data loaded from file; must contain
            'coupling_info' (with 'file', 's_load', 'kramers').
        params (RamanParams): Calculation parameters.
        orientation_vectors (list): List of orientation vectors.
        magnitudes_to_run (list): List of field magnitudes.
        num_phonon_chunks (int): Number of chunks for phonon loop.
        args (argparse.Namespace): Command-line arguments.
        client (dask.distributed.Client): Dask client.
        precomputed_first_field_data (dict, optional): Precomputed Zeeman data
            ('energies', 'state_mu', 'state_mJ', 'U') used instead of
            recomputing for the first orientation and first magnitude.

    Returns:
        dict: The results of the calculation (as returned by ``_execute_tasks``).
    """
    omega_grid = _create_grid(args)
    omega_grid_future = client.scatter(omega_grid, broadcast=True)

    # Phonon mode energies are the same for every field: scatter them once
    # rather than once per (orientation, magnitude).
    modes_for_calc_future = client.scatter(
        initial_data.get('all_mode_energies', [])[:params.num_modes], broadcast=True
    )

    # --- Size the coupling I/O chunks from the per-worker memory limit ---
    workers = client.scheduler_info()['workers']
    worker_memory = next(iter(workers.values()))['memory_limit'] if workers else 4e9
    if not worker_memory:
        # A falsy limit (0/None, e.g. no limit configured) would give
        # single-mode chunks or a TypeError; use the same 4 GB default.
        worker_memory = 4e9
    mem_per_mode = (params.max_state**2) * 16  # one max_state x max_state complex128 matrix
    io_budget_per_task = worker_memory * 0.10
    chunk_size = max(1, int(io_budget_per_task / mem_per_mode))

    # Load operators on Driver for Zeeman splitting
    # We do this once to avoid reloading for every field.
    coupling_info = initial_data['coupling_info']
    with h5py.File(coupling_info['file'], 'r') as hf:
        s_load = coupling_info['s_load']
        kramers = coupling_info['kramers']

        def load_op(dset_group):
            # Load the x, y, z components restricted to s_load and symmetrise each.
            ops = []
            for d in 'xyz':
                mat = dset_group[d][s_load][:, s_load]
                ops.append(symmetrise_op(op=mat, is_kramers=kramers, is_tr_even=False))
            return np.array(ops)

        angmom_driver = load_op(hf['angmom'])
        spin_driver = load_op(hf['spin'])
        energies_driver = hf['energies'][s_load]

    mode_indices = list(range(params.num_modes))
    mode_chunks = [mode_indices[j:j + chunk_size] for j in range(0, params.num_modes, chunk_size)]

    @dask.delayed
    def aggregate_chunks(chunks):
        # Flatten the per-chunk lists of (rows, cols) slices and stack them.
        all_slices = [item for sublist in chunks for item in sublist]
        all_rows = np.array([s[0] for s in all_slices])
        all_cols = np.array([s[1] for s in all_slices])
        return all_rows, all_cols

    tasks = {}
    temp_files = []

    try:
        for i, B_unit in enumerate(orientation_vectors):
            tasks[i] = []
            for mag in magnitudes_to_run:
                # 1. Zeeman splitting on the driver; U goes to a unique temp file.
                tf = tempfile.NamedTemporaryFile(delete=False, prefix=f"tau_U_{i}_{mag}_", suffix=".h5")
                tf.close()
                temp_u_path = tf.name
                # Register before use so the file is cleaned up even on failure.
                temp_files.append(temp_u_path)

                B_vec_val = B_unit * mag

                if (precomputed_first_field_data is not None and
                        i == 0 and mag == magnitudes_to_run[0]):
                    # Reuse the precomputed Zeeman data for the first field.
                    zeeman_energies = precomputed_first_field_data['energies']
                    state_mu = precomputed_first_field_data['state_mu']
                    state_mJ = precomputed_first_field_data['state_mJ']
                    U_matrix = precomputed_first_field_data['U']

                    with h5py.File(temp_u_path, 'w') as f:
                        f.create_dataset('U', data=U_matrix)
                else:
                    # Execute synchronously on Driver
                    zeeman_energies, state_mu, state_mJ = driver_compute_zeeman_and_save(
                        energies_driver, angmom_driver, spin_driver, B_vec_val, temp_u_path
                    )

                # Scatter the small per-field results to workers for the graph.
                zeeman_energies_future = client.scatter(zeeman_energies, broadcast=True)
                B_vec_future = client.scatter(B_vec_val, broadcast=True)
                state_mu_future = client.scatter(state_mu, broadcast=True)
                state_mJ_future = client.scatter(state_mJ, broadcast=True)
                field_mag_future = client.scatter(mag, broadcast=True)

                # 2. Transform and slice the couplings chunk by chunk; tasks
                #    receive the path to U rather than the matrix itself.
                slice_chunk_tasks = [
                    transform_and_slice_chunk(chunk, coupling_info, temp_u_path, params.raman_states)
                    for chunk in mode_chunks
                ]
                coupling_slices_future = aggregate_chunks(slice_chunk_tasks)

                # 3. Package the per-field inputs in the layout expected by
                #    build_field_calculation_graph.
                setup_data_future = {
                    'energies': zeeman_energies_future,
                    'state_mu': state_mu_future,
                    'state_mJ': state_mJ_future,
                    'field_mag': field_mag_future,
                    'B_vec': B_vec_future,
                    'omega_grid': omega_grid_future,
                    'modes_for_calc': modes_for_calc_future,
                }

                # 4. Build Calculation Graph
                task = build_field_calculation_graph(
                    setup_data_future=setup_data_future,
                    coupling_slices_future=coupling_slices_future,
                    num_phonon_chunks=num_phonon_chunks, args=args, params=params
                )
                tasks[i].append(task)

        results = _execute_tasks(tasks, orientation_vectors, magnitudes_to_run, args, client)

    finally:
        # Best-effort cleanup of the temporary U files; a file that cannot be
        # removed must not mask the real result or exception.
        for f in temp_files:
            if os.path.exists(f):
                try:
                    os.remove(f)
                except OSError:
                    pass

    return results

Executes the Raman calculation using the file-based 'Large System' strategy.

Args

initial_data : dict
Initial data loaded from file.
params : RamanParams
Calculation parameters.
orientation_vectors : list
List of orientation vectors.
magnitudes_to_run : list
List of field magnitudes.
num_phonon_chunks : int
Number of chunks for phonon loop.
args : argparse.Namespace
Command-line arguments.
client : dask.distributed.Client
Dask client.
precomputed_first_field_data : dict, optional
Precomputed Zeeman data for first field.

Returns

dict
The results of the calculation.
def run_raman_small_system(*,
initial_data,
params,
orientation_vectors,
magnitudes_to_run,
num_phonon_chunks,
args,
client)
Expand source code
def run_raman_small_system(*, initial_data, params, orientation_vectors, magnitudes_to_run, num_phonon_chunks, args, client):
    """
    Executes the Raman calculation using the in-memory 'Small System' strategy.

    The static inputs are scattered once to every worker: the omega grid, the
    angular momentum and spin operators, the symmetrised coupling array, and
    the remaining initial data. Then, for each (orientation, magnitude) pair,
    a lazy Dask graph is built that:
      1. Zeeman-splits the electronic structure,
      2. transforms and slices the couplings into the Zeeman basis in memory,
      3. packages the per-field setup data, and
      4. builds the spectral-density graph via
         ``build_field_calculation_graph``.
    All graphs are finally executed by ``_execute_tasks``.

    Note:
        ``initial_data`` is modified in place: its ``'angmom'`` and ``'spin'``
        entries are popped before scattering.

    Args:
        initial_data (dict): Initial data loaded from file.
        params (RamanParams): Calculation parameters.
        orientation_vectors (list): List of orientation vectors.
        magnitudes_to_run (list): List of field magnitudes.
        num_phonon_chunks (int): Number of chunks the phonon-pair loop is
            divided into.
        args (argparse.Namespace): Command-line arguments.
        client (dask.distributed.Client): Dask client.

    Returns:
        dict: The results of the calculation.
    """
    omega_grid = _create_grid(args)
    omega_grid_future = client.scatter(omega_grid, broadcast=True)

    # Scatter large static arrays once. angmom/spin are popped so they are not
    # scattered a second time inside the remaining initial-data dictionary.
    angmom_future = client.scatter(initial_data.pop('angmom'), broadcast=True)
    spin_future = client.scatter(initial_data.pop('spin'), broadcast=True)
    s_couplings_array_future = client.scatter(initial_data['symmetrised_couplings_array'], broadcast=True)
    worker_initial_data_future = client.scatter({k: v for k, v in initial_data.items() if k != 'symmetrised_couplings_array'}, broadcast=True)

    # Defined once (not per loop iteration): packages the Zeeman output into the
    # dictionary layout consumed by build_field_calculation_graph.
    @dask.delayed
    def package_setup_data(setup, mag, b_unit, omega_grid, worker_data):
        modes = worker_data.get('all_mode_energies', [])
        return {
            'energies': setup[0], 'state_mu': setup[2], 'state_mJ': setup[3],
            'field_mag': mag, 'B_vec': b_unit * mag, 'omega_grid': omega_grid,
            'modes_for_calc': modes[:params.num_modes]
        }

    tasks = {}

    for i, B_unit in enumerate(orientation_vectors):
        tasks[i] = []
        for mag in magnitudes_to_run:
            # 1. Zeeman Splitting
            setup_data = dask.delayed(apply_zeeman_splitting)(
                initial_data['initial_energies'], angmom_future, spin_future,
                B_unit * mag, couplings=None
            )
            # Element 4 of the Zeeman result is used as the basis-change matrix U.
            U_matrix_future = setup_data[4]

            # 2. Transform and Slice (In-Memory). The task already yields the
            # (all_V_rows, all_V_cols) tuple that build_field_calculation_graph
            # indexes as [0] / [1], so no extra re-packaging task is needed.
            coupling_slices_future = transform_and_slice_in_worker(
                s_couplings_array_future, U_matrix_future, params.raman_states
            )

            # 3. Package Setup Data
            setup_data_future = package_setup_data(setup_data, mag, B_unit, omega_grid_future, worker_initial_data_future)

            # 4. Build Calculation Graph
            task = build_field_calculation_graph(
                setup_data_future=setup_data_future,
                coupling_slices_future=coupling_slices_future,
                num_phonon_chunks=num_phonon_chunks, args=args, params=params
            )
            tasks[i].append(task)

    return _execute_tasks(tasks, orientation_vectors, magnitudes_to_run, args, client)

Executes the Raman calculation using the in-memory 'Small System' strategy.

Args

initial_data : dict
Initial data loaded from file.
params : RamanParams
Calculation parameters.
orientation_vectors : list
List of orientation vectors.
magnitudes_to_run : list
List of field magnitudes.
num_phonon_chunks : int
Number of chunks for phonon loop.
args : argparse.Namespace
Command-line arguments.
client : dask.distributed.Client
Dask client.

Returns

dict
The results of the calculation.
def transform_and_slice_chunk(mode_indices, coupling_info, u_input, active_indices)
Expand source code
@dask.delayed
def transform_and_slice_chunk(mode_indices, coupling_info, u_input, active_indices):
    """
    Read, symmetrise and basis-transform the coupling matrices of a chunk of modes.

    For every mode in ``mode_indices``, the raw coupling matrix is read from
    the HDF5 file described by ``coupling_info``, restricted to the
    ``s_load`` selection and symmetrised with ``symmetrise_op``. It is then
    projected onto the active bra states: rows = <active|V|n>. The column
    block <n|V|active> is taken as the conjugate transpose of the rows, which
    relies on V being Hermitian.

    Args:
        mode_indices (list): Phonon mode indices handled by this chunk.
        coupling_info (dict): Deferred-loading description of the couplings.
            The keys read are 'file', 's_load', 'keys' and 'kramers'.
        u_input (str or np.ndarray): Path to an HDF5 file holding dataset 'U'
            (Large System strategy), or the U matrix itself.
        active_indices (list): Electronic state indices to keep.

    Returns:
        list: One (V_rows, V_cols) tuple per mode, in ``mode_indices`` order.
    """
    # A string means U lives on disk (Large System strategy): read it in full.
    if not isinstance(u_input, str):
        U_matrix = u_input
    else:
        with h5py.File(u_input, 'r') as u_file:
            U_matrix = u_file['U'][...]

    # Bra-side projector <active| : conjugate transpose of U's active columns.
    bra_projector = U_matrix[:, active_indices].conj().T

    chunk_output = []
    with h5py.File(coupling_info['file'], 'r') as coupling_file:
        for mode_idx in mode_indices:
            selection = coupling_info['s_load']
            dataset_key = coupling_info['keys'][mode_idx]
            # Row selection first, then column selection on the loaded block.
            raw_block = coupling_file['couplings'][dataset_key][selection][:, selection]

            symmetrised = symmetrise_op(op=raw_block, is_kramers=coupling_info['kramers'], is_tr_even=True)

            # Only the active rows are needed: <active|V|n>.
            rows = (bra_projector @ symmetrised) @ U_matrix
            # Hermiticity: <n|V|active> = <active|V|n>^*.
            chunk_output.append((rows, rows.conj().T))
    return chunk_output

Loads, symmetrises, transforms, and slices a chunk of coupling matrices.

Args

mode_indices : list
List of phonon mode indices to process in this chunk.
coupling_info : dict
Information for deferred loading of coupling matrices.
u_input : str or np.ndarray
Either the path to the HDF5 file containing 'U' or the U matrix itself.
active_indices : list
List of electronic state indices to keep.

Returns

list
A list of tuples (V_rows, V_cols) for each mode in the chunk.
def transform_and_slice_in_worker(s_couplings_array, U_matrix, active_indices)
Expand source code
@dask.delayed
def transform_and_slice_in_worker(s_couplings_array, U_matrix, active_indices):
    """
    Transforms and slices pre-loaded coupling matrices inside a Dask worker.

    This function is used for the "Small System" strategy, where all coupling
    matrices are held in memory. For each mode it computes the active-row block
    <active|V|n> = (U[:, active]^dagger @ V) @ U. The whole stack is handled in
    one broadcasting matrix product rather than a Python loop. The column block
    <n|V|active> is obtained as the conjugate transpose of the rows, which
    relies on V being Hermitian.

    An empty coupling stack (zero modes) is supported and yields empty
    (0, k, N) / (0, N, k) outputs.

    Args:
        s_couplings_array (np.ndarray): Dense array of symmetrised couplings
            (N_modes, N_states, N_states).
        U_matrix (np.ndarray): The unitary transformation matrix from Zeeman
            splitting, (N_states, N_states).
        active_indices (list[int]): Electronic state indices to keep (k of them).

    Returns:
        tuple: (all_V_rows, all_V_cols) with shapes (N_modes, k, N_states) and
        (N_modes, N_states, k).
    """
    couplings = np.asarray(s_couplings_array)

    # Bra-side projector <active| built from the active columns of U: (k, N).
    U_dag_slice = U_matrix[:, active_indices].conj().T

    # np.matmul broadcasts over the leading mode axis:
    # (k, N) @ (M, N, N) -> (M, k, N), then (M, k, N) @ (N, N) -> (M, k, N).
    # Unlike stacking a Python list, this also works when M == 0.
    all_V_rows = (U_dag_slice @ couplings) @ U_matrix

    # Since V is Hermitian, <n|V|active> = <active|V|n>*:
    # swap the last two axes and conjugate -> (M, N, k).
    all_V_cols = all_V_rows.transpose(0, 2, 1).conj()

    return all_V_rows, all_V_cols

Transforms and slices pre-loaded coupling matrices inside a Dask worker.

This function is used for the "Small System" strategy where all coupling matrices are loaded into memory. It performs the basis transformation and slicing efficiently on the worker.

Args

s_couplings_array : np.ndarray
Dense array of symmetrised couplings (N_modes, N_states, N_states).
U_matrix : np.ndarray
The unitary transformation matrix from Zeeman splitting.
active_indices : list[int]
List of electronic state indices to keep.

Returns

tuple
(all_V_rows, all_V_cols) arrays containing the sliced matrices.
def worker_task(q_ranges_for_worker,
*,
all_V_rows,
all_V_cols,
energies,
modes_for_calc,
omega_grid,
B_vec,
params: RamanParams,
v_offset=0)
Expand source code
@dask.delayed
def worker_task(q_ranges_for_worker, *, all_V_rows, all_V_cols, energies, modes_for_calc,
                omega_grid, B_vec, params: RamanParams, v_offset=0):
    """
    Dask task evaluating one worker's share of the Raman spectral density.

    For a single field, the task walks over the (start, end) q-ranges it was
    assigned. Each range is passed to the Numba kernel ``process_q_range_jit``
    together with the pre-transformed coupling rows and columns. The returned
    integrands are summed per transition and per process.

    Args:
        q_ranges_for_worker (list): (start, end) tuples defining the q-ranges.
        all_V_rows (np.ndarray): Sliced coupling-matrix rows.
        all_V_cols (np.ndarray): Sliced coupling-matrix columns.
        energies (np.ndarray): Zeeman-split electronic energies.
        modes_for_calc (np.ndarray): Phonon mode energies.
        omega_grid (np.ndarray): The frequency integration grid.
        B_vec (np.ndarray): The magnetic field vector. It is part of the
            interface but is not read by this task.
        params (RamanParams): The calculation parameters.
        v_offset (int, optional): Offset added to q indices when indexing into
            all_V_rows/all_V_cols. Defaults to 0.

    Returns:
        tuple: (final_integrands_worker, all_summary_data_worker).
            The first maps transition key -> process name -> summed integrand
            on ``omega_grid``. The second is a list of summary dicts, filled
            only when ``params.debug_summary`` is true.
    """
    # Numba expects the active-state indices (params.raman_states) as an int64 array.
    active_states = np.array(params.raman_states, dtype=np.int64)

    # Integer codes understood by the JIT kernel; processes are decoded back to names.
    code_to_process = {0: 'pp', 1: 'mm', 2: 'pm', 3: 'mp'}
    lineshape_codes = {'antilorentzian': 0, 'lorentzian': 1, 'gaussian': 2}
    grid_variable_codes = {'omega1': 0, 'omega2': 1, 'variable': 2}

    # Element type of the typed List the kernel receives for summary rows:
    # five int64 fields followed by six float64 fields.
    summary_row_type = types.Tuple((types.int64,) * 5 + (types.float64,) * 6)

    accumulated = {}
    summary_rows = []

    for q_start, q_end in q_ranges_for_worker:
        fresh_summary = List.empty_list(summary_row_type)

        jit_keys, jit_integrands, jit_summary = process_q_range_jit(
            q_start, q_end, params.num_modes, modes_for_calc, energies,
            all_V_rows, all_V_cols, omega_grid, params.width, params.hwhm,
            lineshape_codes[params.lineshape],
            grid_variable_codes[params.grid_variable],
            params.max_state, params.debug_summary,
            fresh_summary, active_states, params.window_type, v_offset
        )

        decoded_keys = [(code_to_process[p], (ei, ef)) for p, ei, ef in jit_keys]

        if params.debug_summary:
            for row in jit_summary:
                p_code, E_i, E_f, i_q, i_r, h_q, h_r, r_dist, ml_prod, a_w2, contrib = row
                summary_rows.append({
                    'Process': code_to_process[p_code],
                    'Transition': f"{E_i}->{E_f}",
                    'Pair(q,r)': f"({i_q},{i_r})",
                    'hw_q': h_q,
                    'hw_r': h_r,
                    'ResonanceDist': r_dist,
                    'MaxL_prod': ml_prod,
                    'Avg_W2': a_w2,
                    'Contribution': contrib,
                })

        # Integrands are positionally aligned with the decoded keys.
        for idx, (process, transition) in enumerate(decoded_keys):
            contribution = jit_integrands[idx]
            per_transition = accumulated.setdefault(transition, {})
            if process not in per_transition:
                per_transition[process] = np.zeros_like(omega_grid)
            per_transition[process] += contribution

    return accumulated, summary_rows

A self-contained dask.delayed task that runs on a single worker.

This function is the core parallel unit of the Raman calculation. It is responsible for a single field and a subset of phonon pair chunks. It receives pre-transformed and sliced coupling matrices and invokes a JIT-compiled function to calculate its share of the spectral density.

Args

q_ranges_for_worker : list
List of (start, end) tuples defining q-ranges.
all_V_rows : np.ndarray
Sliced array of coupling matrix rows.
all_V_cols : np.ndarray
Sliced array of coupling matrix columns.
energies : np.ndarray
Zeeman-split electronic energies.
modes_for_calc : np.ndarray
Energies of the phonon modes.
omega_grid : np.ndarray
The frequency integration grid.
B_vec : np.ndarray
The magnetic field vector.
params : RamanParams
The calculation parameters.
v_offset : int, optional
The offset to apply to q indices to index into all_V_rows/cols. Defaults to 0.

Returns

tuple
(final_integrands_worker, all_summary_data_worker)