-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom vmap implementation #25
base: main
Are you sure you want to change the base?
Conversation
TODO: needs description of what is going on
result = batch_rule(self.inner, *args) | ||
return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unlike custom_vjp, custom_vmap does not call custom_vmap on the inner dispatcher!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how come. It doesn't seem to me like Batched(Batched()) wouldn't apply. Perhaps you are saying, it is impossible for a batching rule to recursively refer to itself?
loss88 = d.sum(result, name='loss88') | ||
|
||
grad_x, = d.grad(loss88, [x]) | ||
assert torch.allclose(grad_x, torch.full_like(x, 2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I agree with this?
I would expect that Autograd(Batched(Torch(), length=2))
would give 2 * x
as the gradient while Batched(Autograd(Torch()), length=2)
would give 2
. No?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming someone doesn't use this to completely change the values of the output (like we're doing here), then the output values should be the same. The difference is how the backward pass is being executed.
As written in this PR right now Autograd(Batched(Torch(), length=2)) executes the backward pass of the batching rule, not the backward pass of the original function. To check, your claim is that Autograd(Batched(Torch(), length=2)) should execute the backward pass of the original function, not the backward pass of the batching rule, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
I would expect here to be able to see the different between batch of gradients and element-wise gradients.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@albanD for Batched(Autograd(Torch()) and a custom_vjp(f_fwd, f_bwd, *args) call, what would you expect to happen?
Option 1: The backward pass differentiates through the batching rule for f_fwd
Option 2: The backward pass runs vmap(f_bwd)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't that question beyond what we're discussing here?
The question is still there without considering any custom_vjp right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I'm just trying to understand the difference between this and custom_vjp. I think the current semantics are analogous to what custom_vjp is doing but reasoning through it is confusing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disregarding custom_vjp for a moment, if we're going off of the argument from the meeting today that this should match the behavior of a normal batch rule, isn't this code right?
Let's say we're calling d2.unsqueeze(x, dim)
with d2 = Batched(Autograd(Torch()))
. First it hits the Batched dispatcher, so we get d2.inner.unsqueeze(x, dim + 1)
so the Autograd dispatcher sees unsqueeze(x, dim + 1)
(which is the batch rule) and does autograd on that function. Using the same dispatcher stack, we similarly expect autograd runs on the "batch rule"
As a related a note, I always get confused that Batched(Autograd(Torch()))
is the same as grad(vmap())
so it might be worth to add the transform implementations if that's not too much work? I think this makes a lot more sense if we're able to say grad(vmap())
gets the derivative of the custom batch rule rather than remembering to invert the interpreter stack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that the current Dispatcher for the regular code does that.
But now if we call the custom_vmap
with this unsqueeze function does it do that as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But now if we call the custom_vmap with this unsqueeze function does it do that as well?
Yes?
More explanation:
I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient
Given that the original function is d.mul(x, x)
, this would mean that autograd is running on the unbatched function, not the batch rule. If we agree that given this set of Dispatchers should end up with Autograd running on the batched rule, then 2 is the expected gradient because d.add(x, x)
is the batched rule and the derivative of that is 2
def wrapped(d, *args): | ||
saved = self.inner | ||
try: | ||
self.inner = d |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
woooow so spicy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going to say that you can do this without mutating the dispatcher stack by simply creating a new Autograd dispatcher on the fly, whose inner is d, but then the tapes would not be shared. This seems... dubious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess if we represent the tape with an extra indirection this isn't too hard to do. Probably better and then makes this rule nicely symmetric for how custom_vjp
is implemented in Batched
.
@@ -510,6 +515,25 @@ def propagate(dL_doutputs: List[Tensor]): | |||
) | |||
return r, saved | |||
|
|||
def custom_vmap(self, fn, batch_rule, *args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This appears to diverge substantially from JAX's custom_vmap, at https://github.com/google/jax/blob/main/jax/_src/custom_batching.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang is your comment that the API is different, the implementation is different, or both?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implementation. But later I worked out that this is exactly analogous to how we did batching, so... idk, maybe it's still fine
TODO: needs description of what is going on