mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix a bug by desugaring in-place ops on constants (#113117)
Summary: Python allows users to write code like ``` x: 1 x += y x += z ``` This code has well-defined semantics: because x is an immutable primitive, the first `+=` will actually re-bind x, it is equivalent to `x = x + y`. The second in-place operation will either similarly desugar (if the result of `x + y` is itself immutable), or possibly result in "true" in-place operation. Now, this is a problem for us because today, dynamo tries to both resolve constant variables to their literal values at compile time and also compile in a way that treats `operator.*` builtin functions consistently. This leads to a bug where code like ``` x: 1 x += y ``` actually gets compiled to ``` 1 += y ``` which is both semantically meaningless and a syntax error. A very simple fix that we've already used to fix the special case of `+=` is to detect this, treat it as an edge case, and desugar eagerly into `x = x + y`. The problem with that fix is that it only patched `iadd`, but actually *all* of the in-place operators exhibit this behavior. This commit proposes that we tackle all of the inplace opeartors supported by fx in the same way: eagerly remap the operation to an assignment when the left-side is actually an immutable constant. **Alternatives?** There might be some other fix possible that wouldn't produce a hardcoded remapping; I know that we generally don't like the growth of mappings and blocklists in dynamo. I'm a little skeptical about a general solution though, because the bug is due precisely to Python's highly dynamic dispatching of inplace operations by type; since the fx graph has to be purely static, I suspect that we actually have to desugar this somewhere, because the dataflow is fundamentally different for true inplace operations on types that define `__iadd__`, etc vs the desugaring on primitives. I'm open to other suggestions Test Plan: I verified that the code in https://github.com/pytorch/pytorch/issues/112656 compiles with this fix, and the compiled functions produce the same outputs as the originals. This needs unit tests, but I'd like to get feedback on the approach in the meantime. Fixes #112656 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113117 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
bf452dcde6
commit
cada6c7fee
|
|
@ -217,6 +217,18 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
|
||||
|
||||
def test_inplace_desugaring(self):
|
||||
def inplace_on_literals(y):
|
||||
x0 = 1
|
||||
x0 += y
|
||||
x1 = 1
|
||||
x1 -= y
|
||||
return x0, x1
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, inplace_on_literals, 1, expected_ops=2
|
||||
)
|
||||
|
||||
def test_unpack4(self):
|
||||
def unpack4(a, b):
|
||||
a = a[:5, :]
|
||||
|
|
|
|||
|
|
@ -54,6 +54,23 @@ from .user_defined import UserDefinedVariable
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
IN_PLACE_DESUGARING_MAP = {
|
||||
operator.iadd: operator.add,
|
||||
operator.isub: operator.sub,
|
||||
operator.imul: operator.mul,
|
||||
operator.ifloordiv: operator.floordiv,
|
||||
operator.itruediv: operator.truediv,
|
||||
operator.imod: operator.mod,
|
||||
operator.imatmul: operator.imatmul,
|
||||
operator.ilshift: operator.lshift,
|
||||
operator.irshift: operator.rshift,
|
||||
operator.ipow: operator.pow,
|
||||
operator.iand: operator.and_,
|
||||
operator.ior: operator.or_,
|
||||
operator.ixor: operator.xor,
|
||||
}
|
||||
|
||||
|
||||
class BuiltinVariable(VariableTracker):
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
|
|
@ -488,11 +505,17 @@ class BuiltinVariable(VariableTracker):
|
|||
):
|
||||
try:
|
||||
fn = self.fn
|
||||
if self.fn is operator.iadd and isinstance(
|
||||
|
||||
if self.fn in IN_PLACE_DESUGARING_MAP and isinstance(
|
||||
args[0], variables.ConstantVariable
|
||||
):
|
||||
# Work around weird bug in hf_T5
|
||||
fn, args = operator.add, [args[1], args[0]]
|
||||
# In-place operators like += usually mustate tensor
|
||||
# values, but in the edge case of immutable values they
|
||||
# re-bind the variable.
|
||||
#
|
||||
# The easiest way to keep the graph consistent in this
|
||||
# scenario is to de-sugar eagerly.
|
||||
fn, args = IN_PLACE_DESUGARING_MAP[self.fn], [args[0], args[1]]
|
||||
|
||||
if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable):
|
||||
# Standard indexing will force specialization due to
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user