Skip to content

Commit

Permalink
Merge pull request #19 from kipoi/fix_output_schema
Browse files (browse the repository at this point in the history)
fix get_output_schema
  • Loading branch information
Avsecz authored Oct 28, 2018
2 parents 6cc8c9a + 4af5af3 commit 4a28c91
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
12 changes: 7 additions & 5 deletions kipoiseq/dataloaders/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,12 @@ def __getitem__(self, idx):

@classmethod
def get_output_schema(cls):
output_schema = deepcopy(cls.output_schema)
kwargs = default_kwargs(cls)
ignore_targets = kwargs['ignore_targets']
if ignore_targets:
cls.output_schema.targets = None
return cls.output_schema
output_schema.targets = None
return output_schema


# TODO - properly deal with samples outside of the genome
Expand Down Expand Up @@ -354,6 +355,7 @@ def __getitem__(self, idx):
def get_output_schema(cls):
"""Get the output schema. Overrides the default `cls.output_schema`
"""
output_schema = deepcopy(cls.output_schema)

# get the default kwargs
kwargs = default_kwargs(cls)
Expand All @@ -366,10 +368,10 @@ def get_output_schema(cls):
input_shape = mock_input_transform.get_output_shape(kwargs['auto_resize_len'])

# modify it
cls.output_schema.inputs.shape = input_shape
output_schema.inputs.shape = input_shape

# (optionally) get rid of the target shape
if kwargs['ignore_targets']:
cls.output_schema.targets = None
output_schema.targets = None

return cls.output_schema
return output_schema
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup, find_packages

requirements = [
"kipoi>=0.4.2",
"kipoi>=0.5.5",
# "genomelake",
"pybedtools",
"pyfaidx",
Expand Down
52 changes: 28 additions & 24 deletions tests/dataloaders/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,34 @@ def test_examples_exist(cls):
def test_output_schape():
Dl = deepcopy(SeqIntervalDl)
assert Dl.get_output_schema().inputs.shape == (None, 4)
override_default_kwargs(Dl, {"auto_resize_len": 100})
assert Dl.get_output_schema().inputs.shape == (100, 4)

override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2})
assert Dl.get_output_schema().inputs.shape == (100, 1, 4)
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 2})
assert Dl.get_output_schema().inputs.shape == (100, 4, 1)
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset

override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGTD"})
assert Dl.get_output_schema().inputs.shape == (100, 5)
override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGT"}) # reset

override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0})
assert Dl.get_output_schema().inputs.shape == (4, 160, 1)

override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1})
assert Dl.get_output_schema().inputs.shape == (160, 4, 1)
targets = Dl.get_output_schema().targets
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100})
assert Dlc.get_output_schema().inputs.shape == (100, 4)

# original left intact
assert Dl.get_output_schema().inputs.shape == (None, 4)

Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2})
assert Dlc.get_output_schema().inputs.shape == (100, 1, 4)
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 2})
assert Dlc.get_output_schema().inputs.shape == (100, 4, 1)
# original left intact
assert Dl.get_output_schema().inputs.shape == (None, 4)

Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGTD"})
assert Dlc.get_output_schema().inputs.shape == (100, 5)

Dlc = override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0})
assert Dlc.get_output_schema().inputs.shape == (4, 160, 1)

Dlc = override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1})
assert Dlc.get_output_schema().inputs.shape == (160, 4, 1)
targets = Dlc.get_output_schema().targets
assert targets.shape == (None,)

override_default_kwargs(Dl, {"ignore_targets": True})
assert Dl.get_output_schema().targets is None
Dlc = override_default_kwargs(Dl, {"ignore_targets": True})
assert Dlc.get_output_schema().targets is None
# reset back
override_default_kwargs(Dl, {"ignore_targets": False})
Dl.output_schema.targets = targets

# original left intact
assert Dl.get_output_schema().inputs.shape == (None, 4)
assert Dl.get_output_schema().targets.shape == (None, )

0 comments on commit 4a28c91

Please sign in to comment.