Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit1f92348dc6. Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see6eb3c2e282/1([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
Parent: 6eb3c2e282
Commit: af7719a2fa
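The reverted PR changed how Dynamo names its symbolic sizes: instead of numbering symbols in allocation order (s0, s1, s2, ...), it derived each id from a hash of the symbol's source, producing stable but non-sequential names (s77, s27, s45, ...). The hunks below are expecttest churn from switching back: each changed pair replaces a hashed name with the restored sequential one. The following sketch only illustrates the contrast between the two naming schemes; the class names, hash function, and modulus are assumptions made for illustration, not PyTorch's actual ShapeEnv implementation.

# Illustrative sketch only -- not PyTorch's ShapeEnv code.
import hashlib
import itertools


class SequentialNamer:
    """Pre-PR behavior: symbols are numbered in allocation order (s0, s1, ...)."""

    def __init__(self) -> None:
        self._counter = itertools.count()

    def symbol_for(self, source: str) -> str:
        # The source (e.g. "L['x'].size()[0]") is ignored; only allocation
        # order matters, so a different tracing order yields different names.
        return f"s{next(self._counter)}"


class SourceHashNamer:
    """PR behavior (reverted here): the id is derived from the source string,
    so the same source always maps to the same name, e.g. s77."""

    def symbol_for(self, source: str, modulus: int = 100) -> str:
        digest = hashlib.sha256(source.encode()).hexdigest()
        return f"s{int(digest, 16) % modulus}"


if __name__ == "__main__":
    seq, hashed = SequentialNamer(), SourceHashNamer()
    for src in ("L['x'].size()[0]", "L['x'].size()[1]"):
        print(f"{src}: sequential={seq.symbol_for(src)}, hashed={hashed.symbol_for(src)}")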
@@ -196,46 +196,6 @@ class AOTAutogradCacheTests(InductorTestCase):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
 
-    @inductor_config.patch("fx_graph_remote_cache", False)
-    @inductor_config.patch("fx_graph_cache", True)
-    @functorch_config.patch({"enable_autograd_cache": True})
-    def test_symbol_specialization(self):
-        """
-        Verify the symbol specializations don't cause cache miss.
-        """
-
-        def fn(x, y, z):
-            return (torch.randn(5) + x + y, z * torch.randn(1))
-
-        a = torch.rand(5)
-        torch._dynamo.maybe_mark_dynamic(a, 0)
-        b = torch.rand(5)
-        c = torch.randn(6)
-        torch._dynamo.maybe_mark_dynamic(c, 0)
-
-        compiled_fn = torch.compile(fn, backend="inductor")
-
-        # A first call should miss in the cache.
-        compiled_fn(a, b, c)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
-
-        # A second call should hit even if a new dimension is marked as dynamic
-        # that is later specialized as part of tracing.
-        a = torch.rand(5)
-        torch._dynamo.maybe_mark_dynamic(a, 0)
-        b = torch.rand(5)
-        torch._dynamo.maybe_mark_dynamic(b, 0)
-        c = torch.randn(6)
-        torch._dynamo.maybe_mark_dynamic(c, 0)
-        self._clear_dynamo_and_codecache()
-
-        compiled_fn(a, b, c)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
-
     @functorch_config.patch({"enable_autograd_cache": True})
     def test_aot_runtime_trace_joint(self):
         @torch.compile(backend="inductor")

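The test removed above exercised torch.compile caching together with torch._dynamo.maybe_mark_dynamic. For reference, a standalone sketch of the same pattern is below; the counter names are taken from the removed test, and populating them depends on the cache-related configs patched there (fx_graph_cache, enable_autograd_cache), so treat the printed values as an assumption outside that harness.

import torch
from torch._dynamo.utils import counters


def fn(x, y, z):
    return (torch.randn(5) + x + y, z * torch.randn(1))


a = torch.rand(5)
torch._dynamo.maybe_mark_dynamic(a, 0)  # hint: dim 0 may vary between calls
b = torch.rand(5)
c = torch.randn(6)
torch._dynamo.maybe_mark_dynamic(c, 0)

compiled_fn = torch.compile(fn, backend="inductor")
compiled_fn(a, b, c)

# Cache hit/miss/saved counters checked by the removed test; whether they are
# populated depends on the caching configuration in effect.
print(dict(counters["aot_autograd"]))

Unlike torch._dynamo.mark_dynamic, maybe_mark_dynamic is only a hint: if tracing later specializes the dimension, no error is raised, which is what the removed test relied on to still get a cache hit on its second call.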
@@ -245,7 +245,7 @@ class GraphModule(torch.nn.Module):
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
+    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
         l_inputs_ = L_inputs_
         l_sizes_0_ = L_sizes_0_
         l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter

@@ -264,7 +264,7 @@ class GraphModule(torch.nn.Module):
 
         copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None
 
-        add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
+        add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
 
         result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None
 

@@ -57,18 +57,18 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
         self.assertExpectedInline(
             FILE.getvalue().strip(),
             """\
-FakeTensor(..., size=(s77,))
+FakeTensor(..., size=(s0,))
 2
-[FakeTensor(..., size=(s77,)), 2]
-(FakeTensor(..., size=(s77,)), 2)
-{'foo': FakeTensor(..., size=(s77,))}
+[FakeTensor(..., size=(s0,)), 2]
+(FakeTensor(..., size=(s0,)), 2)
+{'foo': FakeTensor(..., size=(s0,))}
 range(1, 3, 1)
 Employee(name='foo', id=2)
 UserDefinedListVariable(mylist)
 defaultdict(NestedUserFunctionVariable(), {})
 set()
 {'a','b'}
-s77""",
+s0""",
         )
 
     def test_print_graph(self):

@@ -256,34 +256,34 @@ Model:
 ==> L['x'].size()[0]: 3
 ==> L['x'].storage_offset(): 0
 ==> L['x'].stride()[0]: 1
+==> s0: 3
+==> s1: 0
+==> s2: 1
 ==> s3: 1
-==> s52: 1
-==> s77: 3
-==> s86: 0
 
 Assertions:
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s86)
-==> (== L['shape'][1] s52)
+==> (== L['shape'][0] s1)
+==> (== L['shape'][1] s2)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s77)
-==> (> s77 1)
+==> (== L['x'].size()[0] s0)
+==> (> s0 1)
 
 Target Expressions:
-==> (!= (+ s3 s52 s86) s77)
+==> (!= (+ s1 s2 s3) s0)
+==> (<= 0 s1)
+==> (<= 0 s2)
 ==> (<= 0 s3)
-==> (<= 0 s52)
-==> (<= 0 s86)
-==> (<= 2 s77)
+==> (<= 2 s0)
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s86)
-==> (== L['shape'][1] s52)
+==> (== L['shape'][0] s1)
+==> (== L['shape'][1] s2)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s77)
-==> (> s77 0)
-==> (>= 0 s86)
+==> (== L['x'].size()[0] s0)
+==> (> s0 0)
+==> (>= 0 s1)
 
 Failed Source Expressions:
 ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

@@ -309,7 +309,7 @@ Failed Source Expressions:
             BisectValidationException,
             lambda: fn(torch.randn(20), (5, 10, 5)),
             """\
-translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)
+translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
 
 Failure occurred while running node:
     %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

@ -321,33 +321,33 @@ Model:
|
|||
==> L['x'].size()[0]: 3
|
||||
==> L['x'].storage_offset(): 0
|
||||
==> L['x'].stride()[0]: 1
|
||||
==> s0: 3
|
||||
==> s1: 1
|
||||
==> s2: 1
|
||||
==> s3: 0
|
||||
==> s52: 1
|
||||
==> s77: 3
|
||||
==> s86: 1
|
||||
|
||||
Assertions:
|
||||
==> (== 0 L['x'].storage_offset())
|
||||
==> (== 1 L['x'].stride()[0])
|
||||
==> (== L['shape'][0] s86)
|
||||
==> (== L['shape'][1] s52)
|
||||
==> (== L['shape'][0] s1)
|
||||
==> (== L['shape'][1] s2)
|
||||
==> (== L['shape'][2] s3)
|
||||
==> (== L['x'].size()[0] s77)
|
||||
==> (> s77 1)
|
||||
==> (== L['x'].size()[0] s0)
|
||||
==> (> s0 1)
|
||||
|
||||
Target Expressions:
|
||||
==> (!= (+ s3 s52 s86) s77)
|
||||
==> (!= (+ s1 s2 s3) s0)
|
||||
==> (<= 0 s1)
|
||||
==> (<= 0 s2)
|
||||
==> (<= 0 s3)
|
||||
==> (<= 0 s52)
|
||||
==> (<= 0 s86)
|
||||
==> (<= 2 s77)
|
||||
==> (<= 2 s0)
|
||||
==> (== 0 L['x'].storage_offset())
|
||||
==> (== 1 L['x'].stride()[0])
|
||||
==> (== L['shape'][0] s86)
|
||||
==> (== L['shape'][1] s52)
|
||||
==> (== L['shape'][0] s1)
|
||||
==> (== L['shape'][1] s2)
|
||||
==> (== L['shape'][2] s3)
|
||||
==> (== L['x'].size()[0] s77)
|
||||
==> (> s77 0)
|
||||
==> (== L['x'].size()[0] s0)
|
||||
==> (> s0 0)
|
||||
|
||||
Failed Source Expressions:
|
||||
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
|
||||
|
|
|
|||
|
|
@ -2703,7 +2703,7 @@ def forward(self, x):
|
|||
for node in ebar.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
|
||||
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
|
|
@ -3480,23 +3480,23 @@ def forward(self, x):
|
|||
true_graph = """\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, pred, x):
|
||||
arg1: "f32[s77, s27]";
|
||||
arg1: "f32[s1, s2]";
|
||||
|
||||
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
|
||||
l_x_ = arg1
|
||||
|
||||
sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
|
||||
sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
|
||||
return pytree.tree_unflatten([sin], self._out_spec)
|
||||
"""
|
||||
false_graph = """\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, pred, x):
|
||||
arg1: "f32[s77, s27]";
|
||||
arg1: "f32[s1, s2]";
|
||||
|
||||
arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
|
||||
l_x_ = arg1
|
||||
|
||||
cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
|
||||
cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
|
||||
return pytree.tree_unflatten([cos], self._out_spec)
|
||||
"""
|
||||
true_guard_code = [
|
||||
|
|
|
|||
|
|
@ -2655,7 +2655,7 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
sum_1: "f32[]" = l_x_.sum(); l_x_ = None
|
||||
|
|
@ -2885,13 +2885,13 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
|
||||
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
|
||||
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
|
||||
|
||||
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
|
||||
mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None
|
||||
mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None
|
||||
return (mul_2,)
|
||||
""",
|
||||
)
|
||||
|
|
@ -2932,14 +2932,14 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
|
||||
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
|
||||
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
|
||||
|
||||
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
|
||||
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
|
||||
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
|
||||
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
|
||||
return (mul_1,)
|
||||
""",
|
||||
)
|
||||
|
|
@ -2982,14 +2982,14 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
|
||||
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
|
||||
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
|
||||
|
||||
mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
|
||||
|
||||
add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
|
||||
|
||||
mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
|
||||
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
|
||||
return (mul_1,)
|
||||
""",
|
||||
)
|
||||
|
|
@ -3029,14 +3029,14 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
mul: "f32[s77, s77]" = l_x_ * 4
|
||||
mul_1: "f32[s77, s77]" = mul * l_x_; mul = None
|
||||
mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None
|
||||
mul: "f32[s0, s0]" = l_x_ * 4
|
||||
mul_1: "f32[s0, s0]" = mul * l_x_; mul = None
|
||||
mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None
|
||||
|
||||
mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
|
||||
mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
|
||||
return (mul_3,)
|
||||
""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -413,18 +413,18 @@ class GraphModule(torch.nn.Module):
|
|||
actual_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None
|
||||
getitem: "f32[s77]" = wrap[0]; wrap = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"):
|
||||
view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None
|
||||
add: "f32[s77]" = view + 0.5; view = None
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
|
||||
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
|
||||
add: "f32[s0]" = view + 0.5; view = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
|
@ -606,27 +606,27 @@ class GraphModule(torch.nn.Module):
|
|||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
sum_1: "f32[]" = l_x_.sum()
|
||||
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None
|
||||
getitem: "f32[s77]" = wrap[0]; wrap = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None
|
||||
getitem: "f32[s77]" = wrap[0]; wrap = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
|
||||
add: "f32[s77]" = l_x_ + item; l_x_ = item = None
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
||||
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
|
@ -692,7 +692,7 @@ class GraphModule(torch.nn.Module):
|
|||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
c: "i64[u0, 1]" = l_x_.nonzero()
|
||||
|
|
@ -704,22 +704,22 @@ class GraphModule(torch.nn.Module):
|
|||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None
|
||||
getitem: "f32[s77]" = wrap[0]
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
|
||||
getitem: "f32[s0]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
|
||||
child: "f32[s77]" = wrap[0]
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
|
||||
child: "f32[s0]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[s77]" = l_x_.sin(); l_x_ = None
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[s0]" = l_x_.sin(); l_x_ = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
""",
|
||||
|
|
@ -994,25 +994,25 @@ class GraphModule(torch.nn.Module):
|
|||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None
|
||||
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
||||
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None
|
||||
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
||||
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
|
||||
matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
||||
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
||||
return (matmul,)
|
||||
""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10382,22 +10382,19 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
|
||||
> Right: {}
|
||||
==> source_to_var: values don't match.
|
||||
> Left: {x.size()[0]: s93, x.size()[1]: s44}
|
||||
> Right: {}
|
||||
==> unique_ids: values don't match.
|
||||
> Left: {44, 93}
|
||||
> Left: {x.size()[0]: s0, x.size()[1]: s1}
|
||||
> Right: {}
|
||||
==> val_to_var: values don't match.
|
||||
> Left: {0: 0, 1: 1, 2: s44, 3: s93}
|
||||
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
|
||||
> Right: {0: 0, 1: 1}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
|
||||
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
> Right: {}
|
||||
==> var_to_sources: values don't match.
|
||||
> Left: {s44: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)], s93: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)]}
|
||||
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
|
||||
> Right: {}
|
||||
==> var_to_val: values don't match.
|
||||
> Left: {s44: 2, s93: 3}
|
||||
> Left: {s0: 3, s1: 2}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
|
|
@ -10456,13 +10453,13 @@ ShapeEnv not equal: field values don't match:
|
|||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {(Mod(s93, 3)) < 0: False, (Mod(s93, 3)) <= 0: True, 0 < (Mod(s93, 3)): False, 0 <= (Mod(s93, 3)): True, Eq(0, Mod(s93, 3)): True, Eq(Mod(s93, 3), 0): True, Ne(0, Mod(s93, 3)): False, Ne(Mod(s93, 3), 0): False}
|
||||
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
|
||||
> Right: {}
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(s93, 3)}
|
||||
> Left: {Mod(s0, 3)}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(Mod(s93, 3), 0)]
|
||||
> Left: [Eq(Mod(s0, 3), 0)]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
|
|
@ -10499,17 +10496,17 @@ ShapeEnv not equal: field values don't match:
|
|||
> Left: {False: False, True: True}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(s93, 3)]
|
||||
> Left: [Eq(s0, 3)]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> replacements: values don't match.
|
||||
> Left: {s93: 3}
|
||||
> Left: {s0: 3}
|
||||
> Right: {}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s44: VR[2, int_oo], s93: VR[3, 3]}
|
||||
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
|
||||
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
|
||||
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
@ -10540,17 +10537,17 @@ ShapeEnv not equal: field values don't match:
|
|||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {3 <= s93: True, s93 < 3: False}
|
||||
> Left: {3 <= s0: True, s0 < 3: False}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [s93 >= 3]
|
||||
> Left: [s0 >= 3]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s44: VR[2, int_oo], s93: VR[3, int_oo]}
|
||||
> Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
|
||||
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
|
||||
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
|
|
|||
|
|
@ -4768,7 +4768,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertExpectedInline(
|
||||
str(graph.code).strip(),
|
||||
"""\
|
||||
def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
getitem_2 = l_x_[0]
|
||||
sum_1 = getitem_2.sum(); getitem_2 = None
|
||||
|
|
|
|||
|
|
@ -339,7 +339,7 @@ class StructuredTraceTest(TestCase):
|
|||
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
@ -767,15 +767,15 @@ class StructuredTraceTest(TestCase):
|
|||
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s0", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s1", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"create_symbol": {"symbol": "s2", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"create_symbol": {"symbol": "s3", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"guard_added_fast": {"expr": "Eq(s1, s2)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
|
||||
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
|
||||
""", # noqa: B950
|
||||
|
|
|
|||
|
|
@ -1373,21 +1373,21 @@ class GraphModule(torch.nn.Module):
|
|||
# During fakeifying, we end up allocating a separate symint
|
||||
# for the outer and inner tensor (in this test, s0 is unused).
|
||||
expected_var_to_val = {
|
||||
"s50": 4,
|
||||
"s77": 8,
|
||||
"s0": 8,
|
||||
"s1": 4,
|
||||
}
|
||||
expected_var_to_sources = {
|
||||
"s50": "L['x'].inner_elem.size()[0]",
|
||||
"s77": "L['x'].size()[0]",
|
||||
"s0": "L['x'].size()[0]",
|
||||
"s1": "L['x'].inner_elem.size()[0]",
|
||||
}
|
||||
self.assertEqual(curr_var_to_val, expected_var_to_val)
|
||||
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
|
||||
self.assertExpectedInline(
|
||||
"\n".join(guards),
|
||||
"""\
|
||||
Eq(2*s50, s77)
|
||||
2*s50 < 13
|
||||
s50 > 3""",
|
||||
Eq(2*s1, s0)
|
||||
2*s1 < 13
|
||||
s1 > 3""",
|
||||
)
|
||||
|
||||
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
|
||||
|
|
@ -1976,9 +1976,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -1987,9 +1987,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
|
||||
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
|
||||
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
|
||||
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
|
||||
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2009,12 +2009,12 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
|
||||
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
|
||||
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
|
||||
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
|
||||
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2023,9 +2023,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2047,15 +2047,15 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
|
||||
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
|
||||
mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
|
||||
mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
|
||||
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
|
||||
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
|
||||
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
|
||||
mul_11: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
|
||||
mul_16: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
|
||||
mul_19: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
|
||||
mul_24: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
|
||||
mul_27: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
|
||||
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2064,15 +2064,15 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
|
||||
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
|
||||
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
|
||||
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
|
||||
mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
|
||||
mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
|
||||
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
|
||||
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
|
||||
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
mul_32: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
|
||||
mul_33: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
|
||||
mul_34: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
|
||||
mul_35: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
|
||||
mul_36: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
|
||||
mul_37: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
|
||||
mul_38: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
|
||||
mul_39: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
|
||||
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2092,12 +2092,12 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
|
||||
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
|
||||
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
|
||||
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
|
||||
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2106,9 +2106,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2128,13 +2128,13 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
|
||||
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
|
||||
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (view, view_1, mul_6, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2143,9 +2143,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2165,13 +2165,13 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
|
||||
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6])
|
||||
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (clone, view, view_1, mul_6, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2180,9 +2180,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2261,13 +2261,13 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[1].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
|
||||
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
|
||||
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
|
||||
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2287,9 +2287,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[1].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
|
||||
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
|
||||
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2317,13 +2317,13 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
|
||||
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
|
||||
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
|
||||
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2332,9 +2332,9 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
|
||||
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
|
||||
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2501,10 +2501,10 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
|
||||
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
|
||||
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
|
||||
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2513,8 +2513,8 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
|
||||
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
|
||||
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
|
||||
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2534,11 +2534,11 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
|
||||
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
|
||||
add_2: "Sym(2*s55)" = primals_10 + primals_10
|
||||
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
|
||||
add_2: "Sym(2*s1)" = primals_10 + primals_10
|
||||
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
|
@ -2547,11 +2547,11 @@ class GraphModule(torch.nn.Module):
|
|||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
|
||||
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
|
||||
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
|
||||
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
|
||||
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
|
||||
|
||||
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
|
||||
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
|
||||
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
""", # noqa: B950
|
||||
)
@@ -2580,7 +2580,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
|
||||
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"):
|
||||
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@@ -2594,13 +2594,13 @@ class <lambda>(torch.nn.Module):
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
|
||||
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
|
||||
|
||||
cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
|
||||
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
|
||||
|
||||
sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2)
|
||||
mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
|
||||
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2)
|
||||
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
|
||||
|
||||
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
|
||||
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
|
||||
sym_size_int: "Sym(s2 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
|
||||
sym_stride_int: "Sym(s2 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
|
||||
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
|
||||
""", # noqa: B950
|
||||
)
@@ -2757,10 +2757,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
|
||||
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"):
|
||||
l_nt_ = L_nt_
|
||||
|
||||
add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
|
||||
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None
|
||||
return (add,)
|
||||
""", # noqa: B950
|
||||
)
@@ -3254,27 +3254,27 @@ class GraphModule(torch.nn.Module):
# varies based on the type of view
|
||||
guard_str = "\n".join(guards)
|
||||
if nt_view_name == "subclass_dense":
|
||||
self.assertExpectedInline(guard_str, """Eq(s85 - 1, s77)""")
|
||||
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
|
||||
elif nt_view_name == "dense_subclass_dense_subclass":
|
||||
self.assertExpectedInline(
|
||||
guard_str,
|
||||
"""\
|
||||
Eq(s85 - 1, s77)
|
||||
Eq(s80 - 1, s78)
|
||||
Eq(s72, s71)""",
|
||||
Eq(s5 - 1, s2)
|
||||
Eq(s12 - 1, s7)
|
||||
Eq(s11, s9)""",
|
||||
)
|
||||
elif nt_view_name.startswith("base_is_nt_True"):
|
||||
self.assertExpectedInline(
|
||||
guard_str,
|
||||
"""Eq(s17 - 1, s83)""",
|
||||
"""Eq(s3 - 1, s0)""",
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
guard_str,
|
||||
"""\
|
||||
Eq(s85 - 1, s64)
|
||||
Eq(s80 - 1, s77)
|
||||
Eq(s72, s71)""",
|
||||
Eq(s4 - 1, s1)
|
||||
Eq(s13 - 1, s8)
|
||||
Eq(s12, s10)""",
|
||||
)
|
||||
return gm
@@ -1560,7 +1560,7 @@ graph():
)
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
r"Real tensor propagation found an output size mismatch between fake shape s\d+ and real shape 4, "
|
||||
r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
|
||||
r"at output\.size\(0\), for func: mylib.foo.default",
|
||||
):
|
||||
export(
@@ -2848,7 +2848,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*= 9 to be "
|
||||
"of the form 2\\*s92, where s92 is an integer",
|
||||
"of the form 2\\*s1, where s1 is an integer",
|
||||
):
|
||||
ep.module()(torch.randn(9))
@@ -3506,11 +3506,8 @@ def forward(self, x):
dynamic_shapes=({0: Dim("x")},),
|
||||
)
|
||||
|
||||
# Since symbol names are based on hash of source names, and these differ across inference and
|
||||
# training, we do range comparisons instead.
|
||||
self.assertEqual(
|
||||
str(ep_for_training.range_constraints.values()),
|
||||
str(ep_for_real.range_constraints.values()),
|
||||
str(ep_for_training.range_constraints), str(ep_for_real.range_constraints)
|
||||
)
|
||||
|
||||
def test_export_for_training_with_container_type(self):
@@ -4388,7 +4385,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
em.module()(torch.randn(4, 3))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 \- 1\), 0\)",
|
||||
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)",
|
||||
):
|
||||
em.module()(torch.randn(4, 5))
@@ -4399,7 +4396,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
x = torch.randn(3, 5)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
|
||||
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer",
|
||||
):
|
||||
em.module()(x)
@@ -4958,14 +4955,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
|
||||
self.assertEqual(
|
||||
[
|
||||
# First dimension varies across strict and non-strict
|
||||
# since the source names are different, resulting in
|
||||
# different symbol names.
|
||||
str(node.meta["val"].shape[1:])
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([2, 3])", "torch.Size([3, 4])"],
|
||||
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
|
||||
)
|
||||
|
||||
@testing.expectedFailureCppSerDes
@@ -5103,10 +5097,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"y": (batch, size, size),
|
||||
},
|
||||
)
|
||||
|
||||
for node in efoo.graph_module.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
self.assertEqual(node.meta["val"].shape[1], node.meta["val"].shape[2])
|
||||
self.assertEqual(
|
||||
[
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
|
||||
)
|
||||
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
|
||||
|
||||
# pass dynamic shapes of inputs [multiple, mostly distinct]
@@ -5117,14 +5115,13 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
inputs,
|
||||
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
|
||||
)
|
||||
placeholders = [
|
||||
node.meta["val"].shape
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
]
|
||||
self.assertEqual(
|
||||
placeholders[0][2],
|
||||
placeholders[1][1],
|
||||
[
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
|
||||
)
|
||||
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@@ -5141,14 +5138,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
|
||||
self.assertEqual(
|
||||
[
|
||||
# First dimension varies across strict and non-strict
|
||||
# since the source names are different, resulting in
|
||||
# different symbol names.
|
||||
str(node.meta["val"].shape[1:])
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([2, 3])", "torch.Size([3, 4])"],
|
||||
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
|
||||
)
|
||||
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@@ -5165,14 +5159,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
|
||||
self.assertEqual(
|
||||
[
|
||||
# First dimension varies across strict and non-strict
|
||||
# since the source names are different, resulting in
|
||||
# different symbol names.
|
||||
str(node.meta["val"].shape[1:])
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([2, 3])", "torch.Size([3, 4])"],
|
||||
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
|
||||
)
|
||||
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@@ -5482,7 +5473,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder"
|
||||
]
|
||||
self.assertEqual(len(input_shapes), 9)
|
||||
self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes))
|
||||
self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes))
|
||||
|
||||
def test_error_does_not_reference_eager_fallback(self):
|
||||
class Module(torch.nn.Module):
@@ -11146,7 +11137,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, 4\*s77 \- 4\), 0\) on node 'eq.*'",
|
||||
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'",
|
||||
):
|
||||
ep.module()(torch.randn(8, 8)) # fail
@@ -11178,7 +11169,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(40).shape)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
|
||||
r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'",
|
||||
): # fail only at runtime
|
||||
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
@@ -11205,7 +11196,7 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(126).shape)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
|
||||
r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'",
|
||||
): # fail only at runtime
|
||||
ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail
@@ -11286,12 +11277,12 @@ def forward(self, x, y):
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Ne\(s77, 20\)",
|
||||
r"Runtime assertion failed for expression Ne\(s0, 20\)",
|
||||
):
|
||||
ep.module()(torch.randn(20, 20, 16))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)",
|
||||
r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)",
|
||||
):
|
||||
ep.module()(torch.randn(400, 20, 16))
|
||||
ep.module()(torch.randn(42, 20, 16))
@@ -11329,17 +11320,17 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(27).shape)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Ne\(s77, s17\)",
|
||||
r"Runtime assertion failed for expression Ne\(s0, s1\)",
|
||||
): # fail only at runtime
|
||||
ep.module()(torch.randn(4), torch.randn(4)) # fail
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Ne\(s77, s17\**3\)",
|
||||
r"Runtime assertion failed for expression Ne\(s0, s1\**3\)",
|
||||
):
|
||||
ep.module()(torch.randn(64), torch.randn(4)) # fail
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\(s77\**2, 3\*s17\)",
|
||||
r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)",
|
||||
):
|
||||
ep.module()(torch.randn(10), torch.randn(9)) # fail
@@ -539,12 +539,8 @@ def forward(self, x):
ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)
|
||||
|
||||
serialized = ExportedProgramSerializer().serialize(ep)
|
||||
self.assertEqual(
|
||||
serialized.exported_program.range_constraints["s77"].min_val, 2
|
||||
)
|
||||
self.assertEqual(
|
||||
serialized.exported_program.range_constraints["s77"].max_val, 3
|
||||
)
|
||||
self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2)
|
||||
self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3)
|
||||
|
||||
def test_kwargs_default(self) -> None:
|
||||
"""
@@ -5941,8 +5941,8 @@ class TestAOTModuleSimplified(AOTTestCase):
self.assertExpectedInline(
|
||||
shape_env.format_guards(),
|
||||
"""\
|
||||
- Eq(s49, 20)
|
||||
- Eq(s70, 30)""",
|
||||
- Eq(s1, 20)
|
||||
- Eq(s2, 30)""",
|
||||
)
|
||||
|
||||
assert torch.allclose(ref[0], res[0])
@@ -4553,10 +4553,10 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
gm.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1)
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@@ -4719,10 +4719,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
def forward(self, a_1, b_1):
|
||||
sum_1 = torch.ops.aten.sum.default(a_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@@ -5856,7 +5856,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
@@ -5969,7 +5969,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
_tensor_constant1 = self._tensor_constant1
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
@@ -6209,7 +6209,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
@@ -6558,14 +6558,14 @@ def forward(self, l_inp_, l_tmp_):
self.assertExpectedInline(
|
||||
backend.graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
|
||||
def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
|
||||
l_a_ = L_a_
|
||||
l_b_ = L_b_
|
||||
l_self_num = L_self_num
|
||||
tensor = torch.tensor([True])
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
|
||||
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s0)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return (getitem,)""", # noqa: B950
|
||||
)
@@ -6753,10 +6753,10 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x: "f32[s35, 3]";
|
||||
x: "f32[s0, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
@@ -6770,27 +6770,27 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_2 > 0
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
||||
|
||||
getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None
|
||||
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None
|
||||
|
||||
add: "Sym(u1 + 1)" = getitem_2 + 1
|
||||
|
||||
add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
|
||||
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
|
||||
|
||||
lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
|
||||
lt: "Sym(u1 < s0)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
|
||||
|
||||
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
|
||||
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
||||
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
|
||||
|
||||
class while_loop_cond_graph_0(torch.nn.Module):
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
|
||||
sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
|
||||
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None
|
||||
return lt
|
||||
|
||||
class while_loop_body_graph_0(torch.nn.Module):
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
|
||||
clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
|
||||
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
|
||||
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
||||
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
||||
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
@@ -6820,12 +6820,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
cond_fn_0 = self.cond_fn_0
|
||||
body_fn_0 = self.body_fn_0
|
||||
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
|
||||
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s0, s1)); cond_fn_0 = body_fn_0 = l_x_ = s1 = None
|
||||
|
||||
getitem_4: "Sym(u1)" = while_loop[0]
@@ -6835,49 +6835,49 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_4 > 0
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
||||
|
||||
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
|
||||
out_x: "f32[s0, s1]" = while_loop[1]; while_loop = None
|
||||
|
||||
add: "Sym(u1 + 1)" = getitem_4 + 1
|
||||
|
||||
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
|
||||
add_1: "f32[s0, s1]" = getitem_4 + out_x; out_x = None
|
||||
|
||||
lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None
|
||||
lt: "Sym(u1 < s0)" = getitem_4 < s0; s0 = None
|
||||
|
||||
mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None
|
||||
ones: "f32[2*u1]" = torch.ones(mul); mul = None
|
||||
return (add, add_1, lt, ones)
|
||||
|
||||
class cond_fn_0(torch.nn.Module):
|
||||
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
|
||||
s27_1 = s27
|
||||
s77_1 = s77
|
||||
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
|
||||
s0_1 = s0
|
||||
s1_1 = s1
|
||||
|
||||
size = l_x_.size(); l_x_ = None
|
||||
getitem: "Sym(s77)" = size[0]
|
||||
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
||||
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
|
||||
getitem: "Sym(s0)" = size[0]
|
||||
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
|
||||
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; unbacked_symint = getitem = None
|
||||
return lt
|
||||
|
||||
class body_fn_0(torch.nn.Module):
|
||||
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
|
||||
s27_1 = s27
|
||||
s77_1 = s77
|
||||
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1):
|
||||
s0_1 = s0
|
||||
s1_1 = s1
|
||||
|
||||
x_clone: "f32[s77, s27]" = l_x_.clone()
|
||||
x_clone: "f32[s0, s1]" = l_x_.clone()
|
||||
|
||||
ge: "Sym(u0 >= 0)" = unbacked_symint >= 0
|
||||
_check = torch._check(ge); ge = _check = None
|
||||
|
||||
size = l_x_.size(); l_x_ = None
|
||||
getitem: "Sym(s77)" = size[0]
|
||||
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
||||
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None
|
||||
getitem: "Sym(s0)" = size[0]
|
||||
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None
|
||||
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; getitem = None
|
||||
_check_1 = torch._check(lt); lt = _check_1 = None
|
||||
|
||||
select: "f32[s27]" = x_clone.select(0, unbacked_symint)
|
||||
select_1: "f32[s27]" = x_clone.select(0, unbacked_symint)
|
||||
add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None
|
||||
copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
|
||||
select: "f32[s1]" = x_clone.select(0, unbacked_symint)
|
||||
select_1: "f32[s1]" = x_clone.select(0, unbacked_symint)
|
||||
add: "f32[s1]" = select_1 + unbacked_symint; select_1 = None
|
||||
copy_: "f32[s1]" = select.copy_(add); select = add = copy_ = None
|
||||
|
||||
add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
|
||||
return (add_1, x_clone)
@@ -7048,12 +7048,12 @@ class GraphModule(torch.nn.Module):
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x: "f32[s77, 3]";
|
||||
x: "f32[s0, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
|
||||
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
|
||||
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(x); x = None
|
||||
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
@@ -7065,19 +7065,19 @@ class GraphModule(torch.nn.Module):
getitem_9: "Sym(u8)" = while_loop[3]
|
||||
getitem_10: "Sym(u9)" = while_loop[4]
|
||||
|
||||
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
|
||||
getitem_5: "f32[s0, 3]" = while_loop[5]; while_loop = None
|
||||
|
||||
add: "Sym(u7 + 1)" = getitem_8 + 1
|
||||
add_1: "Sym(u8 + 1)" = getitem_9 + 1
|
||||
add_2: "Sym(u9 + 1)" = getitem_10 + 1
|
||||
|
||||
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
|
||||
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
|
||||
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
|
||||
add_3: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
|
||||
add_4: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
|
||||
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
|
||||
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
|
||||
|
||||
class while_loop_cond_graph_0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
|
||||
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
|
||||
mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
|
||||
mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
|
||||
mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
@@ -7085,7 +7085,7 @@ class GraphModule(torch.nn.Module):
return lt
|
||||
|
||||
class while_loop_body_graph_0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
|
||||
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"):
|
||||
add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
|
||||
add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
@@ -7093,7 +7093,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
|
||||
add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
|
||||
|
||||
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
|
||||
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
|
||||
return (add, add_1, add_2, add_3, add_4, add_5)
|
||||
""", # noqa: B950
|
||||
)
@@ -7119,14 +7119,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
|
||||
child: "f32[s0, s1]" = l_x_.sin(); l_x_ = None
|
||||
|
||||
cond_fn_0 = self.cond_fn_0
|
||||
body_fn_0 = self.body_fn_0
|
||||
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
|
||||
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s0, s1, 2, 2, 3, child), (s0, s1)); cond_fn_0 = body_fn_0 = s0 = s1 = child = None
|
||||
|
||||
getitem_10: "Sym(u5)" = while_loop[0]
|
||||
getitem_11: "Sym(u6)" = while_loop[1]
@@ -7134,21 +7134,21 @@ class GraphModule(torch.nn.Module):
getitem_13: "Sym(u8)" = while_loop[3]
|
||||
getitem_14: "Sym(u9)" = while_loop[4]
|
||||
|
||||
out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
|
||||
out_x: "f32[s0, s1]" = while_loop[5]; while_loop = None
|
||||
|
||||
add: "Sym(u7 + 1)" = getitem_12 + 1
|
||||
add_1: "Sym(u8 + 1)" = getitem_13 + 1
|
||||
add_2: "Sym(u9 + 1)" = getitem_14 + 1
|
||||
|
||||
add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
|
||||
add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
|
||||
add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
|
||||
add_3: "f32[s0, s1]" = getitem_12 + out_x; getitem_12 = None
|
||||
add_4: "f32[s0, s1]" = getitem_13 + out_x; getitem_13 = None
|
||||
add_5: "f32[s0, s1]" = getitem_14 + out_x; getitem_14 = None
|
||||
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
|
||||
|
||||
class cond_fn_0(torch.nn.Module):
|
||||
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
|
||||
s27_1 = s27
|
||||
s77_1 = s77
|
||||
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
|
||||
s0_1 = s0
|
||||
s1_1 = s1
|
||||
|
||||
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
|
||||
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
@@ -7157,9 +7157,9 @@ class GraphModule(torch.nn.Module):
return lt
|
||||
|
||||
class body_fn_0(torch.nn.Module):
|
||||
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
|
||||
s27_1 = s27
|
||||
s77_1 = s77
|
||||
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1):
|
||||
s0_1 = s0
|
||||
s1_1 = s1
|
||||
|
||||
add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
|
||||
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
@@ -7168,7 +7168,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None
|
||||
add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None
|
||||
|
||||
child_1: "f32[s77, s27]" = child + 1; child = None
|
||||
child_1: "f32[s0, s1]" = child + 1; child = None
|
||||
return (add, add_1, add_2, add_3, add_4, child_1)
|
||||
""", # noqa: B950
|
||||
)
@@ -7336,30 +7336,30 @@ class GraphModule(torch.nn.Module):
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]";
|
||||
x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]";
|
||||
|
||||
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
|
||||
sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None
|
||||
sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None
|
||||
|
||||
gt: "Sym(s35 > 5)" = sym_size_int_3 > 5
|
||||
gt: "Sym(s0 > 5)" = sym_size_int_3 > 5
|
||||
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
|
||||
getitem: "f32[s35, 3]" = cond[0]; cond = None
|
||||
getitem: "f32[s0, 3]" = cond[0]; cond = None
|
||||
return pytree.tree_unflatten((getitem,), self._out_spec)
|
||||
|
||||
class true_graph_0(torch.nn.Module):
|
||||
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
|
||||
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
|
||||
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
|
||||
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
|
||||
return (add,)
|
||||
|
||||
class false_graph_0(torch.nn.Module):
|
||||
def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
|
||||
mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
|
||||
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
|
||||
mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
|
||||
|
||||
add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
|
||||
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
|
||||
return (add,)
|
||||
""", # noqa: B950
|
||||
)
@@ -7522,7 +7522,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"):
|
||||
l_y_ = L_y_
|
||||
l_z_ = L_z_
|
||||
l_x_ = L_x_
@@ -7532,39 +7532,39 @@ class GraphModule(torch.nn.Module):
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
|
||||
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s1, s0, s0, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None
|
||||
|
||||
getitem_5: "f32[u0, s94]" = cond[0]
|
||||
getitem_5: "f32[u0, s1]" = cond[0]
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
|
||||
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
ret: "f32[u0, s94]" = cond[0]; cond = None
|
||||
ret: "f32[u0, s1]" = cond[0]; cond = None
|
||||
|
||||
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
|
||||
sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
|
||||
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None
|
||||
return (sub,)
|
||||
|
||||
class cond_true_0(torch.nn.Module):
|
||||
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
l_x__1 = l_x_
|
||||
s94_1 = s94
|
||||
s1_1 = s1
|
||||
|
||||
add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
|
||||
getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
|
||||
clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
|
||||
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None
|
||||
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None
|
||||
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None
|
||||
return (clone,)
|
||||
|
||||
class cond_false_0(torch.nn.Module):
|
||||
def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
l_x__1 = l_x_
|
||||
s94_1 = s94
|
||||
s1_1 = s1
|
||||
|
||||
mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
|
||||
add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
|
||||
getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
|
||||
clone: "f32[2, s94]" = getitem.clone(); getitem = None
|
||||
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
|
||||
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None
|
||||
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None
|
||||
clone: "f32[2, s1]" = getitem.clone(); getitem = None
|
||||
return (clone,)
|
||||
""", # noqa: B950
|
||||
)
@@ -403,10 +403,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
|
||||
copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
|
||||
copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -563,13 +563,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
|
||||
getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
|
||||
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
|
||||
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -595,11 +595,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
|
||||
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
|
||||
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
|
||||
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
|
||||
return (add,)""",
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -663,10 +663,10 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
|
||||
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -691,11 +691,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
|
||||
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
|
||||
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
|
||||
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -1291,14 +1291,14 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
|
||||
floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None
|
||||
add_6: "Sym(2)" = floordiv + 2
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None
|
||||
getitem_1: "f32[s77, s77][s77, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
|
||||
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
|
||||
getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
|
||||
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
|
||||
return (slice_3, slice_4)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -1324,13 +1324,13 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
|
||||
slice_tensor: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
|
||||
slice_tensor_1: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"):
|
||||
slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
|
||||
slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
|
||||
foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None
|
||||
copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
|
||||
slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
|
||||
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
|
||||
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
|
||||
return (slice_3, slice_4)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
@@ -1470,18 +1470,18 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
|
||||
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
|
||||
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
return (alias_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
@@ -1517,16 +1517,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
|
||||
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
|
||||
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
|
||||
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
return (arg1_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
@@ -11822,7 +11822,7 @@ class CommonTemplate:
# 'i1 + 3 * i0' is cached.
|
||||
self.assertTrue(
|
||||
"i0 + 2 * i1" in mul_buf.data.inner_fn_str()
|
||||
or "i0 + i1 * s64" in mul_buf.data.inner_fn_str()
|
||||
or "i0 + i1 * s1" in mul_buf.data.inner_fn_str()
|
||||
)
|
||||
|
||||
with add_scheduler_init_hook(hook_fn):
@@ -12551,7 +12551,7 @@ class CommonTemplate:
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
|
||||
|
||||
if is_dynamic_shape_enabled():
|
||||
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950
|
||||
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950
|
||||
else:
|
||||
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
|
||||
FileCheck().check_regex(size_assert_pattern).run(code)
@@ -12570,8 +12570,8 @@ class CommonTemplate:
code = run_and_get_triton_code(f, x)
|
||||
|
||||
if is_dynamic_shape_enabled():
|
||||
FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
|
||||
"assert_size_stride(buf2, (s77, s27), (s27, 1))"
|
||||
FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check(
|
||||
"assert_size_stride(buf2, (s0, s1), (s1, 1))"
|
||||
).run(code)
|
||||
else:
|
||||
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
@@ -393,7 +393,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_and, 0b1000)
|
||||
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
|
||||
)
|
||||
|
||||
a1 = create_symint(shape_env, 3)
@@ -415,7 +415,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_or, 0b1110)
|
||||
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
|
||||
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
|
||||
)
|
||||
|
||||
def test_stride(self):
@@ -497,7 +497,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
self.assertEqual(guard_int(a0), 2)
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
|
||||
|
||||
def test_sym_sum(self):
|
||||
shape_env = ShapeEnv()
@@ -512,7 +512,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
|
||||
s0 = create_symint(shape_env, 2)
|
||||
self.assertEqual(guard_int(s0), 2)
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
|
||||
|
||||
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
|
||||
s0 = create_symint(shape_env, 2)
@@ -520,7 +520,7 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0)
|
||||
self.assertExpectedInline(
|
||||
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
|
||||
"""[Eq(s97, 2)]""",
|
||||
"""[Eq(s0, 2)]""",
|
||||
)
|
||||
|
||||
def test_sym_int(self):
@@ -529,14 +529,14 @@ class TestPySymInt(TestCase):
r = sym_int(a0)
|
||||
self.assertEqual(r, 5)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")
|
||||
|
||||
a1 = create_symint(shape_env, 7)
|
||||
r = sym_int(a1 / 2)
|
||||
self.assertEqual(guard_int(r), 3)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
|
||||
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
|
||||
)
|
||||
|
||||
a3 = create_symint(shape_env, 3)
@@ -544,7 +544,7 @@ class TestPySymInt(TestCase):
self.assertEqual(guard_int(r), 6)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
|
||||
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
|
||||
)
|
||||
|
||||
def test_sym_log2(self):
@@ -554,7 +554,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2.0)
|
||||
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
||||
self.assertExpectedInline(
|
||||
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
|
||||
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
|
||||
)
|
||||
|
||||
def test_sym_sqrt(self):
|
||||
|
|
@@ -564,7 +564,7 @@ class TestPySymInt(TestCase):
         self.assertEqual(r, 2)
         self.assertIsInstance(r, torch.SymFloat, msg=type(r))
         self.assertExpectedInline(
-            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
+            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
         )

     def test_sym_floor(self):
@@ -575,14 +575,14 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
             str(shape_env.guards[0][0]),
-            """Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
+            """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
         )
         r = math.floor(3.0 * a0)
         self.assertEqual(r, 15)
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
             str(shape_env.guards[1][0]),
-            """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
+            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
         )

     def test_sym_trunc(self):
@@ -592,14 +592,14 @@ class TestPySymInt(TestCase):
         self.assertEqual(r, 2)
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
-            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
+            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
         )
         r = torch.sym_int(torch.sym_sqrt(a0))
         self.assertEqual(r, 2)
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
             str(shape_env.guards[1][0]),
-            """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
+            """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
         )

     def test_sym_ceil(self):
@@ -610,7 +610,7 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
             str(shape_env.guards[0][0]),
-            """Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
+            """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
         )
         r1 = 3.0 * a0
         r = math.floor(r1)
@@ -618,7 +618,7 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(
             str(shape_env.guards[1][0]),
-            """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
+            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
         )

     def test_sym_ite(self):
@@ -638,7 +638,7 @@ class TestPySymInt(TestCase):
         self.assertEqual(type(t), type(r3))
         self.assertExpectedInline(
             str(shape_env.guards[0][0]),
-            """Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
+            """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
         )
         b4 = f == 5
         r4 = torch.sym_ite(b4, t, f)
@@ -647,7 +647,7 @@ class TestPySymInt(TestCase):
         self.assertEqual(type(f), type(r4))
         self.assertExpectedInline(
             str(shape_env.guards[1][0]),
-            """Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
+            """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
         )

     def test_tracing_sym_ite(self):
@@ -679,7 +679,7 @@ def forward(self, x_1):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
         int(a0)
-        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
+        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

     def test_data_dependent_guard(self):
         shape_env = ShapeEnv()
@@ -710,7 +710,7 @@ def forward(self, x_1):
         self.assertTrue(expect_true(i0 < s0))
         self.assertExpectedInline(
             str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
-            """[u0 < s97]""",
+            """[u0 < s0]""",
         )
         self.assertTrue(i0 < s0)
         self.assertTrue(i0 != s0)
@@ -1173,18 +1173,18 @@ def forward(self, x_1):
             out.strip(),
             """\
 class f(torch.nn.Module):
-    def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
+    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
         # No stacktrace found for following nodes
-        sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
-        sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
-        add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
-        sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
-        sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
-        add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
-        new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
+        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
+        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
+        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
+        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
+        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
+        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
+        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
         native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
-        getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
-        getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
+        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
+        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
         return (getitem, getitem_1)""", # noqa: B950
         )

@@ -2846,8 +2846,8 @@ class TestGuardsExpressions(TestCase):
             ],
         )

-        self.assertEqual(f"{x.stride()}", "(s49, 1)")
-        self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
+        self.assertEqual(f"{x.stride()}", "(s1, 1)")
+        self.assertEqual(f"{x.shape}", "torch.Size([s0, s1])")

         x_clean = _remove_symbols_without_guarding(x, 4096)

@@ -1084,7 +1084,7 @@ def forward(self, x_1, y_1):
         test_inputs.append([(6, 8)])
         gm = self._test_dynamic(f, [(3, 4)], test_inputs)
         self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
-        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
+        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
         self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
         self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")

@@ -1717,7 +1717,7 @@ def forward(self, a_1):
         gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
         self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
         self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
-        self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")
+        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")

     def test_new_empty(self):
         def f(a, b):
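Only the factor order changes in the guard above. A plausible reading is that show_guards renders the underlying sympy expression, and sympy prints the factors of a product in the symbols' sort order, so renaming the size symbols is enough to reorder the rendered guard string. A minimal sketch of that printing behaviour, with symbol names chosen purely for illustration:

```python
import sympy

# Stand-ins for the two sizes referenced by the guard above (names invented).
a = sympy.Symbol("s0", positive=True, integer=True)
b = sympy.Symbol("s2", positive=True, integer=True)

# sympy canonicalizes commutative products and prints the factors ordered by
# the symbols' sort key, so the rendered string depends on the symbol names.
print(2 * b * a)  # 2*s0*s2
```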
@@ -869,12 +869,6 @@ def _optimized_add(
         if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
             return make_optimized(rhs._args + lhs._args)

-        # (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
-        new_args = list(lhs._args)
-        for a in rhs._args:
-            new_args = _binary_search_insert_arg(new_args, a)
-        return make_optimized(new_args)
-
     # (a0+a2) + a1 => (a0+a1+a2)
     if lhs_is_optimized_summation and rhs.is_symbol:
         new_args = _binary_search_insert_arg(list(lhs._args), rhs)
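The block removed above merged two already-sorted summations by inserting each right-hand term into the left-hand argument list with a binary search, the same sorted-arguments invariant the surviving (a0+a2) + a1 branch relies on. A minimal standalone sketch of that merge step, using plain strings and the standard bisect module in place of the real sympy arguments and the _binary_search_insert_arg helper:

```python
import bisect


def merge_sorted_summations(lhs_args: list[str], rhs_args: list[str]) -> list[str]:
    # (a1+a3) + (a0+a2) => (a0+a1+a2+a3): keep the lhs argument list sorted and
    # drop each rhs term into the position found by binary search.
    new_args = list(lhs_args)
    for a in rhs_args:
        bisect.insort(new_args, a)
    return new_args


print(merge_sorted_summations(["a1", "a3"], ["a0", "a2"]))  # ['a0', 'a1', 'a2', 'a3']
```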
@@ -14,7 +14,6 @@ import atexit
 import collections
 import dis
 import functools
-import hashlib
 import inspect
 import itertools
 import logging
@@ -3290,11 +3289,6 @@ class ShapeEnv:

         self.guards: list[ShapeGuard] = []
         self.axioms: dict[sympy.Expr, sympy.Expr] = {}
-
-        # A set of ids that have already been allocated. This is used
-        # for when we allocate symbol ids using the hash of the source
-        # names to ensure we don't have collisions via linear probing
-        self.unique_ids: set[int] = set()
         # Maps symbolic ints to their original concrete values
         # Currently populated from tensors
         self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
@@ -4546,14 +4540,13 @@ class ShapeEnv:
             # If we're not duck shaping, we always create a new symbol
             # Even if we're duck shaping, if we haven't seen this particular
             # value before, we also create a new symbol
-            symbol_id = self._generate_unique_id(source.name())
             if type(val) is int or is_nested_int(val):
                 sympy_expr = make_symbol(
-                    SymT.SIZE, symbol_id, positive=positive, integer=True
+                    SymT.SIZE, len(self.var_to_val), positive=positive, integer=True
                 )
             else:
                 sympy_expr = make_symbol(
-                    SymT.FLOAT, symbol_id, positive=positive, real=True
+                    SymT.FLOAT, len(self.var_to_val), positive=positive, real=True
                 )
             self.source_to_var[source_name] = sympy_expr
             # We always associate vars to vals
@@ -6565,13 +6558,6 @@ class ShapeEnv:
         sloc, _ = self._get_stack_summary(framework_loc=framework_loc)
         return sloc

-    def _generate_unique_id(self, source_name: str) -> int:
-        attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
-        while attempt in self.unique_ids:
-            attempt += 1
-        self.unique_ids.add(attempt)
-        return attempt
-
     def _find_frame_locals(self) -> _FrameLocalResult:
         """
         Given the current user code frame, finds the relevant lines of code,
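Together with the import and __init__ hunks, the removal above drops the hash-based allocator and its unique_ids bookkeeping and returns to deriving ids from len(self.var_to_val), which is why the expected strings in the tests change from names like s97 and s26 back to s0 and s1. For reference, a minimal standalone sketch of the scheme being reverted; the class name and the example source strings are invented for illustration:

```python
import hashlib


class SourceHashedIds:
    """Sketch of the reverted allocator: the id comes from a hash of the
    source name, with linear probing past ids that are already taken."""

    def __init__(self) -> None:
        self.unique_ids: set[int] = set()

    def generate_unique_id(self, source_name: str) -> int:
        # Stable starting point: SHA-256 of the source name, folded into [0, 100).
        attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
        # Probe forward until we reach an id nobody else holds.
        while attempt in self.unique_ids:
            attempt += 1
        self.unique_ids.add(attempt)
        return attempt


alloc = SourceHashedIds()
# The same source name hashes to the same starting id on every run, which is
# what kept names like "s97" consistent across traces (source names invented).
print(alloc.generate_unique_id("L['a'].size()[0]"))
print(alloc.generate_unique_id("L['b'].size()[0]"))
```

After the revert, ids again follow creation order (the size of var_to_val at the time the symbol is made), so freshly traced graphs see s0, s1, s2, and so on.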