remove dead code for suggesting legacy dynamic shapes fixes (#133700)

Summary: `dynamic_dim` based dynamic shapes are long gone, so pretty-printing suggested fixes for them is dead code.

Test Plan: existing tests

Differential Revision: D61398303

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133700
Approved by: https://github.com/zhxchen17
This commit is contained in:
Avik Chaudhuri 2024-08-17 01:59:34 +00:00 committed by PyTorch MergeBot
parent 455f6bda56
commit 695d7db2d6
2 changed files with 109 additions and 236 deletions

View File

@ -2,11 +2,9 @@
import contextlib
import copy
import inspect
import itertools
import math
import operator
import re
import unittest
import numpy as np
@ -2525,47 +2523,6 @@ class TestDimConstraints(TestCase):
},
)
def dummy_fn(a, b, c, d, e, f):
pass
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn), {})
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
expected_static = """
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
assert a.size()[1] == 22
assert a.size()[2] == 96
assert a.size()[3] == 96
# b:
assert b.size()[0] == 8
assert b.size()[1] == 22
assert b.size()[2] == 3
# c:
assert c.size()[0] == 8
# d:
assert d.size()[0] == 8
# f:
assert f.size()[1] == 1
"""
expected_dynamic = """
def specify_constraints(a, b, c, d, e, f):
return [
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
]
"""
self.assertEqual(static_code, expected_static)
self.assertEqual(dynamic_code, expected_dynamic)
class TestGuardsExpressions(TestCase):
"""

View File

@ -2170,7 +2170,9 @@ class DimConstraints:
):
"""Format a message for constraint violation erros"""
from torch.export.dynamic_shapes import _get_dim_name_mapping
if self._dcp.source_name_to_debug_name:
if not self._dcp.source_name_to_debug_name:
# nothing to do
return ""
def transform(s, inverse=False):
for k, v in self._dcp.source_name_to_debug_name.items():
@ -2293,92 +2295,6 @@ class DimConstraints:
return buf
# Note: Model inputs are wrapped as LocalSource in dynamo.
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def extract_and_rewrite_local(dc):
match = re.search(r"L\['(.+?)'\]", dc)
if match is None:
return
arg = match.expand(r'\1')
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
return arg, dc
def group(results, args_index):
groups = defaultdict(list)
for dc in results:
local = extract_and_rewrite_local(dc)
if local is None:
# This can happen, e.g., with `assume_constant_result`.
# In that case, we drop the constraint.
# TODO(avik) Maybe we should generate an assertion here?
continue
arg, dc = local
if arg in args_index:
groups[args_index[arg]].append(dc)
else:
# This can happen, e.g., with decorators that change the signature.
# In that case, we drop the constraint. Seems hard to do better. :/
# TODO(avik) Maybe warn that `arg` in not in `signature`?
continue
sorted_groups = []
for idx, dcs in sorted(groups.items()):
_, arg = idx
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups
signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
args_index[arg] = (i, arg)
def print_results(grouped, indent, result_fn):
nonlocal buf
space = False
for arg, results in grouped:
if space:
buf += "\n"
else:
space = True
buf += f"\n{indent}# {arg}:"
for result in results:
buf += f"\n{indent}{result_fn(result)}"
buf = ""
if forced_specializations:
buf += (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
)
for s, val in forced_specializations.items():
buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n"
indent = 4 * " "
if self._static_results:
grouped_static_results = group(self._static_results, args_index)
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += f"\n```\ndef specializations{str(signature)}:"
print_results(
grouped_static_results,
indent,
lambda result: f"assert {result}",
)
buf += "\n```\n"
if self._dynamic_results:
grouped_dynamic_results = group(self._dynamic_results, args_index)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nPlease use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
print_results(
grouped_dynamic_results,
indent * 2,
lambda result: f"{result},",
)
buf += f"\n{indent}]\n```\n"
return buf
TLS = threading.local()