Compatibility with Zygote AD #1516
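It is possible, e.g. (just a quick hack):

```julia
using ChainRulesCore, Distributions

# Quick-hack reverse rule for `rand(d::Normal, n)`, based on writing each
# sample as vals = μ .+ σ .* z with z ~ N(0, 1).
function ChainRulesCore.rrule(::typeof(rand), d::Normal{T}, n::Integer) where {T}
    vals = rand(d, n)
    function rand_pullback(rand_bar)
        ybar = unthunk(rand_bar)  # upstream cotangent of `vals`
        # ∂vals/∂μ = 1 and ∂vals/∂σ = (vals .- μ) ./ σ, each contracted with ybar
        d_bar = Tangent{Normal{T}}(; μ = sum(ybar), σ = sum(ybar .* (vals .- d.μ)) / d.σ)
        return NoTangent(), d_bar, NoTangent()
    end
    return vals, rand_pullback
end
```

But I wonder why a dedicated rule is necessary at all. Looking at the definition of `rand` for `Normal`, this should be easily differentiable. I don't really see why Zygote returns …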
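To see why `rand` for `Normal` should be easy to differentiate, here is a short derivation of the pullback. It assumes the sampler is equivalent to the affine reparameterization of standard-normal draws (this matches how a `Normal` sample is usually generated, but treat the exact Distributions.jl implementation as an assumption). The symbol $\bar{y}_i$ denotes the upstream cotangent of the $i$-th sample, which corresponds to `rand_bar` in the rule:

```math
\begin{aligned}
x_i &= \mu + \sigma z_i, \qquad z_i \sim \mathcal{N}(0, 1), \quad i = 1, \dots, n, \\
\bar{\mu} &= \sum_{i=1}^{n} \bar{y}_i \frac{\partial x_i}{\partial \mu} = \sum_{i=1}^{n} \bar{y}_i, \\
\bar{\sigma} &= \sum_{i=1}^{n} \bar{y}_i \frac{\partial x_i}{\partial \sigma} = \sum_{i=1}^{n} \bar{y}_i z_i = \frac{1}{\sigma} \sum_{i=1}^{n} \bar{y}_i \, (x_i - \mu).
\end{aligned}
```

With $\bar{y}_i = 1$ for all $i$, these reduce to $\bar{\mu} = n$ and $\bar{\sigma} = \sum_i (x_i - \mu)/\sigma$, i.e. the hard-coded values in the original quick hack.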
It is due to https://github.com/JuliaDiff/ChainRules.jl/blob/158ca756ef99ccf3f1dde2e66b5855e8e68e0363/src/rulesets/Random/random.jl#L23-L25. It is a deliberate design decision to mark …
Ah, I missed that. Then everything makes sense. Sort of a meta question: where would one put AD rules for such things?
It seems that only ForwardDiff is supported for differentiating through sampling. Is it possible to implement rules for Zygote as well?