add support for integration of gnn module
sfluegel committed Dec 11, 2023
1 parent cb5d6b4 commit 643ffe8
Showing 4 changed files with 10 additions and 7 deletions.
chebai/models/base.py: 8 changes (5 additions, 3 deletions)
@@ -4,6 +4,7 @@
 from lightning.pytorch.core.module import LightningModule
 import torch
 from typing import Optional, Dict, Any
+from chebai.preprocessing.structures import XYData

 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

@@ -83,6 +84,7 @@ def predict_step(self, batch, batch_idx, **kwargs):
         return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)

     def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
+        assert isinstance(batch, XYData)
         data = self._process_batch(batch, batch_idx)
         labels = data["labels"]
         model_output = self(data, **data.get("model_kwargs", dict()))
@@ -101,7 +103,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
             self.log(
                 f"{prefix}loss",
                 loss.item(),
-                batch_size=batch.x.shape[0],
+                batch_size=len(batch),
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
@@ -116,7 +118,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
             self.log(
                 f"{prefix}{metric_name}{k}",
                 m2,
-                batch_size=batch.x.shape[0],
+                batch_size=len(batch),
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
@@ -127,7 +129,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
             self.log(
                 f"{prefix}{metric_name}",
                 m,
-                batch_size=batch.x.shape[0],
+                batch_size=len(batch),
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
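Note on the batch_size change: once graph batches from the GNN module flow through _execute, batch.x is not guaranteed to be a plain tensor with a .shape attribute, while XYData.__len__ always returns the number of samples. A minimal sketch of the difference (the example data is hypothetical; it assumes XYData stores x and y as passed, per chebai/preprocessing/structures.py):

import torch
from chebai.preprocessing.structures import XYData

# Hypothetical GNN-style batch: x is a list of per-molecule node-feature
# tensors with varying node counts, so it has no .shape attribute.
x = [torch.randn(n, 8) for n in (5, 12, 7)]  # 3 graphs with 5, 12 and 7 nodes
y = torch.zeros(3, 10)                       # one label vector per graph

batch = XYData(x, y)
print(len(batch))  # 3: XYData.__len__ returns len(self.x), the batch size
# batch.x.shape[0] would raise AttributeError here, because a list has no .shape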
chebai/preprocessing/collate.py: 2 changes (1 addition, 1 deletion)
@@ -8,7 +8,7 @@ class Collater:
     def __init__(self, **kwargs):
         pass

-    def __call__(self, data):
+    def __call__(self, data) -> XYData:
         raise NotImplementedError

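The new return annotation makes the collater contract explicit: a concrete __call__ packs a list of samples into one XYData batch. A hypothetical subclass illustrating the contract (DenseCollater is illustrative only, not one of the repo's actual collaters; it assumes each sample arrives as a (features, label) tensor pair of fixed size):

import torch
from chebai.preprocessing.structures import XYData

class DenseCollater(Collater):
    # Hypothetical collater for fixed-size tensor samples.
    def __call__(self, data) -> XYData:
        xs, ys = zip(*data)  # split the (features, label) pairs
        return XYData(torch.stack(xs), torch.stack(ys))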
chebai/preprocessing/reader.py: 6 changes (3 additions, 3 deletions)
@@ -23,6 +23,7 @@ def __init__(self, collator_kwargs=None, token_path=None, **kwargs):
         if collator_kwargs is None:
             collator_kwargs = dict()
         self.collater = self.COLLATER(**collator_kwargs)
+        self.dirname = os.path.dirname(__file__)
         self._token_path = token_path

     def _get_raw_data(self, row):
@@ -48,9 +49,8 @@ def token_path(self):
         """Get token path, create file if it does not exist yet"""
         if self._token_path is not None:
             return self._token_path
-        dirname = os.path.dirname(__file__)
-        token_path = os.path.join(dirname, "bin", self.name(), "tokens.txt")
-        os.makedirs(os.path.join(dirname, "bin", self.name()), exist_ok=True)
+        token_path = os.path.join(self.dirname, "bin", self.name(), "tokens.txt")
+        os.makedirs(os.path.join(self.dirname, "bin", self.name()), exist_ok=True)
         if not os.path.exists(token_path):
             with open(token_path, "x"):
                 pass
chebai/preprocessing/structures.py: 1 change (1 addition, 0 deletions)
@@ -8,6 +8,7 @@ def __getitem__(self, index) -> T_co:
         return self.x[index], self.y[index]

     def __len__(self):
+        # return batch size
         return len(self.x)

     def __init__(self, x, y, **kwargs):
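For context, the semantics the new comment documents, as a minimal sketch (the tensor shapes are hypothetical):

import torch
from chebai.preprocessing.structures import XYData

batch = XYData(torch.randn(4, 8), torch.zeros(4, 2))
x0, y0 = batch[0]       # __getitem__ yields the (input, label) pair at index 0
assert len(batch) == 4  # __len__ returns len(self.x), i.e. the batch size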
