Skip to content

Commit

Permalink
initial resample function
Browse files Browse the repository at this point in the history
  • Loading branch information
kseniyausovich committed Aug 5, 2022
1 parent 1fd8bd6 commit 6d27ff9
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,20 @@ def test_check_logger_exists():


def test_class_stratify_check():
    """Check the split sizes produced by a stratified ``resample`` call.

    Builds 100 indices with 5 evenly repeated class labels, splits them with
    ``resample(..., train_frac=0.9, stratify=y)`` and verifies that the
    overall and the per-class train/test sizes match
    ``ceil(count * train_frac)`` and its complement.

    Raises:
        ValueError: If any overall or per-class split size differs from the
            expected value.
    """
    train_frac = 0.9
    idx = np.arange(100)
    # Five classes (0..4), each appearing 20 times.
    y = np.tile(np.arange(5), 20)
    # NOTE(review): `train`/`test` are used below to index `y`, so they are
    # presumably index arrays drawn from `idx` -- confirm against `resample`.
    train, test = resample(idx, train_frac=train_frac, random_state=0, stratify=y)

    # Overall sizes: the train part is rounded up, the test part is the rest.
    if int(np.ceil(len(idx) * train_frac)) != len(train):
        raise ValueError("Incorrect train size")
    if (len(idx) - int(np.ceil(len(idx) * train_frac))) != len(test):
        raise ValueError("Incorrect test size")

    classes, dist = np.unique(y, return_counts=True)

    # Per-class sizes: every class must follow the same rounding rule.
    for cl, di in zip(classes, dist):
        if int(np.ceil(di * train_frac)) != sum(y[train] == cl):
            raise ValueError(f"Incorrect train class size {cl}")
        if di - int(np.ceil(di * train_frac)) != sum(y[test] == cl):
            raise ValueError(f"Incorrect test class size {cl}")

0 comments on commit 6d27ff9

Please sign in to comment.