Skip to content

Commit

Permalink
Merge pull request #2603 from ales-erjavec/fixes/randomization-seed-r…
Browse files Browse the repository at this point in the history
…euse

[FIX] preprocess.randomization: Do not use the same seed for X, Y, and meta
  • Loading branch information
janezd authored Sep 21, 2017
2 parents cdddfc3 + 8bdcb3e commit 351b364
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
19 changes: 11 additions & 8 deletions Orange/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,30 +354,33 @@ def __call__(self, data):
Randomized data table.
"""
new_data = data.copy()
rstate = np.random.RandomState(self.rand_seed)
# ensure the same seed is not used to shuffle X and Y at the same time
r1, r2, r3 = rstate.randint(0, 2 ** 32 - 1, size=3, dtype=np.int64)
if self.rand_type & Randomize.RandomizeClasses:
new_data.Y = self.randomize(new_data.Y)
new_data.Y = self.randomize(new_data.Y, r1)
if self.rand_type & Randomize.RandomizeAttributes:
new_data.X = self.randomize(new_data.X)
new_data.X = self.randomize(new_data.X, r2)
if self.rand_type & Randomize.RandomizeMetas:
new_data.metas = self.randomize(new_data.metas)
new_data.metas = self.randomize(new_data.metas, r3)
return new_data

def randomize(self, table):
np.random.seed(self.rand_seed)
def randomize(self, table, rand_state=None):
rstate = np.random.RandomState(rand_state)
if sp.issparse(table):
table = table.tocsc()
rnd_indices = np.arange(table.shape[0], dtype=table.indices.dtype)
for i in range(table.shape[1]):
col_indices = \
table.indices[table.indptr[i]: table.indptr[i + 1]]
new_indices = rnd_indices[:len(col_indices)]
np.random.shuffle(new_indices)
rstate.shuffle(new_indices)
col_indices[:] = new_indices
elif len(table.shape) > 1:
for i in range(table.shape[1]):
np.random.shuffle(table[:, i])
rstate.shuffle(table[:, i])
else:
np.random.shuffle(table)
rstate.shuffle(table)
return table


Expand Down
11 changes: 8 additions & 3 deletions Orange/widgets/data/tests/test_owpreprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ def test_randomize(self):
self.widget.set_model(model)
self.send_signal(self.widget.Inputs.data, self.zoo)
output = self.get_output(self.widget.Outputs.preprocessed_data)
np.random.seed(1)
np.random.shuffle(self.zoo.Y)
r = Randomize(Randomize.RandomizeClasses, rand_seed=1)
expected = r(self.zoo)

np.testing.assert_array_equal(expected.X, output.X)
np.testing.assert_array_equal(expected.Y, output.Y)
np.testing.assert_array_equal(expected.metas, output.metas)

np.testing.assert_array_equal(self.zoo.X, output.X)
np.testing.assert_array_equal(self.zoo.Y, output.Y)
np.testing.assert_array_equal(self.zoo.metas, output.metas)
self.assertFalse(np.array_equal(self.zoo.Y, output.Y))

def test_normalize(self):
data = Table("iris")
Expand Down

0 comments on commit 351b364

Please sign in to comment.