Skip to content

Commit

Permalink
Release v0.5 (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Feb 17, 2023
1 parent 12ea371 commit 884107c
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 81 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
# Changelog

## v0.5.0 (2022-02-16)

### Enhancements

* Bump Nx dependency
* Update documentation to account for channels last default
* Improve error message in compilation/build errors for models
* Remove deprecated `transform`

### Deprecations

* Deprecate `Axon.Loop.handle/4`

## v0.4.1 (2022-01-21)

### Bug Fixes

* Fixed a shape mismatch when training with certain optimizers

## v0.4.0 (2022-01-19)

### Enhancements
Expand Down
3 changes: 3 additions & 0 deletions lib/axon/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ defmodule Axon.Defn do

@impl true
def __compile__(_, _, _, _), do: raise("not implemented")

@impl true
def __partitions_options__(_), do: raise("not implemented")
end
12 changes: 2 additions & 10 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ defmodule Axon.Loop do
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]
You can attach event handlers to events using `Axon.Loop.handle_event/4`:
Expand Down Expand Up @@ -229,9 +227,7 @@ defmodule Axon.Loop do
:iteration_started,
:iteration_completed,
:epoch_completed,
:epoch_halted,
:halted,
:completed
:epoch_halted
]

@default_handlers %{
Expand Down Expand Up @@ -896,8 +892,6 @@ defmodule Axon.Loop do
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]
Generally, event handlers are side-effecting operations which provide some
Expand Down Expand Up @@ -1066,7 +1060,6 @@ defmodule Axon.Loop do

