[dynamo] Fix a bug by desugaring in-place ops on constants (#113117)

Summary:

Python allows users to write code like
```
x: 1
x += y
x += z
```

This code has well-defined semantics: because x is an immutable primitive, the first `+=` will actually re-bind x, it is equivalent to `x = x + y`.

The second in-place operation will either similarly desugar (if the result of `x + y` is itself immutable), or possibly result in "true" in-place operation.

Now, this is a problem for us because today, dynamo tries to both resolve constant variables to their literal values at compile time and also compile in a way that treats `operator.*` builtin functions consistently. This leads to a bug where code like
```
x: 1
x += y
```
actually gets compiled to
```
1 += y
```
which is both semantically meaningless and a syntax error.

A very simple fix that we've already used to fix the special case of `+=` is to detect this, treat it as an edge case, and desugar eagerly into `x = x + y`.

The problem with that fix is that it only patched `iadd`, but actually *all* of the in-place operators exhibit this behavior.

This commit proposes that we tackle all of the inplace opeartors supported by fx in the same way: eagerly remap the operation to an assignment when the left-side is actually an immutable constant.

**Alternatives?**

There might be some other fix possible that wouldn't produce a hardcoded remapping; I know that we generally don't like the growth of mappings and blocklists in dynamo.

I'm a little skeptical about a general solution though, because the bug is due precisely to Python's highly dynamic dispatching of inplace operations by type; since the fx graph has to be purely static, I suspect that we actually have to desugar this somewhere, because the dataflow is fundamentally different for true inplace operations on types that define `__iadd__`, etc vs the desugaring on primitives.

I'm open to other suggestions

Test Plan:

I verified that the code in
https://github.com/pytorch/pytorch/issues/112656
compiles with this fix, and the compiled functions produce the same outputs as the originals.

This needs unit tests, but I'd like to get feedback on the approach in the meantime.

Fixes #112656

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113117
Approved by: https://github.com/yanboliang
This commit is contained in:
Steven Troxler 2023-11-10 00:22:50 +00:00 committed by PyTorch MergeBot
parent bf452dcde6
commit cada6c7fee
3 changed files with 39 additions and 3 deletions

1
.watchman Normal file
View File

@ -0,0 +1 @@
{}

View File

@ -217,6 +217,18 @@ class MiscTests(torch._dynamo.test_case.TestCase):
torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
def test_inplace_desugaring(self):
def inplace_on_literals(y):
x0 = 1
x0 += y
x1 = 1
x1 -= y
return x0, x1
torch._dynamo.testing.standard_test(
self, inplace_on_literals, 1, expected_ops=2
)
def test_unpack4(self):
def unpack4(a, b):
a = a[:5, :]

View File

@ -54,6 +54,23 @@ from .user_defined import UserDefinedVariable
log = logging.getLogger(__name__)
IN_PLACE_DESUGARING_MAP = {
operator.iadd: operator.add,
operator.isub: operator.sub,
operator.imul: operator.mul,
operator.ifloordiv: operator.floordiv,
operator.itruediv: operator.truediv,
operator.imod: operator.mod,
operator.imatmul: operator.imatmul,
operator.ilshift: operator.lshift,
operator.irshift: operator.rshift,
operator.ipow: operator.pow,
operator.iand: operator.and_,
operator.ior: operator.or_,
operator.ixor: operator.xor,
}
class BuiltinVariable(VariableTracker):
@staticmethod
@functools.lru_cache(None)
@ -488,11 +505,17 @@ class BuiltinVariable(VariableTracker):
):
try:
fn = self.fn
if self.fn is operator.iadd and isinstance(
if self.fn in IN_PLACE_DESUGARING_MAP and isinstance(
args[0], variables.ConstantVariable
):
# Work around weird bug in hf_T5
fn, args = operator.add, [args[1], args[0]]
# In-place operators like += usually mustate tensor
# values, but in the edge case of immutable values they
# re-bind the variable.
#
# The easiest way to keep the graph consistent in this
# scenario is to de-sugar eagerly.
fn, args = IN_PLACE_DESUGARING_MAP[self.fn], [args[0], args[1]]
if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable):
# Standard indexing will force specialization due to