Add interface functions to allow replacing the log density function and replacing AD wrapper type #33
Conversation
cc @devmotion, @tpapp, @torfjelde, @yebai, @miguelbiron
Also see comment in discussion.
It is unfortunate that the AD backend is not recoverable from the gradient wrapper; the cleanest solution would be ADgradient(get_AD(ℓ), new_ℓ) and just implement get_AD.
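A minimal sketch of that idea, assuming get_AD (proposed here, not existing API) and a hypothetical new_gradient helper:

```julia
using LogDensityProblemsAD
using LogDensityProblemsAD: ADGradientWrapper  # abstract type of the gradient wrappers

# Proposed accessor: each concrete wrapper type would implement a method
# returning the backend specification it was constructed from
# (a Symbol/Val or an ADTypes.jl struct).
function get_AD end

# With that single accessor, replacing the log density needs no further
# interface; `new_gradient` is just an illustrative helper name.
new_gradient(∇ℓ::ADGradientWrapper, new_ℓ) = ADgradient(get_AD(∇ℓ), new_ℓ)
```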
It is cleaner. We can opt for it after the PR is merged.
But then we would have to change the interface again... I am inclined to go with the get_AD approach, but will wait to hear from @devmotion.
My impression from TuringLang/Turing.jl#2231 (comment) and related comments in Turing.jl was that there's no clear need for such an API currently? One reason for such an API would be a case where calling ADgradient again is expensive.

Regarding the implementation: couldn't we achieve this functionality by overloading an existing function?
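For concreteness, the expensive case is something like ReverseDiff with a compiled tape, where rebuilding the wrapper recompiles the tape. A toy sketch (ToyProblem is made up, and the compile keyword is assumed to be the ReverseDiff extension's tape-compilation option):

```julia
using LogDensityProblems, LogDensityProblemsAD, ReverseDiff

# Minimal toy problem implementing the LogDensityProblems interface.
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 10
LogDensityProblems.capabilities(::Type{ToyProblem}) =
    LogDensityProblems.LogDensityOrder{0}()

# Building the wrapper with a compiled tape pays a one-time compilation cost.
∇ℓ = ADgradient(:ReverseDiff, ToyProblem(); compile = Val(true))

# Today, "replacing" the log density means calling ADgradient again from
# scratch, which recompiles the tape -- the cost a replacement API would avoid.
∇ℓ2 = ADgradient(:ReverseDiff, ToyProblem(); compile = Val(true))
```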
Does the ADTypes.jl extension not effectively solve this? Or are there some kwargs that are still missing from the ADTypes.jl structs?
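For reference, the ADTypes.jl route looks roughly like this (reusing ToyProblem from the sketch above, and assuming the ADTypes extension forwards AutoReverseDiff's compile flag):

```julia
using ADTypes, LogDensityProblemsAD, ReverseDiff

# The backend object carries its own options (e.g. tape compilation), so a
# caller only needs to keep `backend` around to rebuild the wrapper.
backend = ADTypes.AutoReverseDiff(; compile = true)
∇ℓ = ADgradient(backend, ToyProblem())

# Rebuilding around a new log density is then just ADgradient(backend, new_ℓ),
# i.e. effectively ADgradient(get_AD(∇ℓ), new_ℓ) once a get_AD accessor exists.
```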
I have a new proposal: add an interface function get_AD that returns the AD backend of a gradient wrapper.

EDIT: just realized this is exactly what @tpapp was suggesting 👍

The motivation is that I don't think ...
Sorry for the late responses, I am on holiday with limited net access.

@torfjelde: the problem is that not all of the API uses ADTypes.

@sunxd3: yes, the cleanest solution would be that, see my comment above. But we need to clean up the API first.

I am not sure how pressing the need for this solution is. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.
Gotcha 👍
We have a workaround on our side, so I think it's less pressing atm.
Ref #32 (comment)
Brief summary:
- added a replace_ℓ interface function
- ADgradient can now take in an ADGradientWrapper and recreate a new gradient wrapper with its log density function

I only added some implementations for ReverseDiff.

This is very much a draft right now; everything is open to modification.
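A rough sketch of those two pieces, as one possible reading of the draft rather than its actual code; recovering the wrapped log density via Base.parent is an assumption about LogDensityProblemsAD internals.

```julia
using LogDensityProblemsAD
using LogDensityProblemsAD: ADGradientWrapper

# (1) replace_ℓ(∇ℓ, new_ℓ): return a wrapper of the same kind as ∇ℓ but
#     wrapping new_ℓ. Concrete methods would live in the backend extensions
#     (only ReverseDiff in this draft); here only the generic function is declared.
function replace_ℓ end

# (2) ADgradient(kind, ∇ℓ::ADGradientWrapper): take an existing wrapper, pull
#     out its log density, and wrap it with the requested AD backend.
function LogDensityProblemsAD.ADgradient(kind::Symbol, ∇ℓ::ADGradientWrapper)
    return ADgradient(Val(kind), parent(∇ℓ))
end
```

Inside the package this second method would sit next to the existing ADgradient(kind::Symbol, ℓ) method, so plain construction keeps working unchanged.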