mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
remove dead code for suggesting legacy dynamic shapes fixes (#133700)
Summary: `dynamic_dim` based dynamic shapes are long gone, so pretty-printing suggested fixes for them is dead code. Test Plan: existing tests Differential Revision: D61398303 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133700 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
455f6bda56
commit
695d7db2d6
|
|
@ -2,11 +2,9 @@
|
|||
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -2525,47 +2523,6 @@ class TestDimConstraints(TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def dummy_fn(a, b, c, d, e, f):
|
||||
pass
|
||||
|
||||
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn), {})
|
||||
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
|
||||
expected_static = """
|
||||
def specializations(a, b, c, d, e, f):
|
||||
# a:
|
||||
assert a.size()[0] == 8
|
||||
assert a.size()[1] == 22
|
||||
assert a.size()[2] == 96
|
||||
assert a.size()[3] == 96
|
||||
|
||||
# b:
|
||||
assert b.size()[0] == 8
|
||||
assert b.size()[1] == 22
|
||||
assert b.size()[2] == 3
|
||||
|
||||
# c:
|
||||
assert c.size()[0] == 8
|
||||
|
||||
# d:
|
||||
assert d.size()[0] == 8
|
||||
|
||||
# f:
|
||||
assert f.size()[1] == 1
|
||||
"""
|
||||
expected_dynamic = """
|
||||
def specify_constraints(a, b, c, d, e, f):
|
||||
return [
|
||||
# d:
|
||||
dynamic_dim(d, 1) == dynamic_dim(c, 1),
|
||||
|
||||
# e:
|
||||
dynamic_dim(e, 1) == dynamic_dim(c, 1),
|
||||
]
|
||||
"""
|
||||
|
||||
self.assertEqual(static_code, expected_static)
|
||||
self.assertEqual(dynamic_code, expected_dynamic)
|
||||
|
||||
|
||||
class TestGuardsExpressions(TestCase):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2170,213 +2170,129 @@ class DimConstraints:
|
|||
):
|
||||
"""Format a message for constraint violation erros"""
|
||||
from torch.export.dynamic_shapes import _get_dim_name_mapping
|
||||
if self._dcp.source_name_to_debug_name:
|
||||
if not self._dcp.source_name_to_debug_name:
|
||||
# nothing to do
|
||||
return ""
|
||||
|
||||
def transform(s, inverse=False):
|
||||
for k, v in self._dcp.source_name_to_debug_name.items():
|
||||
s = s.replace(k, v) if not inverse else s.replace(v, k)
|
||||
return s
|
||||
def transform(s, inverse=False):
|
||||
for k, v in self._dcp.source_name_to_debug_name.items():
|
||||
s = s.replace(k, v) if not inverse else s.replace(v, k)
|
||||
return s
|
||||
|
||||
results = defaultdict(dict)
|
||||
if dynamic_shapes is None:
|
||||
dynamic_shapes = {}
|
||||
results = defaultdict(dict)
|
||||
if dynamic_shapes is None:
|
||||
dynamic_shapes = {}
|
||||
|
||||
def flip(op):
|
||||
if op == "<=":
|
||||
return ">="
|
||||
if op == ">=":
|
||||
return "<="
|
||||
if op == "<":
|
||||
return ">"
|
||||
if op == ">":
|
||||
return "<"
|
||||
def flip(op):
|
||||
if op == "<=":
|
||||
return ">="
|
||||
if op == ">=":
|
||||
return "<="
|
||||
if op == "<":
|
||||
return ">"
|
||||
if op == ">":
|
||||
return "<"
|
||||
assert op == "=="
|
||||
return op
|
||||
|
||||
def relation_with_digit(expr, op, digit):
|
||||
if op == "<=":
|
||||
results[expr]["max"] = digit
|
||||
elif op == "<":
|
||||
results[expr]["max"] = digit - 1
|
||||
elif op == ">=":
|
||||
results[expr]["min"] = digit
|
||||
elif op == ">":
|
||||
results[expr]["min"] = digit + 1
|
||||
else:
|
||||
assert op == "=="
|
||||
return op
|
||||
results[expr]["eq"] = digit
|
||||
|
||||
def relation_with_digit(expr, op, digit):
|
||||
if op == "<=":
|
||||
results[expr]["max"] = digit
|
||||
elif op == "<":
|
||||
results[expr]["max"] = digit - 1
|
||||
elif op == ">=":
|
||||
results[expr]["min"] = digit
|
||||
elif op == ">":
|
||||
results[expr]["min"] = digit + 1
|
||||
else:
|
||||
assert op == "=="
|
||||
results[expr]["eq"] = digit
|
||||
# retrieve dynamic shapes
|
||||
name_to_dim = _get_dim_name_mapping(dynamic_shapes)
|
||||
|
||||
# retrieve dynamic shapes
|
||||
name_to_dim = _get_dim_name_mapping(dynamic_shapes)
|
||||
for s in self._static_results.union(self._dynamic_results):
|
||||
t = transform(s)
|
||||
if t == s:
|
||||
continue
|
||||
left, op, right = re.split(r"( == | <= | >= | < | > )", t)
|
||||
op = op.strip()
|
||||
if op == "==" and left == right:
|
||||
continue
|
||||
if right.isdigit():
|
||||
relation_with_digit(left, op, int(right))
|
||||
elif left.isdigit():
|
||||
relation_with_digit(right, flip(op), int(left))
|
||||
else:
|
||||
assert op == "==", t
|
||||
results[left]["eq"] = sympy.sympify(right)
|
||||
|
||||
for s in self._static_results.union(self._dynamic_results):
|
||||
t = transform(s)
|
||||
if t == s:
|
||||
continue
|
||||
left, op, right = re.split(r"( == | <= | >= | < | > )", t)
|
||||
op = op.strip()
|
||||
if op == "==" and left == right:
|
||||
continue
|
||||
if right.isdigit():
|
||||
relation_with_digit(left, op, int(right))
|
||||
elif left.isdigit():
|
||||
relation_with_digit(right, flip(op), int(left))
|
||||
else:
|
||||
assert op == "==", t
|
||||
results[left]["eq"] = sympy.sympify(right)
|
||||
|
||||
# order forced specializations based on name
|
||||
forced_specializations = {
|
||||
k: forced_specializations[k]
|
||||
for k in sorted(
|
||||
forced_specializations.keys(),
|
||||
key=lambda x: x.split(" = ")[1],
|
||||
)
|
||||
}
|
||||
|
||||
buf = ""
|
||||
if forced_specializations:
|
||||
debug_names = set()
|
||||
for k in forced_specializations:
|
||||
dim = name_to_dim[k.split(" = ")[0]]
|
||||
if self._is_derived_dim(dim):
|
||||
debug_names.add(dim.root.__name__)
|
||||
else:
|
||||
debug_names.add(dim.__name__)
|
||||
|
||||
buf += (
|
||||
f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
|
||||
'For more information, run with TORCH_LOGS="+dynamic".\n'
|
||||
)
|
||||
for s, val in forced_specializations.items():
|
||||
buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
|
||||
|
||||
self._process_derived_dim_roots(results, name_to_dim)
|
||||
|
||||
dims = []
|
||||
others = []
|
||||
|
||||
# order results by source name
|
||||
results = {
|
||||
k: results[k] for k in sorted(
|
||||
results.keys(),
|
||||
key=lambda x: transform(x, inverse=True),
|
||||
)
|
||||
}
|
||||
for k, c in results.items():
|
||||
if "eq" in c:
|
||||
other = c["eq"]
|
||||
if isinstance(other, int):
|
||||
others.append(f"{k} = {other}")
|
||||
elif _is_supported_equivalence(other):
|
||||
others.append(f"{k} = {other}")
|
||||
else:
|
||||
min_ = c.get("min", None)
|
||||
if min_ == 2:
|
||||
min_ = None
|
||||
max_ = c.get("max", None)
|
||||
if min_ is not None and max_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
|
||||
elif min_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', min={min_})")
|
||||
elif max_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', max={max_})")
|
||||
else:
|
||||
dims.append(f"{k} = Dim('{k}')")
|
||||
|
||||
# results will get filtered out if no new suggestions,
|
||||
# this can happen if guards are too complex.
|
||||
# in that case don't suggest fix
|
||||
if dims or others:
|
||||
buf += "\nSuggested fixes:\n "
|
||||
buf += "\n ".join(dims + others)
|
||||
|
||||
return buf
|
||||
|
||||
# Note: Model inputs are wrapped as LocalSource in dynamo.
|
||||
# LocalSource.name() wraps the name with L[""]. We use regular
|
||||
# expression to do the replacement to avoid traversing up
|
||||
# the source hierarchy manually.
|
||||
def extract_and_rewrite_local(dc):
|
||||
match = re.search(r"L\['(.+?)'\]", dc)
|
||||
if match is None:
|
||||
return
|
||||
arg = match.expand(r'\1')
|
||||
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
|
||||
return arg, dc
|
||||
|
||||
def group(results, args_index):
|
||||
groups = defaultdict(list)
|
||||
for dc in results:
|
||||
local = extract_and_rewrite_local(dc)
|
||||
if local is None:
|
||||
# This can happen, e.g., with `assume_constant_result`.
|
||||
# In that case, we drop the constraint.
|
||||
# TODO(avik) Maybe we should generate an assertion here?
|
||||
continue
|
||||
arg, dc = local
|
||||
if arg in args_index:
|
||||
groups[args_index[arg]].append(dc)
|
||||
else:
|
||||
# This can happen, e.g., with decorators that change the signature.
|
||||
# In that case, we drop the constraint. Seems hard to do better. :/
|
||||
# TODO(avik) Maybe warn that `arg` in not in `signature`?
|
||||
continue
|
||||
sorted_groups = []
|
||||
for idx, dcs in sorted(groups.items()):
|
||||
_, arg = idx
|
||||
sorted_groups.append((arg, sorted(dcs)))
|
||||
return sorted_groups
|
||||
|
||||
signature = original_signature.replace(return_annotation=inspect.Signature.empty)
|
||||
args_index = {}
|
||||
for i, arg in enumerate(signature.parameters.keys()):
|
||||
args_index[arg] = (i, arg)
|
||||
|
||||
def print_results(grouped, indent, result_fn):
|
||||
nonlocal buf
|
||||
|
||||
space = False
|
||||
for arg, results in grouped:
|
||||
if space:
|
||||
buf += "\n"
|
||||
else:
|
||||
space = True
|
||||
buf += f"\n{indent}# {arg}:"
|
||||
for result in results:
|
||||
buf += f"\n{indent}{result_fn(result)}"
|
||||
# order forced specializations based on name
|
||||
forced_specializations = {
|
||||
k: forced_specializations[k]
|
||||
for k in sorted(
|
||||
forced_specializations.keys(),
|
||||
key=lambda x: x.split(" = ")[1],
|
||||
)
|
||||
}
|
||||
|
||||
buf = ""
|
||||
if forced_specializations:
|
||||
debug_names = set()
|
||||
for k in forced_specializations:
|
||||
dim = name_to_dim[k.split(" = ")[0]]
|
||||
if self._is_derived_dim(dim):
|
||||
debug_names.add(dim.root.__name__)
|
||||
else:
|
||||
debug_names.add(dim.__name__)
|
||||
|
||||
buf += (
|
||||
"Some dynamic dimensions need to be specialized because "
|
||||
"the constraints inferred for them are too complex to specify.\n"
|
||||
f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
|
||||
'For more information, run with TORCH_LOGS="+dynamic".\n'
|
||||
)
|
||||
for s, val in forced_specializations.items():
|
||||
buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n"
|
||||
indent = 4 * " "
|
||||
if self._static_results:
|
||||
grouped_static_results = group(self._static_results, args_index)
|
||||
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
|
||||
buf += f"\n```\ndef specializations{str(signature)}:"
|
||||
print_results(
|
||||
grouped_static_results,
|
||||
indent,
|
||||
lambda result: f"assert {result}",
|
||||
buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
|
||||
|
||||
self._process_derived_dim_roots(results, name_to_dim)
|
||||
|
||||
dims = []
|
||||
others = []
|
||||
|
||||
# order results by source name
|
||||
results = {
|
||||
k: results[k] for k in sorted(
|
||||
results.keys(),
|
||||
key=lambda x: transform(x, inverse=True),
|
||||
)
|
||||
buf += "\n```\n"
|
||||
if self._dynamic_results:
|
||||
grouped_dynamic_results = group(self._dynamic_results, args_index)
|
||||
buf += "\nThe following dimensions CAN be dynamic."
|
||||
buf += "\nPlease use the following code to specify the constraints they must satisfy:"
|
||||
buf += f"\n```\ndef specify_constraints{str(signature)}:"
|
||||
buf += f"\n{indent}return ["
|
||||
print_results(
|
||||
grouped_dynamic_results,
|
||||
indent * 2,
|
||||
lambda result: f"{result},",
|
||||
)
|
||||
buf += f"\n{indent}]\n```\n"
|
||||
}
|
||||
for k, c in results.items():
|
||||
if "eq" in c:
|
||||
other = c["eq"]
|
||||
if isinstance(other, int):
|
||||
others.append(f"{k} = {other}")
|
||||
elif _is_supported_equivalence(other):
|
||||
others.append(f"{k} = {other}")
|
||||
else:
|
||||
min_ = c.get("min", None)
|
||||
if min_ == 2:
|
||||
min_ = None
|
||||
max_ = c.get("max", None)
|
||||
if min_ is not None and max_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
|
||||
elif min_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', min={min_})")
|
||||
elif max_ is not None:
|
||||
dims.append(f"{k} = Dim('{k}', max={max_})")
|
||||
else:
|
||||
dims.append(f"{k} = Dim('{k}')")
|
||||
|
||||
# results will get filtered out if no new suggestions,
|
||||
# this can happen if guards are too complex.
|
||||
# in that case don't suggest fix
|
||||
if dims or others:
|
||||
buf += "\nSuggested fixes:\n "
|
||||
buf += "\n ".join(dims + others)
|
||||
|
||||
return buf
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user