Revert "Use source hashing to generate consistent symbolic ids (#149665)"

This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
This commit is contained in:
PyTorch MergeBot 2025-03-27 16:02:27 +00:00
parent 6eb3c2e282
commit af7719a2fa
21 changed files with 437 additions and 513 deletions

View File

@ -196,46 +196,6 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
def test_symbol_specialization(self):
"""
Verify the symbol specializations don't cause cache miss.
"""
def fn(x, y, z):
return (torch.randn(5) + x + y, z * torch.randn(1))
a = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(a, 0)
b = torch.rand(5)
c = torch.randn(6)
torch._dynamo.maybe_mark_dynamic(c, 0)
compiled_fn = torch.compile(fn, backend="inductor")
# A first call should miss in the cache.
compiled_fn(a, b, c)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# A second call should hit even if a new dimension is marked as dynamic
# that is later specialized as part of tracing.
a = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(a, 0)
b = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(b, 0)
c = torch.randn(6)
torch._dynamo.maybe_mark_dynamic(c, 0)
self._clear_dynamo_and_codecache()
compiled_fn(a, b, c)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
@functorch_config.patch({"enable_autograd_cache": True})
def test_aot_runtime_trace_joint(self):
@torch.compile(backend="inductor")

View File

@ -245,7 +245,7 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
@ -264,7 +264,7 @@ class GraphModule(torch.nn.Module):
copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None
add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

View File

@ -57,18 +57,18 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
FakeTensor(..., size=(s77,))
FakeTensor(..., size=(s0,))
2
[FakeTensor(..., size=(s77,)), 2]
(FakeTensor(..., size=(s77,)), 2)
{'foo': FakeTensor(..., size=(s77,))}
[FakeTensor(..., size=(s0,)), 2]
(FakeTensor(..., size=(s0,)), 2)
{'foo': FakeTensor(..., size=(s0,))}
range(1, 3, 1)
Employee(name='foo', id=2)
UserDefinedListVariable(mylist)
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s77""",
s0""",
)
def test_print_graph(self):

View File

@ -256,34 +256,34 @@ Model:
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 0
==> s2: 1
==> s3: 1
==> s52: 1
==> s77: 3
==> s86: 0
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
Target Expressions:
==> (!= (+ s3 s52 s86) s77)
==> (!= (+ s1 s2 s3) s0)
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (<= 2 s0)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
==> (>= 0 s86)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 0 s1)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -309,7 +309,7 @@ Failed Source Expressions:
BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
Failure occurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
@ -321,33 +321,33 @@ Model:
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0
==> s52: 1
==> s77: 3
==> s86: 1
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
Target Expressions:
==> (!= (+ s3 s52 s86) s77)
==> (!= (+ s1 s2 s3) s0)
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (<= 2 s0)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@ -2703,7 +2703,7 @@ def forward(self, x):
for node in ebar.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
)
@torch._dynamo.config.patch(
@ -3480,23 +3480,23 @@ def forward(self, x):
true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s77, s27]";
arg1: "f32[s1, s2]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)
"""
false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s77, s27]";
arg1: "f32[s1, s2]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = [

View File

@ -2655,7 +2655,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum(); l_x_ = None
@ -2885,13 +2885,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None
mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None
return (mul_2,)
""",
)
@ -2932,14 +2932,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
@ -2982,14 +2982,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
@ -3029,14 +3029,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
l_x_ = L_x_
mul: "f32[s77, s77]" = l_x_ * 4
mul_1: "f32[s77, s77]" = mul * l_x_; mul = None
mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None
mul: "f32[s0, s0]" = l_x_ * 4
mul_1: "f32[s0, s0]" = mul * l_x_; mul = None
mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None
mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
return (mul_3,)
""",
)

View File

@ -413,18 +413,18 @@ class GraphModule(torch.nn.Module):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None
getitem: "f32[s77]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"):
view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None
add: "f32[s77]" = view + 0.5; view = None
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
add: "f32[s0]" = view + 0.5; view = None
return (add,)
""",
)
@ -606,27 +606,27 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum()
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None
getitem: "f32[s77]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None
getitem: "f32[s77]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
add: "f32[s77]" = l_x_ + item; l_x_ = item = None
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
return (add,)
""",
)
@ -692,7 +692,7 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero()
@ -704,22 +704,22 @@ class GraphModule(torch.nn.Module):
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s77]" = wrap[0]
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s0]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s77]" = wrap[0]
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
child: "f32[s0]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s77]" = l_x_.sin(); l_x_ = None
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s0]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
@ -994,25 +994,25 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
l_x_ = L_x_
l_y_ = L_y_
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (matmul,)
""",
)

