diff --git a/src/datachain/toolkit/split.py b/src/datachain/toolkit/split.py index 1fb62fe0..426c2495 100644 --- a/src/datachain/toolkit/split.py +++ b/src/datachain/toolkit/split.py @@ -58,10 +58,14 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]: weights_normalized = [weight / sum(weights) for weight in weights] + resolution = 2**31 - 1 # Maximum positive value for a 32-bit signed integer. + return [ dc.filter( - C("sys__rand") % 1000 >= round(sum(weights_normalized[:index]) * 1000), - C("sys__rand") % 1000 < round(sum(weights_normalized[: index + 1]) * 1000), + C("sys__rand") % resolution + >= round(sum(weights_normalized[:index]) * resolution), + C("sys__rand") % resolution + < round(sum(weights_normalized[: index + 1]) * resolution), ) for index, _ in enumerate(weights_normalized) ] diff --git a/tests/conftest.py b/tests/conftest.py index 3a15e216..1291643e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -707,16 +707,16 @@ def studio_datasets(requests_mock): def not_random_ds(test_session): return DataChain.from_records( [ - {"sys__id": 1, "sys__rand": 50, "fib": 0}, - {"sys__id": 2, "sys__rand": 150, "fib": 1}, - {"sys__id": 3, "sys__rand": 250, "fib": 1}, - {"sys__id": 4, "sys__rand": 350, "fib": 2}, - {"sys__id": 5, "sys__rand": 450, "fib": 3}, - {"sys__id": 6, "sys__rand": 550, "fib": 5}, - {"sys__id": 7, "sys__rand": 650, "fib": 8}, - {"sys__id": 8, "sys__rand": 750, "fib": 13}, - {"sys__id": 9, "sys__rand": 850, "fib": 21}, - {"sys__id": 10, "sys__rand": 950, "fib": 34}, + {"sys__id": 1, "sys__rand": 200000000, "fib": 0}, + {"sys__id": 2, "sys__rand": 400000000, "fib": 1}, + {"sys__id": 3, "sys__rand": 600000000, "fib": 1}, + {"sys__id": 4, "sys__rand": 800000000, "fib": 2}, + {"sys__id": 5, "sys__rand": 1000000000, "fib": 3}, + {"sys__id": 6, "sys__rand": 1200000000, "fib": 5}, + {"sys__id": 7, "sys__rand": 1400000000, "fib": 8}, + {"sys__id": 8, "sys__rand": 1600000000, "fib": 13}, + {"sys__id": 9, "sys__rand": 1800000000, "fib": 21}, + {"sys__id": 10, "sys__rand": 2000000000, "fib": 34}, ], session=test_session, schema={"sys": Sys, "fib": int}, @@ -727,16 +727,16 @@ def not_random_ds(test_session): def pseudo_random_ds(test_session): return DataChain.from_records( [ - {"sys__id": 1, "sys__rand": 1344339883, "fib": 0}, - {"sys__id": 2, "sys__rand": 3901153096, "fib": 1}, - {"sys__id": 3, "sys__rand": 4255991360, "fib": 1}, - {"sys__id": 4, "sys__rand": 2526403609, "fib": 2}, - {"sys__id": 5, "sys__rand": 1871733386, "fib": 3}, - {"sys__id": 6, "sys__rand": 9380910850, "fib": 5}, - {"sys__id": 7, "sys__rand": 2770679740, "fib": 8}, - {"sys__id": 8, "sys__rand": 2538886575, "fib": 13}, - {"sys__id": 9, "sys__rand": 3969542617, "fib": 21}, - {"sys__id": 10, "sys__rand": 7541790992, "fib": 34}, + {"sys__id": 1, "sys__rand": 2406827533654413759, "fib": 0}, + {"sys__id": 2, "sys__rand": 743035223448130834, "fib": 1}, + {"sys__id": 3, "sys__rand": 8572034894545971037, "fib": 1}, + {"sys__id": 4, "sys__rand": 3413911135601125438, "fib": 2}, + {"sys__id": 5, "sys__rand": 8036488725627198326, "fib": 3}, + {"sys__id": 6, "sys__rand": 2020789040280779494, "fib": 5}, + {"sys__id": 7, "sys__rand": 8478782014085172114, "fib": 8}, + {"sys__id": 8, "sys__rand": 1374262678671783922, "fib": 13}, + {"sys__id": 9, "sys__rand": 7728884931956308771, "fib": 21}, + {"sys__id": 10, "sys__rand": 5591681088079559562, "fib": 34}, ], session=test_session, schema={"sys": Sys, "fib": int}, diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 9f7415d0..24ce20bf 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -78,7 +78,7 @@ def _adapt_row(row): def schema(): return { "id": {"type": "UInt64"}, - "sys__rand": {"type": "Int64"}, + "sys__rand": {"type": "UInt64"}, "file__path": {"type": "String"}, "file__etag": {"type": "String"}, "file__version": {"type": "String"}, diff --git a/tests/func/test_toolkit.py b/tests/func/test_toolkit.py index f2388254..673475eb 100644 --- a/tests/func/test_toolkit.py +++ b/tests/func/test_toolkit.py @@ -22,9 +22,9 @@ def test_train_test_split_not_random(not_random_ds, weights, expected): @pytest.mark.parametrize( "weights,expected", [ - [[1, 1], [[2, 3, 5], [1, 4, 6, 7, 8, 9, 10]]], - [[4, 1], [[2, 3, 4, 5, 7, 8, 9], [1, 6, 10]]], - [[0.7, 0.2, 0.1], [[2, 3, 4, 5, 8, 9], [1, 6, 7], [10]]], + [[1, 1], [[1, 5, 6, 7, 8], [2, 3, 4, 9, 10]]], + [[4, 1], [[1, 3, 5, 6, 7, 8, 9], [2, 4, 10]]], + [[0.7, 0.2, 0.1], [[1, 3, 5, 6, 7, 8, 9], [2, 4], [10]]], ], ) def test_train_test_split_random(pseudo_random_ds, weights, expected):