[export] handle new roots & root swapping in derived dims suggested fixes (#125543)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543

This PR addresses two issues with derived dim suggested fixes: 1) newly introduced roots, and 2) root swapping.

1 | Newly introduced roots appear with modulo guards: e.g. Mod(dx, 2) = 0 suggests dx is a derived dim equal to 2 * _dx, introducing a new root _dx. The final suggested fixes currently handle this correctly, but intermediate results can contain related derived dims that don't share a unified root and mix min/max-range and derived suggestions.

For example:
```
"dx": {"eq": 3*_dx-1, "max": 36}
"dy": {"eq": dx+1}
```
This should lead to the suggested fixes:
```
_dx = Dim('_dx', max=12)
dx = 3 * _dx - 1
dy = 3 * _dx
```

This PR cleans up the suggested-fixes routine by unifying related suggestions under a single root, and making each intermediate suggestion either a derived dim or a min/max range, not both.
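As a rough, standalone illustration of that unification (a sympy-only sketch, not the actual export internals; the dict shapes below just mimic the intermediate results shown above):
```python
# Standalone sketch (sympy only; not the export implementation) of unifying
# derived-dim suggestions onto a newly introduced root _dx.
import math
import sympy

_dx, dx = sympy.symbols("_dx dx")
results = {
    "dx": {"eq": 3 * _dx - 1, "min": 5, "max": 36},  # intermediate suggestion for dx
    "dy": {"eq": dx + 1},                            # dy still phrased in terms of dx
}

# Derive the new root's range from dx's range: dx = modulus * _dx + remainder.
modulus, remainder = sympy.div(results["dx"]["eq"], _dx)
root_min = math.ceil((results["dx"]["min"] - remainder) / modulus)
root_max = math.floor((results["dx"]["max"] - remainder) / modulus)
print(f"_dx = Dim('_dx', min={root_min}, max={root_max})")  # _dx = Dim('_dx', min=2, max=12)

# Rewrite derivations of dx in terms of the unified root.
dy_expr = results["dy"]["eq"].subs({dx: results["dx"]["eq"]})
print(f"dy = {dy_expr}")  # dy = 3*_dx
```
The printed range matches the `_dx = Dim('_dx', max=12)` suggestion above, and substituting the unified root turns `dy = dx + 1` into `dy = 3*_dx`.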

2 | The current suggested fixes for derived dims can end up swapping root and derived dims, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes that look like `dy - 1 = Dim("dy - 1")`, since we don't have access to the original variable name.

This PR adds a suggested fix only for the root dim, and drops the suggestions for all of its derived dims.

For example, with the export test case test_derived_dim_out_of_order_simplified:
```
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```

Before:
```
Suggested fixes:
  _dimz = Dim('_dimz', min=3, max=9223372036854775807)  # 2 <= _dimz - 1 <= 9223372036854775806
  _dimz - 2 = Dim('_dimz - 2', min=4, max=6)
  _dimz = Dim('_dimz', min=2, max=9223372036854775806)  # 2 <= _dimz <= 9223372036854775806
  _dimz - 1 = _dimz - 1
  dimz = _dimz
```

New suggested fixes:
```
Suggested fixes:
  dimz = _dimz
```
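The range-recovery step that makes this possible can be illustrated with a small standalone sympy sketch (an illustration of the idea only, not the internal solver code): the solver phrased its results in terms of `_dimz - 2`, so solving at that expression's range endpoints recovers the original root's range.
```python
# Standalone sketch (sympy only): recover the original root's range when the
# solver's suggestion is phrased in terms of a derived expression.
import sympy

_dimz = sympy.Symbol("_dimz")
derived = _dimz - 2                # the expression the solver treated as a root
suggested = {"min": 4, "max": 6}   # range suggested for "_dimz - 2" above

root_min = sympy.solve(sympy.Eq(derived, suggested["min"]), _dimz)[0]
root_max = sympy.solve(sympy.Eq(derived, suggested["max"]), _dimz)[0]
print(root_min, root_max)  # 6 8 -> same range as the user's _dimz, so nothing to
                           # refine; only "dimz = _dimz" remains as a suggestion
```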

Note: this assumes the derived relations specified between dims are correct. That assumption is valid because: 1) if a relation is plainly wrong (e.g. `(dx, dx - 1)` provided with inputs `(6, 4)`), it is caught beforehand in produce_guards(); 2) if a relation is numerically correct but does not match the emitted guard, for example:
```
def forward(self, x, y):
    return x.reshape([-1]) + y  # guard: s0 * 2 = s1
dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}
)
```
This produces two linear equations and leads to specialization, since a) produce_guards() is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints.
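For intuition, here is a minimal sympy sketch (not what export runs internally) showing why such a system pins down a single value:
```python
# Standalone sketch: the emitted guard 2*s0 == s1 combined with the user-specified
# relation s1 == s0 + 6 has a unique solution, so dx specializes.
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
solution = sympy.solve([sympy.Eq(2 * s0, s1), sympy.Eq(s1, s0 + 6)], [s0, s1], dict=True)
print(solution)  # [{s0: 6, s1: 12}] -> dx is forced to 6
```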

Approved by: https://github.com/avikchaudhuri


@@ -968,6 +968,35 @@ class TestExport(TestCase):
6,
)
def test_specialize_derived_dim_roots(self):
# dim & derived dim both specialize
class Foo(torch.nn.Module):
def forward(self, x, y):
return x.reshape([-1]) + y
dy = Dim("dy", min=6)
x, y = torch.randn(6, 2), torch.randn(12)
dynamic_shapes = {
"x": (dy - 6, 2),
"y": (dy,),
}
try:
export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
raise Exception(
"export() call should have failed with dynamic shapes error."
)
except torch._dynamo.exc.UserError as exc:
expected_error_msg = (
"Specializations unexpectedly required \(dy\)!(.*\n)*.*"
".*dy - 6.*must be specialized to 6 because the guards generated for it are too complex(.*\n)*.*"
"Suggested fixes(.*\n)*.*"
".*dy = 12(.*\n)*.*"
)
self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
self.assertTrue(
"dy - 6 = 6" not in exc.args[0]
) # don't suggest fix for non-root dim
def test_derived_dim_out_of_order_simplified(self):
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
@@ -979,22 +1008,25 @@ class TestExport(TestCase):
return x + y[1:] + z[2:]
foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Constraints violated \\(dimz\\)!(.*\n)*.*"
"The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
"Suggested fixes:(.*\n)*.*"
"dimz = _dimz"
),
):
try:
export(
foo,
(u, v, w),
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
except torch._dynamo.exc.UserError as exc:
expected_error_msg = (
"Constraints violated \(dimz\)!(.*\n)*.*"
"The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
"Suggested fixes:(.*\n)*.*"
"dimz = _dimz"
)
self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
# don't suggest fix for non-root dims, and no need to update root here
self.assertTrue("_dimz - 2 = Dim(" not in exc.args[0])
self.assertTrue("_dimz - 1 = _dimz - 1" not in exc.args[0])
self.assertTrue("_dimz = Dim(" not in exc.args[0])
dimz = dimx + 2 # works, effectively = _dimz
ep = export(
@@ -1888,6 +1920,55 @@ class TestExport(TestCase):
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
)
def test_suggested_fixes_new_roots(self):
from torch.export import dims
# suggested fixes should introduce new root dim for modulo guard
class Foo(torch.nn.Module):
def forward(self, x, y, z):
# dy = 3 * _dx
# dx = 3 * _dx - 1
# dz = 3 * _dx + 2
# suggested fixes results will look something like
# {"dx": {"eq": 3*_dx-1, "min": 5, "max": 36}, "dy": {"eq": dx+1}, ...}
if x.shape[0] >= 5 and x.shape[0] <= 36 and y.shape[0] % 3 == 0:
return x + y[1:] + z[3:]
foo = Foo()
inputs = (
torch.randn(
11,
),
torch.randn(
12,
),
torch.randn(
14,
),
)
dx, dy, dz = dims("dx", "dy", "dz")
dynamic_shapes = {
"x": (dx,),
"y": (dy,),
"z": (dz,),
}
with self.assertRaisesRegex( # figure out regex later
torch._dynamo.exc.UserError,
(
"Constraints violated.*!(.*\n)*.*"
"Suggested fixes(.*\n)*.*"
"_dx = Dim\(\\'_dx\\', max=12\)(.*\n)*.*"
"dx = 3\*_dx - 1(.*\n)*.*"
"dy = 3\*_dx(.*\n)*.*"
"dz = 3\*_dx \+ 2"
),
):
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
# retry export
_dx = Dim("_dx", min=2, max=12)
dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)}
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
def test_dynamic_shapes_spec_with_pytree(self):
from torch.export import Dim, export
from torch.utils._pytree import tree_map
@@ -5065,7 +5146,6 @@ def forward(self, x, y):
torch._dynamo.exc.UserError,
r".*Specializations unexpectedly required(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dy = Dim.*(.*\n)*"
r".*dw0 = 3(.*\n)*"
r".*dw1 = 4(.*\n)*"
r".*dx0 = 12(.*\n)*"
@@ -5082,9 +5162,6 @@ def forward(self, x, y):
torch._dynamo.exc.UserError,
r".*Constraints violated(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dw0 = Dim.*(.*\n)*"
r".*dw1 = Dim.*(.*\n)*"
r".*dy = Dim.*(.*\n)*"
r".*dz = dy(.*\n)*",
) as msg:
torch.export._trace._export(


@@ -2435,7 +2435,7 @@ class TestDimConstraints(TestCase):
def dummy_fn(a, b, c, d, e, f):
pass
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn), {})
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
expected_static = """
def specializations(a, b, c, d, e, f):


@@ -1333,7 +1333,10 @@ def export(
dim_constraints.remove_redundant_dynamic_results()
forced_specializations = dim_constraints.forced_specializations()
msg = dim_constraints.prettify_results(
original_signature, constraint_violation_error, forced_specializations
original_signature,
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
if constraint_violation_error:
constraint_violation_error.args = (


@@ -223,6 +223,7 @@ def _flatten_dynamic_shapes(
def produce_guards_and_solve_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
equalities_inputs: EqualityConstraint,
original_signature: inspect.Signature,
_disable_forced_specializations: Optional[bool] = False,
@@ -273,7 +274,10 @@ def produce_guards_and_solve_constraints(
forced_specializations = dim_constraints.forced_specializations()
if not _is_torch_jit_trace:
msg = dim_constraints.prettify_results(
original_signature, constraint_violation_error, forced_specializations
original_signature,
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
else:
# FIXME(ycao): This is a hack to get around missing signature from ScriptMethod


@@ -1316,6 +1316,7 @@ def _non_strict_export(
produce_guards_and_solve_constraints(
fake_mode,
aten_export_artifact.gm,
dynamic_shapes,
equalities_inputs,
original_signature,
_disable_forced_specializations=_disable_forced_specializations,


@@ -1914,9 +1914,192 @@ class DimConstraints:
dynamic_results.add(dc)
self._dynamic_results = dynamic_results
def _is_derived_dim(self, dim):
return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
def _is_dim(self, dim):
return (
isinstance(dim, torch.export.dynamic_shapes._Dim)
and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
)
def _process_derived_dim_roots(
self,
results: Dict[str, Dict[str, Any]],
name_to_dim: Dict[str, Any],
) -> None:
'''
Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
and 2) root swapping.
1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests
dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final
suggested fixes handle this correctly, but we can get intermediate results that look like
{"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}}
and this routine prettifies this by unifying to a single root, and making each suggestion
either a derived dim or min/max range, not both.
2) With suggested fixes for derived dims, roots can be swapped,
e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name,
since this leads to messages like "dx - 1 = Dim("dx - 1", ...)".
Instead we evaluate the new root value, and remove results for its derivations.
First we find all the original roots (specified in dynamic_shapes), that are found in the
values of results (i.e. used for computing suggesting fix values). These original roots
(suppose `dx`) are either specialized, unchanged, refined, or swapped
(expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value
in results, and remove suggestions for derivations of `dx`, assuming the derived relation
is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value,
and then do the same with `dx`'s derivations.
Assuming the originally specified derived relations are correct is valid, because:
1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1))
produce_guards() will catch this and crash before hand.
2) if the relations are numerically correct but do not match the emitted guard,
for example:
def forward(self, x, y):
return x.reshape([-1]) + y # guard: s0 * 2 = s1
inputs = (torch.randn(6, 2), torch.randn(12))
dx = Dim("dx", min=2, max=32)
dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op
then this leads to 2 linear equations, and a) produce_guards() is able to solve for
the unique solution of dx = 6 and specialize, and b) the export constraint solver will
raise an issue due to range constraints (a unique solution means not all values in a
range satisfy a guard) and also force specializations.
'''
from torch.export.dynamic_shapes import Dim
def _check_same_range(c, dim):
# returns True if c & dim are both min/max ranges with same values
return (
self._is_dim(dim)
and ("min" in c or "max" in c)
and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2
and dim.max == c.get("max", sys.maxsize - 1)
)
# 1) newly introduced roots
# this part we handle adding newly introduced roots
# these arise from guards like "x.shape[0] % 3 == 0"
# leading to suggested fixes like "dx = 3*_dx"
# extract _dx, and find appropriate min/max values
#
# before, we have something like:
# {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
# we want instead:
# {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
introduced_roots: Dict[str, str] = {} # map new root -> old root
for k, c in list(results.items()):
if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim
root = next(iter(c["eq"].free_symbols))
if str(root) not in name_to_dim:
introduced_roots[str(root)] = k
# calculate necessary min & max
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
max_ = math.floor((c_max - remainder) / modulus)
# create result & dim
results[str(root)] = {"min": min_, "max": max_}
name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_)
# remove old root min/max bounds
c.pop("min", None)
c.pop("max", None)
# alter derivations that depend on old root, to unify to new root
# e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
for old_root in introduced_roots.values():
for k, c in list(results.items()):
if (
"eq" in c
and isinstance(c["eq"], sympy.Expr)
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
): # derived dim with root = old_root
new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1
new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1
c["eq"] = new_expr
# 2) root swapping
# collect all the original roots that are used for calculating values of suggested fixes
# this consists of:
# 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
# 2) {"dy": "dx + 1"} -> dx: root for suggested fix
modified_roots: Set[str] = set()
for k, c in results.items():
if k not in name_to_dim: # _dynamo.export() may handle source directly
continue
if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1)
modified_roots.add(k)
elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2)
root = next(iter(c["eq"].free_symbols))
assert root is not None
modified_roots.add(str(root))
# exclude newly introduced roots, we've already processed these
modified_roots = modified_roots.difference(introduced_roots)
# evaluate the new value for each root
# this is now either 1) unchanged, 2) refined with a new range,
# or 3) specialized to a concrete value
modified_root_values: Dict[str, Dict[str, Any]] = {}
for root in modified_roots:
swapped_root = True
if root in results:
c = results[root]
if (
("min" in c or "max" in c) # range
or isinstance(c["eq"], int) # specialized
):
# here, the original root is a root Dim or concrete value in results.
# if it is a derived dim, it is swapped, and we handle that below.
if not _check_same_range(c, name_to_dim[root]): # ignore if unchanged
modified_root_values[root] = c
swapped_root = False
if swapped_root:
# if the original root has been swapped in results, that means the new root
# is a range (if it had specialized, the original root would have too).
# find this new root, and solve for the original root's range.
for k, c in results.items():
if k not in name_to_dim:
continue
dim = name_to_dim[k]
if dim.__class__.__name__ == "_DerivedDim" and dim.root.__name__ == root:
# only look for min/max root, otherwise root would have specialized
if "min" in c or "max" in c:
expr = sympy.sympify(k)
s = next(iter(expr.free_symbols))
result = {
"min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type]
"max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type]
}
if not _check_same_range(result, name_to_dim[root]): # ignore if unchanged
modified_root_values[root] = result
break
# filter out results where the key is a derived dim (e.g. {"dx - 1" : 4})
# we only want to suggest fixes for the root, to avoid derived names.
# also, remove anything in modified_roots, since we either add new modified values after this,
# or have decided they are unchanged.
for k in list(results.keys()):
if k not in name_to_dim:
continue
if self._is_derived_dim(name_to_dim[k]) or k in modified_roots:
del results[k]
# update results with modified root values
# now results has the following properties:
# - only contains original roots as keys
# - each root is now either specialized, refined, or derived from another original root
results.update(modified_root_values)
def prettify_results(
self,
original_signature: inspect.Signature,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
constraint_violation_error=None,
forced_specializations=None,
):
@@ -1929,6 +2112,8 @@ class DimConstraints:
return s
results = defaultdict(dict)
if dynamic_shapes is None:
dynamic_shapes = {}
def flip(op):
if op == "<=":
@@ -1955,6 +2140,18 @@ class DimConstraints:
assert op == "=="
results[expr]["eq"] = digit
# retrieve dynamic shapes
name_to_dim = {}
for dim in pytree.tree_flatten(
dynamic_shapes,
is_leaf=lambda x: self._is_derived_dim(x) or self._is_dim(x),
)[0]:
if dim is None or isinstance(dim, int):
continue
name_to_dim[dim.__name__] = dim
if self._is_derived_dim(dim):
name_to_dim[dim.root.__name__] = dim.root
for s in self._static_results.union(self._dynamic_results):
t = transform(s)
if t == s:
@@ -1981,23 +2178,26 @@ class DimConstraints:
}
buf = ""
debug_names = set()
if forced_specializations:
debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys())
debug_names = set()
for k in forced_specializations:
dim = name_to_dim[k.split(" = ")[0]]
if self._is_derived_dim(dim):
debug_names.add(dim.root.__name__)
else:
debug_names.add(dim.__name__)
buf += (
f"Specializations unexpectedly required ({', '.join(debug_names)})! "
f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
"For more information, run with TORCH_LOGS=\"+dynamic\".\n"
)
for s, val in forced_specializations.items():
buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n"
self._process_derived_dim_roots(results, name_to_dim)
dims = []
others = []
match = None
if constraint_violation_error:
match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0])
if match is not None:
debug_names.update(match.expand(r'\1').split(', '))
# order results by source name
results = {
@@ -2007,21 +2207,11 @@ class DimConstraints:
)
}
for k, c in results.items():
# if k not in debug_names:
# continue
if "eq" in c:
other = c["eq"]
if isinstance(other, int):
others.append(f"{k} = {other}")
elif self._is_supported_equivalence(other):
s = next(iter(other.free_symbols))
if str(s) not in results:
modulus, remainder = sympy.polys.polytools.div(other, s)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
max_ = math.floor((c_max - remainder) / modulus)
dims.append(f"{s} = Dim('{s}', min={min_}, max={max_}) # {c_min} <= {other} <= {c_max}")
others.append(f"{k} = {other}")
else:
min_ = c.get("min", None)
@@ -2037,6 +2227,10 @@ class DimConstraints:
else:
dims.append(f"{k} = Dim('{k}')")
# results will get filtered out if no new suggestions,
# this can happen if guards are too complex.
# in that case don't suggest fix
if dims or others:
buf += "\nSuggested fixes:\n "
buf += "\n ".join(dims + others)
@@ -3937,7 +4131,7 @@ class ShapeEnv:
error_msgs.append(msg)
debug_names.add(debug_name)
if len(error_msgs) > 0:
debug_names = ', '.join(debug_names)
debug_names = ', '.join(sorted(debug_names))
err = '\n'.join(error_msgs)
raise ConstraintViolationError(
f"Constraints violated ({debug_names})! "