diff --git a/tests/test_iter_utils.py b/tests/test_iter_utils.py index a4edb5761a..612ab6fe7f 100644 --- a/tests/test_iter_utils.py +++ b/tests/test_iter_utils.py @@ -6,11 +6,11 @@ @pytest.mark.parametrize( "values, expected", [ - ([], set()), - ([0, 0], {0}), - ([0, 0, 0], {0}), - ([1, 2, 3], {1}), - ([1, 5, 8, 8, 10, 15], {4, 3, 0, 2, 5}), + ([], []), + ([0, 0], [0]), + ([0, 0, 0], [0, 0]), + ([1, 2, 3], [1, 1]), + ([1, 5, 8, 8, 10, 15], [4, 3, 0, 2, 5]), ], ) def test_get_internals(values, expected): diff --git a/unblob/iter_utils.py b/unblob/iter_utils.py index 66e7ba794d..4dbb973f29 100644 --- a/unblob/iter_utils.py +++ b/unblob/iter_utils.py @@ -1,5 +1,5 @@ import itertools -from typing import List, Set +from typing import List def pairwise(iterable): @@ -10,7 +10,7 @@ def pairwise(iterable): return zip(a, b) -def get_intervals(values: List[int]) -> Set[int]: +def get_intervals(values: List[int]) -> List[int]: """Get all the intervals between numbers. It's similar to numpy.diff function. @@ -20,7 +20,7 @@ def get_intervals(values: List[int]) -> Set[int]: >>> get_intervals([1, 4, 5, 6, 10]) [3, 1, 1, 4] """ - all_diffs = set() + all_diffs = [] for value, next_value in pairwise(values): - all_diffs.add(next_value - value) + all_diffs.append(next_value - value) return all_diffs