Skip to content

Commit

Permalink
Merge pull request #646 from AayushSabharwal/as/reversediff-ext
Browse files Browse the repository at this point in the history
feat: add `Code.create_array` method for `TrackedArray` in ReverseDiffExt
  • Loading branch information
ChrisRackauckas authored Sep 4, 2024
2 parents 3303acf + b409ba6 commit b16b285
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: "1.8"
version: "1"
- uses: julia-actions/cache@v1
- name: Extract Package Name from Project.toml
id: extract-package-name
Expand Down
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "3.5.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand All @@ -28,12 +29,15 @@ Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"

[weakdeps]
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
SymbolicUtilsLabelledArraysExt = "LabelledArrays"
SymbolicUtilsReverseDiffExt = "ReverseDiff"

[compat]
AbstractTrees = "0.4"
ArrayInterface = "7.8"
Bijections = "0.1.2"
ChainRulesCore = "1"
Combinatorics = "1.0"
Expand All @@ -45,6 +49,7 @@ IfElse = "0.1"
LabelledArrays = "1.5"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1"
ReverseDiff = "1"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
Expand All @@ -62,8 +67,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "Test", "Zygote"]
10 changes: 10 additions & 0 deletions ext/SymbolicUtilsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module SymbolicUtilsReverseDiffExt

using ReverseDiff
using SymbolicUtils

@inline function SymbolicUtils.Code.create_array(::Type{<:ReverseDiff.TrackedArray}, T, v1::Val, v2::Val{dims}, elems...) where dims
SymbolicUtils.ArrayInterface.aos_to_soa(SymbolicUtils.Code.create_array(Array, T, v1, v2, elems...))
end

end
2 changes: 2 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using ConstructionBase
using TermInterface
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments
# For ReverseDiffExt
import ArrayInterface

Base.@deprecate istree iscall
export istree, operation, arguments, sorted_arguments, similarterm, iscall
Expand Down
12 changes: 12 additions & 0 deletions test/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SymbolicUtils.Code: LazyState
using StaticArrays
using LabelledArrays
using SparseArrays
using ReverseDiff
using LinearAlgebra

test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
Expand Down Expand Up @@ -158,6 +159,17 @@ nanmath_st.rewrites[:nanmath] = true
@test eval(toexpr(Let([a 1, b 2, arr @SLVector((:a, :b))(@SVector[1,2])],
MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2])

trackedarr = eval(toexpr(Let([a ReverseDiff.track(1.0), b 2, arr ReverseDiff.track(ones(2))],
MakeArray([a+b,a/b], arr))))
@test trackedarr isa ReverseDiff.TrackedArray
@test trackedarr == [3, 1/2]

trackedarr = eval(toexpr(Let([a ReverseDiff.track(1.0), b 2, arr ReverseDiff.track(ones(2))],
MakeArray([a b; a+b a/b], arr))))
@test trackedarr isa ReverseDiff.TrackedArray
@test trackedarr == [1 2; 3 1/2]


R1 = eval(toexpr(Let([a 1, b 2, arr @MVector([1,2])],MakeArray([a,b,a+b,a/b], arr))))
@test R1 == (@MVector [1, 2, 3, 1/2]) && R1 isa MVector

Expand Down

0 comments on commit b16b285

Please sign in to comment.