Skip to content

Commit

Permalink
fix model evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Dec 13, 2023
1 parent d01a67a commit ecabf79
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions chebai/result/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from chebai.models import ChebaiBaseNet
from chebai.preprocessing.datasets import XYBaseDataModule
from Typing import Optional
from typing import Optional
from lightning.fabric.utilities.types import _PATH


Expand All @@ -31,7 +31,6 @@ def visualise_f1(logs_path):
plt.show()


# get predictions from model
def evaluate_model(
model: ChebaiBaseNet,
data_module: XYBaseDataModule,
Expand All @@ -42,9 +41,10 @@ def evaluate_model(
If buffer_dir is set, results will be saved in buffer_dir. Returns tensors with predictions and labels."""
model.eval()
collate = data_module.reader.COLLATER()
if test_file is None:
test_file = data_module.processed_file_names_dict["test"]
data_path = os.path.join(data_module.processed_dir, test_file)
if data_path is None:
data_path = os.path.join(
data_module.processed_dir, data_module.processed_file_names_dict["test"]
)
data_list = torch.load(data_path)
preds_list = []
labels_list = []
Expand Down

0 comments on commit ecabf79

Please sign in to comment.