From 8bcec967341a0c25f33e33f9fd3dc354cc55d001 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 10:31:28 +0200 Subject: [PATCH 1/6] Add demo test extension. --- thinc/tests/backends/test_ops.py | 13 ++++++------- thinc/tests/strategies.py | 23 ++++++++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 04ab7d231..3a603ab9b 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -15,9 +15,8 @@ from thinc.types import Floats2d import inspect -from .. import strategies -from ..strategies import arrays_BI, ndarrays_of_shape - +from thinc.tests import strategies +from thinc.tests.strategies import ndarrays_of_shape MAX_EXAMPLES = 10 @@ -774,11 +773,11 @@ def test_gemm_out_used(cpu_ops): @pytest.mark.parametrize("cpu_ops", CPU_OPS) -@settings(max_examples=MAX_EXAMPLES, deadline=None) -@given(X=strategies.arrays_BI()) -def test_flatten_unflatten_roundtrip(cpu_ops, X): +@settings(max_examples=MAX_EXAMPLES * 2, deadline=None) +@given(X=strategies.arrays_BI(dtype="i") | strategies.arrays_BI(dtype="f")) +def test_flatten_unflatten_roundtrip(cpu_ops: NumpyOps, X: numpy.ndarray): flat = cpu_ops.flatten([x for x in X]) - assert flat.ndim == 1 + assert flat.ndim == X.ndim - 1 unflat = cpu_ops.unflatten(flat, [len(x) for x in X]) assert_allclose(X, unflat) flat2 = cpu_ops.flatten([x for x in X], pad=1, dtype="f") diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index 322728cd9..d6f289b39 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -1,7 +1,10 @@ +from typing import Protocol, Union, get_args + import numpy -from hypothesis.strategies import just, tuples, integers, floats +from hypothesis.strategies import just, tuples, integers, floats, SearchStrategy from hypothesis.extra.numpy import arrays from thinc.api import NumpyOps, Linear +from thinc.types import DTypes, DTypesFloat def get_ops(): @@ -25,6 +28,12 @@ def get_input(nr_batch, nr_in): return ops.alloc2f(nr_batch, nr_in) +class StrategyCallable(Protocol): + """For proper typing in .lengths().""" + def __call__(self, min_value: Union[int, float], max_value: Union[int, float]) -> SearchStrategy: + ... + + def lengths(lo=1, hi=10): return integers(min_value=lo, max_value=hi) @@ -33,8 +42,8 @@ def shapes(min_rows=1, max_rows=100, min_cols=1, max_cols=100): return tuples(lengths(lo=min_rows, hi=max_rows), lengths(lo=min_cols, hi=max_cols)) -def ndarrays_of_shape(shape, lo=-10.0, hi=10.0, dtype="float32", width=32): - if dtype.startswith("float"): +def ndarrays_of_shape(shape, lo=-10.0, hi=10.0, dtype: DTypes = "float32", width=32): + if dtype in get_args(DTypesFloat): return arrays( dtype, shape=shape, elements=floats(min_value=lo, max_value=hi, width=width) ) @@ -48,18 +57,18 @@ def ndarrays(min_len=0, max_len=10, min_val=-10.0, max_val=10.0): ) -def arrays_BI(min_B=1, max_B=10, min_I=1, max_I=100): +def arrays_BI(min_B=1, max_B=10, min_I=1, max_I=100, dtype: DTypes = "float32"): shapes = tuples(lengths(lo=min_B, hi=max_B), lengths(lo=min_I, hi=max_I)) - return shapes.flatmap(ndarrays_of_shape) + return shapes.flatmap(lambda shape: ndarrays_of_shape(shape, dtype=dtype)) -def arrays_BOP(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5): +def arrays_BOP(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5, dtype: DTypes = "float32"): shapes = tuples( lengths(lo=min_B, hi=max_B), lengths(lo=min_O, hi=max_O), lengths(lo=min_P, hi=max_P), ) - return shapes.flatmap(ndarrays_of_shape) + return shapes.flatmap(lambda shape: ndarrays_of_shape(shape, dtype=dtype)) def arrays_BOP_BO(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5): From de52fa0c7d9f0a8eae62e18dc8b962683608a576 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 10:37:38 +0200 Subject: [PATCH 2/6] Remove unused StrategyCallable. --- thinc/tests/strategies.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index d6f289b39..064a73bf2 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -28,12 +28,6 @@ def get_input(nr_batch, nr_in): return ops.alloc2f(nr_batch, nr_in) -class StrategyCallable(Protocol): - """For proper typing in .lengths().""" - def __call__(self, min_value: Union[int, float], max_value: Union[int, float]) -> SearchStrategy: - ... - - def lengths(lo=1, hi=10): return integers(min_value=lo, max_value=hi) From 825a0909a0a848a4d56e0a5d127c6f6ea45c186e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 10:39:15 +0200 Subject: [PATCH 3/6] Remove unused imports. --- thinc/tests/strategies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index 064a73bf2..91688e989 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -1,7 +1,7 @@ -from typing import Protocol, Union, get_args +from typing import, get_args import numpy -from hypothesis.strategies import just, tuples, integers, floats, SearchStrategy +from hypothesis.strategies import just, tuples, integers, floats from hypothesis.extra.numpy import arrays from thinc.api import NumpyOps, Linear from thinc.types import DTypes, DTypesFloat From b5b37246a31ddd27f94c52888cf3561467267eb8 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 10:44:46 +0200 Subject: [PATCH 4/6] Fix syntax error. --- thinc/tests/strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index 91688e989..448d7e25e 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -1,4 +1,4 @@ -from typing import, get_args +from typing import get_args import numpy from hypothesis.strategies import just, tuples, integers, floats From 1d6f192f6a83db3325f169183ab6b8d50bae22ca Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 10:47:16 +0200 Subject: [PATCH 5/6] Revert imports. --- thinc/tests/backends/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 3a603ab9b..1236329f7 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -15,8 +15,8 @@ from thinc.types import Floats2d import inspect -from thinc.tests import strategies -from thinc.tests.strategies import ndarrays_of_shape +from .. import strategies +from ..strategies import ndarrays_of_shape MAX_EXAMPLES = 10 From 3d81072bc45bd7268c291b82eb2f9db3ad1464ec Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 9 Sep 2022 11:01:45 +0200 Subject: [PATCH 6/6] Remove typing.get_args() for Python < 3.8. --- thinc/tests/strategies.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index 448d7e25e..214500c0e 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -1,10 +1,8 @@ -from typing import get_args - import numpy from hypothesis.strategies import just, tuples, integers, floats from hypothesis.extra.numpy import arrays from thinc.api import NumpyOps, Linear -from thinc.types import DTypes, DTypesFloat +from thinc.types import DTypes def get_ops(): @@ -37,7 +35,7 @@ def shapes(min_rows=1, max_rows=100, min_cols=1, max_cols=100): def ndarrays_of_shape(shape, lo=-10.0, hi=10.0, dtype: DTypes = "float32", width=32): - if dtype in get_args(DTypesFloat): + if dtype.startswith("f"): return arrays( dtype, shape=shape, elements=floats(min_value=lo, max_value=hi, width=width) )