From 348965bab55d317b6aa82608a5ab09fc907646d9 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 18 Apr 2022 15:48:47 -0400
Subject: [PATCH] Custom vmap implementation

custom_vmap(fn, batch_rule, *args) runs fn(*args), except under vmap, where
the user-supplied batch_rule is run instead of tracing vmap through fn. Each
dispatcher handles it as follows:

- Torch (base): no batching is active, so just run fn.
- Batched: run batch_rule against the inner dispatcher instead of fn.
- Autograd: pass through to the inner dispatcher, wrapping both fn and
  batch_rule so that whichever one eventually runs records its history onto
  this dispatcher's gradient tape.
---
 simple_functorch.py | 63 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/simple_functorch.py b/simple_functorch.py
index 22f8266..eddabc2 100644
--- a/simple_functorch.py
+++ b/simple_functorch.py
@@ -46,6 +46,7 @@
 
 import torch
 from torch import Tensor
+from torch.utils._pytree import tree_map
 
 
 class TapeEntry(NamedTuple):
@@ -209,6 +210,10 @@ def custom_vjp(self, fwd_fn, bwd_fn, *args):
         result = label(a), label(b)
         return result
 
+    def custom_vmap(self, fn, batch_rule, *args):
+        results = fn(self, *args)
+        return results
+
     def lift(self, input, d):
         assert d is self
         return input
@@ -510,6 +515,25 @@ def propagate(dL_doutputs: List[Tensor]):
         )
         return r, saved
 
+    def custom_vmap(self, fn, batch_rule, *args):
+        def call_with_current_dispatcher(fn):
+            def wrapped(d, *args):
+                saved = self.inner
+                try:
+                    self.inner = d
+                    result = fn(self, *args)
+                    return result
+                finally:
+                    self.inner = saved
+            return wrapped
+
+        # Either fn or batch_rule gets invoked later down the line. Whichever one
+        # it is, we want to record the history onto this dispatcher's gradient tape.
+        result = self.inner.custom_vmap(
+            call_with_current_dispatcher(fn),
+            call_with_current_dispatcher(batch_rule), *args)
+        return result
+
     def lift(self, input, d):
         if d is self:
             return input
@@ -740,6 +764,10 @@ def new_fn(d, *args):
         r, saved = self.inner.custom_vjp(batchify(fwd_fn), batchify(bwd_fn), *args)
         return r, saved
 
+    def custom_vmap(self, fn, batch_rule, *args):
+        result = batch_rule(self.inner, *args)
+        return result
+
     # The lift operation takes a tensor associated with some inner
     # dispatcher, and "lifts" it so that it is interpreted neutrally
     # for the outer dispatcher. For most dispatchers this is trivial,
@@ -1277,3 +1305,38 @@ def run_gradvmap(d2: "Batched", d1: "Autograd"):
 
 
 run_gradvmap(d2, d1)
+
+d = Batched(Torch(), length=3)
+
+# Custom vmap: under Batched, the batch rule runs instead of f
+def f(d, x):
+    return d.mul(x, x)
+
+def f_batch_rule(d, x):
+    # deliberately different from f, to prove the batch rule is what ran
+    return d.add(x, x)
+
+x = label(torch.tensor([1., 2., 3.]))
+
+result = d.custom_vmap(f, f_batch_rule, x)
+assert torch.allclose(result, 2 * x)
+
+# autograd should let custom_vmap pass through
+x = label(torch.tensor([1., 2., 3.]))
+d = Autograd(Torch())
+result = d.custom_vmap(f, f_batch_rule, x)
+assert torch.allclose(result, x * x)
+loss999 = d.sum(result, name='loss999')
+
+grad_x, = d.grad(loss999, [x])
+assert torch.allclose(grad_x, 2 * x)
+
+# grad-of-vmap: Batched invokes the batch rule, and autograd records it
+x = label(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
+d = Autograd(Batched(Torch(), length=2))
+result = d.custom_vmap(f, f_batch_rule, x)
+assert torch.allclose(result, 2 * x)
+loss88 = d.sum(result, name='loss88')
+
+grad_x, = d.grad(loss88, [x])
+assert torch.allclose(grad_x, torch.full_like(x, 2))
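
The dispatch rule in this patch is easiest to see stripped of tensors
entirely. Below is a minimal sketch, using hypothetical standalone classes
(Base, Batched here are not the simple_functorch dispatchers), of how
custom_vmap picks between fn and batch_rule depending on whether a batching
interpreter sits in the stack:

    class Base:
        # Bottom of the stack: no vmap in sight, so run fn as-is.
        def custom_vmap(self, fn, batch_rule, *args):
            return fn(self, *args)

    class Batched:
        def __init__(self, inner):
            self.inner = inner

        # vmap is active: skip fn and hand the inner dispatcher to batch_rule.
        def custom_vmap(self, fn, batch_rule, *args):
            return batch_rule(self.inner, *args)

    def f(d, x):
        return x * x            # what the function normally computes

    def f_batch_rule(d, x):
        return x + x            # what vmap should run instead

    assert Base().custom_vmap(f, f_batch_rule, 3) == 9           # fn ran
    assert Batched(Base()).custom_vmap(f, f_batch_rule, 3) == 6  # batch_rule ran

The design choice is that the decision is made per-dispatcher rather than
globally: each layer of the stack independently decides whether to keep
passing the pair (fn, batch_rule) down or to commit to one of them.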
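
The subtle part of the patch is Autograd.custom_vmap's
call_with_current_dispatcher wrapper. The sketch below, again with
hypothetical minimal classes rather than the real simple_functorch ones,
isolates that re-parenting trick: the Autograd-like dispatcher temporarily
splices itself on top of whichever dispatcher eventually invokes the
callback, so operations inside the callback still land on its tape:

    class Torch:
        def mul(self, x, y):
            return x * y

        def custom_vmap(self, fn, batch_rule, *args):
            return fn(self, *args)

    class Autograd:
        def __init__(self, inner):
            self.inner = inner
            self.tape = []       # stand-in for the real gradient tape

        def mul(self, x, y):
            self.tape.append('mul')
            return self.inner.mul(x, y)

        def custom_vmap(self, fn, batch_rule, *args):
            def call_with_current_dispatcher(callback):
                def wrapped(d, *args):
                    saved = self.inner
                    try:
                        self.inner = d   # splice self on top of the caller d
                        return callback(self, *args)
                    finally:
                        self.inner = saved
                return wrapped

            # Pass through; whichever callback fires still records onto self.tape.
            return self.inner.custom_vmap(
                call_with_current_dispatcher(fn),
                call_with_current_dispatcher(batch_rule), *args)

    d = Autograd(Torch())
    out = d.custom_vmap(lambda d, x: d.mul(x, x),
                        lambda d, x: d.mul(x, 2.0), 3.0)
    assert out == 9.0            # the base dispatcher chose fn
    assert d.tape == ['mul']     # ...and the mul was still recorded here

Without the wrapper, the inner dispatcher would invoke fn or batch_rule with
itself as the dispatcher argument, and the calls inside the callback would
bypass the Autograd layer entirely, leaving the gradient tape empty.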