diff --git a/tests/test_utils.py b/tests/test_utils.py index d1ca8b8..231bd51 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -334,20 +334,20 @@ def test_check_logger_exists(): def test_class_stratify_check(): - selection_frac = 0.9 + train_frac = 0.9 idx = np.arange(100) y = np.tile(np.arange(5), 20) - train, test = resample(idx, selection_frac=selection_frac, random_state=0, stratify=y) + train, test = resample(idx, train_frac=train_frac, random_state=0, stratify=y) - if int(np.ceil(len(idx) * selection_frac)) != len(train): + if int(np.ceil(len(idx) * train_frac)) != len(train): raise ValueError("Incorrect train size") - if (len(idx) - int(np.ceil(len(idx) * selection_frac))) != len(test): + if (len(idx) - int(np.ceil(len(idx) * train_frac))) != len(test): raise ValueError("Incorrect test size") classes, dist = np.unique(y, return_counts=True) for cl, di in zip(classes, dist): - if int(np.ceil(di * selection_frac)) != sum(y[train] == cl): + if int(np.ceil(di * train_frac)) != sum(y[train] == cl): raise ValueError(f"Incorrect train class size {cl}") - if di - int(np.ceil(di * selection_frac)) != sum(y[test] == cl): + if di - int(np.ceil(di * train_frac)) != sum(y[test] == cl): raise ValueError(f"Incorrect test class size {cl}")