Skip to content

Commit

Permalink
Skip tests safely without sparse installed
Browse files Browse the repository at this point in the history
  • Loading branch information
khaeru committed Nov 8, 2024
1 parent a7622d3 commit 54a23ff
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
7 changes: 3 additions & 4 deletions genno/tests/core/test_sparsedataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
import numpy as np
import pandas as pd
import pytest
import sparse
import xarray as xr
from xarray.testing import assert_equal as assert_xr_equal

import genno
from genno import Computer
from genno.core.sparsedataarray import HAS_SPARSE, SparseDataArray
from genno.core.sparsedataarray import SparseDataArray
from genno.testing import add_test_data, random_qty

pytestmark = pytest.mark.skipif(
condition=not HAS_SPARSE,
sparse = pytest.importorskip(
"sparse",
reason="`sparse` not available → can't test SparseDataArray",
)

Expand Down
24 changes: 16 additions & 8 deletions genno/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import genno.caching
from genno.caching import Encoder, decorate, hash_args, hash_code, hash_contents
from genno.core.attrseries import AttrSeries
from genno.core.sparsedataarray import SparseDataArray
from genno.core.sparsedataarray import HAS_SPARSE, SparseDataArray


class TestEncoder:
Expand Down Expand Up @@ -58,11 +58,19 @@ def _serialize_bar(o: Bar):
@pytest.mark.parametrize(
"value, suffix",
(
(np.array([3]), "pickle"),
(pd.DataFrame(), "parquet"),
(AttrSeries(), "parquet" if sys.version_info >= (3, 9) else "pickle"),
(lambda: np.array([3]), "pickle"),
(pd.DataFrame, "parquet"),
(AttrSeries, "parquet" if sys.version_info >= (3, 9) else "pickle"),
pytest.param(
SparseDataArray(), "parquet", marks=pytest.mark.xfail(raises=TypeError)
SparseDataArray,
"parquet",
marks=[
pytest.mark.skipif(
not HAS_SPARSE,
reason="`sparse` not available → can't test SparseDataArray",
),
pytest.mark.xfail(raises=TypeError),
],
),
),
)
Expand All @@ -72,12 +80,12 @@ def test_decorate(caplog, tmp_path, value, suffix):
caplog.set_level(logging.DEBUG)

def myfunc():
return value
return value()

decorated = decorate(myfunc, cache_path=tmp_path)

# Decorated function runs
assert all(value == decorated())
assert all(value() == decorated())

# Value was cached
# NB use [1] not [-1] to accommodate a possible message about Parquet support
Expand All @@ -86,7 +94,7 @@ def myfunc():
assert 1 == len(files)

# Cache hit on the second call
assert all(value == decorated())
assert all(value() == decorated())
assert caplog.messages[-1].startswith("Cache hit for myfunc(<")

for f in files:
Expand Down

0 comments on commit 54a23ff

Please sign in to comment.