From a4bf7e991848b08a47b128a28a7a9592c9812a6e Mon Sep 17 00:00:00 2001
From: Bart van Merrienboer
Date: Wed, 25 Feb 2015 20:27:34 -0500
Subject: [PATCH] Remove default schemes

---
 fuel/datasets/base.py            | 10 --------
 fuel/datasets/binarized_mnist.py |  1 -
 fuel/datasets/cifar10.py         |  1 -
 fuel/datasets/container.py       |  2 --
 fuel/datasets/mnist.py           |  1 -
 fuel/datasets/text.py            |  4 +--
 tests/test_binarized_mnist.py    |  3 ++-
 tests/test_datasets.py           | 44 ++++++++++++++++-----------------
 tests/test_text.py               |  5 ++--
 9 files changed, 28 insertions(+), 43 deletions(-)

diff --git a/fuel/datasets/base.py b/fuel/datasets/base.py
index 7d14cc378..566caf36a 100644
--- a/fuel/datasets/base.py
+++ b/fuel/datasets/base.py
@@ -28,10 +28,6 @@ class Dataset(object):
         'targets')`` for MNIST (regardless of which data the data stream
         actually requests). Any implementation of a dataset should set this
         attribute on the class (or at least before calling ``super``).
-    default_iteration_scheme : :class:`.IterationScheme`, optional
-        The default iteration scheme that will be used by
-        :meth:`get_default_stream` to create a data stream without needing
-        to specify what iteration scheme to use.
 
     Notes
     -----
@@ -139,12 +135,6 @@ def get_data(self, state=None, request=None):
         """
         raise NotImplementedError
 
-    def get_default_stream(self):
-        """Use the default iteration scheme to construct a data stream."""
-        if not hasattr(self, 'default_scheme'):
-            raise ValueError("Dataset does not provide a default iterator")
-        return DataStream(self, iteration_scheme=self.default_scheme)
-
     def filter_sources(self, data):
         """Filter the requested sources from those provided by the dataset.
 
diff --git a/fuel/datasets/binarized_mnist.py b/fuel/datasets/binarized_mnist.py
index 8c1f25c85..5c514b78a 100644
--- a/fuel/datasets/binarized_mnist.py
+++ b/fuel/datasets/binarized_mnist.py
@@ -61,7 +61,6 @@ def __init__(self, which_set, **kwargs):
             raise ValueError("available splits are 'train', 'valid' and "
                              "'test'")
         self.num_examples = 50000 if which_set == 'train' else 10000
-        self.default_scheme = SequentialScheme(self.num_examples, 1)
         super(BinarizedMNIST, self).__init__(**kwargs)
         self.which_set = which_set
 
diff --git a/fuel/datasets/cifar10.py b/fuel/datasets/cifar10.py
index b9a9ebc23..a0d414ffb 100644
--- a/fuel/datasets/cifar10.py
+++ b/fuel/datasets/cifar10.py
@@ -57,7 +57,6 @@ def __init__(self, which_set, start=None, stop=None, **kwargs):
         if start is None:
             start = 0
         self.num_examples = stop - start
-        self.default_scheme = SequentialScheme(self.num_examples, 1)
         super(CIFAR10, self).__init__(**kwargs)
         self.which_set = which_set
 
diff --git a/fuel/datasets/container.py b/fuel/datasets/container.py
index d39fd0080..a7836f438 100644
--- a/fuel/datasets/container.py
+++ b/fuel/datasets/container.py
@@ -23,8 +23,6 @@ class ContainerDataset(Dataset):
     :class:`BatchDataStream` data stream.
 
     """
-    default_scheme = None
-
     def __init__(self, container, sources=None):
         if isinstance(container, dict):
             self.provides_sources = (sources if sources is not None
diff --git a/fuel/datasets/mnist.py b/fuel/datasets/mnist.py
index 19a1a7bc9..9bd5db5a9 100644
--- a/fuel/datasets/mnist.py
+++ b/fuel/datasets/mnist.py
@@ -58,7 +58,6 @@ def __init__(self, which_set, start=None, stop=None, binary=False,
         if start is None:
             start = 0
         self.num_examples = stop - start
-        self.default_scheme = SequentialScheme(self.num_examples, 1)
         super(MNIST, self).__init__(**kwargs)
         self.which_set = which_set
 
diff --git a/fuel/datasets/text.py b/fuel/datasets/text.py
index bff836e59..b7729099c 100644
--- a/fuel/datasets/text.py
+++ b/fuel/datasets/text.py
@@ -49,7 +49,8 @@ class TextFile(Dataset):
     >>> text_data = TextFile(files=['sentences.txt'],
     ...                      dictionary=dictionary, bos_token=None,
     ...                      preprocess=lower)
-    >>> for data in text_data.get_default_stream().get_epoch_iterator():
+    >>> from fuel.streams import DataStream
+    >>> for data in DataStream(text_data).get_epoch_iterator():
     ...     print(data)
     ([2, 0, 3, 0, 1],)
     ([2, 0, 4, 1],)
@@ -62,7 +63,6 @@ class TextFile(Dataset):
 
     """
     provides_sources = ('features',)
