Skip to content

Commit

Permalink
intermediate commit
Browse files Browse the repository at this point in the history
  • Loading branch information
timkucera committed Nov 28, 2023
1 parent bb9b412 commit a05df4d
Show file tree
Hide file tree
Showing 37 changed files with 177 additions and 87 deletions.
File renamed without changes.
25 changes: 25 additions & 0 deletions proteinshake/backend/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any, Optional


class Collection:
    """
    Holds a set of proteins as the result of a database query and prepares it for dataset creation.

    NOTE: this is an interface skeleton. Every method is currently a no-op
    that returns ``None``.
    """

    def __init__(self, proteins: list[dict]) -> None:
        """
        Creates the collection from `proteins`, a list of protein dictionaries.

        Not implemented yet: the argument is not stored.
        """

    def add(self, metadata: Any) -> None:
        """
        Adds any kind of metadata to the collection, such as split indices.
        """

    def save(self, name: str) -> None:
        """
        Saves the proteins and metadata in compressed format.
        """

    def upload(self, version: Optional[str] = None) -> None:
        """
        Uploads the collection and metadata to Zenodo. `version` defaults to the current date.
        """
13 changes: 13 additions & 0 deletions proteinshake/backend/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pathlib import Path
from .collection import Collection


class Database:
    """
    Skeleton of the database interface: built from a `storage` path, it
    exposes `update` and `query`.

    None of the methods is implemented yet; each one returns ``None``.
    """

    def __init__(self, storage: Path) -> None:
        """Creates the database. `storage` is currently ignored."""

    def update(self) -> None:
        """Placeholder for updating the database; does nothing yet."""

    def query(self, query: str) -> Collection:
        """
        Placeholder for running `query` against the database.

        Annotated to return a `Collection`, but currently returns ``None``.
        """
4 changes: 4 additions & 0 deletions proteinshake/backend/protein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Protein:
    """
    Converts a result row from a database query into a dictionary.

    Placeholder: the conversion is not implemented yet.
    """
4 changes: 4 additions & 0 deletions proteinshake/backend/structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Structure:
    """
    Converts a PDB/mmCIF file into a compressed data format.

    Placeholder: the conversion is not implemented yet.
    """
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## Dataset API

tba

## Task API

A task attaches three objects (a `Splitter`, a `Target`, and an `Evaluator`) to a given `proteinshake.Dataset` instance.
Expand Down
54 changes: 54 additions & 0 deletions proteinshake/frontend/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pathlib import Path
from typing import Any, Optional, Union


