Skip to content

Commit

Permalink
update flux v0.16
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 31, 2024
1 parent 1e97549 commit ff35528
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DifferentiationInterface = "0.6.0"
DocStringExtensions = "0.8,0.9"
ExplicitImports = "1.10.1"
FiniteDifferences = "0.12"
Flux = "0.15"
Flux = "0.16"
ForwardDiff = "0.10.36"
Functors = "0.4, 0.5"
JET = "0.4 - 0.8, 0.9"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,22 @@ end

function DIT.flux_isapprox(a, b; atol, rtol)
isapprox_results = fmapstructure_with_path(a, b) do kp, x, y
if :state in kp # ignore RNN and LSTM state
if x isa AbstractArray{<:Number}
return isapprox(x, y; atol, rtol)
else # ignore non-arrays
return true
else
if x isa AbstractArray{<:Number}
return isapprox(x, y; atol, rtol)
else # ignore non-arrays
return true
end
end
end
return all(fleaves(isapprox_results))
end

function square_loss(model, x)
y = model(x)
y = y isa Tuple ? y[1] : y # handle LSTM
return mean(abs2, y)
end
square_loss(model, x) = mean(abs2, model(x))

function square_loss_iterated(cell, x)
st = cell(x) # uses default initial state
y, st = cell(x) # uses default initial state
for _ in 1:2
st = cell(x, st)
y, st = cell(x, st)
end
y = st isa Tuple ? st[1] : st # handle LSTM
return mean(abs2, y)
end

Expand Down Expand Up @@ -158,6 +149,10 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2, 1)
),
(
Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)),
randn(rng, Float32, 3, 2, 1)
),
#! format: on
]

Expand Down

0 comments on commit ff35528

Please sign in to comment.