Skip to content

Commit

Permalink
docstrings, restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
timkucera committed Jan 7, 2024
1 parent 7f5b035 commit 9629ebe
Show file tree
Hide file tree
Showing 27 changed files with 113 additions and 129 deletions.
6 changes: 6 additions & 0 deletions proteinshake/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class Adapter:
"""
Downloads raw pdb files and/or meta data from a source and formats it to the shake database schema.
"""

pass
7 changes: 5 additions & 2 deletions proteinshake/database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from pathlib import Path
from .collection import Collection


class Database:
"""
Spins up a redis database
"""

def __init__(self, storage: Path) -> None:
pass

def update(self) -> None:
pass

def query(self, query: str) -> Collection:
def query(self, query: str):
pass
9 changes: 8 additions & 1 deletion proteinshake/framework.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
class Framework:
"""
Abstract class for a framework. Used as Mixin with a Transform.
"""

def create_loader(self, iterator):
pass
"""
Creates a framework-specific dataloader from an iterator.
"""
raise NotImplementedError
4 changes: 3 additions & 1 deletion proteinshake/metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
class Metric:
"""For a collection of predictions and target values, return set of performance metrics.,"""
"""
Computes a set of relevant metrics for a task.
"""

pass
2 changes: 1 addition & 1 deletion proteinshake/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .evaluator import *
from .dummy_metric import *
4 changes: 0 additions & 4 deletions proteinshake/metrics/classification.py

This file was deleted.

7 changes: 7 additions & 0 deletions proteinshake/metrics/dummy_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from proteinshake.metric import Metric
import numpy as np


class DummyMetric(Metric):
def __call__(self, y_true, y_pred):
return {"Accuracy": np.random.random()}
Empty file removed proteinshake/metrics/regression.py
Empty file.
Empty file removed proteinshake/metrics/retrieval.py
Empty file.
4 changes: 4 additions & 0 deletions proteinshake/representation.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class Representation:
"""
Abstract class for a representation. Used as Mixin with a Transform.
"""

pass
11 changes: 10 additions & 1 deletion proteinshake/split.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Dict, Iterator


class Split:
"""
Abstract class for selecting train/val/test indices given a dataset.
Abstract class to create data splits from a dataset.
"""

def __call__(self, dataset: Iterator) -> Dict[str, Iterator]:
"""
Takes an Xy iterator and returns a dictionary of Xy iterators, where each key denotes the split name (usually 'train', 'test', and 'val').
"""
raise NotImplementedError

@property
def hash(self):
return self.__class__.__name__
2 changes: 1 addition & 1 deletion proteinshake/splits/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .splitter import *
from .dummy_split import *
14 changes: 0 additions & 14 deletions proteinshake/splits/attribute.py

This file was deleted.

13 changes: 13 additions & 0 deletions proteinshake/splits/dummy_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from proteinshake.split import Split
import itertools


class DummySplit(Split):
def __call__(self, Xy):
train, testval = itertools.tee(Xy)
test, val = itertools.tee(testval)
return {
"train": filter(lambda Xy: Xy[0][0]["split"] == "train", train),
"test": filter(lambda Xy: Xy[0][0]["split"] == "test", test),
"val": filter(lambda Xy: Xy[0][0]["split"] == "val", val),
}
17 changes: 0 additions & 17 deletions proteinshake/splits/pairwise_attribute.py

This file was deleted.

20 changes: 0 additions & 20 deletions proteinshake/splits/random.py

This file was deleted.

13 changes: 10 additions & 3 deletions proteinshake/target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from typing import Dict, Iterator


