mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
AOTAutograd: Make general SymInt hashable when merging view inputs. (#139553)
Fix: #139111 This PR wraps `SymInt` input arguments with `SymIntEqByExpr`, making them hashable when merging view inputs (`merge_view_inputs` function). Pull Request resolved: https://github.com/pytorch/pytorch/pull/139553 Approved by: https://github.com/ezyang
This commit is contained in:
parent
b46e1fc141
commit
a3cb8ee38b
|
|
@ -6632,6 +6632,24 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
|
|||
|
||||
return torch_compile_wrapper
|
||||
|
||||
def test_inputs_overlapping_unsqueeze_with_mutation(self):
|
||||
def f(x, y):
|
||||
x.add_(1)
|
||||
y.add_(1)
|
||||
return x
|
||||
|
||||
def run(f):
|
||||
base = torch.ones(10)
|
||||
inputs = [base.unsqueeze(0), base.unsqueeze(0)]
|
||||
return f(*inputs)
|
||||
|
||||
optf = torch.compile(backend="aot_eager", dynamic=True)(f)
|
||||
|
||||
out = run(f)
|
||||
optout = run(optf)
|
||||
|
||||
self.assertEqual(out, optout)
|
||||
|
||||
|
||||
class MockFXGraphCache:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1408,15 +1408,26 @@ def merge_view_inputs(
|
|||
# If no synthetic bases are necessary, just return the original inputs.
|
||||
return fwd_inputs, None
|
||||
else:
|
||||
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
|
||||
|
||||
def make_hashable(arg):
|
||||
if isinstance(arg, torch.SymInt):
|
||||
# Since only nested SymInt objects can be hashed, we wrap them with
|
||||
# SymIntEqByExpr, which is a hashable wrapper of SymInts.
|
||||
return SymIntEqByExpr(arg)
|
||||
return arg
|
||||
|
||||
# Otherwise, return:
|
||||
# (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
|
||||
# (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
|
||||
# We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
|
||||
args_to_functionalization = base_args + other_args
|
||||
arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
|
||||
arg_to_old_idx_map = {
|
||||
make_hashable(arg): i for (i, arg) in enumerate(fwd_inputs)
|
||||
}
|
||||
for i, other_arg in enumerate(other_args):
|
||||
new_idx = len(base_args) + i
|
||||
old_idx = arg_to_old_idx_map[other_arg]
|
||||
old_idx = arg_to_old_idx_map[make_hashable(other_arg)]
|
||||
inner_calling_convention_meta[old_idx] = new_idx
|
||||
# post process into a list
|
||||
post_processed_calling_convention_meta: List[
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user