Support vectorize/devectorize inside gradients #1533

Open · jyc opened this issue Sep 11, 2024 · 10 comments · Fixed by #1535
Labels: area:defn (Applies to defn), kind:bug (Something isn't working)

jyc commented Sep 11, 2024

Thanks for making Nx!

I tried to use value_and_grad on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor.

defmodule Foo do
  import Nx.Defn
  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

x = ~VEC[0 1] |> Nx.vectorize(:bar)
Foo.f_and_grad(x, 1)

This evaluates to:

{#Nx.Tensor<
   vectorized[bar: 2]
   s64
   EXLA.Backend<host:0, 0.731981912.321781778.128426>
   [1, 2]
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.731981912.321781778.128427>
   2.0
 >}

The value is correct and keeps the vectorized :bar axis of the vectorized input x, but the gradient surprises me. I would have expected a vectorized tensor (rank-1, dimension-2, with the same :bar axis) that is 1 everywhere; instead it looks like Nx is summing the two gradients.

Is this behavior expected? If so, is there a way to make Nx return a vectorized gradient?

Thanks!
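
(Editorial sketch, not from the thread: one way to see the gradient described above is to take the underlying tensor, compute the gradient slice by slice, and restack the results onto the :bar axis. This assumes the Foo module defined above; the loop is only for illustration.)

x = Nx.tensor([0, 1])

grads =
  for i <- 0..(Nx.axis_size(x, 0) - 1) do
    # gradient of x[i] + y with respect to y, one slice at a time
    {_value, grad} = Foo.f_and_grad(x[i], 1)
    grad
  end

Nx.stack(grads) |> Nx.vectorize(:bar)
# a vectorized[bar: 2] f32 tensor that is 1.0 everywhere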


jyc commented Sep 11, 2024

I know that I can use

y = ~VEC[1 1] |> Nx.vectorize(:bar)
Foo.f_and_grad(x, y)

to get the result I expect, but in practice y is actually quite large, so repeating it just so the gradient is computed properly seems wasteful. I will dig into that more though.

@josevalim (Collaborator)

I think this makes sense because the grad is computed over y, but I would like to see if @polvalente has a different opinion.


jyc commented Sep 11, 2024

I tried checking whether it would still be efficient to broadcast y to the size of x in order to get a gradient with the same dimensions as the broadcast y; I wasn't sure whether Nx would create e.g. a vector with zero stride. However, it looks like the byte_size increases, at least with Nx.BinaryBackend and EXLA.Backend:

x = ~VEC[0 0] |> Nx.vectorize(:foo)
y = ~VEC[1]
[x, y] = Nx.broadcast_vectors([x, y])
y |> Nx.byte_size()
# 16
# if more elements are added to `x`, this evaluates to 24, etc.

So I would still be interested to know whether there is a way to get the non-summed gradient, although I understand if it's not possible with this API.
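
(Editorial sketch, not from the thread: the workaround described above, giving y the same vectorized axis as x via Nx.broadcast_vectors at the cost of materializing the broadcast copy, would look roughly like this, assuming the Foo module from the first comment.)

x = ~VEC[0 1] |> Nx.vectorize(:bar)
y = Nx.tensor(1)
# broadcast the vectorized axes so y also carries bar: 2
[x, y] = Nx.broadcast_vectors([x, y])
Foo.f_and_grad(x, y)
# expected, per the discussion above: the value is vectorized[bar: 2] [1, 2]
# and the gradient is vectorized[bar: 2] and 1.0 everywhere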

@polvalente (Contributor)

I agree with @jyc that the grad should have the same vectorized shape as the output. That is, the correct result for the example should be [1.0, 1.0] instead of 2.0.

The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.
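
(Editorial sketch of the mental model above, using a hypothetical fun/1 that is not from the thread: applying a function to a vectorized tensor should match applying it to each slice and restacking.)

fun = fn t -> Nx.multiply(t, t) end

via_vectorize = ~VEC[1 2] |> Nx.vectorize(:foo) |> fun.()
via_stack = Nx.stack([fun.(Nx.tensor(1)), fun.(Nx.tensor(2))]) |> Nx.vectorize(:foo)

# both should be the vectorized[foo: 2] tensor [1, 4]
{via_vectorize, via_stack}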

@polvalente (Contributor)

Memory-wise, vectorization will end up doing the explicit broadcasting, if applicable, regardless of the backend (although some backends might end up fusing things).

polvalente added the kind:bug (Something isn't working) and area:defn (Applies to defn) labels on Sep 11, 2024

jyc commented Sep 11, 2024

Note: This specific comment is wrong and can be ignored; what I said earlier & what polvalente said is correct AFAIK. Sorry for the confusion!

@polvalente Thanks for the reply! Just to be clear: after you mentioned the mental model, I checked and it looks like grad returns the same result even without vectorization, so my mentioning vectorization was a red herring:

defmodule Foo do
  import Nx.Defn
  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

x = ~VEC[0 1 2]
y = ~VEC[1]
Foo.f_and_grad(x, y)
# {~VEC[1 2 3], ~VEC[3]}

This is still surprising to me but at least it is consistent with and without vectorization. I will keep looking for a workaround.


jyc commented Sep 11, 2024

Actually, I have confused myself! I don't believe it's a red herring after all, because it's the other input's axis that is vectorized; I misunderstood. Please ignore my last comment, sorry for the noise. In other words, I agree with your comment here:

The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.


polvalente commented Sep 11, 2024

The problem here is that for that Foo module, this isn't true:

x = Nx.tensor([0, 1, 2])
y = 1

{_, grad0} = Foo.f_and_grad(x[0], y)
{_, grad1} = Foo.f_and_grad(x[1], y)
{_, grad2} = Foo.f_and_grad(x[2], y)

expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
#Nx.Tensor<
  vectorized[foo: 3]
  f32
  [1.0, 1.0, 1.0]
>

actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
{#Nx.Tensor<
   vectorized[foo: 3]
   s32
   [1, 2, 3]
 >,
 #Nx.Tensor<
   f32
   3.0
 >}


jyc commented Sep 11, 2024

You are right! Sorry for the noise.

@josevalim (Collaborator)

Reopening because we still need to support vectorize/devectorize inside the gradient. :)
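
(Editorial sketch, not from the thread: the kind of code the reopened issue refers to, with vectorize/devectorize called inside the function being differentiated. The module and function names are hypothetical, and the pre-fix behavior is not shown in this thread.)

defmodule VecGrad do
  import Nx.Defn

  defn sum_and_grad(x) do
    value_and_grad(x, fn x ->
      # devectorize inside the differentiated function, then reduce to a scalar
      x
      |> Nx.devectorize()
      |> Nx.sum()
    end)
  end
end

x = ~VEC[1 2 3] |> Nx.vectorize(:foo)
VecGrad.sum_and_grad(x)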

josevalim reopened this on Sep 24, 2024
josevalim changed the title from "Nx.Defn.Grad returns sum of gradient over vectorized axis?" to "Support vectorize/devectorize inside gradients" on Sep 28, 2024