View File

@ -10382,22 +10382,19 @@ ShapeEnv not equal: field values don't match:
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
> Right: {}
==> source_to_var: values don't match.
> Left: {x.size()[0]: s93, x.size()[1]: s44}
> Right: {}
==> unique_ids: values don't match.
> Left: {44, 93}
> Left: {x.size()[0]: s0, x.size()[1]: s1}
> Right: {}
==> val_to_var: values don't match.
> Left: {0: 0, 1: 1, 2: s44, 3: s93}
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s44: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)], s93: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)]}
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
> Right: {}
==> var_to_val: values don't match.
> Left: {s44: 2, s93: 3}
> Left: {s0: 3, s1: 2}
> Right: {}
""",
)
@ -10456,13 +10453,13 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {(Mod(s93, 3)) < 0: False, (Mod(s93, 3)) <= 0: True, 0 < (Mod(s93, 3)): False, 0 <= (Mod(s93, 3)): True, Eq(0, Mod(s93, 3)): True, Eq(Mod(s93, 3), 0): True, Ne(0, Mod(s93, 3)): False, Ne(Mod(s93, 3), 0): False}
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s93, 3)}
> Left: {Mod(s0, 3)}
> Right: {}
==> guards: values don't match.
> Left: [Eq(Mod(s93, 3), 0)]
> Left: [Eq(Mod(s0, 3), 0)]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
@ -10499,17 +10496,17 @@ ShapeEnv not equal: field values don't match:
> Left: {False: False, True: True}
> Right: {}
==> guards: values don't match.
> Left: [Eq(s93, 3)]
> Left: [Eq(s0, 3)]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> replacements: values don't match.
> Left: {s93: 3}
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s44: VR[2, int_oo], s93: VR[3, 3]}
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)
@ -10540,17 +10537,17 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {3 <= s93: True, s93 < 3: False}
> Left: {3 <= s0: True, s0 < 3: False}
> Right: {}
==> guards: values don't match.
> Left: [s93 >= 3]
> Left: [s0 >= 3]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s44: VR[2, int_oo], s93: VR[3, int_oo]}
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)

View File

@ -4768,7 +4768,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(
str(graph.code).strip(),
"""\
def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
l_x_ = L_x_
getitem_2 = l_x_[0]
sum_1 = getitem_2.sum(); getitem_2 = None

View File

@ -339,7 +339,7 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@ -767,15 +767,15 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s0", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s1", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"create_symbol": {"symbol": "s2", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s3", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"guard_added_fast": {"expr": "Eq(s1, s2)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950

View File

@ -1373,21 +1373,21 @@ class GraphModule(torch.nn.Module):
# During fakeifying, we end up allocating a separate symint
# for the outer and inner tensor (in this test, s0 is unused).
expected_var_to_val = {
"s50": 4,
"s77": 8,
"s0": 8,
"s1": 4,
}
expected_var_to_sources = {
"s50": "L['x'].inner_elem.size()[0]",
"s77": "L['x'].size()[0]",
"s0": "L['x'].size()[0]",
"s1": "L['x'].inner_elem.size()[0]",
}
self.assertEqual(curr_var_to_val, expected_var_to_val)
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
self.assertExpectedInline(
"\n".join(guards),
"""\
Eq(2*s50, s77)
2*s50 < 13
s50 > 3""",
Eq(2*s1, s0)
2*s1 < 13
s1 > 3""",
)
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
@ -1976,9 +1976,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
""", # noqa: B950
)
@ -1987,9 +1987,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2009,12 +2009,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
""", # noqa: B950
)
@ -2023,9 +2023,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2047,15 +2047,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
mul_11: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
mul_16: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
mul_19: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
""", # noqa: B950
)
@ -2064,15 +2064,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_32: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
mul_35: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
mul_36: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
mul_37: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2092,12 +2092,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
""", # noqa: B950
)
@ -2106,9 +2106,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2128,13 +2128,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2143,9 +2143,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2165,13 +2165,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (clone, view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2180,9 +2180,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2261,13 +2261,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[1].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
@ -2287,9 +2287,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[1].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2317,13 +2317,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
@ -2332,9 +2332,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2501,10 +2501,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
""", # noqa: B950
)
@ -2513,8 +2513,8 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2534,11 +2534,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s55)" = primals_10 + primals_10
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s1)" = primals_10 + primals_10
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
""", # noqa: B950
)
@ -2547,11 +2547,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2580,7 +2580,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"):
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@ -2594,13 +2594,13 @@ class <lambda>(torch.nn.Module):
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
sym_size_int: "Sym(s2 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s2 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
""", # noqa: B950
)
@ -2757,10 +2757,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"):
l_nt_ = L_nt_
add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None
return (add,)
""", # noqa: B950
)
@ -3254,27 +3254,27 @@ class GraphModule(torch.nn.Module):
# varies based on the type of view
guard_str = "\n".join(guards)
if nt_view_name == "subclass_dense":
self.assertExpectedInline(guard_str, """Eq(s85 - 1, s77)""")
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
elif nt_view_name == "dense_subclass_dense_subclass":
self.assertExpectedInline(
guard_str,
"""\
Eq(s85 - 1, s77)
Eq(s80 - 1, s78)
Eq(s72, s71)""",
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
)
elif nt_view_name.startswith("base_is_nt_True"):
self.assertExpectedInline(
guard_str,
"""Eq(s17 - 1, s83)""",
"""Eq(s3 - 1, s0)""",
)
else:
self.assertExpectedInline(
guard_str,
"""\
Eq(s85 - 1, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
)
return gm

