diff --git a/aiida_aimall/workchains.py b/aiida_aimall/workchains.py index 025cc92..f639144 100644 --- a/aiida_aimall/workchains.py +++ b/aiida_aimall/workchains.py @@ -15,7 +15,7 @@ import multiprocess as mp import pandas as pd from aiida.engine import ToContext, WorkChain, calcfunction -from aiida.orm import Code, Dict, Int, List, SinglefileData, Str, load_group +from aiida.orm import Code, Dict, Group, Int, List, SinglefileData, Str, load_group from aiida.orm.extras import EntityExtras from aiida.plugins.factories import CalculationFactory, DataFactory from group_decomposition.fragfunctions import ( @@ -128,7 +128,7 @@ def parse_cml_files(singlefiledata): @calcfunction -def generate_cml_fragments(params, cml_Dict, n_procs): +def generate_cml_fragments(params, cml_Dict, n_procs, prev_smi): """Fragment the molecule defined by a CML Args: @@ -139,6 +139,7 @@ def generate_cml_fragments(params, cml_Dict, n_procs): """ # pylint:disable=too-many-locals # pylint:disable=too-many-statements + done_smi = prev_smi.get_list() cml_list = ( cml_Dict.get_dict().values() ) # maybe just don't store cml files in database, just pass list to cgis here @@ -146,10 +147,10 @@ def generate_cml_fragments(params, cml_Dict, n_procs): input_type = param_dict["input_type"] # should set to cmldict bb_patt = param_dict["bb_patt"] - done_smi = [] + # done_smi = [] # dict_list = [] fd = {} - out_frame = pd.DataFrame() + # out_frame = pd.DataFrame() with mp.Pool(n_procs.value) as pool: # pylint:disable=not-callable no-member result_list = list( pool.map( @@ -166,6 +167,7 @@ def generate_cml_fragments(params, cml_Dict, n_procs): done_smi.append(key) # dict_list.append(frag_dict[0][key]) fd[key] = frag_dict[0][key] + while len(frame_list) > 1: frame_list = frame_list[2:] + [merge_uniques(frame_list[0], frame_list[1])] out_frame = frame_list[0] @@ -221,7 +223,13 @@ def generate_cml_fragments(params, cml_Dict, n_procs): out_frame = out_frame.drop("atom_types", axis=1) out_frame = out_frame.drop("count", axis=1) out_frame = out_frame.drop("numAttachments", axis=1) - out_dict["cgis_frame"] = PDData(out_frame) + g = Group(label="fragment_frames") + g.store() + node_frame = PDData(out_frame) + node_frame.store() + g.add_nodes() + out_dict["cgis_frame"] = node_frame + out_dict["done_smi"] = List(done_smi) return out_dict @@ -285,6 +293,7 @@ def define(cls, spec): super().define(spec) spec.input("cml_file_dict", valid_type=Dict) spec.input("frag_params", valid_type=Dict) + spec.input("prev_smi", valid_type=List, default=List([]), required=False) # spec.input("g16_code", valid_type=Code) spec.input("procs", valid_type=Int, default=Int(8)) # spec.input('aim_code',valid_type=Code) @@ -299,6 +308,7 @@ def generate_fragments(self): self.inputs.frag_params, self.inputs.cml_file_dict, self.inputs.procs, + self.inputs.prev_smi, ) g16_opt_group = load_group("inp_frag") for ( diff --git a/example/test.ipynb b/example/test.ipynb index 3a08c4b..e6db41b 100644 --- a/example/test.ipynb +++ b/example/test.ipynb @@ -106,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,18 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from aiida_aimall.data import AimqbParameters\n", + "spec.input(\"aim_params\", valid_type=AimqbParameters)\n", + " spec.input(\"file\", valid_type=SinglefileData)\n", + " # spec.output('aim_dict',valid_type=Dict)\n", + " spec.input(\"aim_code\", valid_type=Code)\n", + " spec.input(\"frag_label\", valid_type=Str, required=False)\n", + " spec.output(\"rotated_structure\", valid_type=Str)\n", + "wf(AimqbParameters(\n", + " parameter_dict={\"naat\": 2, \"nproc\": 2, \"atlaprhocps\": True}\n", + " ))" + ] } ], "metadata": { diff --git a/tests/workchains/test_aimreorworkchain.py b/tests/workchains/test_aimreorworkchain.py index 4685adb..f51b937 100644 --- a/tests/workchains/test_aimreorworkchain.py +++ b/tests/workchains/test_aimreorworkchain.py @@ -2,6 +2,8 @@ # import os from aiida.orm import Dict, SinglefileData, Str + +# from aiida.plugins import WorkflowFactory from plumpy.utils import AttributesFrozendict from subproptools import qtaim_extract as qt @@ -56,3 +58,9 @@ def test_dict_to_structure(): ) str_str = dict_to_structure(str_dict) assert isinstance(str_str, Str) + + +# def test_aimall(): +# """Test aimall step of AIMReor""" +# AIMAllReor = WorkflowFactory("aimall.aimreor") +# wf = AIMAllReor()