diff --git a/src/delphi/train/shuffle.py b/src/delphi/train/shuffle.py new file mode 100644 index 00000000..1f6431ce --- /dev/null +++ b/src/delphi/train/shuffle.py @@ -0,0 +1,39 @@ +class RNG: + """Random Number Generator + + Linear Congruential Generator equivalent to minstd_rand in C++11 + https://en.cppreference.com/w/cpp/numeric/random + """ + + a = 48271 + m = 2147483647 # 2^31 - 1 + + def __init__(self, seed: int): + assert 0 <= seed < self.m + self.state = seed + + def __call__(self) -> int: + self.state = (self.state * self.a) % self.m + return self.state + + +def shuffle_list(in_out: list, seed: int): + """Deterministically shuffle a list in-place + + Implements Fisher-Yates shuffle with LCG as randomness source + https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm + """ + rng = RNG(seed) + n = len(in_out) + for i in range(n - 1, 0, -1): + j = rng() % (i + 1) + in_out[i], in_out[j] = in_out[j], in_out[i] + + +def shuffle_epoch(samples: list, seed: int, epoch_nr: int): + """Shuffle the samples in-place for a given training epoch""" + rng = RNG((10_000 + seed) % RNG.m) + for _ in range(epoch_nr): + rng() + shuffle_seed = rng() + shuffle_list(samples, shuffle_seed) diff --git a/tests/train/test_shuffle.py b/tests/train/test_shuffle.py new file mode 100644 index 00000000..cf4a6f96 --- /dev/null +++ b/tests/train/test_shuffle.py @@ -0,0 +1,51 @@ +import random + +import pytest + +from delphi.train.shuffle import RNG, shuffle_epoch, shuffle_list + + +def test_rng(): + """ + Compare to the following C++ code: + + #include + #include + + int main() { + unsigned int seed = 12345; + std::minstd_rand generator(seed); + + for (int i = 0; i < 5; i++) + std::cout << generator() << ", "; + } + """ + rng = RNG(12345) + expected = [595905495, 1558181227, 1498755989, 2021244883, 887213142] + for val in expected: + assert rng() == val + + +@pytest.mark.parametrize( + "input_list, seed", + [(random.sample(range(100), 10), random.randint(1, 1000)) for _ in range(5)], +) +def test_shuffle_list(input_list, seed): + original_list = input_list.copy() + shuffle_list(input_list, seed) + assert sorted(input_list) == sorted(original_list) + + +@pytest.mark.parametrize( + "seed, epoch_nr, expected", + [ + (1, 1, [2, 5, 1, 3, 4]), + (2, 5, [2, 1, 4, 5, 3]), + (3, 10, [1, 4, 3, 5, 2]), + (4, 100, [3, 4, 5, 1, 2]), + ], +) +def test_shuffle_epoch(seed, epoch_nr, expected): + samples = [1, 2, 3, 4, 5] + shuffle_epoch(samples, seed, epoch_nr) + assert samples == expected