---
marp: true
paginate: true
---
Akash Srivastava$^{\ast,1,2}$, Kai Xu$^{\ast,3}$, Michael U. Gutmann$^{3}$, Charles Sutton$^{3,4,5}$
To appear in ICLR 2020; OpenReview: https://openreview.net/forum?id=SJg7spEYDS
Implicit deep generative models:
Maximum mean discrepancy networks (MMD-nets)
- ❌ can only work well with low-dimensional data
- ✅ are very stable to train by avoiding the saddle-point optimization problem
Adversarial generative models (e.g. GANs, MMD-GANs)
- ✅ can generate high-dimensional data such as natural images
- ❌ are very difficult to train due to the saddle-point optimization problem
Q: Can we have two ✅✅? A: Yes. Generative ratio matching (GRAM) is a stable learning algorithm for implicit deep generative models that does not involve a saddle-point optimization problem and therefore is easy to train 🎉.
The maximum mean discrepancy (MMD) between two distributions $p$ and $q$, given samples $\{x_i\}_{i=1}^N \sim p$ and $\{y_j\}_{j=1}^M \sim q$, has the empirical estimate
$$
\widehat{\mathrm{MMD}}^2(p, q) = \frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(x_i, x_{i'}) - \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(x_i, y_j) + \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(y_j, y_{j'})
$$
MMD-nets train neural generators by minimizing this empirical estimate.
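A minimal NumPy sketch of this estimator with a Gaussian kernel; the bandwidth `sigma` and the helper names are illustrative choices, not taken from the paper:

```python
import numpy as np

def gaussian_kernel(a, b, sigma=1.0):
    # a: (N, d) array, b: (M, d) array.
    # Pairwise Gaussian kernel matrix: k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2)).
    sq_dists = np.sum(a**2, 1)[:, None] + np.sum(b**2, 1)[None, :] - 2 * a @ b.T
    return np.exp(-sq_dists / (2 * sigma**2))

def mmd2(x, y, sigma=1.0):
    # Biased empirical estimate of the squared MMD between samples x ~ p and y ~ q.
    k_xx = gaussian_kernel(x, x, sigma)
    k_xy = gaussian_kernel(x, y, sigma)
    k_yy = gaussian_kernel(y, y, sigma)
    return k_xx.mean() - 2 * k_xy.mean() + k_yy.mean()
```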
Density ratio estimation: find the ratio $r(x) = \frac{p(x)}{q(x)}$ given only samples from $p$ and $q$.
Matching finite moments of $p$ and $r \cdot q$ under the fixed design setup (estimating $r$ only at the sample points) gives a least-squares problem with a closed-form solution.
Huang et al. (2007) show that by changing the moment matching to a reproducing kernel Hilbert space with kernel $k$, infinitely many moments are matched at once.
Analytical solution: $\hat{\mathbf{r}} = \mathbf{K}{q,q}^{-1}\mathbf{K}{q,p} \mathbf{1}$, where $[\mathbf{K}{p,q}]{i,j} = k(x_i^p, x_j^q)$ given samples
Two targets in the training loop
- Learning a projection function $f_\theta$ that maps the data space into a low-dimensional manifold which preserves the density ratio between data and model.
  - "Preserves": $\frac{p_x(x)}{q_x(x)} = \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))}$, measured by $D(\theta) = \int q_x(x) \left( \frac{p_x(x)}{q_x(x)} - \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx$
  - ❓ $\frac{p_x(x)}{q_x(x)}$ is hard to estimate in the high-dimensional space ...
- Matching the model $G_\gamma$ to data in the low-dimensional manifold by minimizing MMD
  - 👍 MMD works well in low-dimensional space
  - $\mathrm{MMD} = 0$ ➡️ $\frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} = 1$ ➡️ $\frac{p_x(x)}{q_x(x)} = 1$
Both objectives are computed with empirical estimates based on samples from the data and the model.
1️⃣ Learning the projection function
A reminder on LOTUS (law of the unconscious statistician): for $y = f_\theta(x)$ with $x \sim q_x$, $\int \bar{q}(y)\, \phi(y)\, dy = \int q_x(x)\, \phi(f_\theta(x))\, dx$ for any function $\phi$, so expectations under $\bar{q}$ (and $\bar{p}$) can be estimated with samples in the data space.
Because the Pearson divergence (PD) in the data space does not depend on $\theta$, minimizing $D(\theta)$ is equivalent to maximizing the PD between $\bar{q}$ and $\bar{p}$ in the projected space. Monte Carlo approximation of PD:
$$
\mathrm{PD}(\bar{q}, \bar{p}) = \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) - 1 \approx \frac{1}{N} \sum_{i=1}^N \left( \frac{\bar{p}(f_\theta(x_i^q))}{\bar{q}(f_\theta(x_i^q))} \right)^2 - 1
$$
where $x_i^q \sim q_x$ are samples from the model and the ratios $\frac{\bar{p}(f_\theta(x_i^q))}{\bar{q}(f_\theta(x_i^q))}$ are estimated with the analytical solution $\hat{\mathbf{r}} = \mathbf{K}_{q,q}^{-1}\mathbf{K}_{q,p} \mathbf{1}$ computed on the projected samples.
Given samples from the data and the model, $f_\theta$ is trained by maximizing this Monte Carlo estimate with respect to $\theta$.
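A NumPy sketch of this objective under the same assumptions; `project` is a hypothetical stand-in for $f_\theta$ and `density_ratio_at_q` is the helper from the earlier sketch:

```python
def pd_estimate(x_data, x_model, project, sigma=1.0):
    # Monte Carlo estimate of PD(q_bar, p_bar) in the projected space.
    # x_data are samples from p_x, x_model are samples from q_x.
    h_p = project(x_data)                        # projected data samples
    h_q = project(x_model)                       # projected model samples
    r_hat = density_ratio_at_q(h_q, h_p, sigma)  # ratio estimates at the model samples
    return np.mean(r_hat**2) - 1.0
```

During training this estimate is maximized with respect to $\theta$, so in practice it is computed inside an autodiff framework (see the loop sketch further below); the NumPy version only illustrates the arithmetic.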
2️⃣ Minimizing the empirical estimator of MMD in the low-dimensional manifold
$$
\begin{aligned}
\min_\gamma \Bigg[ &\frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(f_\theta(x_i), f_\theta(x_{i'})) - \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(f_\theta(x_i), f_\theta(G_\gamma(z_j))) \\
&\quad + \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(f_\theta(G_\gamma(z_j)), f_\theta(G_\gamma(z_{j'}))) \Bigg]
\end{aligned}
$$
The generator $G_\gamma$ is trained by minimizing this estimate with respect to its parameters $\gamma$, given data samples $x_i$ and noise samples $z_j$.
Loop until convergence
- Sample a minibatch of data and generate samples from $G_\gamma$
- Project the data and generated samples using $f_\theta$
- Compute the kernel Gram matrices using Gaussian kernels in the projected space
- Compute the objectives for $f_\theta$ and $G_\gamma$ using the same kernel Gram matrices
- Backprop the two objectives to get the gradients for $\theta$ and $\gamma$
- Perform gradient updates for $\theta$ and $\gamma$
😎 Fun fact: the objectives in our GRAM algorithm both rely heavily on kernel Gram matrices.
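A compact PyTorch sketch of one pass through this loop; the Gaussian-kernel bandwidth, the ridge term in the matrix solve, the MLP architectures, and the optimizer settings are illustrative assumptions rather than choices prescribed by the paper:

```python
import torch

def gaussian_gram(a, b, sigma=1.0):
    # Gaussian-kernel Gram matrix between two batches of projected samples.
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))

dim_x, dim_z, dim_h = 784, 10, 8  # data, noise, projected dimensions (illustrative)
f = torch.nn.Sequential(torch.nn.Linear(dim_x, 128), torch.nn.ReLU(),
                        torch.nn.Linear(128, dim_h))   # projection f_theta
g = torch.nn.Sequential(torch.nn.Linear(dim_z, 128), torch.nn.ReLU(),
                        torch.nn.Linear(128, dim_x))   # generator G_gamma
opt_f = torch.optim.Adam(f.parameters(), lr=1e-4)
opt_g = torch.optim.Adam(g.parameters(), lr=1e-4)

def gram_update(x):
    # x: a minibatch of data with shape (N, dim_x).
    z = torch.randn(x.shape[0], dim_z)
    y = g(z)                       # generate samples from G_gamma
    hx, hy = f(x), f(y)            # project data and generated samples with f_theta

    # Shared kernel Gram matrices in the projected space.
    k_pp = gaussian_gram(hx, hx)
    k_pq = gaussian_gram(hx, hy)
    k_qq = gaussian_gram(hy, hy)

    # Objective for f_theta: maximize the PD estimate, using
    # r_hat = (K_{q,q} + eps*I)^{-1} K_{q,p} 1 (K_{q,p} 1 = column sums of k_pq).
    eye = torch.eye(k_qq.shape[0])
    r_hat = torch.linalg.solve(k_qq + 1e-6 * eye, k_pq.sum(dim=0))
    loss_f = -((r_hat ** 2).mean() - 1.0)   # minimize the negative PD estimate

    # Objective for G_gamma: minimize the squared-MMD estimate in the projected space.
    loss_g = k_pp.mean() - 2 * k_pq.mean() + k_qq.mean()

    # Backprop each objective into its own parameters only, then update.
    grads_f = torch.autograd.grad(loss_f, list(f.parameters()), retain_graph=True)
    grads_g = torch.autograd.grad(loss_g, list(g.parameters()))
    for p, grad in zip(f.parameters(), grads_f):
        p.grad = grad
    for p, grad in zip(g.parameters(), grads_g):
        p.grad = grad
    opt_f.step()
    opt_g.step()
```

Both objectives reuse the same three Gram matrices, and each parameter set receives only the gradient of its own objective, mirroring the loop above.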
[Figure: qualitative comparison of GAN, MMD-net, MMD-GAN, and GRAM-net. Blue: data, orange: samples; top: original space, bottom: projected space.]
[Figure: x-axis = noise dimension, y-axis = generator layer size.]