Skip to content

Commit

Permalink
Added support for previous smiles input
Browse files Browse the repository at this point in the history
  • Loading branch information
kmlefran committed Jan 30, 2024
1 parent 1f36fde commit 5160bcc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
20 changes: 15 additions & 5 deletions aiida_aimall/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import multiprocess as mp
import pandas as pd
from aiida.engine import ToContext, WorkChain, calcfunction
from aiida.orm import Code, Dict, Int, List, SinglefileData, Str, load_group
from aiida.orm import Code, Dict, Group, Int, List, SinglefileData, Str, load_group
from aiida.orm.extras import EntityExtras
from aiida.plugins.factories import CalculationFactory, DataFactory
from group_decomposition.fragfunctions import (
Expand Down Expand Up @@ -128,7 +128,7 @@ def parse_cml_files(singlefiledata):


@calcfunction
def generate_cml_fragments(params, cml_Dict, n_procs):
def generate_cml_fragments(params, cml_Dict, n_procs, prev_smi):
"""Fragment the molecule defined by a CML
Args:
Expand All @@ -139,17 +139,18 @@ def generate_cml_fragments(params, cml_Dict, n_procs):
"""
# pylint:disable=too-many-locals
# pylint:disable=too-many-statements
done_smi = prev_smi.get_list()
cml_list = (
cml_Dict.get_dict().values()
) # maybe just don't store cml files in database, just pass list to cgis here
param_dict = params.get_dict() # get dict from aiida node
input_type = param_dict["input_type"] # should set to cmldict
bb_patt = param_dict["bb_patt"]

done_smi = []
# done_smi = []
# dict_list = []
fd = {}
out_frame = pd.DataFrame()
# out_frame = pd.DataFrame()
with mp.Pool(n_procs.value) as pool: # pylint:disable=not-callable no-member
result_list = list(
pool.map(
Expand All @@ -166,6 +167,7 @@ def generate_cml_fragments(params, cml_Dict, n_procs):
done_smi.append(key)
# dict_list.append(frag_dict[0][key])
fd[key] = frag_dict[0][key]

while len(frame_list) > 1:
frame_list = frame_list[2:] + [merge_uniques(frame_list[0], frame_list[1])]
out_frame = frame_list[0]
Expand Down Expand Up @@ -221,7 +223,13 @@ def generate_cml_fragments(params, cml_Dict, n_procs):
out_frame = out_frame.drop("atom_types", axis=1)
out_frame = out_frame.drop("count", axis=1)
out_frame = out_frame.drop("numAttachments", axis=1)
out_dict["cgis_frame"] = PDData(out_frame)
g = Group(label="fragment_frames")
g.store()
node_frame = PDData(out_frame)
node_frame.store()
g.add_nodes()
out_dict["cgis_frame"] = node_frame
out_dict["done_smi"] = List(done_smi)
return out_dict


Expand Down Expand Up @@ -285,6 +293,7 @@ def define(cls, spec):
super().define(spec)
spec.input("cml_file_dict", valid_type=Dict)
spec.input("frag_params", valid_type=Dict)
spec.input("prev_smi", valid_type=List, default=List([]), required=False)
# spec.input("g16_code", valid_type=Code)
spec.input("procs", valid_type=Int, default=Int(8))
# spec.input('aim_code',valid_type=Code)
Expand All @@ -299,6 +308,7 @@ def generate_fragments(self):
self.inputs.frag_params,
self.inputs.cml_file_dict,
self.inputs.procs,
self.inputs.prev_smi,
)
g16_opt_group = load_group("inp_frag")
for (
Expand Down
15 changes: 13 additions & 2 deletions example/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -162,7 +162,18 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"from aiida_aimall.data import AimqbParameters\n",
"spec.input(\"aim_params\", valid_type=AimqbParameters)\n",
" spec.input(\"file\", valid_type=SinglefileData)\n",
" # spec.output('aim_dict',valid_type=Dict)\n",
" spec.input(\"aim_code\", valid_type=Code)\n",
" spec.input(\"frag_label\", valid_type=Str, required=False)\n",
" spec.output(\"rotated_structure\", valid_type=Str)\n",
"wf(AimqbParameters(\n",
" parameter_dict={\"naat\": 2, \"nproc\": 2, \"atlaprhocps\": True}\n",
" ))"
]
}
],
"metadata": {
Expand Down
8 changes: 8 additions & 0 deletions tests/workchains/test_aimreorworkchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# import os

from aiida.orm import Dict, SinglefileData, Str

# from aiida.plugins import WorkflowFactory
from plumpy.utils import AttributesFrozendict
from subproptools import qtaim_extract as qt

Expand Down Expand Up @@ -56,3 +58,9 @@ def test_dict_to_structure():
)
str_str = dict_to_structure(str_dict)
assert isinstance(str_str, Str)


# def test_aimall():
# """Test aimall step of AIMReor"""
# AIMAllReor = WorkflowFactory("aimall.aimreor")
# wf = AIMAllReor()

0 comments on commit 5160bcc

Please sign in to comment.