From 6f4d5303ba3b066bbcf0b8e7106cf51982a2a7b6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 23 Aug 2024 12:50:25 +0100 Subject: [PATCH] add DynamicPPL to allow rule-reusing --- Project.toml | 11 ++++++++--- ext/TapirDynamicPPLExt.jl | 30 ++++++++++++++++++++++++++++++ test/ext/TapirDynamicPPLExt.jl | 20 ++++++++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 ext/TapirDynamicPPLExt.jl create mode 100644 test/ext/TapirDynamicPPLExt.jl diff --git a/Project.toml b/Project.toml index 0c41891d0..4cb245e46 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.40" +version = "0.2.41" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -19,12 +20,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] TapirCUDAExt = "CUDA" +TapirDynamicPPLExt = "DynamicPPL" TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" TapirSpecialFunctionsExt = "SpecialFunctions" @@ -38,6 +41,7 @@ DiffRules = "1" DiffTests = "0.1" Distributions = "0.25" Documenter = "1" +DynamicPPL = "0.28" ExprTools = "0.1" FillArrays = "1" Graphs = "1" @@ -49,13 +53,14 @@ Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" TemporalGPs = "0.6" -Turing = "0.32" +Turing = "0.32, 0.33" julia = "1.10" [extras] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -72,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DynamicPPL", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/ext/TapirDynamicPPLExt.jl b/ext/TapirDynamicPPLExt.jl new file mode 100644 index 000000000..42754251b --- /dev/null +++ b/ext/TapirDynamicPPLExt.jl @@ -0,0 +1,30 @@ +module TapirDynamicPPLExt + +using DynamicPPL: DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD +using Tapir + +function DynamicPPL.setmodel( + f::LogDensityProblemsAD.ADGradientWrapper, + model::DynamicPPL.Model, + adtype::ADTypes.AutoTapir, +) + if !hasfield(typeof(f), :rule) + @warn "ADGradientWrapper does not have a `rule` field. Please check Tapir version. It is also possible that `adtype` mismatch `ADGradientWrapper` type." + @warn "Using default rule." + return LogDensityProblemsAD.ADgradient( + Val(:Tapir), + DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model); + safety_on=adtype.safe_mode, + rule=nothing, + ) + else + return LogDensityProblemsAD.ADgradient( + Val(:Tapir), + DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model); + safety_on=adtype.safe_mode, + rule=f.rule, + ) + end +end + +end diff --git a/test/ext/TapirDynamicPPLExt.jl b/test/ext/TapirDynamicPPLExt.jl new file mode 100644 index 000000000..3f4ec3202 --- /dev/null +++ b/test/ext/TapirDynamicPPLExt.jl @@ -0,0 +1,20 @@ +using DynamicPPL +using DynamicPPL: ADTypes, LogDensityProblemsAD + +@testset "TapirDynamicPPLExt" begin + demo_model = DynamicPPL.TestUtils.DEMO_MODELS[1] + new_model = demo_model | (s = [1.0, 2.0],) + f = DynamicPPL.LogDensityFunction(demo_model) + ad_f_safe = LogDensityProblemsAD.ADgradient(ADTypes.AutoTapir(true), f) + new_ad_f_safe = DynamicPPL.setmodel(ad_f_safe, new_model, ADTypes.AutoTapir(true)) + @test new_ad_f_safe.ℓ.x.model === new_model + @test new_ad_f_safe isa LogDensityProblemsAD.ADGradientWrapper + @test new_ad_f_safe.rule isa Tapir.SafeRRule + @test new_ad_f_safe.rule === ad_f_safe.rule + ad_f_unsafe = LogDensityProblemsAD.ADgradient(ADTypes.AutoTapir(false), f) + new_ad_f_unsafe = DynamicPPL.setmodel(ad_f_unsafe, new_model, ADTypes.AutoTapir(false)) + @test new_ad_f_unsafe.ℓ.x.model === new_model + @test new_ad_f_unsafe isa LogDensityProblemsAD.ADGradientWrapper + @test new_ad_f_unsafe.rule isa Tapir.DerivedRule + @test new_ad_f_unsafe.rule === ad_f_unsafe.rule +end