From 8a5708ba3dd1287976a8aaf8f594958dbc23c9c3 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 15 Aug 2024 14:32:25 -0700 Subject: [PATCH] [dynamo] Support object creation of classes with custom __new__ (#132977) Pull Request resolved: https://github.com/pytorch/pytorch/pull/132977 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 101 +++++++++++++++++++++++- torch/_dynamo/polyfill.py | 8 ++ torch/_dynamo/variables/lists.py | 3 + torch/_dynamo/variables/misc.py | 16 ++++ torch/_dynamo/variables/user_defined.py | 12 +++ 5 files changed, 137 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 1f9cfa22259df..eb62b36830cf3 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3215,12 +3215,107 @@ def forward(self, x): x = torch.rand(2, 2) m = Model() - opt_m = torch.compile(backend="eager")(m) + opt_m = torch.compile(backend="eager", fullgraph=True)(m) ref = m(x) res = opt_m(x) self.assertTrue(same(ref, res)) - self.assertEqual(len(counters["graph_break"]), 1) - self.assertFalse("super() nn.Module.__init__" in counters["graph_break"]) + + def test_dunder_new_function_inlining1(self): + class Mock: + def __new__(cls): + return super().__new__(cls) + + def __init__(self): + self.c = 5 + + def run(self, x): + return x * self.c + + def fn(x): + mock = Mock() + return mock.run(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + + self.assertEqual(fn(x), opt_fn(x)) + + def test_dunder_new_function_inlining2(self): + class Vehicle: + def __new__(cls, *args, **kwargs): + return super(Vehicle, cls).__new__(cls) + + def __init__(self, make, model, year): + self.make = make + self.model = model + self.year = year + + class Car(Vehicle): + def __new__(cls, *args, **kwargs): + return super(Car, cls).__new__(cls) + + def __init__(self, make, model, year, num_doors): + super(Car, self).__init__(make, model, year) + self.num_doors = num_doors + + class ElectricCar(Car): + def __new__(cls, *args, **kwargs): + return super(ElectricCar, cls).__new__(cls) + + def __init__(self, make, model, year, num_doors, battery_capacity): + super(ElectricCar, self).__init__(make, model, year, num_doors) + self.battery_capacity = battery_capacity + + def run(self, x): + return torch.sin(x) + + def fn(x): + ev = ElectricCar("Tesla", "Model S", 2022, 4, "100 kWh") + return ev.run(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = torch.randn(4) + + self.assertEqual(fn(x), opt_fn(x)) + + def test_multiple_inheritance(self): + class Base1: + def __new__(cls): + return super().__new__(cls) + + def __init__(self): + super().__init__() + if not hasattr(self, "base2"): + raise ValueError("Wrong MRO tracing") + self.base1 = 3 + + class Base2: + def __new__(cls): + return super().__new__(cls) + + def __init__(self): + super().__init__() + self.base2 = 5 + + class Derived(Base1, Base2): + def __new__(cls): + return super().__new__(cls) + + def __init__(self): + super().__init__() + self.derived = 7 + + def run(self, x): + return self.base1 * self.base2 * self.derived * x + + def fn(x): + o = Derived() + return o.run(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) def test_class_duner_mro(self): class ModuleA(torch.nn.Module): diff --git a/torch/_dynamo/polyfill.py b/torch/_dynamo/polyfill.py index 7a8091c6282d1..531b5cb858904 100644 --- a/torch/_dynamo/polyfill.py +++ b/torch/_dynamo/polyfill.py @@ -131,3 +131,11 @@ def mapping_get(obj, key, value=None): return obj.__getitem__(key) except KeyError: return value + + +def instantiate_user_defined_class_object(*args, **kwargs): + cls = args[0] + other_args = args[1:] + obj = cls.__new__(cls, *other_args, **kwargs) + obj.__init__(*other_args, **kwargs) + return obj diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index f38421471df79..0aaeada7113d7 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -506,6 +506,9 @@ class TupleVariable(BaseListVariable): def python_type(self): return tuple + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + def debug_repr(self): return self.debug_repr_helper("(", ")") diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 1bd2b3a379b5a..aedbe0f0b3005 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -151,6 +151,22 @@ def call_method( ).call_function(tx, [self.objvar] + args, kwargs) else: unimplemented("super() nn.Module.__init__") + elif self.objvar.source and inner_fn is object.__new__: + return tx.output.side_effects.track_object_new( + self.objvar.source, + self.objvar.value, + variables.UnspecializedNNModuleVariable + if issubclass(self.objvar.value, torch.nn.Module) + else UserDefinedObjectVariable, + {}, + ) + elif name == "__new__" and isinstance(inner_fn, types.FunctionType): + # __new__ is a staticmethod object, but accessing __new__ from the super object, as done in + # _resolved_getattr_and_source, results in a function object. If not specialized here, it will try to add + # the `self` arg and fail bind arg matching later. + return variables.UserFunctionVariable( + inner_fn, source=source + ).call_function(tx, args, kwargs) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( inner_fn, source=source diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7ea1b212d5d2c..fdf10bbc0b974 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -495,6 +495,18 @@ def call_function( seed = None random_object = random.Random(seed) return RandomVariable(random_object) + elif ( + not self.is_standard_new() + and SideEffects.cls_supports_mutation_side_effects(self.value) + and self.source + ): + return tx.inline_user_function_return( + SourcelessBuilder.create( + tx, polyfill.instantiate_user_defined_class_object + ), + [self, *args], + kwargs, + ) return super().call_function(tx, args, kwargs)