diff --git a/examples/generative/vq_vae.py b/examples/generative/vq_vae.py
index cccd4037c9..f0a2bb6584 100644
--- a/examples/generative/vq_vae.py
+++ b/examples/generative/vq_vae.py
@@ -139,7 +139,14 @@ def get_code_indices(self, flattened_inputs):
 **A note on straight-through estimation**:
 
 This line of code does the straight-through estimation part: `quantized = x +
-tf.stop_gradient(quantized - x)`. During backpropagation, `(quantized - x)` won't be
+tf.stop_gradient(quantized - x)`. The straight-through estimator affects the
+forward computation and backpropagation differently.
+
+During the forward computation, the `x` values cancel out and the nearest embedding
+`quantized` is passed to the decoder.
+
+During backpropagation, `tf.stop_gradient` blocks gradients from flowing through
+its input. As a result, `(quantized - x)` won't be
 included in the computation graph and the gradients obtained for `quantized` will
 be copied for `inputs`. Thanks to [this video](https://youtu.be/VZFVUrYcig0?t=1393)
 for helping me understand this technique.
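For illustration, the cancel-out/copy-through behaviour described in the added text can be sketched without TensorFlow, using a toy `(value, grad)` pair as a stand-in for an autodiff framework. The `Dual` class, `stop_gradient`, and `quantize` helpers below are hypothetical names invented for this sketch, not part of the VQ-VAE example:

```python
class Dual:
    """A toy (value, grad) pair tracking the derivative w.r.t. the input x."""

    def __init__(self, value, grad):
        self.value = value
        self.grad = grad

    def __add__(self, other):
        return Dual(self.value + other.value, self.grad + other.grad)

    def __sub__(self, other):
        return Dual(self.value - other.value, self.grad - other.grad)


def stop_gradient(d):
    # Pass the value through unchanged but block the gradient,
    # mirroring what tf.stop_gradient does.
    return Dual(d.value, 0.0)


def quantize(value):
    # Toy "nearest embedding" lookup: round to the nearest integer.
    return round(value)


# x with grad 1.0 (the derivative of x with respect to itself).
x = Dual(0.7, 1.0)
quantized = Dual(quantize(x.value), 0.0)

# The straight-through trick: quantized = x + stop_gradient(quantized - x).
st = x + stop_gradient(quantized - x)

print(st.value)  # 1.0: the forward pass yields the quantized value (x cancels out)
print(st.grad)   # 1.0: the gradient for x is copied straight through
```

In the forward direction `x + (quantized - x)` collapses to `quantized`, while in the backward direction the stopped term contributes nothing, so the gradient of the expression with respect to `x` is exactly 1, i.e. the decoder's gradients are copied onto the encoder output.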