diff --git a/proteinshake/frontend/dataset.py b/proteinshake/frontend/dataset.py index 7f6833ed..6c239776 100644 --- a/proteinshake/frontend/dataset.py +++ b/proteinshake/frontend/dataset.py @@ -24,7 +24,7 @@ def __init__( "sequence": "".join( random.choice(amino_acid_alphabet) for _ in range(300) ), - "label": np.random.random(), + "label": np.random.random() * 100, "split": random.choice(["train", "test", "val"]), } for i in range(100) diff --git a/proteinshake/frontend/task.py b/proteinshake/frontend/task.py index 3ca07a7c..6df09396 100644 --- a/proteinshake/frontend/task.py +++ b/proteinshake/frontend/task.py @@ -19,6 +19,9 @@ def __init__( self.index = splitter(dataset) # partition the dataset. the dataset will optimize data loading. dataset.partition(self.index) + # fit the transforms + X_transform.fit(dataset) + y_transform.fit(dataset) # create X,y,dataloader for each item in the split. for name, index in self.index.items(): # get the partition of the split, apply transforms, and save to disk. diff --git a/proteinshake/frontend/transforms/framework/__init__.py b/proteinshake/frontend/transforms/framework/__init__.py index 245bcba0..72f4e5e8 100644 --- a/proteinshake/frontend/transforms/framework/__init__.py +++ b/proteinshake/frontend/transforms/framework/__init__.py @@ -1,2 +1 @@ -from .framework import * from .torch import * diff --git a/proteinshake/frontend/transforms/framework/framework.py b/proteinshake/frontend/transforms/framework/framework.py deleted file mode 100644 index 7fe5322c..00000000 --- a/proteinshake/frontend/transforms/framework/framework.py +++ /dev/null @@ -1,5 +0,0 @@ -from ..transform import Transform - - -class FrameworkTransform(Transform): - pass diff --git a/proteinshake/frontend/transforms/framework/torch.py b/proteinshake/frontend/transforms/framework/torch.py index 477e9836..497ab93a 100644 --- a/proteinshake/frontend/transforms/framework/torch.py +++ b/proteinshake/frontend/transforms/framework/torch.py @@ -1,7 +1,7 @@ import torch -from .framework import FrameworkTransform +from ..transform import FrameworkTransform class TorchFrameworkTransform(FrameworkTransform): - def __call__(self, representation): + def transform(self, representation): return torch.tensor(representation) diff --git a/proteinshake/frontend/transforms/representation/__init__.py b/proteinshake/frontend/transforms/representation/__init__.py index 33f657a5..86ba70b7 100644 --- a/proteinshake/frontend/transforms/representation/__init__.py +++ b/proteinshake/frontend/transforms/representation/__init__.py @@ -1,2 +1 @@ -from .representation import * from .point import * diff --git a/proteinshake/frontend/transforms/representation/point.py b/proteinshake/frontend/transforms/representation/point.py index 13f67a0e..98c9559d 100644 --- a/proteinshake/frontend/transforms/representation/point.py +++ b/proteinshake/frontend/transforms/representation/point.py @@ -1,6 +1,6 @@ -from .representation import RepresentationTransform +from ..transform import RepresentationTransform class PointRepresentationTransform(RepresentationTransform): - def __call__(self, protein): + def transform(self, protein): return protein["coords"] diff --git a/proteinshake/frontend/transforms/representation/representation.py b/proteinshake/frontend/transforms/representation/representation.py deleted file mode 100644 index d0cfac70..00000000 --- a/proteinshake/frontend/transforms/representation/representation.py +++ /dev/null @@ -1,5 +0,0 @@ -from ..transform import Transform - - -class RepresentationTransform(Transform): - pass diff --git a/proteinshake/frontend/transforms/transform.py b/proteinshake/frontend/transforms/transform.py index 09ce1e72..5baebad3 100644 --- a/proteinshake/frontend/transforms/transform.py +++ b/proteinshake/frontend/transforms/transform.py @@ -1,19 +1,38 @@ class Transform: + def fit(self, x): + pass + + def transform(self, x): + return x + + def __call__(self, *args, **kwargs): + return self.transform(*args, **kwargs) + + +class RepresentationTransform(Transform): + pass + + +class FrameworkTransform(Transform): pass -class DataTransform: - def __init__(self, representation_transform, framework_transform): +class DataTransform(Transform): + def __init__( + self, + representation_transform=RepresentationTransform(), + framework_transform=FrameworkTransform(), + ): self.representation_transform = representation_transform self.framework_transform = framework_transform - def __call__(self, x): + def transform(self, x): return self.framework_transform(self.representation_transform(x)) -class TargetTransform: +class TargetTransform(Transform): pass -class LabelTransform: +class LabelTransform(Transform): pass diff --git a/tests/task.py b/tests/task.py index 1132571d..0a6210fa 100644 --- a/tests/task.py +++ b/tests/task.py @@ -7,9 +7,9 @@ class TestTask(unittest.TestCase): def test_task(self): - # CREATOR + # CONTRIBUTOR class MyTargetTransform(TargetTransform): - def __call__(self, protein): + def transform(self, protein): return protein["label"] class MyEvaluator: @@ -34,16 +34,21 @@ def __init__(self, **kwargs): ) # END USER - y_transform = lambda label: label * 100 + class MyLabelTransform(LabelTransform): + def fit(self, dataset): + labels = [p["label"] for p in dataset.split("train").proteins] + self.min, self.max = min(labels), max(labels) + + def transform(self, x): + return (x - self.min) / (self.max - self.min) + + y_transform = MyLabelTransform() X_transform = DataTransform( representation_transform=PointRepresentationTransform(), framework_transform=TorchFrameworkTransform(), ) - task = MyTask( - X_transform=X_transform, - y_transform=y_transform, - ) + task = MyTask(X_transform=X_transform, y_transform=y_transform) print(task.train_index) print(next(task.X_train).shape)