Skip to content

Commit

Permalink
Merge pull request #1770 from OceanParcels/v/1764
Browse files Browse the repository at this point in the history
Fixture in `test_examples.py` to clean up generated data files
  • Loading branch information
VeckoTheGecko authored Nov 25, 2024
2 parents d233947 + d2425a2 commit 3c94b3e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
19 changes: 12 additions & 7 deletions docs/examples/example_mitgcm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import timedelta
from pathlib import Path
from typing import Literal

import numpy as np
import parcels
Expand All @@ -7,7 +9,7 @@
ptype = {"scipy": parcels.ScipyParticle, "jit": parcels.JITParticle}


def run_mitgcm_zonally_reentrant(mode):
def run_mitgcm_zonally_reentrant(mode: Literal["scipy", "jit"], path: Path):
"""Function that shows how to load MITgcm data in a zonally periodic domain."""
data_folder = parcels.download_example_dataset("MITgcm_example_data")
filenames = {
Expand Down Expand Up @@ -41,7 +43,7 @@ def periodicBC(particle, fieldset, time):
size=10,
)
pfile = parcels.ParticleFile(
"MIT_particles_" + str(mode) + ".zarr",
str(path),
pset,
outputdt=timedelta(days=1),
chunks=(len(pset), 1),
Expand All @@ -52,12 +54,15 @@ def periodicBC(particle, fieldset, time):
)


def test_mitgcm_output_compare():
run_mitgcm_zonally_reentrant("scipy")
run_mitgcm_zonally_reentrant("jit")
def test_mitgcm_output_compare(tmpdir):
def get_path(mode: Literal["scipy", "jit"]) -> Path:
return tmpdir / f"MIT_particles_{mode}.zarr"

ds_jit = xr.open_zarr("MIT_particles_jit.zarr")
ds_scipy = xr.open_zarr("MIT_particles_scipy.zarr")
for mode in ["scipy", "jit"]:
run_mitgcm_zonally_reentrant(mode, get_path(mode))

ds_jit = xr.open_zarr(get_path("jit"))
ds_scipy = xr.open_zarr(get_path("scipy"))

np.testing.assert_allclose(ds_jit.lat.data, ds_scipy.lat.data)
np.testing.assert_allclose(ds_jit.lon.data, ds_scipy.lon.data)
30 changes: 30 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import runpy
import shutil
import sys
import time
from pathlib import Path

import pytest
Expand All @@ -8,6 +11,33 @@
example_fnames = [path.name for path in example_folder.glob("*.py")]


@pytest.fixture(autouse=True)
def cleanup_generated_data_files():
"""Clean up generated data files from test run.
Records current folder contents before test, and cleans up any generated `.nc` files
and `.zarr` folders afterwards. For safety this is non-recursive. This function is
only necessary as the scripts being run aren't native pytest tests, so they don't
have access to the `tmpdir` fixture.
"""
folder_contents = os.listdir()
yield
time.sleep(0.1) # Buffer so that files are closed before we try to delete them.
for fname in os.listdir():
if fname in folder_contents:
continue
if not (fname.endswith(".nc") or fname.endswith(".zarr")):
continue

path = Path(fname)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
print(f"Removed {path}")


@pytest.mark.parametrize("example_fname", example_fnames)
def test_example_script(example_fname):
script = str(example_folder / example_fname)
Expand Down

0 comments on commit 3c94b3e

Please sign in to comment.