forked from cooijmanstim/tsa-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
svhn.py
27 lines (22 loc) · 894 Bytes
/
svhn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from fuel.transformers import Mapping
from fuel.datasets.svhn import SVHN
import tasks
def fix_target_representation(data):
    """Remap target label 10 to 0 in an ``(x, y)`` pair.

    Intended as the ``mapping`` callable of a Fuel ``Mapping`` transformer.

    ``data`` is unpacked as exactly two items ``(x, y)``. ``y`` is modified
    in place (it must support boolean-mask assignment, e.g. a NumPy array)
    and the same ``x`` and ``y`` objects are returned.
    """
    x, y = data
    # Label 10 stands for the digit zero; use 0 so that zero represents zero.
    y[y == 10] = 0
    return x, y


class DigitTask(tasks.Classification):
    """Ten-class digit classification task backed by Fuel's ``SVHN`` dataset.

    The datasets are loaded with ``which_format=2``. The first 50000
    examples of the "train" split are used for training, the remainder of
    that split for validation, and the "test" split for testing.
    """

    name = "svhn_digit"

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base task, then set the task sizes."""
        super(DigitTask, self).__init__(*args, **kwargs)
        self.n_classes = 10
        # NOTE(review): presumably the images are single-channel by the time
        # they reach the model -- confirm against the base class's stream.
        self.n_channels = 1

    def load_datasets(self):
        """Return a dict of datasets keyed by "train", "valid" and "test"."""
        return dict(
            # "train" and "valid" are disjoint slices of the same "train" set.
            train=SVHN(which_sets=["train"], which_format=2,
                       subset=slice(None, 50000)),
            valid=SVHN(which_sets=["train"], which_format=2,
                       subset=slice(50000, None)),
            test=SVHN(which_sets=["test"], which_format=2))

    def get_stream(self, *args, **kwargs):
        """Return the base class's stream with target label 10 remapped to 0.

        All arguments are passed through unchanged to the base class's
        ``get_stream``; the result is wrapped in a ``Mapping`` transformer
        that applies ``fix_target_representation`` to the stream's data.
        """
        return Mapping(
            super(DigitTask, self).get_stream(*args, **kwargs),
            mapping=fix_target_representation)