remove dead code for suggesting legacy dynamic shapes fixes (#133700)

Summary: `dynamic_dim`-based dynamic shapes are long gone, so pretty-printing suggested fixes for them is dead code.
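For context, a minimal sketch of the `Dim`-based dynamic shapes API that superseded the legacy style (the module, argument names, and shapes below are illustrative, not taken from this PR):

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


# Current API: name each dynamic dimension with Dim and pass it via dynamic_shapes.
batch = Dim("batch", min=2, max=1024)
ep = export(M(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})

# The deleted code path pretty-printed legacy suggestions of the form
#     dynamic_dim(x, 0) <= 1024
# which no export entry point accepts anymore.
```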

Test Plan: existing tests

Differential Revision: D61398303

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133700
Approved by: https://github.com/zhxchen17
Avik Chaudhuri 2024-08-17 01:59:34 +00:00 committed by PyTorch MergeBot
parent 455f6bda56
commit 695d7db2d6
2 changed files with 109 additions and 236 deletions


@@ -2,11 +2,9 @@
import contextlib
import copy
import inspect
import itertools
import math
import operator
import re
import unittest
import numpy as np
@@ -2525,47 +2523,6 @@ class TestDimConstraints(TestCase):
            },
        )

-        def dummy_fn(a, b, c, d, e, f):
-            pass
-
-        action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn), {})
-        static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
-        expected_static = """
-def specializations(a, b, c, d, e, f):
-    # a:
-    assert a.size()[0] == 8
-    assert a.size()[1] == 22
-    assert a.size()[2] == 96
-    assert a.size()[3] == 96
-    # b:
-    assert b.size()[0] == 8
-    assert b.size()[1] == 22
-    assert b.size()[2] == 3
-    # c:
-    assert c.size()[0] == 8
-    # d:
-    assert d.size()[0] == 8
-    # f:
-    assert f.size()[1] == 1
-"""
-        expected_dynamic = """
-def specify_constraints(a, b, c, d, e, f):
-    return [
-        # d:
-        dynamic_dim(d, 1) == dynamic_dim(c, 1),
-        # e:
-        dynamic_dim(e, 1) == dynamic_dim(c, 1),
-    ]
-"""
-        self.assertEqual(static_code, expected_static)
-        self.assertEqual(dynamic_code, expected_dynamic)


class TestGuardsExpressions(TestCase):
    """


@@ -2170,213 +2170,129 @@ class DimConstraints:
):
"""Format a message for constraint violation erros"""
from torch.export.dynamic_shapes import _get_dim_name_mapping
-        if self._dcp.source_name_to_debug_name:
+        if not self._dcp.source_name_to_debug_name:
+            # nothing to do
+            return ""

-            def transform(s, inverse=False):
-                for k, v in self._dcp.source_name_to_debug_name.items():
-                    s = s.replace(k, v) if not inverse else s.replace(v, k)
-                return s
+        def transform(s, inverse=False):
+            for k, v in self._dcp.source_name_to_debug_name.items():
+                s = s.replace(k, v) if not inverse else s.replace(v, k)
+            return s

-            results = defaultdict(dict)
-            if dynamic_shapes is None:
-                dynamic_shapes = {}
+        results = defaultdict(dict)
+        if dynamic_shapes is None:
+            dynamic_shapes = {}

-            def flip(op):
-                if op == "<=":
-                    return ">="
-                if op == ">=":
-                    return "<="
-                if op == "<":
-                    return ">"
-                if op == ">":
-                    return "<"
-                assert op == "=="
-                return op
+        def flip(op):
+            if op == "<=":
+                return ">="
+            if op == ">=":
+                return "<="
+            if op == "<":
+                return ">"
+            if op == ">":
+                return "<"
+            assert op == "=="
+            return op

-            def relation_with_digit(expr, op, digit):
-                if op == "<=":
-                    results[expr]["max"] = digit
-                elif op == "<":
-                    results[expr]["max"] = digit - 1
-                elif op == ">=":
-                    results[expr]["min"] = digit
-                elif op == ">":
-                    results[expr]["min"] = digit + 1
-                else:
-                    assert op == "=="
-                    results[expr]["eq"] = digit
+        def relation_with_digit(expr, op, digit):
+            if op == "<=":
+                results[expr]["max"] = digit
+            elif op == "<":
+                results[expr]["max"] = digit - 1
+            elif op == ">=":
+                results[expr]["min"] = digit
+            elif op == ">":
+                results[expr]["min"] = digit + 1
+            else:
+                assert op == "=="
+                results[expr]["eq"] = digit

-            # retrieve dynamic shapes
-            name_to_dim = _get_dim_name_mapping(dynamic_shapes)
+        # retrieve dynamic shapes
+        name_to_dim = _get_dim_name_mapping(dynamic_shapes)

-            for s in self._static_results.union(self._dynamic_results):
-                t = transform(s)
-                if t == s:
-                    continue
-                left, op, right = re.split(r"( == | <= | >= | < | > )", t)
-                op = op.strip()
-                if op == "==" and left == right:
-                    continue
-                if right.isdigit():
-                    relation_with_digit(left, op, int(right))
-                elif left.isdigit():
-                    relation_with_digit(right, flip(op), int(left))
-                else:
-                    assert op == "==", t
-                    results[left]["eq"] = sympy.sympify(right)
+        for s in self._static_results.union(self._dynamic_results):
+            t = transform(s)
+            if t == s:
+                continue
+            left, op, right = re.split(r"( == | <= | >= | < | > )", t)
+            op = op.strip()
+            if op == "==" and left == right:
+                continue
+            if right.isdigit():
+                relation_with_digit(left, op, int(right))
+            elif left.isdigit():
+                relation_with_digit(right, flip(op), int(left))
+            else:
+                assert op == "==", t
+                results[left]["eq"] = sympy.sympify(right)
-            # order forced specializations based on name
-            forced_specializations = {
-                k: forced_specializations[k]
-                for k in sorted(
-                    forced_specializations.keys(),
-                    key=lambda x: x.split(" = ")[1],
-                )
-            }
-            buf = ""
-            if forced_specializations:
-                debug_names = set()
-                for k in forced_specializations:
-                    dim = name_to_dim[k.split(" = ")[0]]
-                    if self._is_derived_dim(dim):
-                        debug_names.add(dim.root.__name__)
-                    else:
-                        debug_names.add(dim.__name__)
-                buf += (
-                    f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
-                    'For more information, run with TORCH_LOGS="+dynamic".\n'
-                )
-                for s, val in forced_specializations.items():
-                    buf += f"  - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
-            self._process_derived_dim_roots(results, name_to_dim)
-            dims = []
-            others = []
-            # order results by source name
-            results = {
-                k: results[k] for k in sorted(
-                    results.keys(),
-                    key=lambda x: transform(x, inverse=True),
-                )
-            }
-            for k, c in results.items():
-                if "eq" in c:
-                    other = c["eq"]
-                    if isinstance(other, int):
-                        others.append(f"{k} = {other}")
-                    elif _is_supported_equivalence(other):
-                        others.append(f"{k} = {other}")
-                else:
-                    min_ = c.get("min", None)
-                    if min_ == 2:
-                        min_ = None
-                    max_ = c.get("max", None)
-                    if min_ is not None and max_ is not None:
-                        dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
-                    elif min_ is not None:
-                        dims.append(f"{k} = Dim('{k}', min={min_})")
-                    elif max_ is not None:
-                        dims.append(f"{k} = Dim('{k}', max={max_})")
-                    else:
-                        dims.append(f"{k} = Dim('{k}')")
-            # results will get filtered out if no new suggestions,
-            # this can happen if guards are too complex.
-            # in that case don't suggest fix
-            if dims or others:
-                buf += "\nSuggested fixes:\n  "
-                buf += "\n  ".join(dims + others)
-            return buf

-        # Note: Model inputs are wrapped as LocalSource in dynamo.
-        # LocalSource.name() wraps the name with L[""]. We use regular
-        # expression to do the replacement to avoid traversing up
-        # the source hierarchy manually.
-        def extract_and_rewrite_local(dc):
-            match = re.search(r"L\['(.+?)'\]", dc)
-            if match is None:
-                return
-            arg = match.expand(r'\1')
-            dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
-            return arg, dc
-
-        def group(results, args_index):
-            groups = defaultdict(list)
-            for dc in results:
-                local = extract_and_rewrite_local(dc)
-                if local is None:
-                    # This can happen, e.g., with `assume_constant_result`.
-                    # In that case, we drop the constraint.
-                    # TODO(avik) Maybe we should generate an assertion here?
-                    continue
-                arg, dc = local
-                if arg in args_index:
-                    groups[args_index[arg]].append(dc)
-                else:
-                    # This can happen, e.g., with decorators that change the signature.
-                    # In that case, we drop the constraint. Seems hard to do better. :/
-                    # TODO(avik) Maybe warn that `arg` in not in `signature`?
-                    continue
-            sorted_groups = []
-            for idx, dcs in sorted(groups.items()):
-                _, arg = idx
-                sorted_groups.append((arg, sorted(dcs)))
-            return sorted_groups
-
-        signature = original_signature.replace(return_annotation=inspect.Signature.empty)
-        args_index = {}
-        for i, arg in enumerate(signature.parameters.keys()):
-            args_index[arg] = (i, arg)
-
-        def print_results(grouped, indent, result_fn):
-            nonlocal buf
-
-            space = False
-            for arg, results in grouped:
-                if space:
-                    buf += "\n"
-                else:
-                    space = True
-                buf += f"\n{indent}# {arg}:"
-                for result in results:
-                    buf += f"\n{indent}{result_fn(result)}"
-
-        buf = ""
-        if forced_specializations:
-            buf += (
-                "Some dynamic dimensions need to be specialized because "
-                "the constraints inferred for them are too complex to specify.\n"
-            )
-            for s, val in forced_specializations.items():
-                buf += f"  - {s}, which was marked dynamic, must be specialized to {val}.\n"
-        indent = 4 * " "
-        if self._static_results:
-            grouped_static_results = group(self._static_results, args_index)
-            buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
-            buf += f"\n```\ndef specializations{str(signature)}:"
-            print_results(
-                grouped_static_results,
-                indent,
-                lambda result: f"assert {result}",
-            )
-            buf += "\n```\n"
-        if self._dynamic_results:
-            grouped_dynamic_results = group(self._dynamic_results, args_index)
-            buf += "\nThe following dimensions CAN be dynamic."
-            buf += "\nPlease use the following code to specify the constraints they must satisfy:"
-            buf += f"\n```\ndef specify_constraints{str(signature)}:"
-            buf += f"\n{indent}return ["
-            print_results(
-                grouped_dynamic_results,
-                indent * 2,
-                lambda result: f"{result},",
-            )
-            buf += f"\n{indent}]\n```\n"
-        return buf

+        # order forced specializations based on name
+        forced_specializations = {
+            k: forced_specializations[k]
+            for k in sorted(
+                forced_specializations.keys(),
+                key=lambda x: x.split(" = ")[1],
+            )
+        }
+        buf = ""
+        if forced_specializations:
+            debug_names = set()
+            for k in forced_specializations:
+                dim = name_to_dim[k.split(" = ")[0]]
+                if self._is_derived_dim(dim):
+                    debug_names.add(dim.root.__name__)
+                else:
+                    debug_names.add(dim.__name__)
+            buf += (
+                f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
+                'For more information, run with TORCH_LOGS="+dynamic".\n'
+            )
+            for s, val in forced_specializations.items():
+                buf += f"  - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
+        self._process_derived_dim_roots(results, name_to_dim)
+        dims = []
+        others = []
+        # order results by source name
+        results = {
+            k: results[k] for k in sorted(
+                results.keys(),
+                key=lambda x: transform(x, inverse=True),
+            )
+        }
+        for k, c in results.items():
+            if "eq" in c:
+                other = c["eq"]
+                if isinstance(other, int):
+                    others.append(f"{k} = {other}")
+                elif _is_supported_equivalence(other):
+                    others.append(f"{k} = {other}")
+            else:
+                min_ = c.get("min", None)
+                if min_ == 2:
+                    min_ = None
+                max_ = c.get("max", None)
+                if min_ is not None and max_ is not None:
+                    dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
+                elif min_ is not None:
+                    dims.append(f"{k} = Dim('{k}', min={min_})")
+                elif max_ is not None:
+                    dims.append(f"{k} = Dim('{k}', max={max_})")
+                else:
+                    dims.append(f"{k} = Dim('{k}')")
+        # results will get filtered out if no new suggestions,
+        # this can happen if guards are too complex.
+        # in that case don't suggest fix
+        if dims or others:
+            buf += "\nSuggested fixes:\n  "
+            buf += "\n  ".join(dims + others)
+        return buf