-    default_scheme = None
 
     def __init__(self, files, dictionary, bos_token='<S>', eos_token='</S>',
                  unk_token='<UNK>', level='word', preprocess=None):
diff --git a/tests/test_binarized_mnist.py b/tests/test_binarized_mnist.py
index 232eaa2b8..c365a9263 100644
--- a/tests/test_binarized_mnist.py
+++ b/tests/test_binarized_mnist.py
@@ -1,10 +1,11 @@
-import numpy
 from numpy.testing import assert_raises
 
 from fuel.datasets import BinarizedMNIST
+from tests import skip_if_not_available
 
 
 def test_mnist():
+    skip_if_not_available(datasets=['binarized_mnist'])
     mnist_train = BinarizedMNIST('train')
     assert len(mnist_train.features) == 50000
     assert mnist_train.num_examples == 50000
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index ff6599787..314d7b5ef 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -18,7 +18,7 @@ def test_dataset():
     data = [1, 2, 3]
 
-    # The default stream requests an example at a time
-    stream = ContainerDataset(data).get_default_stream()
+    # Without an explicit scheme the stream requests one example at a time
+    stream = DataStream(ContainerDataset(data))
     epoch = stream.get_epoch_iterator()
     assert list(epoch) == list(zip(data))
 
@@ -33,7 +33,7 @@ def test_data_stream_mapping():
     data = [1, 2, 3]
     data_doubled = [2, 4, 6]
-    stream = ContainerDataset(data).get_default_stream()
+    stream = DataStream(ContainerDataset(data))
     wrapper1 = DataStreamMapping(
         stream, lambda d: (2 * d[0],))
     assert list(wrapper1.get_epoch_iterator()) == list(zip(data_doubled))
 
@@ -49,7 +49,7 @@ def test_data_stream_mapping_sort():
                                  [3, 2, 1]]
     data_sorted = [[1, 2, 3]] * 3
     data_sorted_rev = [[3, 2, 1]] * 3
-    stream = ContainerDataset(data).get_default_stream()
+    stream = DataStream(ContainerDataset(data))
     wrapper1 = DataStreamMapping(stream,
                                  mapping=SortMapping(operator.itemgetter(0)))
     assert list(wrapper1.get_epoch_iterator()) == list(zip(data_sorted))
@@ -71,7 +71,7 @@ def test_data_stream_mapping_sort_multisource_ndarrays():
     data_sorted = [(numpy.array([1, 2, 3]), numpy.array([6, 5, 4])),
                    (numpy.array([1, 2, 3]), numpy.array([4, 6, 5])),
                    (numpy.array([1, 2, 3]), numpy.array([4, 5, 6]))]
-    stream = ContainerDataset(data).get_default_stream()
+    stream = DataStream(ContainerDataset(data))
     wrapper = DataStreamMapping(stream,
                                 mapping=SortMapping(operator.itemgetter(0)))
     for output, ground_truth in zip(wrapper.get_epoch_iterator(), data_sorted):
@@ -87,7 +87,7 @@ def test_data_stream_mapping_sort_multisource():
     data_sorted = [([1, 2, 3], [6, 5, 4]),
                    ([1, 2, 3], [4, 6, 5]),
                    ([1, 2, 3], [4, 5, 6])]
-    stream = ContainerDataset(data).get_default_stream()
+    stream = DataStream(ContainerDataset(data))
     wrapper = DataStreamMapping(stream,
                                 mapping=SortMapping(operator.itemgetter(0)))
     assert list(wrapper.get_epoch_iterator()) == data_sorted
@@ -96,7 +96,7 @@ def test_data_stream_mapping_sort_multisource():
 
 def test_data_stream_filter():
     data = [1, 2, 3]
     data_filtered = [1, 3]
-    stream = ContainerDataset(data).get_default_stream()
+    stream = DataStream(ContainerDataset(data))
     wrapper = DataStreamFilter(stream, lambda d: d[0] % 2 == 1)
     assert list(wrapper.get_epoch_iterator()) == list(zip(data_filtered))
@@ -105,7 +105,7 @@ def test_floatx():
     x = [numpy.array(d, dtype="float64") for d in [[1, 2], [3, 4]]]
     y = [numpy.array(d, dtype="int64") for d in [1, 2, 3]]
     dataset = ContainerDataset(OrderedDict([("x", x), ("y", y)]))
-    data = next(ForceFloatX(dataset.get_default_stream()).get_epoch_iterator())
+    data = next(ForceFloatX(DataStream(dataset)).get_epoch_iterator())
     assert str(data[0].dtype) == floatX
     assert str(data[1].dtype) == "int64"
 
@@ -113,19 +113,19 @@ def test_sources_selection():
     features = [5, 6, 7, 1]
     targets = [1, 0, 1, 1]
