Module tau2.main

Functions

def determine_parallel_strategy(args, magnitudes_to_run, orientations_to_run, params)
Expand source code
def determine_parallel_strategy(args, magnitudes_to_run, orientations_to_run, params):
    """
    Work out how many groups ('chunks') to split the sum over the first phonon, q.

    For calculations at lots of fields (e.g. hysteresis calculations) this will
    typically be 1. For calculations with many phonon pairs (e.g. fine q-meshes),
    the number of chunks will be higher to better balance the computational load
    over multiple cores.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; uses ``n_chunks`` (explicit override) and ``n_cores``.
    magnitudes_to_run : sequence
        Field magnitudes that will be calculated.
    orientations_to_run : sequence
        Field orientations that will be calculated.
    params : object
        Calculation parameters; uses ``num_modes``.

    Returns
    -------
    int
        Number of phonon chunks to use per field.
    """
    if args.n_chunks:
        # Explicit user override takes precedence.
        num_phonon_chunks = args.n_chunks
    else:
        # Aim for roughly this many mode pairs per chunk.
        target_mode_pairs = 250000
        # Number of unique (q, q') mode pairs, including q == q'.
        total_mode_pairs = params.num_modes * (params.num_modes + 1) // 2
        mode_pair_chunks = max(total_mode_pairs // target_mode_pairs, 1)

        if len(magnitudes_to_run) * len(orientations_to_run) >= args.n_cores * mode_pair_chunks * 2:
            # Enough (field, orientation) tasks already to keep the cores busy.
            num_phonon_chunks = mode_pair_chunks
        else:
            # Few field points: split the phonon sum further so every core has
            # at least two chunks, rounding up to a multiple of n_cores.
            target_min_chunks = 2 * args.n_cores
            mode_pair_chunks = ((mode_pair_chunks + args.n_cores - 1) // args.n_cores) * args.n_cores
            num_phonon_chunks = max(target_min_chunks, mode_pair_chunks)

    # Explicit pluralization (replaces the opaque "'s'[:n^1]" slicing trick).
    plural = '' if num_phonon_chunks == 1 else 's'
    print(f"\nUsing {num_phonon_chunks} phonon chunk{plural} per field")
    return num_phonon_chunks

Works out how many groups ('chunks') to split the sum over the first phonon, q. For calculations at lots of fields (e.g. hysteresis calculations) this will typically be 1. For calculations with many phonon pairs (e.g. fine q-meshes), the number of chunks will be higher to better balance the computational load over multiple cores.

def main()
Expand source code
def main():
    """The main entry point for the tau2 program."""

    # ASCII-art banner; the f-string interpolates the module-level `version`.
    # The whitespace layout is intentional -- do not reformat.
    banner=f'''
                                                 x
          *~~~~~~~~~~~~~~~~~~xxxxxxxxxxxxxxxxxxxx
      *~~~~~~~~~~~~~~~~~~~~~xxxxxxxxxxxxxxxxxxxx
                 *~~~~~~~~~x        xxxx
                                   xxxx           tau2
                                  xxxx              Ben Atkinson
                                 xxxx           
                                xxxx              Based on Tau
                    *~~~~~~~~~~xxxx       x         Jon Kragskow
             *~~~~~~~~~~~~~~~~xxxx       xx         Daniel Reta
              *~~~~~~~~~~~~~~~xxxx      xxx         Nicholas Chilton
                               xxxx    xxxx
                 *~~~~~~~~~~~~~~xxxxxxxxxx
                    *~~~~~~~~~~~~xxxxxxx          Version {version}
    '''
    print(banner)
    args = parse_cli_args()

    # Dispatch table mapping the CLI 'mode' argument to its handler function.
    mode_map = {
        'raman': main_raman,
        'raman_rates': main_raman_rates,
        'orbach': main_orbach,
        'hyst': main_hyst,
        'plot_rate': main_plotrate,
        'plot_raman_j': main_plotramanj1d,
        'plot_raman_j_2d': main_plotramanj2d,
        'plot_orbach_j': main_plotorbachj,
    }
    args.func = mode_map[args.mode]

    # For tau2 raman/orbach, process args (e.g., infer energy range from modes)
    if args.mode in ['raman', 'orbach']:
        # Temporarily load just the mode energies for arg validation
        with h5py.File(args.input_file, 'r') as h5:
            all_mode_energies = h5['mode_energies'][...]
            all_state_energies = h5['energies'][...]
        args = prepare_and_validate_args(args, all_mode_energies, all_state_energies)

    # Run the selected mode's handler.
    args.func(args)

The main entry point for the tau2 program.

def main_hyst(args)
Expand source code
def main_hyst(args):
    """Entry point for the 'hyst' mode: delegates to the hysteresis driver."""
    run_hysteresis_simulation(args)

Runs the hysteresis simulation.

def main_orbach(args)
Expand source code
def main_orbach(args):
    """
    High-level orchestrator for the Orbach calculation workflow.

    Loads inputs, builds the field sweep and orientation set, runs the core
    Orbach calculation, and writes the results to disk.
    """
    # Load input data and build the parameter object.
    data, params = prepare_orbach_inputs(args)
    print_calculation_parameters(args)

    # Field setup: the magnitudes and the principal-frame rotation matrix are
    # shared by every orientation.
    B_unit, magnitudes_to_run, rot = setup_field_orientation(
        args, data['angmom'], data['spin'], orientations=args.orientations
    )

    multi_orientation = args.orientations > 1
    if multi_orientation:
        principal_vecs, weights = generate_orientations(args.quadrature, args.orientations)
        # Rotate the quadrature vectors from the principal frame back into the
        # input frame.
        field_vectors = (rot @ principal_vecs.T).T
        first_direction = field_vectors[0]
    else:
        # A single orientation: B_unit itself is the direction to use.
        field_vectors = [B_unit]
        weights = [1.0]
        first_direction = B_unit

    if magnitudes_to_run:
        suffix = ", first orientation" if multi_orientation else ""
        print(f"\n--- Equilibrium Electronic Structure ({magnitudes_to_run[0]:.4f} T{suffix}) ---")
        print_electronic_structure(data, first_direction, magnitudes_to_run[0])

    # Run the core calculation over all fields and orientations.
    all_field_results = run_orbach_calculation(
        initial_data=data,
        params=params,
        orientation_vectors=field_vectors,
        magnitudes_to_run=magnitudes_to_run,
        args=args
    )
    print("\nCalculation complete.")

    # Persist the results.
    save_orbach_outputs(all_field_results, params, weights, args)

High-level orchestrator for the Orbach calculation workflow.

def main_plotorbachj(args)
Expand source code
def main_plotorbachj(args):
    """
    Plotting of the Orbach spectral density as a function of omega_r and field.

    (Docstring previously said "Raman" -- this handler plots the Orbach
    spectral density via plot_orbach_spectral_density.)
    """
    # Default the output prefix to the input filename without its extension.
    output_prefix = args.output or args.input_file.rsplit('.', 1)[0]
    plot_orbach_spectral_density(output_prefix, args)

Plotting of the Orbach spectral density as a function of omega_r and field

def main_plotramanj1d(args)
Expand source code
def main_plotramanj1d(args):
    """
    Plot the Raman spectral density as a function of omega_r.
    """
    # Output naming: an explicit --output wins; a single input file reuses its
    # stem; multiple inputs fall back to a generic comparison prefix.
    if args.output:
        prefix = args.output
    else:
        prefix = (args.input_files[0].rsplit('.', 1)[0]
                  if len(args.input_files) == 1 else "comparison")

    plot_1d_from_hdf5(prefix, args)

Plotting of the Raman spectral density as a function of omega_r

def main_plotramanj2d(args)
Expand source code
def main_plotramanj2d(args):
    """
    Plot the Raman spectral density as a function of omega_r and field.
    """
    # Prefer the explicit --output prefix, else derive one from the input stem.
    prefix = args.output if args.output else args.input_file.rsplit('.', 1)[0]
    plot_2d_from_hdf5(prefix, args)

Plotting of the Raman spectral density as a function of omega_r and field

def main_plotrate(args)
Expand source code
def main_plotrate(args):
    """
    Plot relaxation rates.
    """
    # Derive an output prefix: explicit --output first, then the stem of the
    # first raman/orbach input file, then a generic default.
    if args.output:
        prefix = args.output
    elif args.raman:
        prefix = args.raman[0].rsplit('.', 1)[0]
    elif args.orbach:
        prefix = args.orbach[0].rsplit('.', 1)[0]
    else:
        prefix = "rates"

    plot_rate(prefix, args)

Plotting of relaxation rates.

def main_raman(args)
Expand source code
def main_raman(args):
    """
    High-level orchestrator for the Raman calculation workflow.

    Sets up the Dask cluster, loads inputs, builds the field sweep and
    orientation set, chooses a phonon-chunking strategy, runs the core
    calculation, saves results, and optionally plots a transition.
    """
    # Create the Dask client and cluster.
    cluster, client = _setup_dask_client(args)
    with cluster, client:
        # Step 1: Prepare data and parameters, passing the client to inform strategy
        initial_data, params = prepare_raman_inputs(args, client)
        print_calculation_parameters(args)

        # Step 2: Set up the magnetic field sweep and orientations
        # Get the rotation matrix and magnitudes, which are independent of orientation
        B_unit, magnitudes_to_run, rotation_matrix = setup_field_orientation(
            args, initial_data['angmom'], initial_data['spin'], orientations=args.orientations
        )

        if args.orientations > 1:
            orientations_principal_frame, orientation_weights = generate_orientations(args.quadrature, args.orientations)
            # Rotate the generated vectors from the principal frame back to the input frame
            orientation_vectors = (rotation_matrix @ orientations_principal_frame.T).T
            B_unit_for_print = orientation_vectors[0]
        else:
            # For a single orientation, B_unit is the vector we need
            orientation_vectors = [B_unit]
            orientation_weights = [1.0]
            B_unit_for_print = B_unit

        # Electronic structure at the first field is printed up-front and the
        # result reused by the core calculation to avoid recomputing it.
        precomputed_first_field_data = None
        if magnitudes_to_run:
            if args.orientations > 1:
                header = f"\n--- Equilibrium Electronic Structure ({magnitudes_to_run[0]:.4f} T, first orientation) ---"
            else:
                header = f"\n--- Equilibrium Electronic Structure ({magnitudes_to_run[0]:.4f} T) ---"
            print(header)
            precomputed_first_field_data = print_electronic_structure(initial_data, B_unit_for_print, magnitudes_to_run[0])

        # Step 3: Determine parallelisation strategy
        num_phonon_chunks = determine_parallel_strategy(args, magnitudes_to_run, orientation_vectors, params)

        # Optimisation for Large Systems:
        # If we are using the deferred loading strategy (coupling_info exists), we should also 
        # avoid scattering the huge angmom/spin matrices. We delete them from initial_data
        # here; the workers will load them from disk on-demand during Zeeman splitting.
        if 'coupling_info' in initial_data:
            print("Large system: Removing static angmom/spin arrays from Client to prevent scattering.")
            if 'angmom' in initial_data: del initial_data['angmom']
            if 'spin' in initial_data: del initial_data['spin']

        # Step 4: Run the core calculation, passing the client
        all_field_results = run_raman_calculation(
            initial_data=initial_data,
            params=params,
            orientation_vectors=orientation_vectors,
            magnitudes_to_run=magnitudes_to_run,
            num_phonon_chunks=num_phonon_chunks,
            args=args,
            client=client,
            precomputed_first_field_data=precomputed_first_field_data
        )
        print("\nCalculation complete.")

        # Step 5: Save results
        output_hdf5 = save_raman_outputs(all_field_results, params, orientation_weights, args)

        # Step 6: Plotting
        if args.plot_transition and len(magnitudes_to_run) == 1 and args.orientations == 1:
            print("\n--- Proceeding to Plot Generation ---")
            plot_prefix = output_hdf5.rsplit('.', 1)[0]
            # NOTE(review): main_plotramanj1d calls plot_1d_from_hdf5(prefix, args),
            # but this call passes ([files], prefix, transition_tuple) -- two
            # different calling conventions for the same name. Confirm this
            # 3-argument form matches the current signature.
            plot_1d_from_hdf5([output_hdf5], plot_prefix, tuple(args.plot_transition))

High-level orchestrator for the Raman calculation workflow.

def main_raman_rates(args)
Expand source code
def main_raman_rates(args):
    """
    Handles the 'raman_rates' mode. Loads a pre-computed spectral density
    from an HDF5 file and recalculates relaxation rates at newly specified
    temperatures.
    """
    # Load the entire HDF5 file produced by a 'tau2 raman' run.
    print(f"--- Loading spectral density from {args.input_file} ---")
    hdf5_contents = load_from_hdf5(args.input_file)
    first_orientation = hdf5_contents['orientation_0']
    stored_params = first_orientation['parameters']
    stored_results = first_orientation['results']

    # Newly calculated rates, keyed by field magnitude.
    recomputed_rates = {}

    # Max state used for the original Gamma matrix calculation; it may have
    # been stored as a string, so coerce it to int.
    try:
        raman_max_state = int(stored_params['raman_max_state'])
    except KeyError:
        print("Warning: 'raman_max_state' not found in HDF5 parameters. Defaulting to 2.")
        raman_max_state = 2

    print(f"--- Recalculating rates for {len(stored_results)} magnetic fields at specified temperatures ---")

    for idx, (field_key, field_group) in enumerate(stored_results.items()):
        field_mag = stored_params['field_magnitudes'][idx]

        raman_group = field_group.get('raman', {})
        if not raman_group or 'omega_grid' not in raman_group:
            print(f"Warning: No Raman spectral density found for field {field_mag:.4f} T. Skipping.")
            continue

        # Rebuild the input structure expected by the rate calculator.
        rate_input = {
            'omega_grid': raman_group['omega_grid'],
            'energies': field_group['electronic_energies'],
            'integrands': {}
        }

        # Each 'transition_i_f' group holds a dict such as
        # {'J_pp': array, 'J_mm': array}; rebuild the integrands dictionary.
        for group_name, group_value in raman_group.items():
            if not group_name.startswith('transition_'):
                continue
            # e.g. 'transition_0_1' -> (0, 1)
            _, i_txt, f_txt = group_name.split('_')
            pair = (int(i_txt), int(f_txt))
            # Strip the 'J_' prefix from each process key ('J_pp' -> 'pp').
            rate_input['integrands'][pair] = {
                proc.split('_')[1]: arr for proc, arr in group_value.items()
            }

        # Recompute rates with the loaded data at the new temperatures.
        new_rates, _ = calculate_rates_from_data(
            rate_input,
            args.temperatures,
            raman_max_state,
            args.rate_integrator
        )
        recomputed_rates[field_mag] = new_rates

    if not recomputed_rates:
        print("Calculation failed: No rates were generated.")
        return

    # Write the final results to the specified output CSV file.
    output_csv = args.output or f"{args.input_file.rsplit('.', 1)[0]}_rates.csv"
    write_rates_to_csv(output_csv, args.temperatures, recomputed_rates)

Handles the 'raman_rates' mode. Loads a pre-computed spectral density from an HDF5 file and recalculates relaxation rates at newly specified temperatures.

def prepare_orbach_inputs(args)
Expand source code
def prepare_orbach_inputs(args):
    """
    Loads all initial data from files and prepares the Orbach params object.
    """
    initial_data = load_initial_data(args.input_file, args.kramers, args.max_state, args.states, args.max_modes, args.modes)

    n_states = initial_data['initial_energies'].shape[0]
    if args.max_state is None:
        # Default to including every electronic state.
        args.max_state = n_states

    # Determine which modes (0-indexed) from the input file are in use.
    selected = [m - 1 for m in args.modes] if args.modes else list(range(args.max_modes))

    # Keep only the energies of the selected modes.
    initial_data['all_mode_energies'] = initial_data['all_mode_energies'][selected]

    # Densify the sparse dictionary of couplings into an array ordered to
    # match the selected modes; modes without an entry stay zero.
    coupling_dict = initial_data.pop('symmetrised_couplings')
    dense = np.zeros((len(selected), n_states, n_states), dtype=np.complex128)
    for row, src_idx in enumerate(selected):
        if src_idx in coupling_dict:
            dense[row] = coupling_dict[src_idx]
    initial_data['symmetrised_couplings_array'] = dense

    width = get_integration_width(args.lineshape, args.fwhm, args.integration_width)

    params = OrbachParams(
        max_state=args.max_state,
        max_modes=args.max_modes,
        fwhm=args.fwhm,
        lineshape=args.lineshape,
        gamma_diagonalisation=args.gamma_diagonalisation,
        gamma_precision=args.gamma_precision,
        width=width,
        hwhm=args.fwhm / 2.0,
        temperatures=args.temperatures,
    )
    return initial_data, params

Loads all initial data from files and prepares the Orbach params object.

def prepare_raman_inputs(args, client)
Expand source code
def prepare_raman_inputs(args, client):
    """
    Loads initial data from files and prepares the params object, choosing a
    data loading strategy based on the available worker memory.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments (max_state/max_modes expected to be set by
        prepare_and_validate_args).
    client : dask.distributed.Client
        Dask client, used to inspect the workers' memory limits.

    Returns
    -------
    tuple
        (initial_data dict, RamanParams)
    """
    # Estimate the peak memory required to use the "load-first" strategy.
    # We rely on prepare_and_validate_args having set max_state and max_modes.
    num_states = args.max_state
    # If for some reason max_state is still None (shouldn't happen if validation ran), peek at file
    if num_states is None:
        with h5py.File(args.input_file, 'r') as hf:
            num_states = hf['energies'].shape[0]

    num_modes = args.max_modes
    # Peak memory is ~ (2*M + 6) * N*N * 16 bytes (complex128 matrices).
    peak_mem_estimate = (2 * num_modes + 6) * (num_states**2) * 16

    # Get the memory limit from the first Dask worker.
    workers = client.scheduler_info()['workers']
    if not workers:
        worker_memory = 4e9  # Fallback to 4GB
    else:
        worker_memory = next(iter(workers.values()))['memory_limit']

    # Use 75% of worker memory as a safe threshold.
    should_defer = peak_mem_estimate > (worker_memory * 0.75)

    if should_defer:
        print(f"\nLarge system detected (Est. peak memory {peak_mem_estimate/1e9:.2f} GB > 75% of worker memory {worker_memory/1e9:.2f} GB).")
        print("Deferring coupling matrix loading to conserve memory.")
        initial_data = load_initial_data(
            args.input_file, args.kramers, args.max_state, args.states, 
            args.max_modes, args.modes, defer_coupling_loading=True
        )
    else:
        print(f"\nSmall system detected (Est. peak memory {peak_mem_estimate/1e9:.2f} GB < 75% of worker memory {worker_memory/1e9:.2f} GB).")
        print("Loading all coupling matrices into memory at once.")
        initial_data = load_initial_data(
            args.input_file, args.kramers, args.max_state, args.states, 
            args.max_modes, args.modes, defer_coupling_loading=False
        )

        # Determine which modes (0-indexed) from the input file are in use.
        if args.modes:
            target_mode_indices = [m - 1 for m in args.modes]
        else:
            target_mode_indices = list(range(args.max_modes))

        s_couplings_dict = initial_data.pop('symmetrised_couplings')
        n_electronic_states = initial_data['initial_energies'].shape[0]

        # FIX: size the dense array by the number of *selected* modes rather
        # than args.max_modes. When --modes is given these counts can differ,
        # and sizing by max_modes left spurious all-zero coupling matrices at
        # the end of the array (prepare_orbach_inputs already sizes this way).
        s_couplings_array = np.zeros(
            (len(target_mode_indices), n_electronic_states, n_electronic_states),
            dtype=np.complex128
        )

        for i, original_idx in enumerate(target_mode_indices):
            if original_idx in s_couplings_dict:
                s_couplings_array[i] = s_couplings_dict[original_idx]

        initial_data['symmetrised_couplings_array'] = s_couplings_array

    width = get_integration_width(args.lineshape, args.fwhm, args.integration_width, args.erange)
    params = RamanParams(
        max_state=args.max_state, raman_states=args.raman_states,
        num_modes=args.max_modes, fwhm=args.fwhm, lineshape=args.lineshape,
        grid_variable=args.grid_variable, width=width, hwhm=args.fwhm / 2.0,
        temperatures=args.temperatures, window_type=args.window_type,

        debug_summary=(args.debug_summary or args.debug_verbose),
        debug_verbose=args.debug_verbose, debug_top_n=args.debug_top_n,
    )
    return initial_data, params

Loads initial data from files and prepares the params object, choosing a data loading strategy based on the available worker memory.

def save_orbach_outputs(all_field_results, params, orientation_weights, args)
Expand source code
def save_orbach_outputs(all_field_results, params, orientation_weights, args):
    """
    Handles saving all Orbach results to HDF5 and CSV files.
    """
    # Default output names are derived from the input/output file stems.
    if args.output is None:
        args.output = f"{args.input_file.rsplit('.', 1)[0]}_orbach.hdf5"
    if args.output_csv is None and args.temperatures:
        args.output_csv = f"{args.output.rsplit('.', 1)[0]}.csv"

    save_calculation_to_hdf5(all_field_results, params, orientation_weights, args.output)

    # A rates CSV is only written for single-orientation runs.
    if args.temperatures and args.orientations == 1:
        rates_by_field = {}
        for res in all_field_results[0]:
            if 'relaxation_rates' in res:
                rates_by_field[res['field_mag']] = res['relaxation_rates']
        write_rates_to_csv(args.output_csv, args.temperatures, rates_by_field)

Handles saving all Orbach results to HDF5 and CSV files.

def save_raman_outputs(all_field_results, params, orientation_weights, args)
Expand source code
def save_raman_outputs(all_field_results, params, orientation_weights, args):
    """
    Handles saving Raman results to HDF5 and CSV files.

    Returns the path of the HDF5 output file.
    """
    # Default output names are derived from the input/output file stems.
    if args.output is None:
        args.output = f"{args.input_file.rsplit('.', 1)[0]}_raman.hdf5"
    if args.output_csv is None and args.temperatures:
        args.output_csv = f"{args.output.rsplit('.', 1)[0]}.csv"

    save_calculation_to_hdf5(all_field_results, params, orientation_weights, args.output)

    # A rates CSV is only written for single-orientation runs.
    if args.temperatures and args.orientations == 1:
        rates_by_field = {}
        for res in all_field_results[0]:
            if 'relaxation_rates' in res:
                rates_by_field[res['field_mag']] = res['relaxation_rates']
        write_rates_to_csv(args.output_csv, args.temperatures, rates_by_field)

    return args.output

Handles saving Raman results to HDF5 and CSV files.