View File

@ -1560,7 +1560,7 @@ graph():
)
with self.assertRaisesRegex(
error_type,
r"Real tensor propagation found an output size mismatch between fake shape s\d+ and real shape 4, "
r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
r"at output\.size\(0\), for func: mylib.foo.default",
):
export(
@ -2848,7 +2848,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*= 9 to be "
"of the form 2\\*s92, where s92 is an integer",
"of the form 2\\*s1, where s1 is an integer",
):
ep.module()(torch.randn(9))
@ -3506,11 +3506,8 @@ def forward(self, x):
dynamic_shapes=({0: Dim("x")},),
)
# Since symbol names are based on hash of source names, and these differ across inference and
# training, we do range comparisons instead.
self.assertEqual(
str(ep_for_training.range_constraints.values()),
str(ep_for_real.range_constraints.values()),
str(ep_for_training.range_constraints), str(ep_for_real.range_constraints)
)
def test_export_for_training_with_container_type(self):
@ -4388,7 +4385,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
em.module()(torch.randn(4, 3))
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 \- 1\), 0\)",
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)",
):
em.module()(torch.randn(4, 5))
@ -4399,7 +4396,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
x = torch.randn(3, 5)
with self.assertRaisesRegex(
RuntimeError,
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer",
):
em.module()(x)
@ -4958,14 +4955,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
)
@testing.expectedFailureCppSerDes
@ -5103,10 +5097,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"y": (batch, size, size),
},
)
for node in efoo.graph_module.graph.nodes:
if node.op == "placeholder":
self.assertEqual(node.meta["val"].shape[1], node.meta["val"].shape[2])
self.assertEqual(
[
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
# pass dynamic shapes of inputs [multiple, mostly distinct]
@ -5117,14 +5115,13 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
inputs,
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
)
placeholders = [
node.meta["val"].shape
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
]
self.assertEqual(
placeholders[0][2],
placeholders[1][1],
[
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5141,14 +5138,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5165,14 +5159,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5482,7 +5473,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder"
]
self.assertEqual(len(input_shapes), 9)
self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes))
self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes))
def test_error_does_not_reference_eager_fallback(self):
class Module(torch.nn.Module):
@ -11146,7 +11137,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, 4\*s77 \- 4\), 0\) on node 'eq.*'",
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'",
):
ep.module()(torch.randn(8, 8)) # fail
@ -11178,7 +11169,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(40).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'",
): # fail only at runtime
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
@ -11205,7 +11196,7 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(126).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'",
): # fail only at runtime
ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail
@ -11286,12 +11277,12 @@ def forward(self, x, y):
)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s77, 20\)",
r"Runtime assertion failed for expression Ne\(s0, 20\)",
):
ep.module()(torch.randn(20, 20, 16))
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)",
r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)",
):
ep.module()(torch.randn(400, 20, 16))
ep.module()(torch.randn(42, 20, 16))
@ -11329,17 +11320,17 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(27).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s77, s17\)",
r"Runtime assertion failed for expression Ne\(s0, s1\)",
): # fail only at runtime
ep.module()(torch.randn(4), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s77, s17\**3\)",
r"Runtime assertion failed for expression Ne\(s0, s1\**3\)",
):
ep.module()(torch.randn(64), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s77\**2, 3\*s17\)",
r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)",
):
ep.module()(torch.randn(10), torch.randn(9)) # fail

