148 add training workchain #153
Open
federicazanca wants to merge 9 commits into stfc:main from federicazanca:148-add-training-workchain
+330 −1
Commits (9):
19b831a  initial draft for workgraph, pre-commit fail (federicazanca)
6df534d  minor change (federicazanca)
2c86173  pyproject changes? (federicazanca)
e917a4f  workgraph mostly done but pre-commits fail (federicazanca)
d128b01  change paths
95fbb81  ok but entry point not working
34c535b  fixed workgraph and submission (federicazanca)
1f2389e  fix pre-commit? (federicazanca)
3deb4ff  Apply suggestions from code review (alinelena)
aiida_mlip/workflows/training_workgraph.py
@@ -0,0 +1,220 @@
"""Workgraph to run DFT calculations and use the outputs for training an MLIP model."""

from io import StringIO
from pathlib import Path

from aiida_workgraph import WorkGraph, task
from ase.io import read, write
from sklearn.model_selection import train_test_split

from aiida.orm import SinglefileData
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.helpers.help_load import load_structure

PwRelaxWorkChain = WorkflowFactory("quantumespresso.pw.relax")


@task.graph_builder(outputs=[{"name": "result", "from": "context.pw"}])
def run_pw_calc(folder: Path, dft_inputs: dict) -> WorkGraph:
    """
    Run a Quantum ESPRESSO relaxation with PwRelaxWorkChain for each structure in a folder.

    Parameters
    ----------
    folder : Path
        Path to the folder containing input structure files.
    dft_inputs : dict
        Dictionary of inputs for the DFT calculations.

    Returns
    -------
    WorkGraph
        The work graph containing the PW relaxation tasks.
    """
    wg = WorkGraph()

    for child in folder.glob("**/*"):
        # Skip anything ASE cannot parse as a structure file.
        try:
            read(child.as_posix())
        except Exception:  # pylint: disable=broad-except
            continue
        structure = load_structure(child)
        dft_inputs["base"]["structure"] = structure
        dft_inputs["base"]["pw"]["metadata"]["label"] = child.stem
        pw_task = wg.add_task(
            PwRelaxWorkChain, name=f"pw_relax_{child.stem}", **dft_inputs
        )
        # Collect each relaxed structure in the context under "pw.<name>",
        # which the graph builder exposes through its "result" output.
        pw_task.set_context({"output_structure": f"pw.{child.stem}"})
    return wg

@task.calcfunction()
def create_input(**inputs: dict) -> SinglefileData:
    """
    Create a single extxyz input file from the given structures.

    Parameters
    ----------
    **inputs : dict
        Dictionary where keys are names and values are structure data.

    Returns
    -------
    SinglefileData
        A SinglefileData node containing the generated input data.
    """
    input_data = []
    for _, structure in inputs.items():
        # Convert the AiiDA StructureData to an ASE Atoms object and dump it
        # as an extxyz frame into an in-memory buffer.
        ase_structure = structure.get_ase()
        buffer = StringIO()
        write(buffer, ase_structure, format="extxyz")
        input_data.append(buffer.getvalue())

    temp_file_path = "tmp.extxyz"
    with open(temp_file_path, "w", encoding="utf8") as temp_file:
        temp_file.write("".join(input_data))

    file_data = SinglefileData(file=temp_file_path)

    return file_data

@task.calcfunction(
    outputs=[
        {"name": "train"},
        {"name": "test"},
        {"name": "validation"},
    ]
)
def split_xyz_file(xyz_file: SinglefileData) -> dict:
    """
    Split an XYZ file into training, testing, and validation datasets.

    Parameters
    ----------
    xyz_file : SinglefileData
        A SinglefileData node containing the XYZ file.

    Returns
    -------
    dict
        A dictionary with keys 'train', 'test', and 'validation', each containing
        SinglefileData nodes for the respective datasets.
    """
    # Read whole frames (not individual lines) so each split stays valid extxyz.
    with xyz_file.open() as file:
        frames = read(file, index=":", format="extxyz")

    # 60% train, then split the remaining 40% evenly into test and validation.
    train_data, test_validation_data = train_test_split(
        frames, test_size=0.4, random_state=42
    )
    test_data, validation_data = train_test_split(
        test_validation_data, test_size=0.5, random_state=42
    )

    train_path = "train.extxyz"
    test_path = "test.extxyz"
    validation_path = "validation.extxyz"

    write(train_path, train_data, format="extxyz")
    write(test_path, test_data, format="extxyz")
    write(validation_path, validation_data, format="extxyz")

    return {
        "train": SinglefileData(file=train_path),
        "test": SinglefileData(file=test_path),
        "validation": SinglefileData(file=validation_path),
    }

@task.calcfunction()
def update_janusconfigfile(janusconfigfile: JanusConfigfile) -> JanusConfigfile:
    """
    Update the JanusConfigfile with new paths for train, test, and validation datasets.

    Parameters
    ----------
    janusconfigfile : JanusConfigfile
        The original JanusConfigfile.

    Returns
    -------
    JanusConfigfile
        A new JanusConfigfile with updated paths.
    """
    janus_dict = janusconfigfile.as_dictionary
    config_parse = janusconfigfile.get_content()

    content = config_parse.replace(janus_dict["train_file"], "train.extxyz")
    content = content.replace(janus_dict["test_file"], "test.extxyz")
    # Assumes the config uses a "valid_file" key for the validation set.
    content = content.replace(janus_dict["valid_file"], "validation.extxyz")

    new_config_path = "./config.yml"

    with open(new_config_path, "w", encoding="utf8") as file:
        file.write(content)

    return JanusConfigfile(file=new_config_path)

# pylint: disable=unused-variable
def TrainWorkGraph(
    folder_path: Path, inputs: dict, janusconfigfile: JanusConfigfile
) -> WorkGraph:
    """
    Create a workflow that optimises structures with QE and uses the results to train MLIPs.

    Parameters
    ----------
    folder_path : Path
        Path to the folder containing input structure files.
    inputs : dict
        Dictionary of inputs for the calculations.
    janusconfigfile : JanusConfigfile
        File with inputs for janus calculations.

    Returns
    -------
    WorkGraph
        The workgraph containing the training workflow.
    """
    wg = WorkGraph("trainingworkflow")

    # Relax every structure found in the folder.
    pw_task = wg.add_task(
        run_pw_calc, name="pw_relax", folder=folder_path, dft_inputs=inputs
    )

    # Collect the relaxed structures into a single extxyz file.
    create_file_task = wg.add_task(create_input, name="create_input")
    wg.add_link(pw_task.outputs["result"], create_file_task.inputs["inputs"])

    # Split the collected structures into train/test/validation sets.
    split_files_task = wg.add_task(
        split_xyz_file, name="split_xyz", xyz_file=create_file_task.outputs.result
    )

    # Point the janus training config at the newly written dataset files.
    update_config_task = wg.add_task(
        update_janusconfigfile,
        name="update_janusconfigfile",
        janusconfigfile=janusconfigfile,
    )
    wg.add_link(split_files_task.outputs["_wait"], update_config_task.inputs["_wait"])

    # Train the MLIP with the updated config.
    training_calc = CalculationFactory("mlip.train")
    train_task = wg.add_task(
        training_calc, name="training", mlip_config=update_config_task.outputs.result
    )

    # Expose the optimised structures and the trained model as group outputs
    # (socket paths refer to the task names defined above).
    wg.group_outputs = [
        {"name": "opt_structures", "from": "pw_relax.result"},
        {"name": "final_model", "from": "training.model"},
    ]

    wg.to_html()

    wg.max_number_jobs = 10
    wg.submit(wait=True)
    return wg
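
For reference, the two-stage split in split_xyz_file above yields a 60/20/20 partition. A minimal standalone sketch (illustrative only, not part of this PR), with a toy list standing in for the extxyz frames:

from sklearn.model_selection import train_test_split

frames = list(range(10))  # stand-in for the parsed extxyz frames
train, rest = train_test_split(frames, test_size=0.4, random_state=42)  # keep 60% for training
test, validation = train_test_split(rest, test_size=0.5, random_state=42)  # split the rest 20%/20%
print(len(train), len(test), len(validation))  # prints: 6 2 2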
Example submission script (new file)
@@ -0,0 +1,103 @@
"""Example submission for hts workgraph.""" | ||
|
||
from pathlib import Path | ||
|
||
from aiida.orm import Dict, load_code | ||
|
||
from aiida_mlip.data.config import JanusConfigfile | ||
from aiida_mlip.workflows.training_workgraph import TrainWorkGraph | ||
|
||
folder_path = Path("/work4/scd/scarf1228/prova_train_workgraph/") | ||
code = load_code("qe-7.1@scarf") | ||
inputs = { | ||
"base": { | ||
"settings": Dict({"GAMMA_ONLY": True}), | ||
"pw": { | ||
"parameters": Dict( | ||
{ | ||
"CONTROL": { | ||
"calculation": "vc-relax", | ||
"nstep": 1200, | ||
"etot_conv_thr": 1e-05, | ||
"forc_conv_thr": 1e-04, | ||
}, | ||
"SYSTEM": { | ||
"ecutwfc": 500, | ||
"input_dft": "PBE", | ||
"nspin": 1, | ||
"occupations": "smearing", | ||
"degauss": 0.001, | ||
"smearing": "m-p", | ||
}, | ||
"ELECTRONS": { | ||
"electron_maxstep": 1000, | ||
"scf_must_converge": False, | ||
"conv_thr": 1e-08, | ||
"mixing_beta": 0.25, | ||
"diago_david_ndim": 4, | ||
"startingpot": "atomic", | ||
"startingwfc": "atomic+random", | ||
}, | ||
"IONS": { | ||
"ion_dynamics": "bfgs", | ||
}, | ||
"CELL": { | ||
"cell_dynamics": "bfgs", | ||
"cell_dofree": "ibrav", | ||
}, | ||
} | ||
), | ||
"code": code, | ||
"metadata": { | ||
"options": { | ||
"resources": { | ||
"num_machines": 4, | ||
"num_mpiprocs_per_machine": 32, | ||
}, | ||
"max_wallclock_seconds": 48 * 60 * 60, | ||
}, | ||
}, | ||
}, | ||
}, | ||
"base_final_scf": { | ||
"pw": { | ||
"parameters": Dict( | ||
{ | ||
"CONTROL": { | ||
"calculation": "scf", | ||
"tprnfor": True, | ||
}, | ||
"SYSTEM": { | ||
"ecutwfc": 70, | ||
"ecutrho": 650, | ||
"input_dft": "PBE", | ||
"occupations": "smearing", | ||
"degauss": 0.001, | ||
"smearing": "m-p", | ||
}, | ||
"ELECTRONS": { | ||
"conv_thr": 1e-10, | ||
"mixing_beta": 0.25, | ||
"diago_david_ndim": 4, | ||
"startingpot": "atomic", | ||
"startingwfc": "atomic+random", | ||
}, | ||
} | ||
), | ||
"code": code, | ||
"metadata": { | ||
"options": { | ||
"resources": { | ||
"num_machines": 1, | ||
"num_mpiprocs_per_machine": 32, | ||
}, | ||
"max_wallclock_seconds": 3 * 60 * 60, | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
janusconfigfile_path = "/work4/scd/scarf1228/prova_train_workgraph/mlip_train.yml" | ||
janusconfigfile = JanusConfigfile(file=janusconfigfile_path) | ||
|
||
TrainWorkGraph(folder_path, inputs, janusconfigfile) |
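
Because TrainWorkGraph calls wg.submit(wait=True), running this example needs a configured AiiDA profile and the relevant plugin packages installed; one of the commit messages ("ok but entry point not working") suggests this tripped up an earlier revision. A quick, illustrative check (not part of this PR) that the required entry points resolve:

from aiida import load_profile
from aiida.plugins import CalculationFactory, WorkflowFactory

load_profile()  # assumes a default AiiDA profile has been set up

# Both factories raise a MissingEntryPointError if the corresponding plugin
# packages (aiida-mlip, aiida-quantumespresso) are not installed.
CalculationFactory("mlip.train")
WorkflowFactory("quantumespresso.pw.relax")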
(Diff for one further changed file could not be loaded.)
Review comment: Same comment as in #147