From 652b93e8d14abc8a252895b90f013c0aa9ffb71b Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 2 Oct 2024 13:46:36 -0400 Subject: [PATCH] Format files using ruff --- pyrenew/datasets/wastewater.py | 1 - pyrenew/distutil.py | 1 + test/test_forecast.py | 3 +-- test/test_integrate_discrete.py | 1 + test/test_latent_admissions.py | 3 +-- test/test_latent_infections.py | 3 +-- test/test_model_basic_renewal.py | 3 +-- test/test_model_hosp_admissions.py | 3 +-- test/test_predictive.py | 3 +-- test/test_random_key.py | 3 +-- test/test_scan_rv_plate_compatibility.py | 1 + 11 files changed, 10 insertions(+), 15 deletions(-) diff --git a/pyrenew/datasets/wastewater.py b/pyrenew/datasets/wastewater.py index 2ff68556..5dfb2f86 100644 --- a/pyrenew/datasets/wastewater.py +++ b/pyrenew/datasets/wastewater.py @@ -4,7 +4,6 @@ This module loads the package dataset named 'wastewater' and provides functions to manipulate the data. It uses the 'polars' library. """ - from importlib.resources import files import polars as pl diff --git a/pyrenew/distutil.py b/pyrenew/distutil.py index 315fc6b3..b8bcadf6 100755 --- a/pyrenew/distutil.py +++ b/pyrenew/distutil.py @@ -6,6 +6,7 @@ found in renewal equation modeling, such as discrete time-to-event distributions """ + from __future__ import annotations import jax.numpy as jnp diff --git a/test/test_forecast.py b/test/test_forecast.py index df70436e..abe7cbd5 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpyro @@ -17,6 +15,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_forecast(): diff --git a/test/test_integrate_discrete.py b/test/test_integrate_discrete.py index 0e168b18..da7b6288 100644 --- a/test/test_integrate_discrete.py +++ b/test/test_integrate_discrete.py @@ -2,6 +2,7 @@ Test the integrate_discrete function used in DifferencedProcess and elsewhere """ + import jax import jax.numpy as jnp import pytest diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index e652e382..053e1545 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import numpy.testing as testing import numpyro @@ -10,6 +8,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_admissions_sample(): diff --git a/test/test_latent_infections.py b/test/test_latent_infections.py index 0edd3231..0c33f597 100755 --- a/test/test_latent_infections.py +++ b/test/test_latent_infections.py @@ -1,13 +1,12 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import numpy.testing as testing import numpyro import pytest from pyrenew.latent import Infections +from test.utils import SimpleRt def test_infections_as_deterministic(): diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index dfc33c3b..8b1f99f7 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpy as np @@ -19,6 +17,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_model_basicrenewal_no_timepoints_or_observations(): diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 4aed146c..68b9a86b 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -1,7 +1,5 @@ # numpydoc ignore=GL08 -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpy as np @@ -24,6 +22,7 @@ from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def test_model_hosp_no_timepoints_or_observations(): diff --git a/test/test_predictive.py b/test/test_predictive.py index 414da12e..4619a3aa 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -3,8 +3,6 @@ when no posterior samples are available. """ -from test.utils import SimpleRt - import jax.numpy as jnp import numpyro.distributions as dist import pytest @@ -18,6 +16,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) diff --git a/test/test_random_key.py b/test/test_random_key.py index 99314e8f..11d0899c 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -3,8 +3,6 @@ with different random keys behave appropriately. """ -from test.utils import SimpleRt - import jax.numpy as jnp import jax.random as jr import numpyro @@ -20,6 +18,7 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable +from test.utils import SimpleRt def create_test_model(): # numpydoc ignore=GL08 diff --git a/test/test_scan_rv_plate_compatibility.py b/test/test_scan_rv_plate_compatibility.py index 4b60b4bf..f75540d1 100644 --- a/test/test_scan_rv_plate_compatibility.py +++ b/test/test_scan_rv_plate_compatibility.py @@ -3,6 +3,7 @@ classes behave as expected in a :func:`numpyro.plate` context. """ + import jax.numpy as jnp import numpyro import numpyro.distributions as dist