View File

@ -539,12 +539,8 @@ def forward(self, x):
ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)
serialized = ExportedProgramSerializer().serialize(ep)
self.assertEqual(
serialized.exported_program.range_constraints["s77"].min_val, 2
)
self.assertEqual(
serialized.exported_program.range_constraints["s77"].max_val, 3
)
self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2)
self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3)
def test_kwargs_default(self) -> None:
"""

View File

@ -5941,8 +5941,8 @@ class TestAOTModuleSimplified(AOTTestCase):
self.assertExpectedInline(
shape_env.format_guards(),
"""\
- Eq(s49, 20)
- Eq(s70, 30)""",
- Eq(s1, 20)
- Eq(s2, 30)""",
)
assert torch.allclose(ref[0], res[0])

View File

@ -4553,10 +4553,10 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -4719,10 +4719,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
def forward(self, a_1, b_1):
sum_1 = torch.ops.aten.sum.default(a_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -5856,7 +5856,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -5969,7 +5969,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6209,7 +6209,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6558,14 +6558,14 @@ def forward(self, l_inp_, l_tmp_):
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
l_a_ = L_a_
l_b_ = L_b_
l_self_num = L_self_num
tensor = torch.tensor([True])
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s0)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
@ -6753,10 +6753,10 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s35, 3]";
x: "f32[s0, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -6770,27 +6770,27 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_2 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_2 + 1
add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
lt: "Sym(u1 < s0)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
@ -6820,12 +6820,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
l_x_ = L_x_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s0, s1)); cond_fn_0 = body_fn_0 = l_x_ = s1 = None
getitem_4: "Sym(u1)" = while_loop[0]
@ -6835,49 +6835,49 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_4 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
out_x: "f32[s0, s1]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_4 + 1
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
add_1: "f32[s0, s1]" = getitem_4 + out_x; out_x = None
lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None
lt: "Sym(u1 < s0)" = getitem_4 < s0; s0 = None
mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None
ones: "f32[2*u1]" = torch.ones(mul); mul = None
return (add, add_1, lt, ones)
class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
size = l_x_.size(); l_x_ = None
getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
getitem: "Sym(s0)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; unbacked_symint = getitem = None
return lt
class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
x_clone: "f32[s77, s27]" = l_x_.clone()
x_clone: "f32[s0, s1]" = l_x_.clone()
ge: "Sym(u0 >= 0)" = unbacked_symint >= 0
_check = torch._check(ge); ge = _check = None
size = l_x_.size(); l_x_ = None
getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None
getitem: "Sym(s0)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; getitem = None
_check_1 = torch._check(lt); lt = _check_1 = None
select: "f32[s27]" = x_clone.select(0, unbacked_symint)
select_1: "f32[s27]" = x_clone.select(0, unbacked_symint)
add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None
copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
select: "f32[s1]" = x_clone.select(0, unbacked_symint)
select_1: "f32[s1]" = x_clone.select(0, unbacked_symint)
add: "f32[s1]" = select_1 + unbacked_symint; select_1 = None
copy_: "f32[s1]" = select.copy_(add); select = add = copy_ = None
add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
return (add_1, x_clone)
@ -7048,12 +7048,12 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s77, 3]";
x: "f32[s0, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(x); x = None
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -7065,19 +7065,19 @@ class GraphModule(torch.nn.Module):
getitem_9: "Sym(u8)" = while_loop[3]
getitem_10: "Sym(u9)" = while_loop[4]
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
getitem_5: "f32[s0, 3]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_8 + 1
add_1: "Sym(u8 + 1)" = getitem_9 + 1
add_2: "Sym(u9 + 1)" = getitem_10 + 1
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
add_3: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
@ -7085,7 +7085,7 @@ class GraphModule(torch.nn.Module):
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
@ -7093,7 +7093,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
return (add, add_1, add_2, add_3, add_4, add_5)
""", # noqa: B950
)
@ -7119,14 +7119,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
l_x_ = L_x_
child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
child: "f32[s0, s1]" = l_x_.sin(); l_x_ = None
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s0, s1, 2, 2, 3, child), (s0, s1)); cond_fn_0 = body_fn_0 = s0 = s1 = child = None
getitem_10: "Sym(u5)" = while_loop[0]
getitem_11: "Sym(u6)" = while_loop[1]
@ -7134,21 +7134,21 @@ class GraphModule(torch.nn.Module):
getitem_13: "Sym(u8)" = while_loop[3]
getitem_14: "Sym(u9)" = while_loop[4]
out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
out_x: "f32[s0, s1]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_12 + 1
add_1: "Sym(u8 + 1)" = getitem_13 + 1
add_2: "Sym(u9 + 1)" = getitem_14 + 1
add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
add_3: "f32[s0, s1]" = getitem_12 + out_x; getitem_12 = None
add_4: "f32[s0, s1]" = getitem_13 + out_x; getitem_13 = None
add_5: "f32[s0, s1]" = getitem_14 + out_x; getitem_14 = None
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
@ -7157,9 +7157,9 @@ class GraphModule(torch.nn.Module):
return lt
class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
@ -7168,7 +7168,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None
add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None
child_1: "f32[s77, s27]" = child + 1; child = None
child_1: "f32[s0, s1]" = child + 1; child = None
return (add, add_1, add_2, add_3, add_4, child_1)
""", # noqa: B950
)
@ -7336,30 +7336,30 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, x, y, z):
x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]";
x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]";
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None
sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None
gt: "Sym(s35 > 5)" = sym_size_int_3 > 5
gt: "Sym(s0 > 5)" = sym_size_int_3 > 5
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
getitem: "f32[s35, 3]" = cond[0]; cond = None
getitem: "f32[s0, 3]" = cond[0]; cond = None
return pytree.tree_unflatten((getitem,), self._out_spec)
class true_graph_0(torch.nn.Module):
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
return (add,)
class false_graph_0(torch.nn.Module):
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
return (add,)
""", # noqa: B950
)
@ -7522,7 +7522,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"):
l_y_ = L_y_
l_z_ = L_z_
l_x_ = L_x_
@ -7532,39 +7532,39 @@ class GraphModule(torch.nn.Module):
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s1, s0, s0, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None
getitem_5: "f32[u0, s94]" = cond[0]
getitem_5: "f32[u0, s1]" = cond[0]
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
ret: "f32[u0, s94]" = cond[0]; cond = None
ret: "f32[u0, s1]" = cond[0]; cond = None
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None
return (sub,)
class cond_true_0(torch.nn.Module):
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s94_1 = s94
s1_1 = s1
add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None
return (clone,)
class cond_false_0(torch.nn.Module):
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s94_1 = s94
s1_1 = s1
mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s94]" = getitem.clone(); getitem = None
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s1]" = getitem.clone(); getitem = None
return (clone,)
""", # noqa: B950
)

