diff --git a/.watchman b/.watchman
new file mode 100644
index 00000000000..0967ef424bc
--- /dev/null
+++ b/.watchman
@@ -0,0 +1 @@
+{}
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 9c85949a9d0..0a110f7e869 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -217,6 +217,18 @@ class MiscTests(torch._dynamo.test_case.TestCase):
 
         torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
 
+    def test_inplace_desugaring(self):
+        def inplace_on_literals(y):
+            x0 = 1
+            x0 += y
+            x1 = 1
+            x1 -= y
+            return x0, x1
+
+        torch._dynamo.testing.standard_test(
+            self, inplace_on_literals, 1, expected_ops=2
+        )
+
     def test_unpack4(self):
         def unpack4(a, b):
             a = a[:5, :]
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 5cf6a09c9d1..1c4b932d165 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -54,6 +54,25 @@ from .user_defined import UserDefinedVariable
 log = logging.getLogger(__name__)
 
 
+# Maps each in-place operator to the out-of-place operator that computes
+# the same result without mutating its left operand.
+IN_PLACE_DESUGARING_MAP = {
+    operator.iadd: operator.add,
+    operator.isub: operator.sub,
+    operator.imul: operator.mul,
+    operator.ifloordiv: operator.floordiv,
+    operator.itruediv: operator.truediv,
+    operator.imod: operator.mod,
+    operator.imatmul: operator.matmul,
+    operator.ilshift: operator.lshift,
+    operator.irshift: operator.rshift,
+    operator.ipow: operator.pow,
+    operator.iand: operator.and_,
+    operator.ior: operator.or_,
+    operator.ixor: operator.xor,
+}
+
+
 class BuiltinVariable(VariableTracker):
     @staticmethod
     @functools.lru_cache(None)
@@ -488,11 +507,17 @@ class BuiltinVariable(VariableTracker):
         ):
             try:
                 fn = self.fn
-                if self.fn is operator.iadd and isinstance(
+
+                if self.fn in IN_PLACE_DESUGARING_MAP and isinstance(
                     args[0], variables.ConstantVariable
                 ):
-                    # Work around weird bug in hf_T5
-                    fn, args = operator.add, [args[1], args[0]]
+                    # In-place operators like += usually mutate tensor
+                    # values, but in the edge case of immutable values they
+                    # re-bind the variable.
+                    #
+                    # The easiest way to keep the graph consistent in this
+                    # scenario is to de-sugar eagerly.
+                    fn, args = IN_PLACE_DESUGARING_MAP[self.fn], [args[0], args[1]]
 
                 if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable):
                     # Standard indexing will force specialization due to