-    stream = ContainerDataset(OrderedDict(
-        [('features', features), ('targets', targets)])).get_default_stream()
+    stream = DataStream(ContainerDataset(OrderedDict(
+        [('features', features), ('targets', targets)])))
     assert list(stream.get_epoch_iterator()) == list(zip(features, targets))
 
-    stream = ContainerDataset({'features': features, 'targets': targets},
-                              sources=('targets',)).get_default_stream()
+    stream = DataStream(ContainerDataset(
+        {'features': features, 'targets': targets},
+        sources=('targets',)))
     assert list(stream.get_epoch_iterator()) == list(zip(targets))
 
 
 def test_data_driven_epochs():
     class TestDataset(ContainerDataset):
         sources = ('data',)
-        default_scheme = ConstantScheme(1)
 
         def __init__(self):
             self.data = [[1, 2, 3, 4],
@@ -152,7 +152,7 @@ def get_data(self, state, request):
     epochs = []
     epochs.append([([1],), ([2],), ([3],), ([4],)])
     epochs.append([([5],), ([6],), ([7],), ([8],)])
-    stream = TestDataset().get_default_stream()
+    stream = DataStream(TestDataset(), iteration_scheme=ConstantScheme(1))
     assert list(stream.get_epoch_iterator()) == epochs[0]
     assert list(stream.get_epoch_iterator()) == epochs[1]
     assert list(stream.get_epoch_iterator()) == epochs[0]
@@ -206,7 +206,7 @@ def test_cache():
 
 
 def test_batch_data_stream():
-    stream = ContainerDataset([1, 2, 3, 4, 5]).get_default_stream()
+    stream = DataStream(ContainerDataset([1, 2, 3, 4, 5]))
     batches = list(BatchDataStream(stream, ConstantScheme(2))
                    .get_epoch_iterator())
     expected = [(numpy.array([1, 2]),),
@@ -223,7 +223,7 @@ def try_strict(strictness):
                        .get_epoch_iterator())
     assert_raises(ValueError, try_strict, 2)
     assert len(try_strict(1)) == 2
-    stream2 = ContainerDataset([1, 2, 3, 4, 5, 6]).get_default_stream()
+    stream2 = DataStream(ContainerDataset([1, 2, 3, 4, 5, 6]))
     assert len(list(BatchDataStream(stream2, ConstantScheme(2), strictness=2)
                     .get_epoch_iterator())) == 3
 
@@ -231,8 +231,7 @@ def try_strict(strictness):
 def test_padding_data_stream():
     # 1-D sequences
     stream = BatchDataStream(
-        ContainerDataset([[1], [2, 3], [], [4, 5, 6], [7]])
-        .get_default_stream(),
+        DataStream(ContainerDataset([[1], [2, 3], [], [4, 5, 6], [7]])),
         ConstantScheme(2))
     mask_stream = PaddingDataStream(stream)
     assert mask_stream.sources == ("data", "data_mask")
@@ -249,8 +248,8 @@ def test_padding_data_stream():
 
     # 2D sequences
     stream2 = BatchDataStream(
-        ContainerDataset([numpy.ones((3, 4)), 2 * numpy.ones((2, 4))])
-        .get_default_stream(),
+        DataStream(ContainerDataset([numpy.ones((3, 4)),
+                                     2 * numpy.ones((2, 4))])),
         ConstantScheme(2))
     it = PaddingDataStream(stream2).get_epoch_iterator()
     data, mask = next(it)
@@ -261,8 +260,7 @@ def test_padding_data_stream():
 
     # 2 sources
     stream3 = PaddingDataStream(BatchDataStream(
-        ContainerDataset(dict(features=[[1], [2, 3], []],
-                              targets=[[4, 5, 6], [7]]))
-        .get_default_stream(),
+        DataStream(ContainerDataset(dict(features=[[1], [2, 3], []],
+                                         targets=[[4, 5, 6], [7]]))),
         ConstantScheme(2)))
     assert len(next(stream3.get_epoch_iterator())) == 4
diff --git a/tests/test_text.py b/tests/test_text.py
index 907592755..10b3f37fd 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -5,6 +5,7 @@
 from six.moves import cPickle
 
 from fuel.datasets import TextFile
+from fuel.streams import DataStream
 
 
 def lower(s):
@@ -25,7 +26,7 @@ def test_text():
     text_data = TextFile(files=[sentences1, sentences2],
                          dictionary=dictionary, bos_token=None,
                          preprocess=lower)
-    stream = text_data.get_default_stream()
+    stream = DataStream(text_data)
     epoch = stream.get_epoch_iterator()
     assert len(list(epoch)) == 4
     epoch = stream.get_epoch_iterator()
@@ -46,6 +47,6 @@ def test_text():
     text_data = TextFile(files=[sentences1, sentences2],
                          dictionary=dictionary, preprocess=lower,
                          level="character")
-    sentence = next(text_data.get_default_stream().get_epoch_iterator())[0]
+    sentence = next(DataStream(text_data).get_epoch_iterator())[0]
    assert sentence[:3] == [27, 19, 7]
     assert sentence[-3:] == [2, 4, 28]