
Minimal implementation of flow matching with JAX


gle-bellier/jax-fm


Flow Matching with JAX

Notebook by Georges Le Bellier (Twitter, Website)


This notebook provides a minimal implementation of Flow Matching [1] using JAX [2], Flax [3] and JaxTyping [4]. My main goal was to define the conditional velocity as the time derivative of the interpolant (i.e., the conditional path), computed with JAX automatic differentiation:

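```python
import jax
from jaxtyping import Array, Float

# Linear interpolant (conditional path): x0 at t = 0, x1 at t = 1
def interpolant(x0: Float[Array, "N"], x1: Float[Array, "N"], t: float) -> Float[Array, "N"]:
    return x0 + (x1 - x0) * t

# Conditional velocity: derivative of the interpolant w.r.t. t (argument 2), i.e. x1 - x0
velocity = jax.jacrev(interpolant, argnums=2)
```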
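To show how this velocity can serve as a regression target, here is a minimal sketch of conditional flow matching in plain JAX. It regresses a learned field v_theta(x_t, t) onto dx_t/dt = x1 - x0, a simplified form of the conditional flow matching objective of [1] with sigma_min = 0. This is not the notebook's code. It assumes a standard Gaussian source distribution, t ~ U[0, 1], a hypothetical two-layer MLP (`init_params`, `vector_field`) standing in for a Flax model, plain SGD with an assumed learning rate `LR` instead of an optimizer library, and an explicit Euler sampler (`sample`). All of these names are illustrative. The interpolant is repeated without JaxTyping annotations so the block runs on its own.

```python
import jax
import jax.numpy as jnp

# Linear interpolant (conditional path) from the README: x_t = x0 + (x1 - x0) * t
def interpolant(x0, x1, t):
    return x0 + (x1 - x0) * t

# Conditional velocity dx_t/dt, obtained by differentiating w.r.t. t (argnums=2)
velocity = jax.jacrev(interpolant, argnums=2)

LR = 5e-3  # assumed SGD learning rate (hypothetical choice)

# Hypothetical two-layer MLP v_theta(x, t), a stand-in for a Flax model
def init_params(key, dim, hidden=64):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (dim + 1, hidden)) / jnp.sqrt(dim + 1.0),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, dim)) / jnp.sqrt(float(hidden)),
        "b2": jnp.zeros(dim),
    }

def vector_field(params, x, t):
    # Append the scalar time t to the input features x
    h = jnp.concatenate([x, jnp.atleast_1d(t)])
    h = jnp.tanh(h @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

# Batched vector field: shared params, per-example x and t
batched_field = jax.vmap(vector_field, in_axes=(None, 0, 0))

def cfm_loss(params, key, x1):
    # Assumed source: standard Gaussian noise x0; time t ~ U[0, 1]
    k0, kt = jax.random.split(key)
    x0 = jax.random.normal(k0, x1.shape)
    t = jax.random.uniform(kt, (x1.shape[0],))
    xt = jax.vmap(interpolant)(x0, x1, t)
    target = jax.vmap(velocity)(x0, x1, t)  # equals x1 - x0 for the linear path
    pred = batched_field(params, xt, t)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))

@jax.jit
def sgd_step(params, key, x1):
    # One plain gradient-descent step on the conditional flow matching loss
    loss, grads = jax.value_and_grad(cfm_loss)(params, key, x1)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

def sample(params, key, n, dim, steps=100):
    # Integrate dx/dt = v_theta(x, t) from t = 0 (noise) to t = 1 with explicit Euler
    x = jax.random.normal(key, (n, dim))
    dt = 1.0 / steps
    def body(i, x):
        t = jnp.ones((n,)) * (i * dt)
        return x + dt * batched_field(params, x, t)
    return jax.lax.fori_loop(0, steps, body, x)
```

Training then amounts to calling `sgd_step` repeatedly with fresh PRNG keys on data batches, and `sample(params, key, n, dim)` draws new points by transporting Gaussian noise along the learned field. For the linear path, `velocity` always returns `x1 - x0`. Computing it with `jax.jacrev` mainly pays off if the interpolant is replaced by a nonlinear one, since the target then follows automatically.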

References:

📄 [1] Flow Matching for Generative Modeling by Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, Matt Le - Article

🐍 [2] JAX - Code

🐍 [3] Flax - Code

🐍 [4] JaxTyping by Patrick Kidger - Doc

🐍 [5] Introduction to Flow Matching by Georges Le Bellier - Notebook
