Use source hashing to generate consistent symbolic ids (#149665)

This PR was motivated by internal models that were cache-missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We do a static compile and save example values to the PGO/automatic dynamic profile.

Run 1, Invocation 2: We detect varying inputs, do a dynamic compile, get a dynamic graph, and save it to PGO. Crucially, what we save to PGO is a superset of what is actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized. When a value is specialized, its symbol is removed from the graph. This leads to an interesting conundrum: even though we produce an isomorphic graph, PGO makes the second run cache miss. Let's see how.

Run 2, Invocation 1: We fetch the PGO profile, over-mark things as dynamic, get an FX graph, look it up in the cache, and... whoops, cache miss! This happens because the PGO profile causes us to over-allocate symbols. In practice, we end up caching a graph with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, where s3, s4, and s5 were all optimistically marked dynamic by PGO and subsequently specialized.
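
To make the failure mode concrete, here is a toy sketch (illustrative only, not PyTorch's actual symbol allocator): with a sequential counter, allocating extra symbols that are later specialized away shifts the ids of every symbol allocated after them, so the cached graph no longer matches.

```python
import itertools


def allocate_sequential(sources):
    # Toy counter-based allocator: ids depend on allocation order.
    counter = itertools.count()
    return {src: f"s{next(counter)}" for src in sources}


# Run 1: only x and y end up dynamic.
print(allocate_sequential(["x.size()[0]", "y.size()[0]"]))
# {'x.size()[0]': 's0', 'y.size()[0]': 's1'}

# Run 2: PGO optimistically marks two more sources dynamic; they are later
# specialized and their symbols dropped from the graph, but y's id has
# already shifted, so a cache lookup keyed on the old ids misses.
print(allocate_sequential(["x.size()[0]", "a.size()[0]", "b.size()[0]", "y.size()[0]"]))
# {'x.size()[0]': 's0', 'a.size()[0]': 's1', 'b.size()[0]': 's2', 'y.size()[0]': 's3'}
```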

We solve this problem by deriving symbol ids from a hash of the source names, which gives a reasonably stable assignment across runs. To prevent catastrophic symbol collisions, we resolve hash collisions with linear probing.
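
A minimal sketch of the idea (hypothetical helper, not the actual ShapeEnv code): the symbol index is derived from a hash of the source name, so it depends only on the source rather than on allocation order, and a collision simply probes forward to the next free index.

```python
import hashlib


def symbol_for_source(source_name: str, used_ids: set, modulus: int = 1000) -> str:
    # Hash the source name so the index is stable across runs.
    digest = hashlib.sha256(source_name.encode()).digest()
    idx = int.from_bytes(digest[:8], "big") % modulus
    # Linear probing: on collision, walk forward to the next unused index.
    while idx in used_ids:
        idx = (idx + 1) % modulus
    used_ids.add(idx)
    return f"s{idx}"


used: set = set()
print(symbol_for_source("L['x'].size()[0]", used))  # same id on every run
print(symbol_for_source("L['y'].size()[0]", used))
```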

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
bobrenjc93 2025-03-27 17:53:11 -07:00 committed by PyTorch MergeBot
parent c49315e645
commit f649ee73ce
23 changed files with 521 additions and 443 deletions

View File

@ -196,6 +196,46 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
def test_symbol_specialization(self):
"""
Verify the symbol specializations don't cause cache miss.
"""
def fn(x, y, z):
return (torch.randn(5) + x + y, z * torch.randn(1))
a = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(a, 0)
b = torch.rand(5)
c = torch.randn(6)
torch._dynamo.maybe_mark_dynamic(c, 0)
compiled_fn = torch.compile(fn, backend="inductor")
# A first call should miss in the cache.
compiled_fn(a, b, c)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# A second call should hit even if a new dimension is marked as dynamic
# that is later specialized as part of tracing.
a = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(a, 0)
b = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(b, 0)
c = torch.randn(6)
torch._dynamo.maybe_mark_dynamic(c, 0)
self._clear_dynamo_and_codecache()
compiled_fn(a, b, c)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
@functorch_config.patch({"enable_autograd_cache": True})
def test_aot_runtime_trace_joint(self):
@torch.compile(backend="inductor")

View File

@ -245,7 +245,7 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
@ -264,7 +264,7 @@ class GraphModule(torch.nn.Module):
copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None
add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

View File

@ -57,18 +57,18 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
FakeTensor(..., size=(s0,))
FakeTensor(..., size=(s77,))
2
[FakeTensor(..., size=(s0,)), 2]
(FakeTensor(..., size=(s0,)), 2)
{'foo': FakeTensor(..., size=(s0,))}
[FakeTensor(..., size=(s77,)), 2]
(FakeTensor(..., size=(s77,)), 2)
{'foo': FakeTensor(..., size=(s77,))}
range(1, 3, 1)
Employee(name='foo', id=2)
UserDefinedListVariable(mylist)
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s0""",
s77""",
)
def test_print_graph(self):

View File

@ -256,34 +256,34 @@ Model:
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 0
==> s2: 1
==> s3: 1
==> s52: 1
==> s77: 3
==> s86: 0
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
Target Expressions:
==> (!= (+ s1 s2 s3) s0)
==> (<= 0 s1)
==> (<= 0 s2)
==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s3)
==> (<= 2 s0)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 0 s1)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
==> (>= 0 s86)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -309,7 +309,7 @@ Failed Source Expressions:
BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)
Failure occurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
@ -321,33 +321,33 @@ Model:
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0
==> s52: 1
==> s77: 3
==> s86: 1
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
Target Expressions:
==> (!= (+ s1 s2 s3) s0)
==> (<= 0 s1)
==> (<= 0 s2)
==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s3)
==> (<= 2 s0)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@ -2703,7 +2703,7 @@ def forward(self, x):
for node in ebar.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
)
@torch._dynamo.config.patch(
@ -3480,23 +3480,23 @@ def forward(self, x):
true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";
arg1: "f32[s77, s27]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)
"""
false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";
arg1: "f32[s77, s27]";
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
l_x_ = arg1
cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = [

View File

@ -2655,7 +2655,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum(); l_x_ = None
@ -2885,13 +2885,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None
mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None
return (mul_2,)
""",
)
@ -2932,14 +2932,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
@ -2982,14 +2982,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
@ -3029,14 +3029,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"):
l_x_ = L_x_
mul: "f32[s0, s0]" = l_x_ * 4
mul_1: "f32[s0, s0]" = mul * l_x_; mul = None
mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None
mul: "f32[s77, s77]" = l_x_ * 4
mul_1: "f32[s77, s77]" = mul * l_x_; mul = None
mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None
mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
return (mul_3,)
""",
)

View File

@ -413,18 +413,18 @@ class GraphModule(torch.nn.Module):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
getitem: "f32[s0]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None
getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
add: "f32[s0]" = view + 0.5; view = None
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"):
view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None
add: "f32[s77]" = view + 0.5; view = None
return (add,)
""",
)
@ -606,27 +606,27 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum()
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None
getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None
getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
add: "f32[s77]" = l_x_ + item; l_x_ = item = None
return (add,)
""",
)
@ -692,7 +692,7 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero()
@ -704,22 +704,22 @@ class GraphModule(torch.nn.Module):
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s0]" = wrap[0]
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
child: "f32[s0]" = wrap[0]
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s0]" = l_x_.sin(); l_x_ = None
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
@ -994,25 +994,25 @@ class GraphModule(torch.nn.Module):
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"):
l_x_ = L_x_
l_y_ = L_y_
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (matmul,)
""",
)

View File

@ -10382,19 +10382,22 @@ ShapeEnv not equal: field values don't match:
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
> Right: {}
==> source_to_var: values don't match.
> Left: {x.size()[0]: s0, x.size()[1]: s1}
> Left: {x.size()[0]: s93, x.size()[1]: s44}
> Right: {}
==> unique_ids: values don't match.
> Left: {44, 93}
> Right: {}
==> val_to_var: values don't match.
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Left: {0: 0, 1: 1, 2: s44, 3: s93}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
> Left: {s44: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)], s93: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)]}
> Right: {}
==> var_to_val: values don't match.
> Left: {s0: 3, s1: 2}
> Left: {s44: 2, s93: 3}
> Right: {}
""",
)
@ -10453,13 +10456,13 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
> Left: {(Mod(s93, 3)) < 0: False, (Mod(s93, 3)) <= 0: True, 0 < (Mod(s93, 3)): False, 0 <= (Mod(s93, 3)): True, Eq(0, Mod(s93, 3)): True, Eq(Mod(s93, 3), 0): True, Ne(0, Mod(s93, 3)): False, Ne(Mod(s93, 3), 0): False}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s0, 3)}
> Left: {Mod(s93, 3)}
> Right: {}
==> guards: values don't match.
> Left: [Eq(Mod(s0, 3), 0)]
> Left: [Eq(Mod(s93, 3), 0)]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
@ -10496,17 +10499,17 @@ ShapeEnv not equal: field values don't match:
> Left: {False: False, True: True}
> Right: {}
==> guards: values don't match.
> Left: [Eq(s0, 3)]
> Left: [Eq(s93, 3)]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> replacements: values don't match.
> Left: {s0: 3}
> Left: {s93: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s44: VR[2, int_oo], s93: VR[3, 3]}
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
""",
)
self._replay_and_check(main)
@ -10537,17 +10540,17 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {3 <= s0: True, s0 < 3: False}
> Left: {3 <= s93: True, s93 < 3: False}
> Right: {}
==> guards: values don't match.
> Left: [s0 >= 3]
> Left: [s93 >= 3]
> Right: []
==> name_to_node: values don't match.
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s44: VR[2, int_oo], s93: VR[3, int_oo]}
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
""",
)
self._replay_and_check(main)

View File

@ -4768,7 +4768,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertExpectedInline(
str(graph.code).strip(),
"""\
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
l_x_ = L_x_
getitem_2 = l_x_[0]
sum_1 = getitem_2.sum(); getitem_2 = None

View File

@ -339,7 +339,7 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@ -767,15 +767,15 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s0", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s1", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s2", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s3", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"guard_added_fast": {"expr": "Eq(s1, s2)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950

View File

@ -1373,21 +1373,21 @@ class GraphModule(torch.nn.Module):
# During fakeifying, we end up allocating a separate symint
# for the outer and inner tensor (in this test, s0 is unused).
expected_var_to_val = {
"s0": 8,
"s1": 4,
"s50": 4,
"s77": 8,
}
expected_var_to_sources = {
"s0": "L['x'].size()[0]",
"s1": "L['x'].inner_elem.size()[0]",
"s50": "L['x'].inner_elem.size()[0]",
"s77": "L['x'].size()[0]",
}
self.assertEqual(curr_var_to_val, expected_var_to_val)
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
self.assertExpectedInline(
"\n".join(guards),
"""\
Eq(2*s1, s0)
2*s1 < 13
s1 > 3""",
Eq(2*s50, s77)
2*s50 < 13
s50 > 3""",
)
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
@ -1976,9 +1976,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
""", # noqa: B950
)
@ -1987,9 +1987,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2009,12 +2009,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
""", # noqa: B950
)
@ -2023,9 +2023,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2047,15 +2047,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
mul_11: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
mul_16: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
mul_19: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
""", # noqa: B950
)
@ -2064,15 +2064,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_32: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
mul_35: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
mul_36: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
mul_37: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2092,12 +2092,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
""", # noqa: B950
)
@ -2106,9 +2106,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2128,13 +2128,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2143,9 +2143,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2165,13 +2165,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (clone, view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2180,9 +2180,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2261,13 +2261,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[1].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
@ -2287,9 +2287,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[1].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2317,13 +2317,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
@ -2332,9 +2332,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2501,10 +2501,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
""", # noqa: B950
)
@ -2513,8 +2513,8 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2534,11 +2534,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s1)" = primals_10 + primals_10
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s55)" = primals_10 + primals_10
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
""", # noqa: B950
)
@ -2547,11 +2547,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2580,7 +2580,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"):
def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@ -2594,13 +2594,13 @@ class <lambda>(torch.nn.Module):
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
sym_size_int: "Sym(s2 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s2 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
""", # noqa: B950
)
@ -2757,10 +2757,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"):
def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
l_nt_ = L_nt_
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None
add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
return (add,)
""", # noqa: B950
)
@ -3254,27 +3254,27 @@ class GraphModule(torch.nn.Module):
# varies based on the type of view
guard_str = "\n".join(guards)
if nt_view_name == "subclass_dense":
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
self.assertExpectedInline(guard_str, """Eq(s85 - 1, s77)""")
elif nt_view_name == "dense_subclass_dense_subclass":
self.assertExpectedInline(
guard_str,
"""\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
Eq(s85 - 1, s77)
Eq(s80 - 1, s78)
Eq(s72, s71)""",
)
elif nt_view_name.startswith("base_is_nt_True"):
self.assertExpectedInline(
guard_str,
"""Eq(s3 - 1, s0)""",
"""Eq(s17 - 1, s83)""",
)
else:
self.assertExpectedInline(
guard_str,
"""\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
Eq(s85 - 1, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
)
return gm

View File

@ -1560,7 +1560,7 @@ graph():
)
with self.assertRaisesRegex(
error_type,
r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
r"Real tensor propagation found an output size mismatch between fake shape s\d+ and real shape 4, "
r"at output\.size\(0\), for func: mylib.foo.default",
):
export(
@ -2848,7 +2848,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*= 9 to be "
"of the form 2\\*s1, where s1 is an integer",
"of the form 2\\*s92, where s92 is an integer",
):
ep.module()(torch.randn(9))
@ -3506,8 +3506,11 @@ def forward(self, x):
dynamic_shapes=({0: Dim("x")},),
)
# Since symbol names are based on hash of source names, and these differ across inference and
# training, we do range comparisons instead.
self.assertEqual(
str(ep_for_training.range_constraints), str(ep_for_real.range_constraints)
str(ep_for_training.range_constraints.values()),
str(ep_for_real.range_constraints.values()),
)
def test_export_for_training_with_container_type(self):
@ -4398,7 +4401,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
em.module()(torch.randn(4, 3))
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)",
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 \- 1\), 0\)",
):
em.module()(torch.randn(4, 5))
@ -4409,7 +4412,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
x = torch.randn(3, 5)
with self.assertRaisesRegex(
RuntimeError,
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer",
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
):
em.module()(x)
@ -4968,11 +4971,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
str(node.meta["val"].shape)
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
)
@testing.expectedFailureCppSerDes
@ -5110,14 +5116,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"y": (batch, size, size),
},
)
self.assertEqual(
[
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
)
for node in efoo.graph_module.graph.nodes:
if node.op == "placeholder":
self.assertEqual(node.meta["val"].shape[1], node.meta["val"].shape[2])
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
# pass dynamic shapes of inputs [multiple, mostly distinct]
@ -5128,13 +5130,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
inputs,
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
)
placeholders = [
node.meta["val"].shape
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
]
self.assertEqual(
[
str(node.meta["val"].shape)
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
placeholders[0][2],
placeholders[1][1],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5151,11 +5154,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
str(node.meta["val"].shape)
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5172,11 +5178,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
self.assertEqual(
[
str(node.meta["val"].shape)
# First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
["torch.Size([2, 3])", "torch.Size([3, 4])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5486,7 +5495,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder"
]
self.assertEqual(len(input_shapes), 9)
self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes))
self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes))
def test_error_does_not_reference_eager_fallback(self):
class Module(torch.nn.Module):
@ -11165,7 +11174,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'",
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, 4\*s77 \- 4\), 0\) on node 'eq.*'",
):
ep.module()(torch.randn(8, 8)) # fail
@ -11197,7 +11206,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(40).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'",
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
@ -11224,7 +11233,7 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(126).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'",
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime
ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail
@ -11305,12 +11314,12 @@ def forward(self, x, y):
)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, 20\)",
r"Runtime assertion failed for expression Ne\(s77, 20\)",
):
ep.module()(torch.randn(20, 20, 16))
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)",
r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)",
):
ep.module()(torch.randn(400, 20, 16))
ep.module()(torch.randn(42, 20, 16))
@ -11348,17 +11357,17 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(27).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, s1\)",
r"Runtime assertion failed for expression Ne\(s77, s17\)",
): # fail only at runtime
ep.module()(torch.randn(4), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, s1\**3\)",
r"Runtime assertion failed for expression Ne\(s77, s17\**3\)",
):
ep.module()(torch.randn(64), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)",
r"Runtime assertion failed for expression Eq\(s77\**2, 3\*s17\)",
):
ep.module()(torch.randn(10), torch.randn(9)) # fail

View File

@ -539,8 +539,12 @@ def forward(self, x):
ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)
serialized = ExportedProgramSerializer().serialize(ep)
self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2)
self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3)
self.assertEqual(
serialized.exported_program.range_constraints["s77"].min_val, 2
)
self.assertEqual(
serialized.exported_program.range_constraints["s77"].max_val, 3
)
def test_kwargs_default(self) -> None:
"""

View File

@ -5941,8 +5941,8 @@ class TestAOTModuleSimplified(AOTTestCase):
self.assertExpectedInline(
shape_env.format_guards(),
"""\
- Eq(s1, 20)
- Eq(s2, 30)""",
- Eq(s49, 20)
- Eq(s70, 30)""",
)
assert torch.allclose(ref[0], res[0])

View File

@ -4553,10 +4553,10 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -4719,10 +4719,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
def forward(self, a_1, b_1):
sum_1 = torch.ops.aten.sum.default(a_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -5856,7 +5856,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -5969,7 +5969,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6209,7 +6209,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6558,14 +6558,14 @@ def forward(self, l_inp_, l_tmp_):
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
l_a_ = L_a_
l_b_ = L_b_
l_self_num = L_self_num
tensor = torch.tensor([True])
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s0)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
@ -6753,10 +6753,10 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s0, 3]";
x: "f32[s35, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -6770,27 +6770,27 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_2 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None
getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_2 + 1
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
lt: "Sym(u1 < s0)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
@ -6820,12 +6820,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
l_x_ = L_x_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s0, s1)); cond_fn_0 = body_fn_0 = l_x_ = s1 = None
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
getitem_4: "Sym(u1)" = while_loop[0]
@ -6835,49 +6835,49 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_4 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
out_x: "f32[s0, s1]" = while_loop[1]; while_loop = None
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_4 + 1
add_1: "f32[s0, s1]" = getitem_4 + out_x; out_x = None
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
lt: "Sym(u1 < s0)" = getitem_4 < s0; s0 = None
lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None
mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None
ones: "f32[2*u1]" = torch.ones(mul); mul = None
return (add, add_1, lt, ones)
class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
size = l_x_.size(); l_x_ = None
getitem: "Sym(s0)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; unbacked_symint = getitem = None
getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
return lt
class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
x_clone: "f32[s0, s1]" = l_x_.clone()
x_clone: "f32[s77, s27]" = l_x_.clone()
ge: "Sym(u0 >= 0)" = unbacked_symint >= 0
_check = torch._check(ge); ge = _check = None
size = l_x_.size(); l_x_ = None
getitem: "Sym(s0)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; getitem = None
getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None
_check_1 = torch._check(lt); lt = _check_1 = None
select: "f32[s1]" = x_clone.select(0, unbacked_symint)
select_1: "f32[s1]" = x_clone.select(0, unbacked_symint)
add: "f32[s1]" = select_1 + unbacked_symint; select_1 = None
copy_: "f32[s1]" = select.copy_(add); select = add = copy_ = None
select: "f32[s27]" = x_clone.select(0, unbacked_symint)
select_1: "f32[s27]" = x_clone.select(0, unbacked_symint)
add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None
copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
return (add_1, x_clone)
@ -7048,12 +7048,12 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s0, 3]";
x: "f32[s77, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(x); x = None
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -7065,19 +7065,19 @@ class GraphModule(torch.nn.Module):
getitem_9: "Sym(u8)" = while_loop[3]
getitem_10: "Sym(u9)" = while_loop[4]
getitem_5: "f32[s0, 3]" = while_loop[5]; while_loop = None
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_8 + 1
add_1: "Sym(u8 + 1)" = getitem_9 + 1
add_2: "Sym(u9 + 1)" = getitem_10 + 1
add_3: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
@ -7085,7 +7085,7 @@ class GraphModule(torch.nn.Module):
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
@ -7093,7 +7093,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
return (add, add_1, add_2, add_3, add_4, add_5)
""", # noqa: B950
)
@ -7119,14 +7119,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
l_x_ = L_x_
child: "f32[s0, s1]" = l_x_.sin(); l_x_ = None
child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s0, s1, 2, 2, 3, child), (s0, s1)); cond_fn_0 = body_fn_0 = s0 = s1 = child = None
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
getitem_10: "Sym(u5)" = while_loop[0]
getitem_11: "Sym(u6)" = while_loop[1]
@ -7134,21 +7134,21 @@ class GraphModule(torch.nn.Module):
getitem_13: "Sym(u8)" = while_loop[3]
getitem_14: "Sym(u9)" = while_loop[4]
out_x: "f32[s0, s1]" = while_loop[5]; while_loop = None
out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_12 + 1
add_1: "Sym(u8 + 1)" = getitem_13 + 1
add_2: "Sym(u9 + 1)" = getitem_14 + 1
add_3: "f32[s0, s1]" = getitem_12 + out_x; getitem_12 = None
add_4: "f32[s0, s1]" = getitem_13 + out_x; getitem_13 = None
add_5: "f32[s0, s1]" = getitem_14 + out_x; getitem_14 = None
add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
@ -7157,9 +7157,9 @@ class GraphModule(torch.nn.Module):
return lt
class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
s0_1 = s0
s1_1 = s1
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s27_1 = s27
s77_1 = s77
add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
@ -7168,7 +7168,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None
add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None
child_1: "f32[s0, s1]" = child + 1; child = None
child_1: "f32[s77, s27]" = child + 1; child = None
return (add, add_1, add_2, add_3, add_4, child_1)
""", # noqa: B950
)
@ -7336,30 +7336,30 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, x, y, z):
x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]";
x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]";
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None
sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None
gt: "Sym(s0 > 5)" = sym_size_int_3 > 5
gt: "Sym(s35 > 5)" = sym_size_int_3 > 5
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
getitem: "f32[s0, 3]" = cond[0]; cond = None
getitem: "f32[s35, 3]" = cond[0]; cond = None
return pytree.tree_unflatten((getitem,), self._out_spec)
class true_graph_0(torch.nn.Module):
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
return (add,)
class false_graph_0(torch.nn.Module):
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
return (add,)
""", # noqa: B950
)
@ -7522,7 +7522,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"):
def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
l_y_ = L_y_
l_z_ = L_z_
l_x_ = L_x_
@ -7532,39 +7532,39 @@ class GraphModule(torch.nn.Module):
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s1, s0, s0, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
getitem_5: "f32[u0, s1]" = cond[0]
getitem_5: "f32[u0, s94]" = cond[0]
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
ret: "f32[u0, s1]" = cond[0]; cond = None
ret: "f32[u0, s94]" = cond[0]; cond = None
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None
sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
return (sub,)
class cond_true_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s1_1 = s1
s94_1 = s94
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None
add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
return (clone,)
class cond_false_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s1_1 = s1
s94_1 = s94
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s1]" = getitem.clone(); getitem = None
mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s94]" = getitem.clone(); getitem = None
return (clone,)
""", # noqa: B950
)

View File

@ -403,10 +403,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -563,13 +563,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -595,11 +595,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (add,)""",
ignore_comments=True,
ignore_empty_lines=True,
@ -663,10 +663,10 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -691,11 +691,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1291,14 +1291,14 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None
add_6: "Sym(2)" = floordiv + 2
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None
getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
getitem_1: "f32[s77, s77][s77, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1324,13 +1324,13 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
slice_tensor: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_tensor_1: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -1470,18 +1470,18 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
@ -1517,16 +1517,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg1_1, slice_2)""", # noqa: B950
ignore_comments=True,

View File

@ -53,11 +53,13 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check(
"pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
"pool1 = empty_strided_"
+ GPU_TYPE
+ "((4*s27*s77 + align(4*s77*s77), ), (1, )"
).check_next(
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s77, s77), (s77, 1))"
).check(
"buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
"buf1 = alloc_from_pool(pool1, align(4*s77*s77),"
).run(
code
)
@ -95,7 +97,7 @@ class TestMemoryPlanning(TestCase):
)
FileCheck().check(
"int64_t int_array_2[] = {24L + align(12L*s0), };"
"int64_t int_array_2[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;"
).check_next(
@ -103,7 +105,7 @@ class TestMemoryPlanning(TestCase):
).check_next(
"RAIIAtenTensorHandle pool1(pool1_handle);"
).check_next(
"int64_t int_array_4[] = {s0, 3L};"
"int64_t int_array_4[] = {s77, 3L};"
).check_next(
"int64_t int_array_5[] = {3L, 1L};"
).check_next(

View File

@ -11819,7 +11819,7 @@ class CommonTemplate:
# 'i1 + 3 * i0' is cached.
self.assertTrue(
"i0 + 2 * i1" in mul_buf.data.inner_fn_str()
or "i0 + i1 * s1" in mul_buf.data.inner_fn_str()
or "i0 + i1 * s64" in mul_buf.data.inner_fn_str()
)
with add_scheduler_init_hook(hook_fn):
@ -12548,7 +12548,7 @@ class CommonTemplate:
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
if is_dynamic_shape_enabled():
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950
else:
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
FileCheck().check_regex(size_assert_pattern).run(code)
@ -12567,8 +12567,8 @@ class CommonTemplate:
code = run_and_get_triton_code(f, x)
if is_dynamic_shape_enabled():
FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check(
"assert_size_stride(buf2, (s0, s1), (s1, 1))"
FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
"assert_size_stride(buf2, (s77, s27), (s27, 1))"
).run(code)
else:
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(

View File

@ -456,7 +456,7 @@ def forward(self, x_1, output_1):
self.assertIn("output_handles[0] = ", code)
self.assertIn("output_handles[1] = ", code)
else:
self.assertIn("return (buf0, s0, )", code)
self.assertIn("return (buf0, s92, )", code)
else:
self.assertIn(
"output_handles[0] = "

View File

@ -393,7 +393,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
)
a1 = create_symint(shape_env, 3)
@ -415,7 +415,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
)
def test_stride(self):
@ -497,7 +497,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_sym_sum(self):
shape_env = ShapeEnv()
@ -512,7 +512,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(s0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
@ -520,7 +520,7 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s0, 2)]""",
"""[Eq(s97, 2)]""",
)
def test_sym_int(self):
@ -529,14 +529,14 @@ class TestPySymInt(TestCase):
r = sym_int(a0)
self.assertEqual(r, 5)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
a1 = create_symint(shape_env, 7)
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
)
a3 = create_symint(shape_env, 3)
@ -544,7 +544,7 @@ class TestPySymInt(TestCase):
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
)
def test_sym_log2(self):
@ -554,7 +554,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2.0)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
)
def test_sym_sqrt(self):
@ -564,7 +564,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
)
def test_sym_floor(self):
@ -575,14 +575,14 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
"""Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_trunc(self):
@ -592,14 +592,14 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
)
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
)
def test_sym_ceil(self):
@ -610,7 +610,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
"""Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
)
r1 = 3.0 * a0
r = math.floor(r1)
@ -618,7 +618,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_ite(self):
@ -638,7 +638,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(t), type(r3))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
"""Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
)
b4 = f == 5
r4 = torch.sym_ite(b4, t, f)
@ -647,7 +647,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(f), type(r4))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
"""Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
)
def test_tracing_sym_ite(self):
@ -679,7 +679,7 @@ def forward(self, x_1):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
int(a0)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_data_dependent_guard(self):
shape_env = ShapeEnv()
@ -710,7 +710,7 @@ def forward(self, x_1):
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[u0 < s0]""",
"""[u0 < s97]""",
)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
@ -1173,18 +1173,18 @@ def forward(self, x_1):
out.strip(),
"""\
class f(torch.nn.Module):
def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
# No stacktrace found for following nodes
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""", # noqa: B950
)
@ -2846,8 +2846,8 @@ class TestGuardsExpressions(TestCase):
],
)
self.assertEqual(f"{x.stride()}", "(s1, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s0, s1])")
self.assertEqual(f"{x.stride()}", "(s49, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
x_clean = _remove_symbols_without_guarding(x, 4096)

View File

@ -1084,7 +1084,7 @@ def forward(self, x_1, y_1):
test_inputs.append([(6, 8)])
gm = self._test_dynamic(f, [(3, 4)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
@ -1717,7 +1717,7 @@ def forward(self, a_1):
gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")
self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")
def test_new_empty(self):
def f(a, b):

View File

@ -869,6 +869,12 @@ def _optimized_add(
if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
return make_optimized(rhs._args + lhs._args)
# (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
new_args = list(lhs._args)
for a in rhs._args:
new_args = _binary_search_insert_arg(new_args, a)
return make_optimized(new_args)
# (a0+a2) + a1 => (a0+a1+a2)
if lhs_is_optimized_summation and rhs.is_symbol:
new_args = _binary_search_insert_arg(list(lhs._args), rhs)
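For context on the new merge case above, here is a minimal standalone sketch (not the PyTorch implementation) of combining two already-sorted argument lists by binary-search insertion; sortkey below is a stand-in for whatever ordering _optimized_add actually uses:

import bisect

def merge_sorted_args(lhs_args, rhs_args, sortkey=str):
    # Both inputs are assumed sorted by sortkey, as optimized summations maintain;
    # each rhs argument is inserted at its sorted position in the lhs list.
    merged = list(lhs_args)
    keys = [sortkey(a) for a in merged]
    for a in rhs_args:
        i = bisect.bisect_right(keys, sortkey(a))
        merged.insert(i, a)
        keys.insert(i, sortkey(a))
    return merged

# (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
print(merge_sorted_args(["a1", "a3"], ["a0", "a2"]))  # ['a0', 'a1', 'a2', 'a3']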

View File

@ -14,6 +14,7 @@ import atexit
import collections
import dis
import functools
import hashlib
import inspect
import itertools
import logging
@ -3289,6 +3290,11 @@ class ShapeEnv:
self.guards: list[ShapeGuard] = []
self.axioms: dict[sympy.Expr, sympy.Expr] = {}
# A set of ids that have already been allocated. This is used when we
# allocate symbol ids from a hash of the source name; linear probing
# then ensures two sources never collide on the same id.
self.unique_ids: set[int] = set()
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
@ -4540,13 +4546,14 @@ class ShapeEnv:
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
symbol_id = self._generate_unique_id(source.name())
if type(val) is int or is_nested_int(val):
sympy_expr = make_symbol(
SymT.SIZE, len(self.var_to_val), positive=positive, integer=True
SymT.SIZE, symbol_id, positive=positive, integer=True
)
else:
sympy_expr = make_symbol(
SymT.FLOAT, len(self.var_to_val), positive=positive, real=True
SymT.FLOAT, symbol_id, positive=positive, real=True
)
self.source_to_var[source_name] = sympy_expr
# We always associate vars to vals
@ -6558,6 +6565,13 @@ class ShapeEnv:
sloc, _ = self._get_stack_summary(framework_loc=framework_loc)
return sloc
def _generate_unique_id(self, source_name: str) -> int:
attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
while attempt in self.unique_ids:
attempt += 1
self.unique_ids.add(attempt)
return attempt
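The allocation scheme can be sketched on its own (same hashing and probing, but outside ShapeEnv, with hypothetical source names) to show why the assignment is stable across runs: the starting id depends only on the source name, and linear probing only steps past ids that are already taken.

import hashlib

class IdAllocator:
    # Toy version of _generate_unique_id: hash the source name, probe on collision.
    def __init__(self):
        self.unique_ids = set()

    def allocate(self, source_name: str) -> int:
        attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
        while attempt in self.unique_ids:
            attempt += 1  # linear probing keeps allocated ids distinct
        self.unique_ids.add(attempt)
        return attempt

alloc = IdAllocator()
x_id = alloc.allocate("L['x'].size()[0]")  # hypothetical source name
y_id = alloc.allocate("L['y'].size()[0]")
assert x_id != y_id

Because the hash is taken over the source name alone, the id a symbol receives does not depend on how many symbols happened to be allocated before it, only on which ids are already occupied.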
def _find_frame_locals(self) -> _FrameLocalResult:
"""
Given the current user code frame, finds the relevant lines of code,