AOTAutograd: Make general SymInt hashable when merging view inputs. (#139553)

Fix: #139111

This PR wraps `SymInt` input arguments with `SymIntEqByExpr`, making them hashable when
merging view inputs (`merge_view_inputs` function).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139553
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi 2024-11-02 12:07:59 -03:00 committed by PyTorch MergeBot
parent b46e1fc141
commit a3cb8ee38b
2 changed files with 31 additions and 2 deletions

View File

@ -6632,6 +6632,24 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
return torch_compile_wrapper
def test_inputs_overlapping_unsqueeze_with_mutation(self):
def f(x, y):
x.add_(1)
y.add_(1)
return x
def run(f):
base = torch.ones(10)
inputs = [base.unsqueeze(0), base.unsqueeze(0)]
return f(*inputs)
optf = torch.compile(backend="aot_eager", dynamic=True)(f)
out = run(f)
optout = run(optf)
self.assertEqual(out, optout)
class MockFXGraphCache:
"""

View File

@ -1408,15 +1408,26 @@ def merge_view_inputs(
# If no synthetic bases are necessary, just return the original inputs.
return fwd_inputs, None
else:
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
def make_hashable(arg):
if isinstance(arg, torch.SymInt):
# Since only nested SymInt objects can be hashed, we wrap them with
# SymIntEqByExpr, which is a hashable wrapper of SymInts.
return SymIntEqByExpr(arg)
return arg
# Otherwise, return:
# (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
# (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
# We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
args_to_functionalization = base_args + other_args
arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
arg_to_old_idx_map = {
make_hashable(arg): i for (i, arg) in enumerate(fwd_inputs)
}
for i, other_arg in enumerate(other_args):
new_idx = len(base_args) + i
old_idx = arg_to_old_idx_map[other_arg]
old_idx = arg_to_old_idx_map[make_hashable(other_arg)]
inner_calling_convention_meta[old_idx] = new_idx
# post process into a list
post_processed_calling_convention_meta: List[