diff --git a/wfl/generate/md/__init__.py b/wfl/generate/md/__init__.py index 6f0399eb..56e4944c 100644 --- a/wfl/generate/md/__init__.py +++ b/wfl/generate/md/__init__.py @@ -113,29 +113,37 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere logger_constructor = logger_kwargs.pop("logger", MDLogger) logger_logfile = logger_kwargs.get("logfile", "-") - if temperature_tau is None and (temperature is not None and not isinstance(temperature, (float, int))): - raise RuntimeError(f'NVE (temperature_tau is None) can only accept temperature=float for initial T, got {type(temperature)}') - - if temperature is not None: - if isinstance(temperature, (float, int)): - # float into a list - temperature = [temperature] - if not isinstance(temperature[0], dict): - # create a stage dict from a constant or ramp - t_stage_data = temperature - # start with constant - t_stage = {'T_i': t_stage_data[0], 'T_f': t_stage_data[0], 'traj_frac': 1.0, 'n_stages': 10, 'steps': steps} - if len(t_stage_data) >= 2: - # set different final T for ramp - t_stage['T_f'] = t_stage_data[1] - if len(t_stage_data) >= 3: - # set number of stages - t_stage['n_stages'] = t_stage_data[2] - temperature = [t_stage] - else: - for t_stage in temperature: - if 'n_stages' not in t_stage: - t_stage['n_stages'] = 10 + def _get_pressure(atoms): + return atoms.info.get("WFL_MD_PRESSURE", pressure) + + def _get_temperature(atoms): + temperature_use = atoms.info.get("WFL_MD_TEMPERATURE", temperature) + + if temperature_tau is None and (temperature_use is not None and not isinstance(temperature_use, (float, int))): + raise RuntimeError(f'NVE (temperature_tau is None) can only accept temperature=float for initial T, got {type(temperature_use)}') + + if temperature_use is not None: + if isinstance(temperature_use, (float, int)): + # float into a list + temperature_use = [temperature_use] + if not isinstance(temperature_use[0], dict): + # create a stage dict from a constant or ramp + t_stage_data = temperature_use + # start with constant + t_stage = {'T_i': t_stage_data[0], 'T_f': t_stage_data[0], 'traj_frac': 1.0, 'n_stages': 10, 'steps': steps} + if len(t_stage_data) >= 2: + # set different final T for ramp + t_stage['T_f'] = t_stage_data[1] + if len(t_stage_data) >= 3: + # set number of stages + t_stage['n_stages'] = t_stage_data[2] + temperature_use = [t_stage] + else: + for t_stage in temperature_use: + if 'n_stages' not in t_stage: + t_stage['n_stages'] = 10 + + return temperature_use for at_i, at in enumerate(atoms_to_list(atoms)): # get rng from autopara_per_item info if available ("rng" arg that was passed in was @@ -143,12 +151,15 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere rng = _autopara_per_item_info[at_i].get("rng") item_i = _autopara_per_item_info[at_i].get("item_i") + temperature_use = _get_temperature(at) + pressure_use = _get_pressure(at) + at.calc = calculator - if pressure is not None and compressibility_au is None: - pressure = sample_pressure(pressure, at, rng=rng) + if pressure_use is not None and compressibility_au is None: + pressure_use = sample_pressure(pressure_use, at, rng=rng) at.info['MD_pressure_GPa'] = pressure # convert to ASE internal units - pressure *= GPa + pressure_use *= GPa E0 = at.get_potential_energy() c0 = at.get_cell() @@ -160,17 +171,17 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere d2E_dF2 = (Ep + Em - 2.0 * E0) / (compressibility_fd_displ ** 2) compressibility_au = at.get_volume() / d2E_dF2 - if temperature is not None: + if temperature_use is not None: # set initial temperature assert rng is not None - MaxwellBoltzmannDistribution(at, temperature_K=temperature[0]['T_i'], force_temp=True, communicator=None, rng=rng) + MaxwellBoltzmannDistribution(at, temperature_K=temperature_use[0]['T_i'], force_temp=True, communicator=None, rng=rng) Stationary(at, preserve_temperature=True) stage_kwargs = {'timestep': dt * fs, 'logfile': logfile} if temperature_tau is None: # NVE - if pressure is not None: + if pressure_use is not None: raise RuntimeError('Cannot do NPH dynamics') md_constructor = VelocityVerlet # one stage, simple @@ -182,7 +193,7 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere all_stage_kwargs = [] all_run_kwargs = [] - if pressure is not None: + if pressure_use is not None: md_constructor = NPTBerendsen stage_kwargs['pressure_au'] = pressure stage_kwargs['compressibility_au'] = compressibility_au @@ -199,7 +210,7 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere assert rng is not None stage_kwargs["rng"] = rng - for t_stage_i, t_stage in enumerate(temperature): + for t_stage_i, t_stage in enumerate(temperature_use): stage_steps = t_stage['traj_frac'] * steps if t_stage['T_f'] == t_stage['T_i']: