Notebook by Georges Le Bellier - Twitter, Website
This notebook proposes a minimal implementation of Flow Matching [1] using Jax [2], Flax [3] and JaxTyping [4]. My main goal was to define the conditional velocity directly as the time derivative of the interpolant (i.e., the conditional path), computed with automatic differentiation rather than written out by hand:
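```python
import jax
from jaxtyping import Array, Float

def interpolant(x0: Float[Array, "N"], x1: Float[Array, "N"], t: float) -> Float[Array, "N"]:
    return x0 + (x1 - x0) * t  # linear path: x0 at t=0, x1 at t=1

velocity = jax.jacrev(interpolant, argnums=2)  # d(interpolant)/dt = x1 - x0
```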
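To show how `velocity` can serve as the regression target, here is a minimal sketch of a conditional flow matching loss and an Euler sampler in plain JAX. It is not the notebook's implementation: the standard Gaussian source, uniform time sampling, squared-error loss, and the names `v_theta`, `cfm_loss` and `sample` are assumptions, and the toy affine `v_theta` is a hypothetical stand-in for the Flax model.

```python
import jax
import jax.numpy as jnp

def interpolant(x0, x1, t):
    # Linear conditional path: x_t = x0 + (x1 - x0) * t
    return x0 + (x1 - x0) * t

# Conditional velocity u_t = d x_t / dt, obtained by differentiating w.r.t. t.
velocity = jax.jacrev(interpolant, argnums=2)

def v_theta(params, x, t):
    # Hypothetical toy velocity model, affine in (x, t); stands in for a Flax network.
    return x @ params["W"] + t * params["b"] + params["c"]

def cfm_loss(params, key, x1_batch):
    # Conditional flow matching loss on a batch of data samples x1 of shape (n, d).
    n, d = x1_batch.shape
    k0, kt = jax.random.split(key)
    x0_batch = jax.random.normal(k0, (n, d))  # assumed source: x0 ~ N(0, I)
    t_batch = jax.random.uniform(kt, (n,))    # assumed time sampling: t ~ U(0, 1)
    xt = jax.vmap(interpolant)(x0_batch, x1_batch, t_batch)  # points on the paths
    ut = jax.vmap(velocity)(x0_batch, x1_batch, t_batch)     # regression targets
    pred = jax.vmap(v_theta, in_axes=(None, 0, 0))(params, xt, t_batch)
    return jnp.mean(jnp.sum((pred - ut) ** 2, axis=-1))

def sample(params, key, n, d, num_steps=100):
    # Euler integration of dx/dt = v_theta(x, t) from t = 0 (noise) to t = 1.
    x = jax.random.normal(key, (n, d))
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = jnp.full((n,), i * dt)
        x = x + dt * jax.vmap(v_theta, in_axes=(None, 0, 0))(params, x, t)
    return x
```

Training would repeatedly take `jax.grad(cfm_loss)` with respect to `params` and apply an optimizer step. Since `t` is a scalar, `jax.jacfwd(interpolant, argnums=2)` would give the same derivative with a single forward-mode pass; for this linear path both reduce to `x1 - x0`.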
📄 [1] Flow Matching for Generative Modeling by Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, Matt Le - Article
🐍 [2] Jax - Code
🐍 [3] Flax - Code
🐍 [4] JaxTyping by Patrick Kidger - Doc
🐍 [5] Introduction to Flow Matching by Georges Le Bellier - Notebook