Custom vmap implementation #25
base: main
@@ -46,6 +46,7 @@
import torch
from torch import Tensor
from torch.utils._pytree import tree_map


class TapeEntry(NamedTuple):
@@ -209,6 +210,10 @@ def custom_vjp(self, fwd_fn, bwd_fn, *args):
        result = label(a), label(b)
        return result

    def custom_vmap(self, fn, batch_rule, *args):
        results = fn(self, *args)
        return results

    def lift(self, input, d):
        assert d is self
        return input
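As a usage note, here is a minimal sketch of what this base case means in practice (assuming the hunk lands in the base Torch dispatcher, as the surrounding label() calls suggest): at the bottom of the stack there is no batching, so the original function runs as-is and the batch rule is never consulted.

# Sketch only; names (Torch, label, d.mul, d.add) are taken from elsewhere in this file.
d0 = Torch()
t = label(torch.tensor([1., 2., 3.]))
y = d0.custom_vmap(lambda d, a: d.mul(a, a),   # original function
                   lambda d, a: d.add(a, a),   # batch rule, ignored at the base
                   t)
# y is t * t, i.e. tensor([1., 4., 9.])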
@@ -510,6 +515,25 @@ def propagate(dL_doutputs: List[Tensor]):
        )
        return r, saved
    def custom_vmap(self, fn, batch_rule, *args):
        def call_with_current_dispatcher(fn):
            def wrapped(d, *args):
                saved = self.inner
                try:
                    self.inner = d
                    result = fn(self, *args)
                    return result
                finally:
                    self.inner = saved
            return wrapped

        # either fn or batch_rule gets invoked later down the line. Whichever one
        # it is, we want to record the history onto this dispatcher's gradient tape.
        result = self.inner.custom_vmap(
            call_with_current_dispatcher(fn),
            call_with_current_dispatcher(batch_rule), *args)
        return result

    def lift(self, input, d):
        if d is self:
            return input

Review comments on the line "self.inner = d":
- woooow so spicy
- I was going to say that you can do this without mutating the dispatcher stack by simply creating a new Autograd dispatcher on the fly, whose inner is d, but then the tapes would not be shared. This seems... dubious.
- I guess if we represent the tape with an extra indirection this isn't too hard to do. Probably better, and then makes this rule nicely symmetric for how …
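For reference, a rough sketch of the reviewer's alternative: rather than temporarily mutating self.inner, build a fresh Autograd dispatcher whose inner is d and let it share this dispatcher's tape. This is not part of the PR; the gradient_tape attribute name is an assumption about this repo's Autograd class, though the Autograd(inner) constructor matches the tests below.

def call_with_current_dispatcher(self, fn):
    def wrapped(d, *args):
        # Hypothetical: stack a throwaway Autograd layer on top of d instead
        # of rewiring self.inner.
        fresh = Autograd(d)
        # Share the tape object so history recorded by `fresh` also lands on
        # this dispatcher's tape (this is the "extra indirection" mentioned
        # in the thread above).
        fresh.gradient_tape = self.gradient_tape
        return fn(fresh, *args)
    return wrapped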
@@ -740,6 +764,10 @@ def new_fn(d, *args):
        r, saved = self.inner.custom_vjp(batchify(fwd_fn), batchify(bwd_fn), *args)
        return r, saved

    def custom_vmap(self, fn, batch_rule, *args):
        result = batch_rule(self.inner, *args)
        return result

Comments on lines +768 to +769:
- Unlike custom_vjp, custom_vmap does not call custom_vmap on the inner dispatcher!
- How come? It doesn't seem to me that Batched(Batched()) wouldn't apply. Perhaps you are saying it is impossible for a batching rule to recursively refer to itself?
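To make the nesting question concrete, here is a small sketch of what the current rule does for a doubly batched stack. This is not code from the PR; it reuses f, f_batch_rule, and label from the tests further down and assumes the file's leading-batch-dimension convention.

# Hypothetical doubly batched stack: outer length 3, inner length 2.
d_nested = Batched(Batched(Torch(), length=2), length=3)
x_nested = label(torch.rand(3, 2))

out = d_nested.custom_vmap(f, f_batch_rule, x_nested)
# The outer Batched layer runs the batch rule directly against its inner
# dispatcher: f_batch_rule(Batched(Torch(), length=2), x_nested), i.e. an
# ordinary batched add. The inner Batched layer never receives a
# custom_vmap call, so only the outermost layer applies the custom rule.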
    # The lift operation takes a tensor associated with some inner
    # dispatcher, and "lifts" it so that it is interpreted neutrally
    # for the outer dispatcher. For most dispatchers this is trivial,
@@ -1277,3 +1305,38 @@ def run_gradvmap(d2: "Batched", d1: "Autograd"):

run_gradvmap(d2, d1)

d = Batched(Torch(), length=3)

# Custom vmap
def f(d, x):
    return d.mul(x, x)

def f_batch_rule(d, x):
    # to prove a point
    return d.add(x, x)

x = label(torch.tensor([1., 2., 3.]))

result = d.custom_vmap(f, f_batch_rule, x)
assert torch.allclose(result, 2 * x)

# autograd should let custom_vmap pass through
x = label(torch.tensor([1., 2., 3.]))
d = Autograd(Torch())
result = d.custom_vmap(f, f_batch_rule, x)
assert torch.allclose(result, x * x)
loss999 = d.sum(result, name='loss999')

grad_x, = d.grad(loss999, [x])
assert torch.allclose(grad_x, 2 * x)

# autograd should let custom_vmap pass through
x = label(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
d = Autograd(Batched(Torch(), length=2))
result = d.custom_vmap(f, f_batch_rule, x)
assert torch.allclose(result, 2 * x)
loss88 = d.sum(result, name='loss88')

grad_x, = d.grad(loss88, [x])
assert torch.allclose(grad_x, torch.full_like(x, 2))
Review thread on the Autograd(Batched(...)) test above:
- I don't think I agree with this?
- Assuming someone doesn't use this to completely change the values of the output (like we're doing here), the output values should be the same. The difference is how the backward pass is executed. As written in this PR right now, Autograd(Batched(Torch(), length=2)) executes the backward pass of the batching rule, not the backward pass of the original function. To check, your claim is that Autograd(Batched(Torch(), length=2)) should execute the backward pass of the original function, not the backward pass of the batching rule, right?
- Yes.
- @albanD for Batched(Autograd(Torch())) and a custom_vjp(f_fwd, f_bwd, *args) call, what would you expect to happen? Option 1: the backward pass differentiates through the batching rule for f_fwd. Option 2: the backward pass runs vmap(f_bwd).
- Isn't that question beyond what we're discussing here?
- Oh, I'm just trying to understand the difference between this and custom_vjp. I think the current semantics are analogous to what custom_vjp is doing, but reasoning through it is confusing.
- Disregarding custom_vjp for a moment, if we're going off of the argument from the meeting today that this should match the behavior of a normal batch rule, isn't this code right? Let's say we're calling … As a related note, I always get confused that …
- I agree that the current Dispatcher for the regular code does that.
- Yes? More explanation: given that the original function is …
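To spell out the distinction being debated, here is a plain-PyTorch sketch (my own, not code from this PR) of the two candidate backward behaviors for the test's f(x) = x * x and its deliberately wrong batch rule x + x:

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)

# Backward of the original function: d/dx (x * x) = 2 * x
(x * x).sum().backward()
print(x.grad)   # tensor([2., 4., 6.])

x.grad = None

# Backward of the batch rule: d/dx (x + x) = 2, a constant
(x + x).sum().backward()
print(x.grad)   # tensor([2., 2., 2.])

The Autograd(Batched(Torch(), length=2)) test above asserts the second gradient, torch.full_like(x, 2), which is exactly the behavior the thread is questioning.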
A separate thread on the relationship to JAX:
- This appears to diverge substantially from JAX's custom_vmap, at https://github.com/google/jax/blob/main/jax/_src/custom_batching.py
- @ezyang is your comment that the API is different, the implementation is different, or both?
- Implementation. But later I worked out that this is exactly analogous to how we did batching, so... idk, maybe it's still fine.
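For comparison, here is roughly what the JAX API linked above looks like, as I understand jax.custom_batching; treat the exact rule signature as an assumption and check the linked source for the authoritative version.

import jax
import jax.numpy as jnp

@jax.custom_batching.custom_vmap
def g(x):
    return jnp.sin(x)

# The vmap rule receives the axis size and which inputs are batched, and
# returns both the outputs and which outputs are batched.
@g.def_vmap
def g_vmap_rule(axis_size, in_batched, xs):
    xs_batched, = in_batched
    return jnp.cos(xs), xs_batched   # "to prove a point", like f_batch_rule above

out = jax.vmap(g)(jnp.arange(3.0))   # invokes g_vmap_rule, i.e. cos instead of sin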