Remove default schemes

bartvm committed Feb 26, 2015
1 parent 8b975e1 commit a4bf7e9
Showing 9 changed files with 27 additions and 42 deletions.
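The change is the same across all nine files: datasets no longer define a default_scheme, Dataset.get_default_stream() is removed, and callers construct a DataStream themselves. A minimal before/after sketch of the migration (fuel.streams.DataStream is imported in the diffs below; the behaviour of a scheme-less stream is inferred from the updated tests):

    from fuel.streams import DataStream

    # Before this commit: the dataset chose its own iteration scheme
    stream = dataset.get_default_stream()

    # After this commit: build the stream explicitly; with no iteration
    # scheme, container-style datasets yield one example at a time in order
    stream = DataStream(dataset)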
10 changes: 0 additions & 10 deletions fuel/datasets/base.py
@@ -28,10 +28,6 @@ class Dataset(object):
        'targets')`` for MNIST (regardless of which data the data stream
        actually requests). Any implementation of a dataset should set this
        attribute on the class (or at least before calling ``super``).
    default_iteration_scheme : :class:`.IterationScheme`, optional
        The default iteration scheme that will be used by
        :meth:`get_default_stream` to create a data stream without needing
        to specify what iteration scheme to use.
    Notes
    -----
@@ -139,12 +135,6 @@ def get_data(self, state=None, request=None):
"""
raise NotImplementedError

def get_default_stream(self):
"""Use the default iteration scheme to construct a data stream."""
if not hasattr(self, 'default_scheme'):
raise ValueError("Dataset does not provide a default iterator")
return DataStream(self, iteration_scheme=self.default_scheme)

def filter_sources(self, data):
"""Filter the requested sources from those provided by the dataset.
1 change: 0 additions & 1 deletion fuel/datasets/binarized_mnist.py
@@ -61,7 +61,6 @@ def __init__(self, which_set, **kwargs):
raise ValueError("available splits are 'train', 'valid' and "
"'test'")
self.num_examples = 50000 if which_set == 'train' else 10000
self.default_scheme = SequentialScheme(self.num_examples, 1)
super(BinarizedMNIST, self).__init__(**kwargs)

self.which_set = which_set
1 change: 0 additions & 1 deletion fuel/datasets/cifar10.py
@@ -57,7 +57,6 @@ def __init__(self, which_set, start=None, stop=None, **kwargs):
        if start is None:
            start = 0
        self.num_examples = stop - start
        self.default_scheme = SequentialScheme(self.num_examples, 1)
        super(CIFAR10, self).__init__(**kwargs)

        self.which_set = which_set
2 changes: 0 additions & 2 deletions fuel/datasets/container.py
@@ -23,8 +23,6 @@ class ContainerDataset(Dataset):
    :class:`BatchDataStream` data stream.
    """
    default_scheme = None

    def __init__(self, container, sources=None):
        if isinstance(container, dict):
            self.provides_sources = (sources if sources is not None
1 change: 0 additions & 1 deletion fuel/datasets/mnist.py
@@ -58,7 +58,6 @@ def __init__(self, which_set, start=None, stop=None, binary=False,
        if start is None:
            start = 0
        self.num_examples = stop - start
        self.default_scheme = SequentialScheme(self.num_examples, 1)
        super(MNIST, self).__init__(**kwargs)

        self.which_set = which_set
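For MNIST, CIFAR10, and BinarizedMNIST the deleted default was SequentialScheme(self.num_examples, 1). Code that relied on it can recreate the same stream explicitly; a sketch (the fuel.schemes import path is an assumption, not shown in this diff):

    from fuel.datasets import MNIST
    from fuel.schemes import SequentialScheme  # assumed import path
    from fuel.streams import DataStream

    mnist = MNIST('train')
    # Same behaviour as the removed default: sequential order, one example per request
    stream = DataStream(mnist,
                        iteration_scheme=SequentialScheme(mnist.num_examples, 1))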
4 changes: 2 additions & 2 deletions fuel/datasets/text.py
@@ -49,7 +49,8 @@ class TextFile(Dataset):
    >>> text_data = TextFile(files=['sentences.txt'],
    ...                      dictionary=dictionary, bos_token=None,
    ...                      preprocess=lower)
    >>> for data in text_data.get_default_stream().get_epoch_iterator():
    >>> from fuel.streams import DataStream
    >>> for data in DataStream(text_data).get_epoch_iterator():
    ...     print(data)
    ([2, 0, 3, 0, 1],)
    ([2, 0, 4, 1],)
@@ -62,7 +63,6 @@ class TextFile(Dataset):
"""
provides_sources = ('features',)
default_scheme = None

def __init__(self, files, dictionary, bos_token='<S>', eos_token='</S>',
unk_token='<UNK>', level='word', preprocess=None):
3 changes: 2 additions & 1 deletion tests/test_binarized_mnist.py
@@ -1,10 +1,11 @@
import numpy
from numpy.testing import assert_raises

from fuel.datasets import BinarizedMNIST
from tests import skip_if_not_available


def test_mnist():
    skip_if_not_available(datasets=['binarized_mnist'])
    mnist_train = BinarizedMNIST('train')
    assert len(mnist_train.features) == 50000
    assert mnist_train.num_examples == 50000
42 changes: 20 additions & 22 deletions tests/test_datasets.py
@@ -18,7 +18,7 @@
def test_dataset():
    data = [1, 2, 3]
    # A stream without an iteration scheme requests one example at a time
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    epoch = stream.get_epoch_iterator()
    assert list(epoch) == list(zip(data))

@@ -33,7 +33,7 @@ def test_dataset():
def test_data_stream_mapping():
    data = [1, 2, 3]
    data_doubled = [2, 4, 6]
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    wrapper1 = DataStreamMapping(
        stream, lambda d: (2 * d[0],))
    assert list(wrapper1.get_epoch_iterator()) == list(zip(data_doubled))
@@ -49,7 +49,7 @@ def test_data_stream_mapping_sort():
            [3, 2, 1]]
    data_sorted = [[1, 2, 3]] * 3
    data_sorted_rev = [[3, 2, 1]] * 3
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    wrapper1 = DataStreamMapping(stream,
                                 mapping=SortMapping(operator.itemgetter(0)))
    assert list(wrapper1.get_epoch_iterator()) == list(zip(data_sorted))
@@ -71,7 +71,7 @@ def test_data_stream_mapping_sort_multisource_ndarrays():
    data_sorted = [(numpy.array([1, 2, 3]), numpy.array([6, 5, 4])),
                   (numpy.array([1, 2, 3]), numpy.array([4, 6, 5])),
                   (numpy.array([1, 2, 3]), numpy.array([4, 5, 6]))]
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    wrapper = DataStreamMapping(stream,
                                mapping=SortMapping(operator.itemgetter(0)))
    for output, ground_truth in zip(wrapper.get_epoch_iterator(), data_sorted):
@@ -87,7 +87,7 @@ def test_data_stream_mapping_sort_multisource():
    data_sorted = [([1, 2, 3], [6, 5, 4]),
                   ([1, 2, 3], [4, 6, 5]),
                   ([1, 2, 3], [4, 5, 6])]
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    wrapper = DataStreamMapping(stream,
                                mapping=SortMapping(operator.itemgetter(0)))
    assert list(wrapper.get_epoch_iterator()) == data_sorted
@@ -96,7 +96,7 @@ def test_data_stream_mapping_sort_multisource():
def test_data_stream_filter():
    data = [1, 2, 3]
    data_filtered = [1, 3]
    stream = ContainerDataset(data).get_default_stream()
    stream = DataStream(ContainerDataset(data))
    wrapper = DataStreamFilter(stream, lambda d: d[0] % 2 == 1)
    assert list(wrapper.get_epoch_iterator()) == list(zip(data_filtered))

@@ -105,27 +105,27 @@ def test_floatx():
    x = [numpy.array(d, dtype="float64") for d in [[1, 2], [3, 4]]]
    y = [numpy.array(d, dtype="int64") for d in [1, 2, 3]]
    dataset = ContainerDataset(OrderedDict([("x", x), ("y", y)]))
    data = next(ForceFloatX(dataset.get_default_stream()).get_epoch_iterator())
    data = next(ForceFloatX(DataStream(dataset)).get_epoch_iterator())
    assert str(data[0].dtype) == floatX
    assert str(data[1].dtype) == "int64"


def test_sources_selection():
    features = [5, 6, 7, 1]
    targets = [1, 0, 1, 1]
    stream = ContainerDataset(OrderedDict(
        [('features', features), ('targets', targets)])).get_default_stream()
    stream = DataStream(ContainerDataset(OrderedDict(
        [('features', features), ('targets', targets)])))
    assert list(stream.get_epoch_iterator()) == list(zip(features, targets))

    stream = ContainerDataset({'features': features, 'targets': targets},
                              sources=('targets',)).get_default_stream()
    stream = DataStream(ContainerDataset(
        {'features': features, 'targets': targets},
        sources=('targets',)))
    assert list(stream.get_epoch_iterator()) == list(zip(targets))


def test_data_driven_epochs():
    class TestDataset(ContainerDataset):
        sources = ('data',)
        default_scheme = ConstantScheme(1)

        def __init__(self):
            self.data = [[1, 2, 3, 4],
@@ -152,7 +152,7 @@ def get_data(self, state, request):
    epochs = []
    epochs.append([([1],), ([2],), ([3],), ([4],)])
    epochs.append([([5],), ([6],), ([7],), ([8],)])
    stream = TestDataset().get_default_stream()
    stream = DataStream(TestDataset(), iteration_scheme=ConstantScheme(1))
    assert list(stream.get_epoch_iterator()) == epochs[0]
    assert list(stream.get_epoch_iterator()) == epochs[1]
    assert list(stream.get_epoch_iterator()) == epochs[0]
@@ -206,7 +206,7 @@ def test_cache():


def test_batch_data_stream():
    stream = ContainerDataset([1, 2, 3, 4, 5]).get_default_stream()
    stream = DataStream(ContainerDataset([1, 2, 3, 4, 5]))
    batches = list(BatchDataStream(stream, ConstantScheme(2))
                   .get_epoch_iterator())
    expected = [(numpy.array([1, 2]),),
@@ -223,16 +223,15 @@ def try_strict(strictness):
                   .get_epoch_iterator())
    assert_raises(ValueError, try_strict, 2)
    assert len(try_strict(1)) == 2
    stream2 = ContainerDataset([1, 2, 3, 4, 5, 6]).get_default_stream()
    stream2 = DataStream(ContainerDataset([1, 2, 3, 4, 5, 6]))
    assert len(list(BatchDataStream(stream2, ConstantScheme(2), strictness=2)
                    .get_epoch_iterator())) == 3
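The strictness assertions pin down the BatchDataStream semantics: with five examples and batches of two, strictness=2 raises ValueError, strictness=1 yields only two batches, and six examples split evenly under strictness=2. A hedged reading of the levels, inferred from these tests rather than stated in the diff:

    # strictness=0 (presumed default): allow a smaller final batch
    # strictness=1: drop a final batch that is not full
    # strictness=2: raise ValueError unless the batches divide the data exactly
    batches = list(BatchDataStream(DataStream(ContainerDataset([1, 2, 3, 4, 5])),
                                   ConstantScheme(2), strictness=1)
                   .get_epoch_iterator())
    assert len(batches) == 2  # [1, 2] and [3, 4]; the lone 5 is dropped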


def test_padding_data_stream():
    # 1-D sequences
    stream = BatchDataStream(
        ContainerDataset([[1], [2, 3], [], [4, 5, 6], [7]])
        .get_default_stream(),
        DataStream(ContainerDataset([[1], [2, 3], [], [4, 5, 6], [7]])),
        ConstantScheme(2))
    mask_stream = PaddingDataStream(stream)
    assert mask_stream.sources == ("data", "data_mask")
@@ -249,8 +248,8 @@ def test_padding_data_stream():

    # 2D sequences
    stream2 = BatchDataStream(
        ContainerDataset([numpy.ones((3, 4)), 2 * numpy.ones((2, 4))])
        .get_default_stream(),
        DataStream(ContainerDataset([numpy.ones((3, 4)),
                                     2 * numpy.ones((2, 4))])),
        ConstantScheme(2))
    it = PaddingDataStream(stream2).get_epoch_iterator()
    data, mask = next(it)
@@ -261,8 +260,7 @@ def test_padding_data_stream():

    # 2 sources
    stream3 = PaddingDataStream(BatchDataStream(
        ContainerDataset(dict(features=[[1], [2, 3], []],
                              targets=[[4, 5, 6], [7]]))
        .get_default_stream(),
        DataStream(ContainerDataset(dict(features=[[1], [2, 3], []],
                                         targets=[[4, 5, 6], [7]]))),
        ConstantScheme(2)))
    assert len(next(stream3.get_epoch_iterator())) == 4
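As these cases show, the wrappers compose: a DataStream feeds a BatchDataStream, which feeds a PaddingDataStream, and each padded source gains a companion "<source>_mask" source (hence the four entries asserted above for the two-source case). A usage sketch restating the 1-D case:

    stream = PaddingDataStream(BatchDataStream(
        DataStream(ContainerDataset([[1], [2, 3], [], [4, 5, 6], [7]])),
        ConstantScheme(2)))
    data, data_mask = next(stream.get_epoch_iterator())
    # data is padded to the longest sequence in the batch;
    # data_mask marks which positions hold real entries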
5 changes: 3 additions & 2 deletions tests/test_text.py
@@ -5,6 +5,7 @@
from six.moves import cPickle

from fuel.datasets import TextFile
from fuel.streams import DataStream


def lower(s):
@@ -25,7 +26,7 @@ def test_text():
    text_data = TextFile(files=[sentences1, sentences2],
                         dictionary=dictionary, bos_token=None,
                         preprocess=lower)
    stream = text_data.get_default_stream()
    stream = DataStream(text_data)
    epoch = stream.get_epoch_iterator()
    assert len(list(epoch)) == 4
    epoch = stream.get_epoch_iterator()
@@ -46,6 +47,6 @@ def test_text():
    text_data = TextFile(files=[sentences1, sentences2],
                         dictionary=dictionary, preprocess=lower,
                         level="character")
    sentence = next(text_data.get_default_stream().get_epoch_iterator())[0]
    sentence = next(DataStream(text_data).get_epoch_iterator())[0]
    assert sentence[:3] == [27, 19, 7]
    assert sentence[-3:] == [2, 4, 28]
