Skip to content

Commit

Permalink
update test_bundle to include a new test case for vararg (with node i…
Browse files Browse the repository at this point in the history
…nstance arg mixture)
  • Loading branch information
allenanie committed Nov 11, 2024
1 parent 917ad5a commit e811dba
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion tests/unit_tests/test_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def add(self, x, y):


# Test functions with *args and *kwargs
print("*args, **kwargs test 1")
@bundle() # This is the default behavior
def fun(a, args, kwargs, *_args, **_kwargs):
print(a)
Expand All @@ -132,12 +133,44 @@ def fun(a, args, kwargs, *_args, **_kwargs):
print(v)
return a


x = fun(
node(1), node("args"), node("kwargs"), node("_args_1"), node("_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c")
)
print(x, x.inputs)

print("*args, **kwargs test 2")
@bundle() # This is the default behavior
def fun(a, args, kwargs, *_args, **_kwargs):
print(a)
print(args)
print(kwargs)
for v in _args:
print(v)
for k, v in _kwargs.items():
print(v)
return a

x = fun(
node(1), 'arg1', 'kwargs', node("var_args_1"), node("var_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c")
)
print(x, x.inputs)

@bundle() # This is the default behavior
def fun(a, args, kwargs, *_args, **_kwargs):
print(a)
print(args)
print(kwargs)
for v in _args:
print(v)
for k, v in _kwargs.items():
print(v)
return a

x = fun(
node(1), 'arg1', 'kwargs', "var_args_1", node("var_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c")
)
print(x, x.inputs)


# Test stop_tracing
x = node(1)
Expand Down

0 comments on commit e811dba

Please sign in to comment.