Use python for loop even for tensorflow in FastMCInference because of random seed is fixed issue
henrysky committed Sep 2, 2024
1 parent c2bc0b5 commit e4fe706
Showing 3 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions demo_tutorial/NN_uncertainty_analysis/README.md
@@ -12,7 +12,7 @@ Uncertainty with Dropout Variational Inference Demonstration for Regression and
astroNN was used to do a regression task with Dropout VI in a paper
**Deep learning of multi-element abundances from high-resolution spectroscopic data**
with the code available at [https://github.com/henrysky/astroNN_spectra_paper_figures](https://github.com/henrysky/astroNN_spectra_paper_figures)
-and the paper available at [[arxiv:1808.04428](https://arxiv.org/abs/1808.04428)][[ADS](https://ui.adsabs.harvard.edu/#abs/2018arXiv180804428L/)].
+and the paper available at [[arxiv:1808.04428](https://ui.adsabs.harvard.edu/abs/2019MNRAS.483.3255L/abstract)][[ADS](https://ui.adsabs.harvard.edu/#abs/2018arXiv180804428L/)].
We demonstrated Dropout VI can report reasonable uncertainty with high prediction
accuracy trained on incomplete stellar parameters and abundances data from high-resolution stellar spectroscopic data.

@@ -39,7 +39,7 @@ For Dropout variational inference, related material:
* Paper: [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142)
* Paper: [What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?](https://arxiv.org/abs/1703.04977)
* Paper: [Bayesian Convolutional Neural Networks with Bernoulli Approximate Variational Inference](https://arxiv.org/abs/1506.02158)
-* Yarin Gal's Blog: [What My Deep Model Doesn't Know...](https://mlg.eng.cam.ac.uk/yarin/blog_3d801aa532c1ce.html)
+* Yarin Gal's Blog: [What My Deep Model Doesn't Know...](https://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html)
* [Demo from Yarin Gal written in javascript](https://github.com/yaringal/HeteroscedasticDropoutUncertainty)
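The idea behind the Dropout VI references above can be sketched in a few lines: keep dropout active at prediction time, run many stochastic forward passes, and take the spread of the predictions as a model uncertainty estimate. Below is a minimal numpy sketch of that idea (a toy linear "network" with a hand-rolled dropout mask, not astroNN's actual implementation — `dropout_forward` and `mc_dropout_predict` are hypothetical names for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout_forward(x, w, p=0.2):
    """One stochastic forward pass: drop units with probability p, rescale the rest."""
    mask = rng.random(x.shape) >= p
    return (x * mask / (1.0 - p)) @ w

def mc_dropout_predict(x, w, n=100, p=0.2):
    """Monte Carlo dropout: n stochastic passes -> predictive mean and std."""
    samples = np.stack([dropout_forward(x, w, p) for _ in range(n)])
    return samples.mean(axis=0), samples.std(axis=0)

x = rng.normal(size=(5, 8))   # 5 inputs, 8 features
w = rng.normal(size=(8, 3))   # toy linear "network"
mean, std = mc_dropout_predict(x, w, n=200)
print(mean.shape, std.shape)  # (5, 3) (5, 3)
```

The key point, following Gal's work, is that the per-pass randomness must actually differ between the `n` passes — which is exactly what this commit fixes in the TensorFlow code path.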

<br>
18 changes: 12 additions & 6 deletions src/astroNN/nn/layers.py
@@ -343,16 +343,22 @@ def loop_fn(i):
             return self.layer(inputs)
 
         # vectorizing operation depends on backend
-        # tensorflow vectorized_map traced operation so there is no randomness which affects e.g., dropout
-        if keras.backend.backend() == "tensorflow":
-            # outputs = backend_framework.vectorized_map(loop_fn, self.arange_n)
-            outputs = backend_framework.map_fn(loop_fn, self.arange_n)
-        elif keras.backend.backend() == "torch":
+        # TODO: tensorflow vectorized_map traced operation so there is no randomness which affects e.g., dropout
+        # if keras.backend.backend() == "tensorflow":
+        #     outputs = backend_framework.vectorized_map(loop_fn, self.arange_n)
+        if keras.backend.backend() == "torch":
             outputs = backend_framework.vmap(
                 loop_fn, randomness="different", in_dims=0
             )(self.arange_n)
         else:  # fallback to simple for loop
-            outputs = keras.ops.concatenate([self.layer(inputs) for _ in self.arange_n])
+            outputs = [self.layer(inputs) for _ in range(self.n)]
+            if isinstance(outputs[0], dict):
+                outputs = {
+                    key: keras.ops.stack([output[key] for output in outputs])
+                    for key in outputs[0].keys()
+                }
+            else:
+                outputs = keras.ops.stack(outputs)
         return outputs  # outputs can be tensor or dict of tensors
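The rationale for the new `else` branch: a traced vectorized map evaluates the wrapped function once and replays it, so random ops like dropout can yield the same sample for every Monte Carlo draw, whereas a plain Python loop calls the stochastic layer `n` independent times. A runnable sketch of that fallback path, with numpy standing in for `keras.ops` and a toy dropout callable standing in for the wrapped layer (both stand-ins are assumptions for illustration, not the real backend tensors):

```python
import numpy as np

def fast_mc_loop(layer, inputs, n):
    """Run n independent stochastic forward passes and stack the results.

    Mirrors the fallback branch above: dict outputs (multi-output models)
    are stacked per key; plain tensors are stacked along a new axis 0.
    """
    outputs = [layer(inputs) for _ in range(n)]
    if isinstance(outputs[0], dict):
        return {key: np.stack([out[key] for out in outputs]) for key in outputs[0]}
    return np.stack(outputs)

rng = np.random.default_rng(42)
stochastic_layer = lambda x: x * (rng.random(x.shape) >= 0.5)  # toy dropout
x = np.ones((4, 6))
mc = fast_mc_loop(stochastic_layer, x, n=10)
print(mc.shape)  # (10, 4, 6)
```

Because each loop iteration draws a fresh mask, the `n` stacked samples genuinely differ — the property the fixed-seed tracing issue broke.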


2 changes: 1 addition & 1 deletion tests/test_layers.py
@@ -214,7 +214,7 @@ def test_FastMCInference(random_data):
     # ======== Simple Keras sequential Model, one input one output ======== #
     smodel = keras.models.Sequential()
     smodel.add(keras.layers.Input(shape=(7514,)))
-    smodel.add(keras.layers.Dense(32, input_shape=(7514,)))
+    smodel.add(keras.layers.Dense(32))
     smodel.add(keras.layers.Dense(10, activation="softmax"))
     smodel.compile(
         optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"]
