diff --git a/tests/unit_tests/test_bundle.py b/tests/unit_tests/test_bundle.py index f085a347..aa1055eb 100644 --- a/tests/unit_tests/test_bundle.py +++ b/tests/unit_tests/test_bundle.py @@ -121,6 +121,7 @@ def add(self, x, y): # Test functions with *args and *kwargs + print("*args, **kwargs test 1") @bundle() # This is the default behavior def fun(a, args, kwargs, *_args, **_kwargs): print(a) @@ -132,12 +133,44 @@ def fun(a, args, kwargs, *_args, **_kwargs): print(v) return a - x = fun( node(1), node("args"), node("kwargs"), node("_args_1"), node("_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c") ) print(x, x.inputs) + print("*args, **kwargs test 2") + @bundle() # This is the default behavior + def fun(a, args, kwargs, *_args, **_kwargs): + print(a) + print(args) + print(kwargs) + for v in _args: + print(v) + for k, v in _kwargs.items(): + print(v) + return a + + x = fun( + node(1), 'arg1', 'kwargs', node("var_args_1"), node("var_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c") + ) + print(x, x.inputs) + + @bundle() # This is the default behavior + def fun(a, args, kwargs, *_args, **_kwargs): + print(a) + print(args) + print(kwargs) + for v in _args: + print(v) + for k, v in _kwargs.items(): + print(v) + return a + + x = fun( + node(1), 'arg1', 'kwargs', "var_args_1", node("var_args_2"), b=node("_kwargs_b"), c=node("_kwargs_c") + ) + print(x, x.inputs) + # Test stop_tracing x = node(1)