Module tau2.main
Functions
def determine_parallel_strategy(args, magnitudes_to_run, orientations_to_run, params)-
Expand source code
def determine_parallel_strategy(args, magnitudes_to_run, orientations_to_run, params, *, target_mode_pairs=250000):
    """
    Work out how many groups ('chunks') to split the sum over the first phonon, q, into.

    For calculations at lots of fields (e.g. hysteresis calculations) this will
    typically be 1, because the field/orientation tasks alone already provide
    enough parallel work. For calculations with many phonon pairs (e.g. fine
    q-meshes), the number of chunks will be higher to better balance the
    computational load over multiple cores.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``n_chunks`` (an explicit override, used whenever it is truthy)
        and ``n_cores``.
    magnitudes_to_run : sized
        Field magnitudes to compute; only ``len()`` is used.
    orientations_to_run : sized
        Field orientations to compute; only ``len()`` is used.
    params
        Reads ``num_modes``.
    target_mode_pairs : int, optional (keyword-only)
        Approximate number of mode pairs per chunk used when sizing chunks
        automatically. Defaults to 250000 (the previously hard-coded value).

    Returns
    -------
    int
        Number of phonon chunks to use per field (also printed).
    """
    if args.n_chunks:
        num_phonon_chunks = args.n_chunks
    else:
        # Number of mode pairs (i, j) with i <= j.
        total_mode_pairs = params.num_modes * (params.num_modes + 1) // 2
        mode_pair_chunks = max(total_mode_pairs // target_mode_pairs, 1)

        n_field_tasks = len(magnitudes_to_run) * len(orientations_to_run)
        if n_field_tasks >= args.n_cores * mode_pair_chunks * 2:
            # Enough field/orientation tasks to keep every core busy: only
            # split as much as the mode-pair count requires.
            num_phonon_chunks = mode_pair_chunks
        else:
            # Too few fields to occupy the cores: aim for at least two chunks
            # per core, with the size-based count rounded up to a multiple of
            # n_cores.
            target_min_chunks = 2 * args.n_cores
            mode_pair_chunks = ((mode_pair_chunks + args.n_cores - 1) // args.n_cores) * args.n_cores
            num_phonon_chunks = max(target_min_chunks, mode_pair_chunks)

    plural = '' if num_phonon_chunks == 1 else 's'
    print(f"\nUsing {num_phonon_chunks} phonon chunk{plural} per field")
    return num_phonon_chunks
def main()-
Expand source code
def main():
    """
    Entry point for the tau2 command-line program.

    Prints the start-up banner, parses the command line and stores the
    handler for the requested mode on ``args.func``. For the 'raman' and
    'orbach' modes, the mode energies and electronic state energies are first
    read from the input HDF5 file and handed to ``prepare_and_validate_args``
    (e.g. to infer the energy range from the modes). Finally the handler is
    called with the (possibly updated) arguments.
    """
    banner=f''' x *~~~~~~~~~~~~~~~~~~xxxxxxxxxxxxxxxxxxxx *~~~~~~~~~~~~~~~~~~~~~xxxxxxxxxxxxxxxxxxxx *~~~~~~~~~x xxxx xxxx tau2 xxxx Ben Atkinson xxxx xxxx Based on Tau *~~~~~~~~~~xxxx x Jon Kragskow *~~~~~~~~~~~~~~~~xxxx xx Daniel Reta *~~~~~~~~~~~~~~~xxxx xxx Nicholas Chilton xxxx xxxx *~~~~~~~~~~~~~~xxxxxxxxxx *~~~~~~~~~~~~xxxxxxx Version {version} '''
    print(banner)

    args = parse_cli_args()

    # CLI mode name -> function implementing that mode.
    handlers = {
        'raman': main_raman,
        'raman_rates': main_raman_rates,
        'orbach': main_orbach,
        'hyst': main_hyst,
        'plot_rate': main_plotrate,
        'plot_raman_j': main_plotramanj1d,
        'plot_raman_j_2d': main_plotramanj2d,
        'plot_orbach_j': main_plotorbachj,
    }
    args.func = handlers[args.mode]

    if args.mode in ('raman', 'orbach'):
        # Only the two energy arrays are needed for argument validation, so
        # the file is opened just long enough to read them.
        with h5py.File(args.input_file, 'r') as h5_file:
            mode_energies = h5_file['mode_energies'][...]
            state_energies = h5_file['energies'][...]
        args = prepare_and_validate_args(args, mode_energies, state_energies)

    args.func(args)
def main_hyst(args)-
Expand source code
def main_hyst(args):
    """
    Handler for the 'hyst' mode: run the hysteresis simulation.

    The parsed command-line arguments are passed straight through to
    ``run_hysteresis_simulation``; nothing is returned.
    """
    run_hysteresis_simulation(args)
def main_orbach(args)-
Expand source code
def main_orbach(args):
    """
    High-level orchestrator for the Orbach calculation workflow.

    Steps:
      1. Load the input data and build the Orbach parameters.
      2. Set up the field magnitudes and orientation(s), and print the
         equilibrium electronic structure at the first field.
      3. Run the Orbach calculation.
      4. Save the results (HDF5, plus CSV rates where applicable).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed and validated command-line arguments for the 'orbach' mode.
    """
    # Step 1: Prepare data and parameters
    initial_data, params = prepare_orbach_inputs(args)
    print_calculation_parameters(args)

    # Step 2: Set up the magnetic field sweep and orientations.
    # Get the rotation matrix and magnitudes, which are independent of orientation.
    B_unit, magnitudes_to_run, rotation_matrix = setup_field_orientation(
        args, initial_data['angmom'], initial_data['spin'], orientations=args.orientations
    )
    if args.orientations > 1:
        orientations_principal_frame, orientation_weights = generate_orientations(
            args.quadrature, args.orientations
        )
        # Rotate the generated vectors from the principal frame back to the input frame.
        orientation_vectors = (rotation_matrix @ orientations_principal_frame.T).T
        B_unit_for_print = orientation_vectors[0]
    else:
        # For a single orientation, B_unit is the vector we need.
        orientation_vectors = [B_unit]
        orientation_weights = [1.0]
        B_unit_for_print = B_unit

    # Use len() rather than truthiness: if setup_field_orientation returns a
    # NumPy array, `if array:` raises for more than one element.
    if magnitudes_to_run is not None and len(magnitudes_to_run) > 0:
        first_field = magnitudes_to_run[0]
        if args.orientations > 1:
            header = f"\n--- Equilibrium Electronic Structure ({first_field:.4f} T, first orientation) ---"
        else:
            header = f"\n--- Equilibrium Electronic Structure ({first_field:.4f} T) ---"
        print(header)
        print_electronic_structure(initial_data, B_unit_for_print, first_field)

    # Step 3: Run the core calculation
    all_field_results = run_orbach_calculation(
        initial_data=initial_data,
        params=params,
        orientation_vectors=orientation_vectors,
        magnitudes_to_run=magnitudes_to_run,
        args=args,
    )
    print("\nCalculation complete.")

    # Step 4: Save results
    save_orbach_outputs(all_field_results, params, orientation_weights, args)
def main_plotorbachj(args)-
Expand source code
def main_plotorbachj(args):
    """
    Plot the Orbach spectral density read from ``args.input_file``.

    The output prefix is ``args.output`` if given, otherwise the input file
    path with its extension removed.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments for the 'plot_orbach_j' mode.
    """
    import os

    # os.path.splitext only strips a real extension. The old rsplit('.')
    # approach truncated extension-less paths containing a dot elsewhere
    # (e.g. '../run/out' -> '.').
    output_prefix = args.output or os.path.splitext(args.input_file)[0]
    plot_orbach_spectral_density(output_prefix, args)
def main_plotramanj1d(args)-
Expand source code
def main_plotramanj1d(args):
    """
    Plot the Raman spectral density as a function of omega_r from one or more HDF5 files.

    Output prefix, in order of preference:
      * ``args.output`` if given;
      * the single input file path with its extension removed;
      * ``"comparison"`` when several input files are being compared.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments for the 'plot_raman_j' mode
        (reads ``output`` and ``input_files``).
    """
    import os

    if args.output:
        output_prefix = args.output
    elif len(args.input_files) == 1:
        # splitext strips only a real extension (rsplit('.') broke on dotted directories).
        output_prefix = os.path.splitext(args.input_files[0])[0]
    else:
        output_prefix = "comparison"
    plot_1d_from_hdf5(output_prefix, args)
def main_plotramanj2d(args)-
Expand source code
def main_plotramanj2d(args):
    """
    Plot the Raman spectral density as a function of omega_r and field.

    The output prefix is ``args.output`` if given, otherwise the input file
    path with its extension removed.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments for the 'plot_raman_j_2d' mode.
    """
    import os

    # splitext strips only a real extension (rsplit('.') broke on dotted directories).
    output_prefix = args.output or os.path.splitext(args.input_file)[0]
    plot_2d_from_hdf5(output_prefix, args)
def main_plotrate(args)-
Expand source code
def main_plotrate(args):
    """
    Plot relaxation rates.

    Output prefix, in order of preference:
      * ``args.output`` if given;
      * the first ``args.raman`` file with its extension removed;
      * the first ``args.orbach`` file with its extension removed;
      * ``"rates"``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments for the 'plot_rate' mode.
    """
    import os

    if args.output:
        output_prefix = args.output
    elif args.raman:
        output_prefix = os.path.splitext(args.raman[0])[0]
    elif args.orbach:
        output_prefix = os.path.splitext(args.orbach[0])[0]
    else:
        output_prefix = "rates"
    plot_rate(output_prefix, args)
def main_raman(args)-
Expand source code
def main_raman(args):
    """
    High-level orchestrator for the Raman calculation workflow.

    Creates the Dask cluster/client (both used as context managers), then:
      1. Loads the input data and builds the Raman parameters (the client is
         used to pick a memory-aware loading strategy).
      2. Sets up the field magnitudes and orientation(s), and prints the
         equilibrium electronic structure at the first field.
      3. Chooses the number of phonon chunks per field.
      4. Runs the Raman calculation on the client.
      5. Saves the results.
      6. Optionally plots a single transition, only for a single-field,
         single-orientation run.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed and validated command-line arguments for the 'raman' mode.
    """
    import os

    # Create the Dask client and cluster.
    cluster, client = _setup_dask_client(args)

    with cluster, client:
        # Step 1: Prepare data and parameters, passing the client to inform strategy.
        initial_data, params = prepare_raman_inputs(args, client)
        print_calculation_parameters(args)

        # Step 2: Set up the magnetic field sweep and orientations.
        # Get the rotation matrix and magnitudes, which are independent of orientation.
        B_unit, magnitudes_to_run, rotation_matrix = setup_field_orientation(
            args, initial_data['angmom'], initial_data['spin'], orientations=args.orientations
        )
        if args.orientations > 1:
            orientations_principal_frame, orientation_weights = generate_orientations(
                args.quadrature, args.orientations
            )
            # Rotate the generated vectors from the principal frame back to the input frame.
            orientation_vectors = (rotation_matrix @ orientations_principal_frame.T).T
            B_unit_for_print = orientation_vectors[0]
        else:
            # For a single orientation, B_unit is the vector we need.
            orientation_vectors = [B_unit]
            orientation_weights = [1.0]
            B_unit_for_print = B_unit

        precomputed_first_field_data = None
        # Use len() rather than truthiness: if setup_field_orientation returns a
        # NumPy array, `if array:` raises for more than one element.
        if magnitudes_to_run is not None and len(magnitudes_to_run) > 0:
            first_field = magnitudes_to_run[0]
            if args.orientations > 1:
                header = f"\n--- Equilibrium Electronic Structure ({first_field:.4f} T, first orientation) ---"
            else:
                header = f"\n--- Equilibrium Electronic Structure ({first_field:.4f} T) ---"
            print(header)
            precomputed_first_field_data = print_electronic_structure(
                initial_data, B_unit_for_print, first_field
            )

        # Step 3: Determine parallelisation strategy.
        num_phonon_chunks = determine_parallel_strategy(
            args, magnitudes_to_run, orientation_vectors, params
        )

        # Optimisation for large systems: with the deferred loading strategy
        # (coupling_info exists), also avoid scattering the large angmom/spin
        # matrices. The workers load them from disk on demand during Zeeman
        # splitting.
        if 'coupling_info' in initial_data:
            print("Large system: Removing static angmom/spin arrays from Client to prevent scattering.")
            initial_data.pop('angmom', None)
            initial_data.pop('spin', None)

        # Step 4: Run the core calculation, passing the client.
        all_field_results = run_raman_calculation(
            initial_data=initial_data,
            params=params,
            orientation_vectors=orientation_vectors,
            magnitudes_to_run=magnitudes_to_run,
            num_phonon_chunks=num_phonon_chunks,
            args=args,
            client=client,
            precomputed_first_field_data=precomputed_first_field_data,
        )
        print("\nCalculation complete.")

        # Step 5: Save results.
        output_hdf5 = save_raman_outputs(all_field_results, params, orientation_weights, args)

        # Step 6: Plotting (single field, single orientation only).
        if args.plot_transition and len(magnitudes_to_run) == 1 and args.orientations == 1:
            print("\n--- Proceeding to Plot Generation ---")
            plot_prefix = os.path.splitext(output_hdf5)[0]
            # NOTE(review): main_plotramanj1d calls plot_1d_from_hdf5 as
            # plot_1d_from_hdf5(output_prefix, args), but this call passes three
            # positional arguments in a different order. Confirm which form
            # matches the helper's current signature.
            plot_1d_from_hdf5([output_hdf5], plot_prefix, tuple(args.plot_transition))
def main_raman_rates(args)-
Expand source code
def _rebuild_raman_integrands(raman_data):
    """Rebuild the ``{(i, f): {process: J_array}}`` integrand map from a loaded 'raman' group."""
    integrands = {}
    for key, value in raman_data.items():
        if not key.startswith('transition_'):
            continue
        # Key is 'transition_0_1', for example.
        _, i_str, f_str = key.split('_')
        transition = (int(i_str), int(f_str))
        # Value is a dict like {'J_pp': array, 'J_mm': array}; keep only the
        # process name ('pp', 'mm', ...) as the inner key.
        integrands[transition] = {
            proc_key.split('_')[1]: J_array for proc_key, J_array in value.items()
        }
    return integrands


def main_raman_rates(args):
    """
    Handle the 'raman_rates' mode.

    Loads a pre-computed spectral density from an HDF5 file (written by a
    'tau2 raman' run) and recalculates relaxation rates at newly specified
    temperatures. The rates are written to ``args.output``, or to
    ``<input file without extension>_rates.csv`` when no output is given.
    Nothing is written if no field yields rates.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``input_file``, ``temperatures``, ``rate_integrator`` and ``output``.
    """
    import os

    # Load the entire HDF5 file created by a 'tau2 raman' run.
    print(f"--- Loading spectral density from {args.input_file} ---")
    data = load_from_hdf5(args.input_file)
    # NOTE(review): only the first orientation group is read, so for
    # multi-orientation runs the rates come from that orientation alone.
    orient_data = data['orientation_0']
    loaded_params = orient_data['parameters']
    results_data = orient_data['results']

    # Newly calculated rates for all fields, keyed by field magnitude.
    all_new_rates_data = {}

    # max_state used for the original Gamma matrix calculation. It may be
    # stored as a string, so convert it to int.
    try:
        raman_max_state = int(loaded_params['raman_max_state'])
    except KeyError:
        print("Warning: 'raman_max_state' not found in HDF5 parameters. Defaulting to 2.")
        raman_max_state = 2

    print(f"--- Recalculating rates for {len(results_data)} magnetic fields at specified temperatures ---")

    # NOTE(review): the field key is ignored and the i-th results entry is paired
    # with the i-th entry of 'field_magnitudes'. This assumes results_data
    # iterates in the order the magnitudes were saved; confirm against
    # load_from_hdf5 (HDF5 groups iterate by name unless order is tracked).
    for field_idx, field_data in enumerate(results_data.values()):
        field_mag = loaded_params['field_magnitudes'][field_idx]

        raman_data = field_data.get('raman', {})
        if not raman_data or 'omega_grid' not in raman_data:
            print(f"Warning: No Raman spectral density found for field {field_mag:.4f} T. Skipping.")
            continue

        # Reconstruct the input data required by the rate calculator.
        calc_input = {
            'omega_grid': raman_data['omega_grid'],
            'energies': field_data['electronic_energies'],
            'integrands': _rebuild_raman_integrands(raman_data),
        }

        # Recalculate with the loaded data and the new temperatures.
        new_rates, _ = calculate_rates_from_data(
            calc_input, args.temperatures, raman_max_state, args.rate_integrator
        )
        all_new_rates_data[field_mag] = new_rates

    if not all_new_rates_data:
        print("Calculation failed: No rates were generated.")
        return

    # splitext strips only a real extension (rsplit('.') broke on dotted directories).
    output_csv = args.output or f"{os.path.splitext(args.input_file)[0]}_rates.csv"
    write_rates_to_csv(output_csv, args.temperatures, all_new_rates_data)
def prepare_orbach_inputs(args)-
Expand source code
def prepare_orbach_inputs(args):
    """
    Load the input data and build the OrbachParams for an Orbach run.

    Besides loading, this:
      * sets ``args.max_state`` to the number of loaded electronic states when
        it was None (mutates ``args``);
      * restricts ``initial_data['all_mode_energies']`` to the selected modes;
      * replaces the sparse ``'symmetrised_couplings'`` mapping (original mode
        index -> matrix) with a dense complex128 array stored under
        ``'symmetrised_couplings_array'``, one row per selected mode. Modes
        missing from the mapping stay zero.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed and validated command-line arguments for the 'orbach' mode.
        ``args.modes`` holds 1-based mode numbers.

    Returns
    -------
    tuple
        ``(initial_data, params)``.
    """
    initial_data = load_initial_data(
        args.input_file,
        args.kramers,
        args.max_state,
        args.states,
        args.max_modes,
        args.modes,
    )
    n_states = initial_data['initial_energies'].shape[0]
    if args.max_state is None:
        args.max_state = n_states

    # 0-based indices of the modes in use: the explicit 1-based selection if
    # given, otherwise the first max_modes modes.
    mode_indices = [mode - 1 for mode in args.modes] if args.modes else list(range(args.max_modes))

    initial_data['all_mode_energies'] = initial_data['all_mode_energies'][mode_indices]

    # Densify the sparse coupling mapping in the order of mode_indices.
    sparse_couplings = initial_data.pop('symmetrised_couplings')
    dense_couplings = np.zeros((len(mode_indices), n_states, n_states), dtype=np.complex128)
    for row, mode_idx in enumerate(mode_indices):
        if mode_idx not in sparse_couplings:
            continue
        dense_couplings[row] = sparse_couplings[mode_idx]
    initial_data['symmetrised_couplings_array'] = dense_couplings

    integration_width = get_integration_width(args.lineshape, args.fwhm, args.integration_width)

    params = OrbachParams(
        max_state=args.max_state,
        max_modes=args.max_modes,
        fwhm=args.fwhm,
        lineshape=args.lineshape,
        gamma_diagonalisation=args.gamma_diagonalisation,
        gamma_precision=args.gamma_precision,
        width=integration_width,
        hwhm=args.fwhm / 2.0,
        temperatures=args.temperatures,
    )
    return initial_data, params
def prepare_raman_inputs(args, client)-
Expand source code
def prepare_raman_inputs(args, client):
    """
    Load initial data from files and prepare the RamanParams object, choosing a
    data-loading strategy based on the available worker memory.

    A rough peak-memory estimate for loading every coupling matrix up front
    ("load-first") is compared with 75% of the memory limit reported by the
    first Dask worker. If it exceeds that, coupling-matrix loading is deferred
    (``defer_coupling_loading=True``).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. ``max_state`` and ``max_modes`` are
        expected to have been set by ``prepare_and_validate_args``. If
        ``max_state`` is still None, the state count is read from the
        ``energies`` dataset of ``args.input_file``. ``args.modes`` holds
        1-based mode numbers.
    client : dask.distributed.Client
        Only used to query worker memory limits via ``scheduler_info()``.

    Returns
    -------
    tuple
        ``(initial_data, params)``: the dict from ``load_initial_data``, with
        the sparse ``'symmetrised_couplings'`` mapping replaced by a dense
        ``'symmetrised_couplings_array'``, and the ``RamanParams`` instance.
    """
    # Memory limit assumed when no worker, or no usable limit, is reported.
    fallback_worker_memory = 4e9

    # Estimate the peak memory required to use the "load-first" strategy.
    # We rely on prepare_and_validate_args having set max_state and max_modes.
    num_states = args.max_state
    # If max_state is still None (shouldn't happen if validation ran), peek at the file.
    if num_states is None:
        with h5py.File(args.input_file, 'r') as hf:
            num_states = hf['energies'].shape[0]
    num_modes = args.max_modes

    # Peak memory is ~ (2*M + 6) * N*N * 16 bytes (16 bytes per complex128 element).
    peak_mem_estimate = (2 * num_modes + 6) * (num_states**2) * 16

    # Get the memory limit from the first Dask worker.
    workers = client.scheduler_info()['workers']
    worker_memory = next(iter(workers.values()))['memory_limit'] if workers else None
    if not worker_memory:
        # Covers no workers, and presumably a worker with no configured limit
        # (Dask reports 0/None). Using 0 made the threshold 0 (always defer) and
        # None raised TypeError, so use the fallback in both cases.
        worker_memory = fallback_worker_memory

    # Use 75% of worker memory as a safe threshold.
    should_defer = peak_mem_estimate > (worker_memory * 0.75)

    if should_defer:
        print(f"\nLarge system detected (Est. peak memory {peak_mem_estimate/1e9:.2f} GB > 75% of worker memory {worker_memory/1e9:.2f} GB).")
        print("Deferring coupling matrix loading to conserve memory.")
    else:
        print(f"\nSmall system detected (Est. peak memory {peak_mem_estimate/1e9:.2f} GB < 75% of worker memory {worker_memory/1e9:.2f} GB).")
        print("Loading all coupling matrices into memory at once.")

    initial_data = load_initial_data(
        args.input_file,
        args.kramers,
        args.max_state,
        args.states,
        args.max_modes,
        args.modes,
        defer_coupling_loading=should_defer,
    )

    # 0-based indices of the modes in use.
    if args.modes:
        target_mode_indices = [m - 1 for m in args.modes]
    else:
        target_mode_indices = list(range(args.max_modes))

    # Densify the sparse coupling mapping (original mode index -> matrix).
    # Modes missing from the mapping stay zero.
    # NOTE(review): the array has args.max_modes rows, but with --modes there
    # are len(args.modes) indices. This assumes prepare_and_validate_args keeps
    # the two equal (prepare_orbach_inputs sizes by len(target_mode_indices)).
    s_couplings_dict = initial_data.pop('symmetrised_couplings')
    n_electronic_states = initial_data['initial_energies'].shape[0]
    s_couplings_array = np.zeros(
        (num_modes, n_electronic_states, n_electronic_states), dtype=np.complex128
    )
    for i, original_idx in enumerate(target_mode_indices):
        if original_idx in s_couplings_dict:
            s_couplings_array[i] = s_couplings_dict[original_idx]
    initial_data['symmetrised_couplings_array'] = s_couplings_array

    width = get_integration_width(args.lineshape, args.fwhm, args.integration_width, args.erange)

    params = RamanParams(
        max_state=args.max_state,
        raman_states=args.raman_states,
        num_modes=args.max_modes,
        fwhm=args.fwhm,
        lineshape=args.lineshape,
        grid_variable=args.grid_variable,
        width=width,
        hwhm=args.fwhm / 2.0,
        temperatures=args.temperatures,
        window_type=args.window_type,
        debug_summary=(args.debug_summary or args.debug_verbose),
        debug_verbose=args.debug_verbose,
        debug_top_n=args.debug_top_n,
    )
    return initial_data, params
def save_orbach_outputs(all_field_results, params, orientation_weights, args)-
Expand source code
def save_orbach_outputs(all_field_results, params, orientation_weights, args):
    """
    Save all Orbach results to HDF5 and, where applicable, rates to CSV.

    Fills in defaults on ``args`` (mutating it):
      * ``args.output``: ``<input file without extension>_orbach.hdf5``;
      * ``args.output_csv`` (only when temperatures were given):
        ``<output without extension>.csv``.

    The CSV of relaxation rates is only written for single-orientation runs
    with temperatures, using the per-field results of the first (only)
    orientation.

    Returns
    -------
    str
        The HDF5 output path (``args.output``), as ``save_raman_outputs`` does.
    """
    import os

    # splitext strips only a real extension (rsplit('.') broke on dotted directories).
    if args.output is None:
        args.output = f"{os.path.splitext(args.input_file)[0]}_orbach.hdf5"
    if args.output_csv is None and args.temperatures:
        args.output_csv = f"{os.path.splitext(args.output)[0]}.csv"

    save_calculation_to_hdf5(all_field_results, params, orientation_weights, args.output)

    if args.temperatures and args.orientations == 1:
        all_rates_data = {
            res['field_mag']: res['relaxation_rates']
            for res in all_field_results[0]
            if 'relaxation_rates' in res
        }
        write_rates_to_csv(args.output_csv, args.temperatures, all_rates_data)

    return args.output
def save_raman_outputs(all_field_results, params, orientation_weights, args)-
Expand source code
def save_raman_outputs(all_field_results, params, orientation_weights, args):
    """
    Save Raman results to HDF5 and, where applicable, rates to CSV.

    Fills in defaults on ``args`` (mutating it):
      * ``args.output``: ``<input file without extension>_raman.hdf5``;
      * ``args.output_csv`` (only when temperatures were given):
        ``<output without extension>.csv``.

    The CSV of relaxation rates is only written for single-orientation runs
    with temperatures, using the per-field results of the first (only)
    orientation.

    Returns
    -------
    str
        The HDF5 output path (``args.output``).
    """
    import os

    # splitext strips only a real extension (rsplit('.') broke on dotted directories).
    if args.output is None:
        args.output = f"{os.path.splitext(args.input_file)[0]}_raman.hdf5"
    if args.output_csv is None and args.temperatures:
        args.output_csv = f"{os.path.splitext(args.output)[0]}.csv"

    save_calculation_to_hdf5(all_field_results, params, orientation_weights, args.output)

    if args.temperatures and args.orientations == 1:
        all_rates_data = {
            res['field_mag']: res['relaxation_rates']
            for res in all_field_results[0]
            if 'relaxation_rates' in res
        }
        write_rates_to_csv(args.output_csv, args.temperatures, all_rates_data)

    return args.output