Optimisers.jl defines many standard optimisers and utilities for learning loops.

MurrellGroup/Optimisers.jl

Optimisers.jl

Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.

This was written as the new training system for Flux.jl neural networks, and is also used by Lux.jl. But it can be used separately on any array, or on anything else understood by Functors.jl.

Warning

With version 0.4, the default update rule for AdamW has changed to match the PyTorch implementation. The previous rule, which is closer to the original paper, can be obtained by setting AdamW(..., couple=false). See this issue for more details.

Installation

] add Optimisers

Usage

The core idea is that optimiser state (such as momentum) is handled explicitly. It is initialised by setup, and then at each step, update returns both the new state and the model with its trainable parameters adjusted:

using Optimisers, Zygote

state = Optimisers.setup(Optimisers.Adam(), model)  # just once

grad = Zygote.gradient(m -> loss(m(x), y), model)[1]  # explicit-style gradient w.r.t. the model

state, model = Optimisers.update(state, model, grad)  # at every step
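For a self-contained illustration of the setup/update loop, here is a minimal sketch that avoids Zygote by using a hypothetical quadratic loss whose gradient can be written by hand. The names `toy_model`, `toy_loss`, `toy_grad` and `train`, and the choice of `Descent(0.1)`, are assumptions for illustration only, not part of the package.

```julia
using Optimisers  # assumes Optimisers.jl is installed

# Hypothetical toy model: any Functors.jl-compatible structure works,
# here a NamedTuple of two arrays.
toy_model = (W = [1.0 2.0; 3.0 4.0], b = [0.5, -0.5])

# Hypothetical loss: 0.5 * (sum of squared parameters). Its gradient with
# respect to each array is the array itself, so no AD package is needed.
toy_loss(m) = 0.5 * (sum(abs2, m.W) + sum(abs2, m.b))
toy_grad(m) = (W = copy(m.W), b = copy(m.b))  # same tree shape as the model

function train(model, nsteps)
    state = Optimisers.setup(Optimisers.Descent(0.1), model)  # just once
    for _ in 1:nsteps
        grad = toy_grad(model)
        # update (unlike update!) does not mutate its arguments;
        # it returns a new state tree and a new model
        state, model = Optimisers.update(state, model, grad)
    end
    return model
end

trained = train(toy_model, 10)
```

With plain gradient descent at rate 0.1 on this loss, each step multiplies every parameter by 0.9, so after 10 steps `trained.W` should be close to `0.9^10 .* toy_model.W`, and `toy_loss(trained)` should be smaller than `toy_loss(toy_model)`.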

For models with deeply nested layers containing the parameters (like Flux.jl models), this state is a similarly nested tree, and so is the gradient. If using Zygote, you must use the "explicit" style shown above, not the "implicit" one with Params.
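To see how the state tree mirrors a nested model, here is a minimal sketch using a hypothetical two-layer model written as nested NamedTuples, standing in for a Flux.jl model. The field names `layer1`, `layer2` and `act`, and the `Momentum(0.01, 0.9)` settings, are illustrative assumptions.

```julia
using Optimisers  # assumes Optimisers.jl is installed

# Hypothetical nested model: two "layers" of arrays plus a non-trainable function
nested_model = (layer1 = (W = ones(3, 2), b = zeros(3)),
                layer2 = (W = ones(1, 3), b = zeros(1)),
                act = tanh)

nested_state = Optimisers.setup(Optimisers.Momentum(0.01, 0.9), nested_model)

# The state has the same nesting as the model: each array gets its own
# Optimisers.Leaf, holding the rule and that array's optimiser state.
leaf = nested_state.layer1.W

# For Momentum the per-array state is a velocity buffer the same size as
# the parameter, starting at zero.
velocity = leaf.state

# The function `act` has no trainable arrays, so its entry carries no
# optimiser state.
```

A gradient for this model, for example from Zygote's explicit style, would be a NamedTuple with the same nesting, which is what lets update walk the model, state and gradient trees together.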

The function destructure collects all the trainable parameters into one vector, and returns it along with a function to rebuild a similar model:

vector, re = Optimisers.destructure(model)  # flat parameter vector and a rebuilder

model2 = re(2 .* vector)  # same structure as model, every parameter doubled
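Here is a small sketch of destructure on a hypothetical nested NamedTuple, showing what the flat vector contains and what the rebuilder returns. The model `flat_model` and the names `vec_params`, `rebuild` and `doubled` are assumptions for illustration; the function `sin` stands in for a non-trainable field such as an activation function.

```julia
using Optimisers  # assumes Optimisers.jl is installed

# Hypothetical model: arrays at two levels of nesting, plus a function
flat_model = (a = [1.0, 2.0], b = (c = [3.0], f = sin))

vec_params, rebuild = Optimisers.destructure(flat_model)
# vec_params holds the arrays' entries in order: [1.0, 2.0, 3.0]

# Rebuild a model of the same shape from a modified vector
doubled = rebuild(2 .* vec_params)
# doubled.a == [2.0, 4.0], doubled.b.c == [6.0], doubled.b.f === sin
```

Only the numeric arrays end up in the vector; the function is not a parameter and is carried over unchanged when the model is rebuilt.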

The documentation explains usage in more detail, describes all the optimisation rules, and shows how to define new ones.
