[Graph Partition] remove PRECOMPUTED_SIZE from partition symbol inputs (#152864)

A PRECOMPUTED_SIZE is computed at runtime inside the generated code, so it should not be included in graph_partition_inputs. See the following example involving a PRECOMPUTED_SIZE `ps0`.

![image](https://github.com/user-attachments/assets/5aa949a9-b8e0-4b77-8702-95b96b58694e)

full output code: [P1803820480](https://www.internalfb.com/phabricator/paste/view/P1803820480)
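
The screenshot above may not render in every view, so here is a hand-written sketch of the failure mode it illustrates. This is not the generated code from the paste: `partition_0`, the `ps0` expression, and the shapes are all invented for illustration, following inductor's naming conventions (`s*` for input size symbols, `ps*` for precomputed sizes).

```python
import torch

# Sketch of the shape of inductor's generated wrapper code (illustrative only).
def partition_0(args):
    arg0_1, s27 = args
    # ps0 is a PRECOMPUTED_SIZE: it is derived from s27 *inside* the
    # generated code, so it does not exist before the partition runs.
    ps0 = (s27 + 1) // 2 * 2  # hypothetical precomputed expression
    buf0 = torch.nn.functional.pad(arg0_1, (0, ps0 - s27))
    return (buf0,)

def call(args):
    (arg0_1,) = args
    s27 = arg0_1.size(0)
    # Before this fix, the partition signature also demanded ps0, even
    # though ps0 is only computed inside the partition:
    #     partition_0([arg0_1, s27, ps0])   # NameError: ps0 is undefined
    return partition_0([arg0_1, s27])

(out,) = call([torch.randn(7)])
```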

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152864
Approved by: https://github.com/eellison
Boyuan Feng authored on 2025-05-06 17:35:29 +00:00; committed by PyTorch MergeBot
parent 5d36485b4a
commit 7dd9d514d2
2 changed files with 24 additions and 1 deletion


@@ -13255,6 +13255,30 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         assert len(inps) == 0
 
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_pad_dynamic(self):
+        def get_same_padding(x: int, k: int, s: int, d: int):
+            return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+
+        def pad_same(x, k, s, d=(1, 1), value=0):
+            ih, iw = x.size()[-2:]
+            pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
+                iw, k[1], s[1], d[1]
+            )
+            if pad_h > 0 or pad_w > 0:
+                x = torch.nn.functional.pad(
+                    x,
+                    [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
+                    value=value,
+                )
+            return x
+
+        x = torch.randn(2, 24, 110, 110, device=self.device)
+        opt = torch.compile(pad_same, dynamic=True)
+        res = opt(x, (5, 5), (2, 2))
+        ref = pad_same(x, (5, 5), (2, 2))
+        self.assertEqual(res, ref, atol=0, rtol=0)
+
     def test_remove_noop_view_default(self):
         def f(x):
             batch_size = x.shape[0]

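For anyone who wants to poke at this outside the test suite, here is a minimal standalone script in the same spirit as the new test (a simplified, hypothetical variant, not the test itself; running it with `TORCH_LOGS=output_code` is one way to inspect the generated wrapper and spot the `ps*` symbols):

```python
import math
import torch

# With dynamic=True, the padding amount becomes a symbolic expression
# that inductor hoists into a precomputed size (ps*) in its output code.
def pad_width(x, k=5, s=2):
    w = x.shape[-1]
    pad = max((math.ceil(w / s) - 1) * s + k - w, 0)
    return torch.nn.functional.pad(x, (pad // 2, pad - pad // 2))

x = torch.randn(2, 24, 110, 110)
opt = torch.compile(pad_width, dynamic=True)
torch.testing.assert_close(opt(x), pad_width(x))
```
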

@@ -4183,7 +4183,6 @@ class Scheduler:
                     SymT.FLOAT,
                     SymT.UNBACKED_INT,
                     SymT.UNBACKED_FLOAT,
-                    SymT.PRECOMPUTED_SIZE,
                 ),
             )
         )
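
For context, `SymT` and `symbol_is_type` come from `torch.utils._sympy.symbol`. Below is a minimal sketch of the kind of filtering this hunk changes, assuming the surrounding Scheduler code collects partition symbol inputs by symbol type; the symbol names and the exact tuple contents are illustrative, not the real Scheduler code.

```python
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

# Two symbols as inductor might see them: a graph input size (s0) and a
# precomputed size (ps0) that the generated code derives at runtime.
s0 = make_symbol(SymT.SIZE, 0)
ps0 = make_symbol(SymT.PRECOMPUTED_SIZE, 0)

# Hypothetical stand-in for the tuple in the hunk above: the symbol
# types that qualify as partition inputs. After this patch,
# SymT.PRECOMPUTED_SIZE is no longer one of them.
input_symbol_types = (SymT.SIZE, SymT.FLOAT, SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)

partition_inputs = [
    sym for sym in (s0, ps0) if symbol_is_type(sym, input_symbol_types)
]
print(partition_inputs)  # [s0] -- ps0 is now computed inside the partition
```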