View File

@ -403,10 +403,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -563,13 +563,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -595,11 +595,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (add,)""",
ignore_comments=True,
ignore_empty_lines=True,
@ -663,10 +663,10 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -691,11 +691,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1291,14 +1291,14 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None
add_6: "Sym(2)" = floordiv + 2
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None
getitem_1: "f32[s77, s77][s77, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1324,13 +1324,13 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
slice_tensor: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_tensor_1: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1470,18 +1470,18 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
@ -1517,16 +1517,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg1_1, slice_2)""", # noqa: B950
ignore_comments=True,

View File

@ -11822,7 +11822,7 @@ class CommonTemplate:
# 'i1 + 3 * i0' is cached.
self.assertTrue(
"i0 + 2 * i1" in mul_buf.data.inner_fn_str()
or "i0 + i1 * s64" in mul_buf.data.inner_fn_str()
or "i0 + i1 * s1" in mul_buf.data.inner_fn_str()
)
with add_scheduler_init_hook(hook_fn):
@ -12551,7 +12551,7 @@ class CommonTemplate:
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
if is_dynamic_shape_enabled():
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950
else:
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
FileCheck().check_regex(size_assert_pattern).run(code)
@ -12570,8 +12570,8 @@ class CommonTemplate:
code = run_and_get_triton_code(f, x)
if is_dynamic_shape_enabled():
FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
"assert_size_stride(buf2, (s77, s27), (s27, 1))"
FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check(
"assert_size_stride(buf2, (s0, s1), (s1, 1))"
).run(code)
else:
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(

View File

@ -393,7 +393,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
)
a1 = create_symint(shape_env, 3)
@ -415,7 +415,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
)
def test_stride(self):
@ -497,7 +497,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
def test_sym_sum(self):
shape_env = ShapeEnv()
@ -512,7 +512,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(s0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
@ -520,7 +520,7 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s97, 2)]""",
"""[Eq(s0, 2)]""",
)
def test_sym_int(self):
@ -529,14 +529,14 @@ class TestPySymInt(TestCase):
r = sym_int(a0)
self.assertEqual(r, 5)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")
a1 = create_symint(shape_env, 7)
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
)
a3 = create_symint(shape_env, 3)
@ -544,7 +544,7 @@ class TestPySymInt(TestCase):
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
)
def test_sym_log2(self):
@ -554,7 +554,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2.0)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
)
def test_sym_sqrt(self):
@ -564,7 +564,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
)
def test_sym_floor(self):
@ -575,14 +575,14 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
"""Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
)
def test_sym_trunc(self):
@ -592,14 +592,14 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
)
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
)
def test_sym_ceil(self):
@ -610,7 +610,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
"""Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
)
r1 = 3.0 * a0
r = math.floor(r1)
@ -618,7 +618,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
)
def test_sym_ite(self):
@ -638,7 +638,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(t), type(r3))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
"""Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
)
b4 = f == 5
r4 = torch.sym_ite(b4, t, f)
@ -647,7 +647,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(f), type(r4))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
"""Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
)
def test_tracing_sym_ite(self):
@ -679,7 +679,7 @@ def forward(self, x_1):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
int(a0)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
def test_data_dependent_guard(self):
shape_env = ShapeEnv()
@ -710,7 +710,7 @@ def forward(self, x_1):
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[u0 < s97]""",
"""[u0 < s0]""",
)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
@ -1173,18 +1173,18 @@ def forward(self, x_1):
out.strip(),
"""\
class f(torch.nn.Module):
def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
# No stacktrace found for following nodes
sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""", # noqa: B950
)
@ -2846,8 +2846,8 @@ class TestGuardsExpressions(TestCase):
],
)
self.assertEqual(f"{x.stride()}", "(s49, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
self.assertEqual(f"{x.stride()}", "(s1, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s0, s1])")
x_clean = _remove_symbols_without_guarding(x, 4096)

View File

@ -1084,7 +1084,7 @@ def forward(self, x_1, y_1):
test_inputs.append([(6, 8)])
gm = self._test_dynamic(f, [(3, 4)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
@ -1717,7 +1717,7 @@ def forward(self, a_1):
gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")
self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")
def test_new_empty(self):
def f(a, b):

View File

@ -869,12 +869,6 @@ def _optimized_add(
if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
return make_optimized(rhs._args + lhs._args)
# (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
new_args = list(lhs._args)
for a in rhs._args:
new_args = _binary_search_insert_arg(new_args, a)
return make_optimized(new_args)
# (a0+a2) + a1 => (a0+a1+a2)
if lhs_is_optimized_summation and rhs.is_symbol:
new_args = _binary_search_insert_arg(list(lhs._args), rhs)

View File

@ -14,7 +14,6 @@ import atexit
import collections
import dis
import functools
import hashlib
import inspect
import itertools
import logging
@ -3290,11 +3289,6 @@ class ShapeEnv:
self.guards: list[ShapeGuard] = []
self.axioms: dict[sympy.Expr, sympy.Expr] = {}
# A set of ids that have already been allocated. This is used
# for when we allocate symbol ids using the hash of the source
# names to ensure we don't have collisions via linear probing
self.unique_ids: set[int] = set()
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
@ -4546,14 +4540,13 @@ class ShapeEnv:
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
symbol_id = self._generate_unique_id(source.name())
if type(val) is int or is_nested_int(val):
sympy_expr = make_symbol(
SymT.SIZE, symbol_id, positive=positive, integer=True
SymT.SIZE, len(self.var_to_val), positive=positive, integer=True
)
else:
sympy_expr = make_symbol(
SymT.FLOAT, symbol_id, positive=positive, real=True
SymT.FLOAT, len(self.var_to_val), positive=positive, real=True
)
self.source_to_var[source_name] = sympy_expr
# We always associate vars to vals
@ -6565,13 +6558,6 @@ class ShapeEnv:
sloc, _ = self._get_stack_summary(framework_loc=framework_loc)
return sloc
def _generate_unique_id(self, source_name: str) -> int:
attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
while attempt in self.unique_ids:
attempt += 1
self.unique_ids.add(attempt)
return attempt
def _find_frame_locals(self) -> _FrameLocalResult:
"""
Given the current user code frame, finds the relevant lines of code,