Skip to content

Commit

Permalink
add transform fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
timkucera committed Dec 1, 2023
1 parent 70e70d3 commit 5e47ae1
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 29 deletions.
2 changes: 1 addition & 1 deletion proteinshake/frontend/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
"sequence": "".join(
random.choice(amino_acid_alphabet) for _ in range(300)
),
"label": np.random.random(),
"label": np.random.random() * 100,
"split": random.choice(["train", "test", "val"]),
}
for i in range(100)
Expand Down
3 changes: 3 additions & 0 deletions proteinshake/frontend/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(
self.index = splitter(dataset)
# partition the dataset. the dataset will optimize data loading.
dataset.partition(self.index)
# fit the transforms
X_transform.fit(dataset)
y_transform.fit(dataset)
# create X,y,dataloader for each item in the split.
for name, index in self.index.items():
# get the partition of the split, apply transforms, and save to disk.
Expand Down
1 change: 0 additions & 1 deletion proteinshake/frontend/transforms/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .framework import *
from .torch import *
5 changes: 0 additions & 5 deletions proteinshake/frontend/transforms/framework/framework.py

This file was deleted.

4 changes: 2 additions & 2 deletions proteinshake/frontend/transforms/framework/torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from .framework import FrameworkTransform
from ..transform import FrameworkTransform


class TorchFrameworkTransform(FrameworkTransform):
    """Framework transform that converts a representation to a torch tensor."""

    def transform(self, representation):
        """Return ``representation`` wrapped in a ``torch.tensor``.

        The conversion lives in ``transform`` (not ``__call__``) so that the
        base class ``__call__`` dispatches to it.
        """
        return torch.tensor(representation)
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .representation import *
from .point import *
4 changes: 2 additions & 2 deletions proteinshake/frontend/transforms/representation/point.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .representation import RepresentationTransform
from ..transform import RepresentationTransform


class PointRepresentationTransform(RepresentationTransform):
    """Representation transform that exposes a protein as its coordinates."""

    def transform(self, protein):
        """Return the value stored under the ``"coords"`` key of ``protein``.

        The lookup lives in ``transform`` (not ``__call__``) so that the
        base class ``__call__`` dispatches to it.
        """
        return protein["coords"]

This file was deleted.

29 changes: 24 additions & 5 deletions proteinshake/frontend/transforms/transform.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
class Transform:
    """Base class for every transform.

    Subclasses override ``fit`` to learn parameters and ``transform`` to
    apply them. Calling an instance forwards to ``transform``.
    """

    def fit(self, x):
        """Learn parameters from ``x``. The base implementation does nothing."""
        pass

    def transform(self, x):
        """Return ``x`` unchanged."""
        return x

    def __call__(self, *args, **kwargs):
        """Forward all arguments to ``transform`` and return its result."""
        return self.transform(*args, **kwargs)


class RepresentationTransform(Transform):
    """Base class for transforms that build a representation (identity here)."""

    pass


class FrameworkTransform(Transform):
    """Base class for transforms targeting a framework (identity here)."""

    pass


class DataTransform(Transform):
    """Chain a representation transform and a framework transform.

    Args:
        representation_transform: applied first. Defaults to a fresh
            identity ``RepresentationTransform``.
        framework_transform: applied to the output of the representation
            transform. Defaults to a fresh identity ``FrameworkTransform``.
    """

    def __init__(
        self,
        representation_transform=None,
        framework_transform=None,
    ):
        # Build the defaults per instance: a default created in the signature
        # is evaluated once and would be shared (together with any fitted
        # state) by every DataTransform.
        if representation_transform is None:
            representation_transform = RepresentationTransform()
        if framework_transform is None:
            framework_transform = FrameworkTransform()
        self.representation_transform = representation_transform
        self.framework_transform = framework_transform

    def fit(self, x):
        """Fit both wrapped transforms on ``x``.

        Without this override the inherited no-op ``fit`` would leave the
        wrapped transforms unfitted. Both receive the same ``x``.
        NOTE(review): confirm the framework transform should be fitted on the
        raw input rather than on the representation output.
        """
        for step in (self.representation_transform, self.framework_transform):
            # Plain callables (e.g. lambdas) have no ``fit``; skip them.
            fit = getattr(step, "fit", None)
            if callable(fit):
                fit(x)

    def transform(self, x):
        """Apply the representation transform, then the framework transform."""
        return self.framework_transform(self.representation_transform(x))


class TargetTransform(Transform):
    """Base class for target transforms (identity here)."""

    pass


class LabelTransform(Transform):
    """Base class for label transforms (identity here)."""

    pass
19 changes: 12 additions & 7 deletions tests/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

class TestTask(unittest.TestCase):
def test_task(self):
# CREATOR
# CONTRIBUTOR
class MyTargetTransform(TargetTransform):
def __call__(self, protein):
def transform(self, protein):
return protein["label"]

class MyEvaluator:
Expand All @@ -34,16 +34,21 @@ def __init__(self, **kwargs):
)

# END USER
y_transform = lambda label: label * 100
class MyLabelTransform(LabelTransform):
def fit(self, dataset):
labels = [p["label"] for p in dataset.split("train").proteins]
self.min, self.max = min(labels), max(labels)

def transform(self, x):
return (x - self.min) / (self.max - self.min)

y_transform = MyLabelTransform()
X_transform = DataTransform(
representation_transform=PointRepresentationTransform(),
framework_transform=TorchFrameworkTransform(),
)

task = MyTask(
X_transform=X_transform,
y_transform=y_transform,
)
task = MyTask(X_transform=X_transform, y_transform=y_transform)

print(task.train_index)
print(next(task.X_train).shape)
Expand Down

0 comments on commit 5e47ae1

Please sign in to comment.