Skip to content

Commit

Permalink
Merge pull request spacetelescope#1864 from kecnry/model-residuals
Browse files Browse the repository at this point in the history
Model fitting residuals
  • Loading branch information
kecnry authored Dec 5, 2022
2 parents 4366f75 + 18ec056 commit 8122602
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 36 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ New Features

- Spinner in plot options while processing changes to contour settings. [#1794]

- Model fitting plugin can optionally expose the residuals as an additional data collection entry.
[#1864]

Cubeviz
^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions jdaviz/components/plugin_add_results.vue
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
</v-switch>
</v-row>

<slot></slot>

<v-row justify="end">
<j-tooltip :tooltipcontent="label_overwrite ? action_tooltip+' and replace existing entry' : action_tooltip">
<v-btn :disabled="label_invalid_msg.length > 0 || action_disabled"
Expand Down
60 changes: 34 additions & 26 deletions jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
* :meth:`get_model_component`
* :meth:`set_model_component`
* ``equation`` (:class:`~jdaviz.core.template_mixin.AutoTextField`)
* ``add_results`` (:class:`~jdaviz.core.template_mixin.AddResults`)
* ``cube_fit``
Only exposed for Cubeviz. Whether to fit the model to the cube instead of to the
collapsed spectrum.
* ``add_results`` (:class:`~jdaviz.core.template_mixin.AddResults`)
* ``residuals_expose`` (bool)
Whether to calculate and expose the residuals (model minus data).
* ``residuals`` (:class:`~jdaviz.core.template_mixin.AutoTextField`)
Label of the residuals to apply when calling :meth:`calculate_fit` if ``residuals_expose``
is ``True``.
* :meth:`calculate_fit`
"""
dialog = Bool(False).tag(sync=True)
Expand Down Expand Up @@ -93,6 +98,13 @@ class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,

cube_fit = Bool(False).tag(sync=True)

# residuals (non-cube fit only)
residuals_expose = Bool(False).tag(sync=True)
residuals_label = Unicode().tag(sync=True)
residuals_label_default = Unicode().tag(sync=True)
residuals_label_auto = Bool(True).tag(sync=True)
residuals_label_invalid_msg = Unicode('').tag(sync=True)

def __init__(self, *args, **kwargs):
self._spectrum1d = None
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -138,6 +150,9 @@ def __init__(self, *args, **kwargs):
self.equation = AutoTextField(self, 'model_equation', 'model_equation_default',
'model_equation_auto', 'model_equation_invalid_msg')

self.residuals = AutoTextField(self, 'residuals_label', 'residuals_label_default',
'residuals_label_auto', 'residuals_label_invalid_msg')

# set the filter on the viewer options
self._update_viewer_filters()

Expand All @@ -149,7 +164,7 @@ def user_api(self):
expose += ['spectral_subset', 'model_component', 'poly_order', 'model_component_label',
'model_components', 'create_model_component', 'remove_model_component',
'get_model_component', 'set_model_component',
'equation', 'add_results']
'equation', 'add_results', 'residuals_expose', 'residuals']
if self.config == "cubeviz":
expose += ['cube_fit']
expose += ['calculate_fit']
Expand Down Expand Up @@ -620,6 +635,10 @@ def _set_default_results_label(self, event={}):
label_comps += ["model"]
self.results_label_default = " ".join(label_comps)

@observe("results_label")
def _set_residuals_label_default(self, event={}):
self.residuals_label_default = self.results_label+" residuals"

@observe("cube_fit")
def _update_viewer_filters(self, event={}):
if event.get('new', self.cube_fit):
Expand All @@ -643,6 +662,7 @@ def calculate_fit(self, add_data=True):
-------
fitted model
fitted spectrum/cube
residuals (if ``residuals_expose`` is set to ``True``)
"""
if self.cube_fit:
return self._fit_model_to_cube(add_data=add_data)
Expand Down Expand Up @@ -684,7 +704,16 @@ def _fit_model_to_spectrum(self, add_data):

if add_data:
self.app.fitted_models[self.results_label] = fitted_model
self._register_spectrum({"spectrum": fitted_spectrum})
self.add_results.add_results_from_plugin(fitted_spectrum)

if self.residuals_expose:
# NOTE: this will NOT load into the viewer since we have already called
# add_results_from_plugin above.
self.add_results.add_results_from_plugin(self._spectrum1d-fitted_spectrum,
label=self.residuals.value,
replace=False)

self._set_default_results_label()

# Update component model parameters with fitted values
if isinstance(self._fitted_model, QuantityModel):
Expand All @@ -699,6 +728,8 @@ def _fit_model_to_spectrum(self, add_data):
# Reset the data mask in case we use a different subset next time
self._spectrum1d.mask = self._original_mask

if self.residuals_expose:
return fitted_model, fitted_spectrum, self._spectrum1d-fitted_spectrum
return fitted_model, fitted_spectrum

def _fit_model_to_cube(self, add_data):
Expand Down Expand Up @@ -780,29 +811,6 @@ def _fit_model_to_cube(self, add_data):

return fitted_model, output_cube

def _register_spectrum(self, event):
"""
Add a spectrum to the data collection based on the currently displayed
parameters (these could be user input or fit values).
"""
if self._warn_if_no_equation():
return
# Make sure the initialized models are updated with any user-specified
# parameters
self._update_initialized_parameters()

# Need to run the model fitter with run_fitter=False to get spectrum
if "spectrum" in event:
spectrum = event["spectrum"]
else:
model, spectrum = fit_model_to_spectrum(self._spectrum1d,
self._initialized_models.values(),
self.model_equation,
window=self._window)

self.add_results.add_results_from_plugin(spectrum)
self._set_default_results_label()

def _apply_subset_masks(self, spectrum, subset_component):
"""
For a spectrum/spectral cube ``spectrum``, add a mask attribute
Expand Down
24 changes: 23 additions & 1 deletion jdaviz/configs/default/plugins/model_fitting/model_fitting.vue
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,29 @@
action_tooltip="Fit the model to the data"
:action_disabled="model_equation_invalid_msg.length > 0"
@click:action="apply"
></plugin-add-results>
>
<div v-if="config!=='cubeviz' || !cube_fit">
<v-row>
<v-switch
v-model="residuals_expose"
label="Expose residuals"
hint="Whether to compute and export residuals (data minus model)."
persistent-hint
></v-switch>
</v-row>

<plugin-auto-label
v-if="residuals_expose"
:value.sync="residuals_label"
:default="residuals_label_default"
:auto.sync="residuals_label_auto"
:invalid_msg="residuals_label_invalid_msg"
label="Residuals Data Label"
hint="Label for the residuals. Data entry will not be loaded into the viewer automatically."
></plugin-auto-label>

</div>
</plugin-add-results>

<v-row>
<span class="v-messages v-messages__message text--secondary">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def test_fit_gaussian_with_fixed_mean(specviz_helper, spectrum1d):
old_mean = params[1]['value']
old_std = params[2]['value']

result = modelfit_plugin.calculate_fit()[0]
modelfit_plugin.residuals_expose = True
result, spectrum, resids = modelfit_plugin.calculate_fit()

# Make sure mean is really fixed.
assert_allclose(result.mean.value, old_mean)
Expand Down
19 changes: 11 additions & 8 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,14 +1696,17 @@ def _on_label_changed(self, msg={}):
self.label_invalid_msg = ''
self.label_overwrite = False

def add_results_from_plugin(self, data_item, replace=None):
def add_results_from_plugin(self, data_item, replace=None, label=None):
"""
Add ``data_item`` to the app's data_collection according to the default or user-provided
label and adds to any requested viewers.
"""
if self.label_invalid_msg:
raise ValueError(self.label_invalid_msg)

if label is None:
label = self.label

if replace is None:
replace = self.viewer.selected_reference != 'spectrum-viewer'

Expand All @@ -1714,29 +1717,29 @@ def add_results_from_plugin(self, data_item, replace=None):
viewer_item = self.app._viewer_item_by_reference(viewer_reference)
viewer = self.app.get_viewer(viewer_reference)
viewer_loaded_labels = [layer.layer.label for layer in viewer.layers]
add_to_viewer_selected = viewer_reference if self.label in viewer_loaded_labels else 'None' # noqa
visible = self.label in viewer_item['visible_layers']
add_to_viewer_selected = viewer_reference if label in viewer_loaded_labels else 'None' # noqa
visible = label in viewer_item['visible_layers']
else:
add_to_viewer_selected = self.add_to_viewer_selected
visible = True

if self.label in self.app.data_collection:
if label in self.app.data_collection:
if add_to_viewer_selected != 'None':
self.app.remove_data_from_viewer(self.viewer.selected_reference, self.label)
self.app.data_collection.remove(self.app.data_collection[self.label])
self.app.remove_data_from_viewer(self.viewer.selected_reference, label)
self.app.data_collection.remove(self.app.data_collection[label])

if not hasattr(data_item, 'meta'):
data_item.meta = {}
data_item.meta['Plugin'] = self._plugin.__class__.__name__
if self.app.config == 'mosviz':
data_item.meta['mosviz_row'] = self.app.state.settings['mosviz_row']
self.app.add_data(data_item, self.label)
self.app.add_data(data_item, label)

if add_to_viewer_selected != 'None':
# replace the contents in the selected viewer with the results from this plugin
# TODO: switch to an instance/classname check?
self.app.add_data_to_viewer(self.viewer.selected_id,
self.label,
label,
visible=visible, clear_other_data=replace)

# update overwrite warnings, etc
Expand Down

0 comments on commit 8122602

Please sign in to comment.