fix: updates for breaking changes #164

Merged 9 commits on Jul 5, 2024
3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
@@ -39,7 +39,7 @@ jobs:
os:
- ubuntu-latest
version:
- 'nightly' # coverage fast on nightly
- '1'
threads:
- '3'
- '4'
@@ -82,7 +82,6 @@ jobs:
- ubuntu-latest
version:
- '1.6'
- '1'
steps:
- uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9
- uses: julia-actions/setup-julia@f40c4b69330df1d22e7590c12e76dc2f9c66e0bc
16 changes: 4 additions & 12 deletions Project.toml
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.4.6"
version = "0.4.7"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -38,22 +38,14 @@ LayoutPointers = "0.1.3"
LoopVectorization = "0.12.104"
ManualMemory = "0.1.8"
Polyester = "0.4, 0.5, 0.6, 0.7"
Random = "<0.0.1, 1"
SIMDTypes = "0.1"
SLEEFPirates = "0.6"
Static = "0.8.4"
Static = "0.8.4, 1"
StaticArrayInterface = "1"
StaticArrays = "1"
StrideArraysCore = "0.4.7"
StrideArraysCore = "0.4.7, 0.5"
UnPack = "1"
VectorizationBase = "0.21.40"
VectorizedRNG = "0.2.13"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ForwardDiff", "Zygote", "Test"]
2 changes: 1 addition & 1 deletion docs/src/examples/mnist.md
@@ -43,7 +43,7 @@ We define the inputs as being statically sized `(28,28,1)` images.
Specifying the input sizes allows these to be checked.
Making them static, which we can do either in our simple chain, or by adding
static sizing to the images themselves using a package like [StrideArrays.jl](https://github.com/JuliaSIMD/StrideArrays.jl)
or [HybridArrays.jl](git@github.com:JuliaArrays/HybridArrays.jl.git). These packages are recommended
or [HybridArrays.jl](https://github.com/JuliaArrays/HybridArrays.jl). These packages are recommended
for allowing you to mix dynamic and static sizes; the batch size should probably
be left dynamic, as you're unlikely to want to specialize code generation on this,
given that it is likely to vary, increasing compile times while being unlikely to
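The documentation excerpt above describes declaring the input size statically on the chain itself. A minimal sketch of what that looks like in practice; the static `(28, 28, 1)` input tuple is the point being made, while the particular layer stack, array names, and batch size here are illustrative:

```julia
using SimpleChains, Static

# The input size is declared statically on the chain itself; the batch
# dimension is left dynamic, as the documentation recommends.
lenet = SimpleChain(
  (static(28), static(28), static(1)),
  SimpleChains.Conv(SimpleChains.relu, (5, 5), 6),
  SimpleChains.MaxPool(2, 2),
  SimpleChains.Flatten(3),
  TurboDense(SimpleChains.relu, 120),
  TurboDense(identity, 10),
)

x = rand(Float32, 28, 28, 1, 64)     # a batch of 64 images; the batch size stays dynamic
p = SimpleChains.init_params(lenet)
lenet(x, p)                          # input dimensions are checked against the static spec
```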
1 change: 1 addition & 0 deletions src/SimpleChains.jl
@@ -42,6 +42,7 @@ import ChainRulesCore
import ForwardDiff
import LoopVectorization
import StaticArrays
using StaticArrays: SVector, SMatrix
using Random: AbstractRNG

using LoopVectorization: matmul_params, @turbo
4 changes: 2 additions & 2 deletions src/chain_rules.jl
@@ -162,9 +162,9 @@ function valgrad_noloss(
@inbounds @simd ivdep for i in eachindex(parg)
parg2[i] = parg[i]
end
pm += aoff
pm = __add(pm, aoff)
g = PtrArray(Ptr{T}(pm), (glen,))
pm += goff
pm = __add(pm, goff)
l, pbl =
chain_valgrad_pullback!(pointer(g), parg2, layers, pointer(params), pm)
end
8 changes: 4 additions & 4 deletions src/dense.jl
@@ -1004,7 +1004,7 @@ end
function dense_param_update!(::TurboDense{true}, Ā, C̄, B)
Kp1 = static_size(Ā, StaticInt(2))
K = Kp1 - StaticInt(1)
dense!(identity, nothing, view(Ā, :, static(1):K), C̄, B', False())
dense!(identity, nothing, view(Ā, :, static(1):K), C̄, __adjoint(B), False())
@turbo for m ∈ axes(Ā, 1)
s = zero(eltype(Ā))
for n ∈ axes(C̄, 2)
@@ -1014,12 +1014,12 @@ function dense_param_update!(::TurboDense{true}, Ā, C̄, B)
end
end
function dense_param_update!(::TurboDense{false}, Ā, C̄, B)
dense!(identity, nothing, Ā, C̄, B', False())
dense!(identity, nothing, Ā, C̄, __adjoint(B), False())
end
function dense_param_update!(::TurboDense{true}, Ā, C̄, B, inds)
Kp1 = static_size(Ā, StaticInt(2))
K = Kp1 - StaticInt(1)
denserev!(identity, nothing, view(Ā, :, static(1):K), C̄, B', inds, False())
denserev!(identity, nothing, view(Ā, :, static(1):K), C̄, __adjoint(B), inds, False())
@turbo for m ∈ axes(Ā, 1)
s = zero(eltype(Ā))
for n ∈ axes(C̄, 2)
@@ -1029,7 +1029,7 @@ function dense_param_update!(::TurboDense{true}, Ā, C̄, B, inds)
end
end
function dense_param_update!(::TurboDense{false}, Ā, C̄, B, inds)
dense!(identity, nothing, Ā, C̄, B', inds, False())
dense!(identity, nothing, Ā, C̄, __adjoint(B), inds, False())
end

@inline function dense!(
2 changes: 1 addition & 1 deletion src/dropout.jl
@@ -120,7 +120,7 @@ function valgrad_layer!(
VectorizedRNG.storestate!(rng, state)
end # GC preserve

pg, x, p, align(pu + ((static(7) + N) >>> static(3)))
pg, x, p, align(__add(pu, ((static(7) + N) >>> static(3))))
end

function pullback_arg!(
4 changes: 2 additions & 2 deletions src/simple_chain.jl
@@ -820,7 +820,7 @@ function valgrad_core(
) where {T}
@unpack layers = c
g = PtrArray(Ptr{T}(pu), (glen,))
l = unsafe_valgrad!(c, pu + align(glen * static_sizeof(T)), g, params, arg)
l = unsafe_valgrad!(c, __add(pu, align(glen * static_sizeof(T))), g, params, arg)
Base.FastMath.add_fast(
l,
apply_penalty!(g, getpenalty(c), params, static_size(arg))
@@ -838,7 +838,7 @@ function valgrad_core_sarray(
l = Base.FastMath.add_fast(
unsafe_valgrad!(
c,
pu + align(static(L) * static_sizeof(T)),
__add(pu, align(static(L) * static_sizeof(T))),
g,
params,
arg
7 changes: 7 additions & 0 deletions src/utils.jl
@@ -336,3 +336,10 @@ function _add_memory(t::Tuple, p)
(A, B...)
end
_add_memory(::Nothing, p) = nothing

__add(x, y) = x + y
__add(x::Ptr, ::StaticInt{N}) where {N} = x + N
__add(::StaticInt{N}, y::Ptr) where {N} = y + N

__adjoint(x) = x'
__adjoint(x::SVector{N, <:Real}) where {N} = SMatrix{1, N, eltype(x)}(x.data)
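The helpers added above read as compatibility shims for the breaking changes this PR adapts to: `__add` routes pointer-plus-`StaticInt` offsets through plain integer addition, and `__adjoint` turns a real `SVector` into an ordinary `1×N` `SMatrix` rather than an `Adjoint` wrapper. Below is a self-contained sketch of their behavior; the definitions are copied from the diff, while the sample values and the stated motivation are an assumption, not taken from the PR:

```julia
using StaticArrays, Static

__add(x, y) = x + y
__add(x::Ptr, ::StaticInt{N}) where {N} = x + N
__add(::StaticInt{N}, y::Ptr) where {N} = y + N

__adjoint(x) = x'
__adjoint(x::SVector{N,<:Real}) where {N} = SMatrix{1,N,eltype(x)}(x.data)

v = SVector(1.0, 2.0, 3.0)
__adjoint(v)                    # 1×3 SMatrix, a plain static matrix rather than an Adjoint
__adjoint([1.0 2.0; 3.0 4.0])   # generic fallback: the usual adjoint

p = Ptr{Float64}(0)
__add(p, static(16))            # pointer advanced by 16 bytes via the StaticInt method
```

The call sites changed in `chain_rules.jl`, `dense.jl`, `dropout.jl`, and `simple_chain.jl` swap direct `+` and `'` for these helpers, which is why the same two functions appear across several files.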
11 changes: 3 additions & 8 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using SimpleChains
using Test, Aqua, ForwardDiff, Zygote, ChainRules, Random
@static if VERSION >= v"1.9"
@static if VERSION ≥ v"1.9"
using JET: @test_opt
else
macro test_opt(ex)
@@ -84,12 +84,8 @@ InteractiveUtils.versioninfo(; verbose = true)
SquaredLoss"""

@test sprint((io, t) -> show(io, t), sc) == print_str0
if VERSION >= v"1.6"
@test sprint((io, t) -> show(io, t), scflp) == print_str1
else
# typename doesn't work on 1.5
@test_broken sprint((io, t) -> show(io, t), scflp) == print_str1
end
@test sprint((io, t) -> show(io, t), scflp) == print_str1

p = SimpleChains.init_params(scflp, T; rng = Random.default_rng())
g = similar(p)
let sc = SimpleChains.remove_loss(sc)
@@ -574,5 +570,4 @@ end
Aqua.test_all(
SimpleChains;
ambiguities = false,
project_toml_formatting = false
)
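The `@test_opt` fallback branch in `test/runtests.jl` is truncated by the diff view above. On Julia versions where JET is not loaded, a typical stand-in simply runs the expression unchecked; the following is a hypothetical sketch of that pattern, not the exact code from the PR:

```julia
@static if VERSION ≥ v"1.9"
  using JET: @test_opt
else
  # No JET on older Julia: expand @test_opt to the bare expression so the
  # call sites still execute, just without JET's optimization checks.
  macro test_opt(ex)
    esc(ex)
  end
end
```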