Skip to content

Commit

Permalink
update L7 and X40 to use python base yaml package
Browse files Browse the repository at this point in the history
  • Loading branch information
mcneela committed Mar 12, 2024
1 parent 802b70b commit e969b54
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
59 changes: 49 additions & 10 deletions openqdc/datasets/interaction/L7.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,53 @@
from typing import Dict, List

import numpy as np
import yaml
from loguru import logger
from ruamel.yaml import YAML

from openqdc.datasets.interaction import BaseInteractionDataset
from openqdc.utils.molecule import atom_table


class DataItemYAMLObj:
def __init__(self, name, shortname, geometry, reference_value, setup, group, tags):
self.name = name
self.shortname = shortname
self.geometry = geometry
self.reference_value = reference_value
self.setup = setup
self.group = group
self.tags = tags


class DataSetYAMLObj:
def __init__(self, name, references, text, method_energy, groups_by, groups, global_setup):
self.name = name
self.references = references
self.text = text
self.method_energy = method_energy
self.groups_by = groups_by
self.groups = groups
self.global_setup = global_setup


def data_item_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
"""Construct an employee."""
return DataItemYAMLObj(**loader.construct_mapping(node))


def dataset_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
"""Construct an employee."""
return DataSetYAMLObj(**loader.construct_mapping(node))


def get_loader():
"""Add constructors to PyYAML loader."""
loader = yaml.SafeLoader
loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", data_item_constructor)
loader.add_constructor("!ruby/object:ProtocolDataset::DataSetDescription", dataset_constructor)
return loader


class L7(BaseInteractionDataset):
"""
The L7 interaction energy dataset as described in:
Expand Down Expand Up @@ -43,23 +83,22 @@ def read_raw_entries(self) -> List[Dict]:
yaml_fpath = os.path.join(self.root, "l7.yaml")
logger.info(f"Reading L7 interaction data from {self.root}")
yaml_file = open(yaml_fpath, "r")
yaml = YAML()
data = []
data_dict = yaml.load(yaml_file)
charge0 = int(data_dict["description"]["global_setup"]["molecule_a"]["charge"])
charge1 = int(data_dict["description"]["global_setup"]["molecule_b"]["charge"])
data_dict = yaml.load(yaml_file, Loader=get_loader())
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])

for idx, item in enumerate(data_dict["items"]):
energies = []
name = np.array([item["shortname"]])
fname = item["geometry"].split(":")[1]
energies.append(item["reference_value"])
name = np.array([item.shortname])
fname = item.geometry.split(":")[1]
energies.append(item.reference_value)
xyz_file = open(os.path.join(self.root, f"{fname}.xyz"), "r")
lines = list(map(lambda x: x.strip().split(), xyz_file.readlines()))
lines.pop(1)
n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
n_atoms_first = np.array([int(item["setup"]["molecule_a"]["selection"].split("-")[1])], dtype=np.int32)
subset = np.array([item["group"]])
n_atoms_first = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32)
subset = np.array([item.group])
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
energies = np.array([energies], dtype=np.float32)
pos = np.array(lines[1:])[:, 1:].astype(np.float32)
Expand Down
18 changes: 9 additions & 9 deletions openqdc/datasets/interaction/X40.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Dict, List

import numpy as np
import yaml
from loguru import logger
from ruamel.yaml import YAML

from openqdc.datasets.interaction import BaseInteractionDataset
from openqdc.datasets.interaction.L7 import get_loader
from openqdc.utils.molecule import atom_table


Expand Down Expand Up @@ -41,23 +42,22 @@ def read_raw_entries(self) -> List[Dict]:
yaml_fpath = os.path.join(self.root, "x40.yaml")
logger.info(f"Reading X40 interaction data from {self.root}")
yaml_file = open(yaml_fpath, "r")
yaml = YAML()
data = []
data_dict = yaml.load(yaml_file)
charge0 = int(data_dict["description"]["global_setup"]["molecule_a"]["charge"])
charge1 = int(data_dict["description"]["global_setup"]["molecule_b"]["charge"])
data_dict = yaml.load(yaml_file, Loader=get_loader())
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])

for idx, item in enumerate(data_dict["items"]):
energies = []
name = np.array([item["shortname"]])
energies.append(float(item["reference_value"]))
xyz_file = open(os.path.join(self.root, f"{item['shortname']}.xyz"), "r")
name = np.array([item.shortname])
energies.append(float(item.reference_value))
xyz_file = open(os.path.join(self.root, f"{item.shortname}.xyz"), "r")
lines = list(map(lambda x: x.strip().split(), xyz_file.readlines()))
setup = lines.pop(1)
n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
n_atoms_first = setup[0].split("-")[1]
n_atoms_first = np.array([int(n_atoms_first)], dtype=np.int32)
subset = np.array([item["group"]])
subset = np.array([item.group])
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
energies = np.array([energies], dtype=np.float32)
pos = np.array(lines[1:])[:, 1:].astype(np.float32)
Expand Down

0 comments on commit e969b54

Please sign in to comment.