diff --git a/tests/test_model.py b/tests/test_model.py index 096a2ba..53dddeb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from astropy import time, units +from astropy import coordinates, time, units from zodipy import Model, grid_number_density, model_registry @@ -60,12 +60,16 @@ def test_get_parameters() -> None: def test_update_model() -> None: """Tests that the model can be updated.""" model = Model(20 * units.micron, name="dirbe") - + obstime = time.Time("2021-01-01T00:00:00") + skycoord = coordinates.SkyCoord(20, 30, unit=units.deg, obstime=obstime) + emission_before = model.evaluate(skycoord) parameters = model.get_parameters() comp = random.choice(list(parameters["comps"].keys())) parameter = random.choice(list(parameters["comps"][comp])) parameters["comps"][comp][parameter] = random.random() model.update_parameters(parameters) + emission_after = model.evaluate(skycoord) + assert not np.allclose(emission_before, emission_after) def test_get_model_raises_error() -> None: diff --git a/zodipy/model.py b/zodipy/model.py index 94ff412..15bed65 100644 --- a/zodipy/model.py +++ b/zodipy/model.py @@ -88,16 +88,10 @@ def __init__( normalized_weights = weights / integrate.trapezoid(weights, x) else: normalized_weights = None - self._b_nu_table = tabulate_blackbody_emission(x, normalized_weights) - # We interpolate the spectrally dependant zodiacal light parameters over the provided - # bandpass or delta frequency/wavelength. - interp_and_unpack_func = get_model_interp_func(self._ipd_model) - self._interped_comp_params, self._interped_shared_params = interp_and_unpack_func( - x, normalized_weights, self._ipd_model - ) - - self._ephemeris = ephemeris + self._x = x + self._normalized_weights = normalized_weights + self._b_nu_table = tabulate_blackbody_emission(self._x, self._normalized_weights) quad_points, quad_weights = np.polynomial.legendre.leggauss(gauss_quad_degree) self._integrate_leggauss = functools.partial( @@ -106,14 +100,14 @@ def __init__( weights=quad_weights, ) - # Build partial functions to be evaluated when simulating the zodiacal light. These partials - # are pre-populated functions that contains all non line-of-sight related parameters. - self._number_density_partials = get_partial_number_density_func(comps=self._ipd_model.comps) - self._shared_brightness_partial = functools.partial( - self._ipd_model.brightness_at_step_callable, - bp_interpolation_table=self._b_nu_table, - **self._interped_shared_params, - ) + self._ephemeris = ephemeris + + # Make mypy happy by declaring types of to-be initialized attributes. + self._number_density_partials: dict[ComponentLabel, functools.partial] + self._interped_comp_params: dict[ComponentLabel, dict] + self._interped_shared_params: dict + + self._init_ipd_model_partials() def evaluate( self, @@ -178,6 +172,7 @@ def evaluate( number_density_partials = self._number_density_partials shared_brightness_partial = self._shared_brightness_partial + dist_coords_to_cores = skycoord.size > nprocesses and nprocesses > 1 if instantaneous or not dist_coords_to_cores: # Populate the instantaneous Earth and observer position in the partial functions. @@ -276,6 +271,26 @@ def evaluate( emission <<= units.MJy / units.sr return emission if return_comps else emission.sum(axis=0) + def _init_ipd_model_partials(self) -> None: + """Initialize the partial functions for the interplanetary dust model. + + The spectrally dependant model parameters are interpolated over the provided bandpass or + delta frequency/wavelength. The partial functions are pre-populated functions that contains + all non line-of-sight related parameters. + """ + interp_and_unpack_func = get_model_interp_func(self._ipd_model) + dicts = interp_and_unpack_func(self._x, self._normalized_weights, self._ipd_model) + self._interped_comp_params = dicts[0] + self._interped_shared_params = dicts[1] + + self._shared_brightness_partial = functools.partial( + self._ipd_model.brightness_at_step_callable, + bp_interpolation_table=self._b_nu_table, + **self._interped_shared_params, + ) + + self._number_density_partials = get_partial_number_density_func(comps=self._ipd_model.comps) + def get_parameters(self) -> dict: """Return a dictionary containing the interplanetary dust model parameters. @@ -294,7 +309,8 @@ def update_parameters(self, parameters: dict) -> None: Args: parameters: Dictionary of parameters to update. The keys must be the names of the parameters as defined in the model. To get the parameters dict - of an existing model, use `Zodipy("dirbe").get_parameters()`. + of an existing model, use the`get_parameters` method of an initialized + `zodipy.Model`. """ _dict = parameters.copy() _dict["comps"] = {} @@ -308,6 +324,7 @@ def update_parameters(self, parameters: dict) -> None: _dict[key] = {ComponentLabel(k): v for k, v in value.items()} self._ipd_model = self._ipd_model.__class__(**_dict) + self._init_ipd_model_partials() def validate_user_input(skycoord: coords.SkyCoord, obspos: units.Quantity | str) -> time.Time: