forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
decoders.py
88 lines (71 loc) · 2.49 KB
/
decoders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder architectures to be used with VAE."""
import abc
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
class DecoderBase(hk.Module):
  """Base class for decoder network classes.

  Subclasses implement `__call__` to map latent samples to reconstructions;
  `data_fidelity` scores reconstructions against the original inputs.
  """

  def __init__(self, obs_var: float):
    """Class initializer.

    Args:
      obs_var: observation variance of the dataset. It is the variance that
        `data_fidelity` divides the squared reconstruction error by.
    """
    super().__init__()
    self._obs_var = obs_var

  @abc.abstractmethod
  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct from a given latent sample.

    Args:
      z: latent samples of shape (batch_size, latent_dim)

    Returns:
      Reconstruction with shape (batch_size, ...).
    """

  def data_fidelity(
      self,
      input_data: jnp.ndarray,
      recons: jnp.ndarray,
  ) -> jnp.ndarray:
    """Compute data fidelity (reconstruction loss) for given input and recons.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      recons: Reconstruction of the input data. An array with the same shape
        as `input_data`.

    Returns:
      Computed data fidelity term across batch of data. An array of shape
      `(batch_size,)`.
    """
    # Flatten every non-batch dimension so the squared error is summed per
    # example, whatever the per-example shape is.
    error = (input_data - recons).reshape(input_data.shape[0], -1)
    # -||x - x_hat||^2 / (2 * obs_var): the isotropic-Gaussian log-density
    # with variance obs_var, up to an additive constant.
    return -0.5 * jnp.sum(jnp.square(error), axis=1) / self._obs_var
class ColorMnistMLPDecoder(DecoderBase):
  """MLP decoder for Color MNIST."""

  _hidden_units = (200, 200, 200, 200)
  _image_dims = (28, 28, 3)  # Dimensions of a single Color MNIST image.

  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct with given latent sample.

    Args:
      z: latent samples of shape (batch_size, latent_dim)

    Returns:
      Reconstructed data of shape (batch_size, 28, 28, 3), with values in
      [0, 1] produced by the final sigmoid.
    """
    out = z
    for units in self._hidden_units:
      out = hk.Linear(units)(out)
      out = jax.nn.relu(out)
    # `np.product` was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # `np.prod` is the supported spelling. `int()` hands hk.Linear a plain
    # Python int rather than a NumPy scalar.
    out = hk.Linear(int(np.prod(self._image_dims)))(out)
    out = jax.nn.sigmoid(out)
    # Unflatten the per-example vector back into image layout.
    return jnp.reshape(out, (-1,) + self._image_dims)