metrics =
Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end)
|> log(fn _ -> "\n" end, event: :completed)
|> run(validation_data, model_state)
|> Access.get(0)
|> Map.new(fn {k, v} ->
Expand Down Expand Up @@ -1733,8 +1726,7 @@ defmodule Axon.Loop do
end
end

{_, state} = fire_event(status, handler_fns, state, debug?)
state = %State{state | metrics: final_metrics}
state = %State{state | metrics: final_metrics, status: status}

output_transform.(state)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/axon/loop/state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ defmodule Axon.Loop.State do
`event_counts` is a metadata field which stores information about the number
of times each event has been fired. This is useful when creating custom filters.
`status` refers to the loop state status after the loop has executed. You can
use this to determine if the loop ran to completion or if it was halted early.
"""
@enforce_keys [:step_state]
defstruct [
:step_state,
:status,
handler_metadata: %{},
epoch: 0,
max_epoch: 1,
Expand Down
14 changes: 7 additions & 7 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule Axon.MixProject do
use Mix.Project

@source_url "https://github.com/elixir-nx/axon"
@version "0.4.1"
@version "0.5.0"

def project do
[
Expand Down Expand Up @@ -35,9 +35,9 @@ defmodule Axon.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:exla, "~> 0.4.0", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.4.0", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.4.0", nx_opts()},
{:exla, "~> 0.5.0", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.5.0", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.5.0", nx_opts()},
{:ex_doc, "~> 0.23", only: :docs},
{:table_rex, "~> 3.1.1", optional: true},
{:kino, "~> 0.7", optional: true},
Expand All @@ -57,23 +57,23 @@ defmodule Axon.MixProject do
if path = System.get_env("AXON_NX_PATH") do
[path: path, override: true]
else
[github: "elixir-nx/nx", sparse: "nx", override: true]
[]
end
end

defp exla_opts do
if path = System.get_env("AXON_EXLA_PATH") do
[path: path]
else
[github: "elixir-nx/nx", sparse: "exla", override: true]
[]
end
end

defp torchx_opts do
if path = System.get_env("AXON_TORCHX_PATH") do
[path: path]
else
[github: "elixir-nx/nx", sparse: "torchx", override: true]
[]
end
end

Expand Down
8 changes: 4 additions & 4 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
%{
"castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
"complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
"earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
"elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
"ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "exla"]},
"exla": {:hex, :exla, "0.5.0", "a002cb70e59c26d4ec78a256489e4026c428ff4917f25d266e6a86c58636dc7f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "9219366cb0ea18c421349b8e0f130d85e83d7404df8054e5af6e18a47540c886"},
"kino": {:hex, :kino, "0.8.1", "da3b2cba121b7542146cffdb8af055fa0129395fa67aead9e7e3df93aed1f107", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "da45dd141db30db18973de0e3398bda3ab8cb0b5da58d6a0debbe5b864aba295"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "nx"]},
"nx": {:hex, :nx, "0.5.0", "c5e62e82606ff372d986e72cce505c98421bb4305ce9cc8e439fe6cc1966c6ad", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b29c246318181c3ebfcf0f230a0d33783ac4c92dfa34ca3aa5b9b38ae58c187e"},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "torchx"]},
"torchx": {:hex, :torchx, "0.5.0", "d787ea5a62f299a93c03a7a9f1d0d903dd854797e8fc27bbbee984d8e3e6acf1", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "832205d22259011930231e5203cc1b929136a3ad1b160e1f4690d35dfb11ddbd"},
"vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"},
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
}
69 changes: 9 additions & 60 deletions test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ defmodule Axon.LoopTest do
Axon.input("input", shape: {nil, 1})
|> Axon.dense(1)
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
|> Loop.handle(
|> Loop.handle_event(
:epoch_completed,
fn %State{step_state: pstate} = state ->
{
Expand All @@ -376,14 +376,6 @@ defmodule Axon.LoopTest do
}
end
)
|> Loop.handle(
:completed,
fn %State{step_state: %{counter: counter}} = state ->
assert 4 = counter

{:continue, state}
end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
Expand All @@ -396,7 +388,7 @@ defmodule Axon.LoopTest do
Axon.input("input", shape: {nil, 1})
|> Axon.dense(1)
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
|> Loop.handle(
|> Loop.handle_event(
:epoch_completed,
fn %State{step_state: pstate} = state ->
{
Expand All @@ -416,14 +408,6 @@ defmodule Axon.LoopTest do
}
end
)
|> Loop.handle(
:completed,
fn %State{step_state: %{counter: counter}} = state ->
assert {{4}, 4} = counter

{:continue, state}
end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
Expand Down Expand Up @@ -477,7 +461,7 @@ defmodule Axon.LoopTest do
end

def send_handler(loop, event) do
Axon.Loop.handle(loop, event, fn state ->
Axon.Loop.handle_event(loop, event, fn state ->
send(self(), event)
{:continue, state}
end)
Expand Down Expand Up @@ -540,15 +524,6 @@ defmodule Axon.LoopTest do
refute_received :iteration_completed
end

test "fires correctly on :completed" do
ExUnit.CaptureIO.capture_io(fn ->
run_dummy_loop!(:completed, 5, 10)
end)

assert_received :completed
refute_received :completed
end

test "fires correctly on :epoch_halted" do
model = Axon.input("foo")

Expand All @@ -562,7 +537,7 @@ defmodule Axon.LoopTest do
ExUnit.CaptureIO.capture_io(fn ->
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.handle(:iteration_completed, fn state ->
|> Axon.Loop.handle_event(:iteration_completed, fn state ->
{:halt_epoch, state}
end)
|> send_handler(:epoch_halted)
Expand All @@ -576,30 +551,6 @@ defmodule Axon.LoopTest do
refute_received :epoch_halted
end

test "fires correctly on :halted" do
model = Axon.input("foo")

data =
Stream.repeatedly(fn ->
xs = Nx.tensor([[Enum.random(0..10)]])
ys = Nx.greater(xs, 5)
{xs, ys}
end)

ExUnit.CaptureIO.capture_io(fn ->
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.handle(:iteration_completed, fn state ->
{:halt_loop, state}
end)
|> send_handler(:halted)
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 10)
end)

assert_received :halted
refute_received :halted
end

test "events fire in order" do
model = Axon.input("foo")

Expand All @@ -618,7 +569,6 @@ defmodule Axon.LoopTest do
|> send_handler(:iteration_started)
|> send_handler(:iteration_completed)
|> send_handler(:epoch_completed)
|> send_handler(:completed)
|> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1)
end)

Expand All @@ -627,7 +577,6 @@ defmodule Axon.LoopTest do
assert_received :iteration_started
assert_received :iteration_completed
assert_received :epoch_completed
assert_received :completed

refute_received _
end
Expand All @@ -651,7 +600,7 @@ defmodule Axon.LoopTest do
end

def send_handler(loop, event, filter) do
Axon.Loop.handle(
Axon.Loop.handle_event(
loop,
event,
fn state ->
Expand Down Expand Up @@ -863,7 +812,7 @@ defmodule Axon.LoopTest do
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.from_state(state1)
|> Axon.Loop.handle(:epoch_completed, fn %{epoch: epoch} = state ->
|> Axon.Loop.handle_event(:epoch_completed, fn %{epoch: epoch} = state ->
assert epoch >= 3
{:continue, state}
end)
Expand All @@ -888,7 +837,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{metrics: metrics} = state ->
assert Map.has_key?(metrics, "validation_accuracy")
Expand Down Expand Up @@ -918,7 +867,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.early_stop("validation_accuracy", mode: :max)
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{handler_metadata: meta} = state ->
assert %{early_stop: %{"validation_accuracy" => _, :since_last_improvement => _}} =
Expand Down Expand Up @@ -1006,7 +955,7 @@ defmodule Axon.LoopTest do
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, Enum.take(data, 5))
|> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max)
|> Axon.Loop.handle(
|> Axon.Loop.handle_event(
:epoch_completed,
fn %{handler_metadata: meta} = state ->
assert %{reduce_lr: %{"validation_accuracy" => _, :since_last_improvement => _}} =
Expand Down

0 comments on commit 884107c

Please sign in to comment.