Skip to content

Commit

Permalink
test added
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Dec 1, 2024
1 parent 2848612 commit 41e0036
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,6 @@ def forward(self, x):
ids=("remove_duplicate=False", "remove_duplicate=True"),
)
def test_named_params_and_named_buffers(prefix, recurse, remove_duplicate):

buffer_tensor = torch.tensor([1.0])

class SubMod(torch.nn.Module):
Expand Down Expand Up @@ -1141,7 +1140,6 @@ def test_custom_autograd_function():
from torch.testing._internal.common_utils import gradcheck

class MyFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x * 2.0
Expand Down Expand Up @@ -1204,7 +1202,6 @@ def forward(self, x):


def test_autograd_function_apply():

def forward(ctx, x):
saved_for_backward = (x,)
return x.sin(), saved_for_backward
Expand Down Expand Up @@ -1273,7 +1270,6 @@ def my_sin_with_wrong_backward(x):


def test_autograd_function_empty_forward():

class Fn(torch.autograd.Function):
@staticmethod
def forward(self, x):
Expand Down Expand Up @@ -1462,3 +1458,30 @@ def foo(a):
expected = foo(a)

assert_close(actual, expected)


def test_cache_symbolic_values_dict():
def foo(a, v):
return a[v].relu()

jfoo = thunder.jit(foo, cache="symbolic values")

a = {
2: torch.randn(2, 3, 8, requires_grad=True, device="cpu"),
5: torch.randn(4, 8, requires_grad=True, device="cpu"),
}

actual = jfoo(a, 2)
expected = foo(a, 2)

assert_close(actual, expected)

b = {
"a": torch.randn(2, 8, requires_grad=True, device="cpu"),
"b": torch.randn(7, requires_grad=True, device="cpu"),
}

actual = jfoo(b, "b")
expected = foo(b, "b")

assert_close(actual, expected)

0 comments on commit 41e0036

Please sign in to comment.