class Dataset:
    """
    Frontend dataset built from a compressed collection.

    NOTE: this is an interface skeleton. Only `to_graph` has a body, and it
    relies on `self.proteins` and `GraphTransform`, neither of which is
    defined or imported in this file yet.
    """

    # Types that are not importable from this file yet are written as string
    # forward references so that defining the class does not raise NameError.

    def __init__(
        self,
        path: Path,
        version: str = "latest",
        shard_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        random_seed: int = 42,
    ) -> None:
        """
        Takes a compressed collection and applies transforms.
        `path` is either pointing to a Zenodo repository or a directory in the local filesystem.

        Not implemented yet: none of the arguments is stored.
        """

    def to_graph(
        self,
        pre_transform: Optional["PreRepresentationTransform"] = None,
        post_transform: Optional["PostRepresentationTransform"] = None,
        **kwargs: Any,
    ) -> "Dataset":
        """
        Applies pre/representation/post transforms to all proteins in the dataset.

        `pre_transform` and `post_transform` are optional and skipped when
        ``None``. `kwargs` are forwarded to `GraphTransform`. Returns `self`
        so that calls can be chained.
        """
        # The transforms default to None; only apply the ones actually given.
        if pre_transform is not None:
            self.proteins.apply(pre_transform)
        self.proteins.apply(GraphTransform(**kwargs))
        if post_transform is not None:
            self.proteins.apply(post_transform)
        return self

    def pyg(
        self,
        pre_transform: Optional["PreFrameworkTransform"] = None,
        post_transform: Optional["PostFrameworkTransform"] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Creates an iterable that wraps around __next__ or __getitem__ and applies pre/framework/post transforms.
        Returns a framework-specific dataset instance (iterable-style if sharded, map-style if in-memory or on-disk).

        Not implemented yet: currently returns ``None``.
        """

    def __next__(self) -> None:
        """
        Yields the next protein from a shard. When the shard is finished, loads the next one.
        If `shuffle` is True, loads a random shard and applies shuffling within the shard.

        Not implemented yet: currently returns ``None``.
        """

    def __getitem__(self, index: Union[int, list, tuple, "ndarray"]) -> None:
        """
        Returns the indexed proteins. Not available with sharding for performance reasons.

        Not implemented yet: currently returns ``None``.
        """
Empty file.
4 changes: 4 additions & 0 deletions proteinshake/frontend/evaluators/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class ClassificationEvaluator(Evaluator):
    """
    Evaluator for classification tasks; reports the accuracy of predictions.
    """

    def __call__(self, pred: list, truth: list) -> dict:
        """
        Scores the predictions `pred` against the ground truth `truth`.

        Returns a dictionary with the single key ``'accuracy'``.
        """
        # `sklearn.accuracy` does not exist; the metric lives in
        # `sklearn.metrics`. Imported locally because this module has no
        # import section yet.
        from sklearn.metrics import accuracy_score

        # accuracy_score takes (y_true, y_pred), in that order.
        return {'accuracy': accuracy_score(truth, pred)}
File renamed without changes.
4 changes: 4 additions & 0 deletions proteinshake/frontend/protein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Protein:
    """
    Converts a (compressed) protein from a collection into an uncompressed
    protein dictionary.

    Placeholder: the conversion is not implemented yet.
    """
14 changes: 14 additions & 0 deletions proteinshake/frontend/splitters/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class AttributeSplitter(Splitter):
    """
    Splitter that derives the train/val/test split from attributes already
    present in the dataset.
    """

    def __init__(
        self, train_attribute: str, val_attribute: str, test_attribute: str
    ) -> None:
        """Stores the attribute names used for the train, val and test splits."""
        (
            self.train_attribute,
            self.val_attribute,
            self.test_attribute,
        ) = (train_attribute, val_attribute, test_attribute)

    def __call__(self, dataset) -> tuple[list, list, list]:
        """
        Placeholder for computing the split of `dataset`.

        Not implemented yet: currently returns ``None`` rather than the
        annotated tuple of index lists.
        """
Empty file.
17 changes: 17 additions & 0 deletions proteinshake/frontend/splitters/pairwise_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class PairwiseAttributeSplitter(Splitter):
    """
    Splitter for pairwise tasks, based on attributes already present in the
    dataset. Takes all pairs of train/val/test from the single-attribute
    splitting setting.
    """

    def __init__(
        self, train_attribute: str, val_attribute: str, test_attribute: str
    ) -> None:
        """Stores the attribute names used for the train, val and test splits."""
        (
            self.train_attribute,
            self.val_attribute,
            self.test_attribute,
        ) = (train_attribute, val_attribute, test_attribute)

    def __call__(self, dataset) -> tuple[list, list, list]:
        """
        Placeholder for computing the pairwise split of `dataset`.

        Not implemented yet: builds the single-attribute splitter but
        currently returns ``None``.
        """
        # Non-paired splitter whose splits the index pairs will be built from.
        unpaired_splitter = AttributeSplitter(
            self.train_attribute,
            self.val_attribute,
            self.test_attribute,
        )
        # TODO: compute pairs of indices on the non-paired splits
        return None
Empty file.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
class Splitter:
    """
    Abstract base class for selecting train/val/test indices from a dataset.

    Subclasses override `__call__`; the base implementation always raises.
    """

    def __call__(self, dataset) -> tuple[list, list, list]:
        """
        Returns the (train, val, test) index lists for `dataset`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
Empty file.
File renamed without changes.
File renamed without changes.
30 changes: 30 additions & 0 deletions proteinshake/frontend/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional


class Task:
    """
    Abstract class for Tasks. A task contains the logic for splitting, target generation, and evaluation.
    Optionally, a Task can be used to sync with a Papers with Code instance
    (https://github.com/paperswithcode/paperswithcode-client).
    """

    def __init__(
        self,
        dataset: "proteinshake.Dataset",
        splitter: "proteinshake.Splitter",
        target: "proteinshake.Target",
        evaluator: "proteinshake.Evaluator",
        task_id: Optional[int] = None,
    ) -> None:
        """
        Binds a splitter, target and evaluator to `dataset`.

        `splitter` is called with `dataset` and must return the
        (train, val, test) index lists, which are stored as `train_idx`,
        `val_idx` and `test_idx`. `task_id` is the id used to build the
        leaderboard URL; ``None`` means the task has no leaderboard.
        """
        # The annotations are strings because `proteinshake` is not imported
        # in this module.
        self.dataset = dataset
        # A Splitter exposes only `__call__(dataset)`; it has no
        # `train_idx()` / `val_idx()` / `test_idx()` methods.
        self.train_idx, self.val_idx, self.test_idx = splitter(dataset)

        self.task_id = task_id

        self.target = target
        self.evaluator = evaluator

    def leaderboard_fetch(self):
        """
        Load current leaderboard results for this task.

        Returns ``None`` when the task has no `task_id`.
        """
        if self.task_id is None:
            return None
        # NOTE(review): `get_leaderboard` is not defined or imported in this
        # file; it must be provided before this branch can run.
        return get_leaderboard(f"https://paperswithcode.com/sota/{self.task_id}")
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
File renamed without changes.
16 changes: 0 additions & 16 deletions proteinshake/tasks/attribute_splitter.py

This file was deleted.

4 changes: 0 additions & 4 deletions proteinshake/tasks/classification_evaluator.py

This file was deleted.

21 changes: 0 additions & 21 deletions proteinshake/tasks/pairwise_attribute_splitter.py

This file was deleted.

29 changes: 0 additions & 29 deletions proteinshake/tasks/task.py

This file was deleted.

15 changes: 0 additions & 15 deletions proteinshake/tasks/time_splitter.py

This file was deleted.

0 comments on commit a05df4d

Please sign in to comment.