class Target:
"""Returns the attribute to predict for a single instance, given arbitrary inputs.
Different tasks will have target computations on different types and numbers of entitites.
"""
Abstract class for reshaping a dataset into the correct data-target structure for a task.
"""

pass
def __call__(self, dataset: Iterator[dict]) -> Dict[str, Iterator]:
"""
Takes a dataset iterator and returns an Xy iterator, whose elements are ((X1,X2,...), y) pairs of data tuples and targets.
"""
raise NotImplementedError
2 changes: 1 addition & 1 deletion proteinshake/targets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .target import Target
from .attribute_target import AttributeTarget
10 changes: 10 additions & 0 deletions proteinshake/targets/attribute_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from proteinshake.target import Target


class AttributeTarget(Target):
def __init__(self, attribute) -> None:
super().__init__()
self.attribute = attribute

def __call__(self, dataset):
return (((p,), p[self.attribute]) for p in dataset)
9 changes: 0 additions & 9 deletions proteinshake/targets/pairwise_property_target.py

This file was deleted.

8 changes: 0 additions & 8 deletions proteinshake/targets/property_target.py

This file was deleted.

11 changes: 6 additions & 5 deletions proteinshake/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from proteinshake.split import Split
from proteinshake.target import Target
from proteinshake.metric import Metric
from proteinshake.transform import Transform, Compose
from proteinshake.transform import Transform, Compose, IdentityTransform
from proteinshake.util import amino_acid_alphabet, sharded, save_shards, load, warn


Expand All @@ -15,7 +15,7 @@ class Task:
split: Split = None
target: Target = None
metrics: Metric = None
augmentation: Transform = None
augmentation: Transform = IdentityTransform

def __init__(
self,
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(
@property
def proteins(self):
# return dataset iterator
# this is a dummy for now. It will load a dataset from file in the future.
rng = np.random.default_rng(42)
return (
{
Expand All @@ -62,7 +63,7 @@ def proteins(self):

def transform(self, *transforms) -> None:
Xy = self.target(self.proteins)
partitions = self.split(Xy) # returns dict of generators[(X,...),y]
partitions = self.split(Xy)
self.transform = Compose(*[self.augmentation, *transforms])
# cache from here
self.transform.fit(partitions["train"])
Expand All @@ -73,7 +74,7 @@ def transform(self, *transforms) -> None:
)
save_shards(
data_transformed,
self.root / self.split.hash / self.transform.hash / "shards",
self.root / self.split.hash / name / self.transform.hash / "shards",
)
setattr(self, f"{name}_loader", partial(self.loader, split=name))
return self
Expand All @@ -87,7 +88,7 @@ def loader(
**kwargs,
):
rng = np.random.default_rng(random_seed)
path = self.root / self.split.hash / self.transform.hash / "shards"
path = self.root / self.split.hash / split / self.transform.hash / "shards"
shard_index = load(path / "index.npy")
if self.shard_size % batch_size != 0 and batch_size % self.shard_size != 0:
warn(
Expand Down
1 change: 1 addition & 0 deletions proteinshake/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dummy_task import DummyTask
11 changes: 11 additions & 0 deletions proteinshake/tasks/dummy_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from proteinshake.task import Task
from proteinshake.metrics import DummyMetric
from proteinshake.targets import AttributeTarget
from proteinshake.splits import DummySplit


class DummyTask(Task):
dataset = "test"
split = DummySplit
target = AttributeTarget
metrics = DummyMetric
1 change: 0 additions & 1 deletion proteinshake/tasks/gene_ontology_classification.py

This file was deleted.

13 changes: 13 additions & 0 deletions proteinshake/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@


class BaseTransform:
"""
Abstract class for transforms. A transform can be stochastic or deterministic, which decides whether the transformed result can be precomputed and saved to disk (deterministic), or if it needs to be computed when retrieving a data item (stochastic). Transforms generally take a batch of Xy tuples, some subclasses exist that facilitate reshaping (see below). Transforms can be fit beforehand (on the 'train' partition).
"""

stochastic = False

def __call__(self, Xy):
Expand Down Expand Up @@ -55,7 +59,16 @@ def inverse_transform(self, y):
return y


class IdentityTransform(BaseTransform):
def __call__(self, Xy):
return Xy


class Compose:
"""
Composes multiple transforms into one object. Takes care of splitting the deterministic and stochastic part, as well as storing the framework create_dataloader method.
"""

def __init__(self, *transforms):
self.transforms = transforms
self.deterministic_transforms = []
Expand Down
43 changes: 3 additions & 40 deletions tests/task.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,26 @@
import unittest
import numpy as np
import itertools
from proteinshake.metric import Metric
from proteinshake.target import Target
from proteinshake.split import Split
from proteinshake.task import Task
from proteinshake.tasks import DummyTask
from proteinshake.transform import *
from proteinshake.transforms import *


class TestTask(unittest.TestCase):
def test_task(self):
# CONTRIBUTOR
class MyTarget(Target):
def __call__(self, dataset):
return (((p,), p["label"]) for p in dataset)

class MyMetric(Metric):
def __call__(self, y_true, y_pred):
return {"Accuracy": np.random.random()}

class MySplit(Split):
def __call__(self, Xy):
# this implementation looks a bit inefficient
train, testval = itertools.tee(Xy)
test, val = itertools.tee(testval)
return {
"train": filter(lambda Xy: Xy[0][0]["split"] == "train", train),
"test": filter(lambda Xy: Xy[0][0]["split"] == "test", test),
"val": filter(lambda Xy: Xy[0][0]["split"] == "val", val),
}

class MyAugmentation(Transform):
def transform(self, X):
return X

class MyTask(Task):
dataset = "test"
split = MySplit
target = MyTarget
metrics = MyMetric
augmentation = MyAugmentation

# END USER
class MyLabelTransform(LabelTransform):
def transform(self, y):
return -y

def inverse_transform(self, y):
return -y

task = MyTask(shard_size=8).transform(
task = DummyTask(target_kwargs={"attribute": "label"}, shard_size=8).transform(
MyLabelTransform(),
PointRepresentationTransform(),
TorchFrameworkTransform(),
)

for epoch in range(5):
for X, y in task.train_loader(batch_size=64, random_seed=0):
for X, y in task.train_loader(batch_size=16, shuffle=True, random_seed=0):
print("X", X.shape)
print("y", y.shape)
break
Expand Down

0 comments on commit 9629ebe

Please sign in to comment.