From 15c8d12b4c757c50d62a09099d34aa1e1020aa8d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 19 Nov 2024 15:31:32 +0900 Subject: [PATCH] avoid flattening non-tensor args of subclass ctor Signed-off-by: Masaki Kozuki --- thunder/core/proxies.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 2f7f054cd5..9799ea0561 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1896,12 +1896,19 @@ def __init__(self, *args, **kwargs): kwarg_non_tensors = kwargs.pop("non_tensors", []) subclass_type = kwargs.pop("subclass_type", None) + has_name_before_init = hasattr(self, "_name") # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass` # where `self` should already have gotten its name. flat_args, spec = tree_flatten((args, kwargs)) - tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args)) - non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args)) - has_name_before_init = hasattr(self, "_name") + tensors: list[TensorProxy] = [] + non_tensors: list[Any] = [] + for t in args + tuple(kwargs.values()): + if type(t) is SubclassTensorProxy: + continue + if type(t) is TensorProxy: + tensors.append(t) + else: + non_tensors.append(t) is_dunder_init_following_make_wrapper_subclass: bool = False if tensors: