Skip to content

Commit

Permalink
Permute list in place (#23)
Browse files Browse the repository at this point in the history
* Permute list in place

* Actual tests

* Move test to correct position and updated tests

* Fixed bug, separated function

* Adds epoch control

* misc fixes

* type hint

* % m in shuffle_epoch

---------

Co-authored-by: Jett <[email protected]>
  • Loading branch information
SrGonao and jettjaniak authored Feb 5, 2024
1 parent cb3976e commit 5448b4b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/delphi/train/shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
class RNG:
"""Random Number Generator
Linear Congruential Generator equivalent to minstd_rand in C++11
https://en.cppreference.com/w/cpp/numeric/random
"""

a = 48271
m = 2147483647 # 2^31 - 1

def __init__(self, seed: int):
assert 0 <= seed < self.m
self.state = seed

def __call__(self) -> int:
self.state = (self.state * self.a) % self.m
return self.state


def shuffle_list(in_out: list, seed: int):
"""Deterministically shuffle a list in-place
Implements Fisher-Yates shuffle with LCG as randomness source
https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
"""
rng = RNG(seed)
n = len(in_out)
for i in range(n - 1, 0, -1):
j = rng() % (i + 1)
in_out[i], in_out[j] = in_out[j], in_out[i]


def shuffle_epoch(samples: list, seed: int, epoch_nr: int):
"""Shuffle the samples in-place for a given training epoch"""
rng = RNG((10_000 + seed) % RNG.m)
for _ in range(epoch_nr):
rng()
shuffle_seed = rng()
shuffle_list(samples, shuffle_seed)
51 changes: 51 additions & 0 deletions tests/train/test_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import random

import pytest

from delphi.train.shuffle import RNG, shuffle_epoch, shuffle_list


def test_rng():
"""
Compare to the following C++ code:
#include <iostream>
#include <random>
int main() {
unsigned int seed = 12345;
std::minstd_rand generator(seed);
for (int i = 0; i < 5; i++)
std::cout << generator() << ", ";
}
"""
rng = RNG(12345)
expected = [595905495, 1558181227, 1498755989, 2021244883, 887213142]
for val in expected:
assert rng() == val


@pytest.mark.parametrize(
"input_list, seed",
[(random.sample(range(100), 10), random.randint(1, 1000)) for _ in range(5)],
)
def test_shuffle_list(input_list, seed):
original_list = input_list.copy()
shuffle_list(input_list, seed)
assert sorted(input_list) == sorted(original_list)


@pytest.mark.parametrize(
"seed, epoch_nr, expected",
[
(1, 1, [2, 5, 1, 3, 4]),
(2, 5, [2, 1, 4, 5, 3]),
(3, 10, [1, 4, 3, 5, 2]),
(4, 100, [3, 4, 5, 1, 2]),
],
)
def test_shuffle_epoch(seed, epoch_nr, expected):
samples = [1, 2, 3, 4, 5]
shuffle_epoch(samples, seed, epoch_nr)
assert samples == expected

0 comments on commit 5448b4b

Please sign in to comment.