diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5dcf2271..5144226e 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,20 +20,17 @@ repos: language: script always_run: true files: "docs/source/tutorials/.*(qmd|md)$" - - repo: https://github.com/psf/black - rev: 23.10.0 - hooks: - - id: black - args: ["--line-length", "79"] - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", "--line-length", "79"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.0 + rev: v0.6.8 hooks: + # Sort imports + - id: ruff + args: ['check', '--select', 'I', '--fix'] + # Run the linter - id: ruff + # Run the formatter + - id: ruff-format + args: ['--line-length', '79'] - repo: https://github.com/numpy/numpydoc rev: v1.7.0 hooks: diff --git a/pyrenew/datasets/wastewater.py b/pyrenew/datasets/wastewater.py index 2ff68556..5dfb2f86 100644 --- a/pyrenew/datasets/wastewater.py +++ b/pyrenew/datasets/wastewater.py @@ -4,7 +4,6 @@ This module loads the package dataset named 'wastewater' and provides functions to manipulate the data. It uses the 'polars' library. """ - from importlib.resources import files import polars as pl diff --git a/pyrenew/distutil.py b/pyrenew/distutil.py index 315fc6b3..b8bcadf6 100755 --- a/pyrenew/distutil.py +++ b/pyrenew/distutil.py @@ -6,6 +6,7 @@ found in renewal equation modeling, such as discrete time-to-event distributions """ + from __future__ import annotations import jax.numpy as jnp diff --git a/test/test_forecast.py b/test/test_forecast.py index df70436e..abe7cbd5 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpyro @@ -17,6 +15,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_forecast(): diff --git a/test/test_integrate_discrete.py b/test/test_integrate_discrete.py index 0e168b18..da7b6288 100644 --- a/test/test_integrate_discrete.py +++ b/test/test_integrate_discrete.py @@ -2,6 +2,7 @@ Test the integrate_discrete function used in DifferencedProcess and elsewhere """ + import jax import jax.numpy as jnp import pytest diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index e652e382..053e1545 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import numpy.testing as testing import numpyro @@ -10,6 +8,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_admissions_sample(): diff --git a/test/test_latent_infections.py b/test/test_latent_infections.py index 0edd3231..0c33f597 100755 --- a/test/test_latent_infections.py +++ b/test/test_latent_infections.py @@ -1,13 +1,12 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import numpy.testing as testing import numpyro import pytest from pyrenew.latent import Infections +from test.utils import SimpleRt def test_infections_as_deterministic(): diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index dfc33c3b..8b1f99f7 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpy as np @@ -19,6 +17,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_model_basicrenewal_no_timepoints_or_observations(): diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 4aed146c..68b9a86b 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpy as np @@ -24,6 +22,7 @@ from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_model_hosp_no_timepoints_or_observations(): diff --git a/test/test_predictive.py b/test/test_predictive.py index 414da12e..4619a3aa 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -3,8 +3,6 @@ when no posterior samples are available. """ -from test.utils import SimpleRt - import jax.numpy as jnp import numpyro.distributions as dist import pytest @@ -18,6 +16,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) diff --git a/test/test_random_key.py b/test/test_random_key.py index 99314e8f..11d0899c 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -3,8 +3,6 @@ with different random keys behave appropriately. """ -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpyro @@ -20,6 +18,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def create_test_model(): # numpydoc ignore=GL08 diff --git a/test/test_scan_rv_plate_compatibility.py b/test/test_scan_rv_plate_compatibility.py index 4b60b4bf..f75540d1 100644 --- a/test/test_scan_rv_plate_compatibility.py +++ b/test/test_scan_rv_plate_compatibility.py @@ -3,6 +3,7 @@ classes behave as expected in a :func:`numpyro.plate` context. """ + import jax.numpy as jnp import numpyro import numpyro.distributions as dist