Flow Matching with Jax

Notebook by Georges Le Bellier (Twitter, Website)


This notebook provides a minimal implementation of Flow Matching [1] using Jax [2], Flax [3] and JaxTyping [4]. My main goal was to define the conditional velocity as the time derivative of the interpolant (i.e. the conditional path), obtained by automatic differentiation with respect to t:

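```python
import jax
from jaxtyping import Array, Float

def interpolant(x0: Float[Array, "N"], x1: Float[Array, "N"], t: float) -> Float[Array, "N"]:
    return x0 + (x1 - x0) * t  # linear conditional path from x0 (t=0) to x1 (t=1)

# Conditional velocity d(interpolant)/dt = x1 - x0; t must be a float to be differentiable
velocity = jax.jacrev(interpolant, argnums=2)
```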

References:

📄 [1] Flow Matching for Generative Modeling by Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, Matt Le - Article

🐍 [2] Jax - Code

🐍 [3] Flax - Code

🐍 [4] JaxTyping by Patrick Kidger - Doc

🐍 [5] Introduction to Flow Matching by Georges Le Bellier - Notebook