mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] handle new roots & root swapping in derived dims suggested fixes (#125543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543 This PR addresses 2 issues with derived dim suggested fixes: 1) newly introduced roots, and 2) root swapping. 1 | Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final suggested fixes handle this correctly, but we can get intermediate results where related derived dims don't rely on a unified root, and are a mixture of min/max range and derived suggestions. For example: ``` "dx": {"eq": 3*_dx-1, "max": 36} "dy": {"eq": dx+1} ``` This should lead to suggested fixes: ``` _dx = Dim('_dx', max=12) dx = 3 * _dx - 1 dy = 3 * _dx ``` This PR prettifies the suggested fixes routine by unifying to a single root, and making each intermediate suggestion either a derived dim or min/max range, not both. 2 | The current suggested fixes for derived dims can lead to root dims/derived dims being swapped, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes that look like `dy - 1 = Dim("dy - 1")` since we don't have access to the original variable name. This PR only adds a suggested fix for the root dim, and removes all other derived suggestions.
For example, with the export test case test_derived_dim_out_of_order_simplified: ``` _dimz = torch.export.Dim("_dimz", min=6, max=8) dimy = _dimz - 1 dimx = dimy - 1 dimz = torch.export.Dim("dimz", min=6, max=8) # doesn't work, should be = _dimz class Foo(torch.nn.Module): def forward(self, x, y, z): return x + y[1:] + z[2:] foo = Foo() u, v, w = torch.randn(5), torch.randn(6), torch.randn(7) export( foo, (u, v, w), dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}), ) ``` Before: ``` Suggested fixes: _dimz = Dim('_dimz', min=3, max=9223372036854775807) # 2 <= _dimz - 1 <= 9223372036854775806 _dimz - 2 = Dim('_dimz - 2', min=4, max=6) _dimz = Dim('_dimz', min=2, max=9223372036854775806) # 2 <= _dimz <= 9223372036854775806 _dimz - 1 = _dimz - 1 dimz = _dimz ``` New suggested fixes: ``` Suggested fixes: dimz = _dimz ``` Note: This assumes the specified derived relations between dims are correct. This should be valid because: 1) if the relation is plain wrong (e.g. (dx, dx - 1) provided with inputs (6, 4)), this gets caught beforehand in produce_guards. 2) if the relation is correct but does not match the emitted guard, for example: ``` def forward(self, x, y): return x.reshape([-1]) + y # guard: s0 * 2 = s1 dx = Dim("dx") export( model, (torch.randn(6, 2), torch.randn(12)), dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} ) ``` This produces two linear equations, leading to specialization since a) produce_guards is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
parent
0a9d73a814
commit
f206c5c628
|
|
@ -968,6 +968,35 @@ class TestExport(TestCase):
|
|||
6,
|
||||
)
|
||||
|
||||
def test_specialize_derived_dim_roots(self):
    """Check the error raised when a root dim and its derived dim both specialize.

    The message must report the specialization, suggest a fix for the root
    dim ``dy`` only, and must not suggest a fix keyed on the derived
    expression ``dy - 6``.
    """

    # dim & derived dim both specialize
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            return x.reshape([-1]) + y

    dy = Dim("dy", min=6)
    x, y = torch.randn(6, 2), torch.randn(12)
    dynamic_shapes = {
        "x": (dy - 6, 2),
        "y": (dy,),
    }

    # export() must fail with a dynamic shapes error; assertRaises fails the
    # test if no UserError is raised.
    with self.assertRaises(torch._dynamo.exc.UserError) as ctx:
        export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
    error_msg = ctx.exception.args[0]

    # Raw strings: "\(" is not a valid escape in a plain string literal.
    expected_error_msg = (
        r"Specializations unexpectedly required \(dy\)!(.*\n)*.*"
        r".*dy - 6.*must be specialized to 6 because the guards generated for it are too complex(.*\n)*.*"
        r"Suggested fixes(.*\n)*.*"
        r".*dy = 12(.*\n)*.*"
    )
    self.assertRegex(error_msg, expected_error_msg)
    # don't suggest fix for non-root dim
    self.assertNotIn("dy - 6 = 6", error_msg)
|
||||
|
||||
def test_derived_dim_out_of_order_simplified(self):
|
||||
_dimz = torch.export.Dim("_dimz", min=6, max=8)
|
||||
dimy = _dimz - 1
|
||||
|
|
@ -979,22 +1008,25 @@ class TestExport(TestCase):
|
|||
return x + y[1:] + z[2:]
|
||||
|
||||
foo = Foo()
|
||||
|
||||
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
(
|
||||
"Constraints violated \\(dimz\\)!(.*\n)*.*"
|
||||
"The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
|
||||
"Suggested fixes:(.*\n)*.*"
|
||||
"dimz = _dimz"
|
||||
),
|
||||
):
|
||||
try:
|
||||
export(
|
||||
foo,
|
||||
(u, v, w),
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
|
||||
)
|
||||
except torch._dynamo.exc.UserError as exc:
|
||||
expected_error_msg = (
|
||||
"Constraints violated \(dimz\)!(.*\n)*.*"
|
||||
"The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
|
||||
"Suggested fixes:(.*\n)*.*"
|
||||
"dimz = _dimz"
|
||||
)
|
||||
self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
|
||||
# don't suggest fix for non-root dims, and no need to update root here
|
||||
self.assertTrue("_dimz - 2 = Dim(" not in exc.args[0])
|
||||
self.assertTrue("_dimz - 1 = _dimz - 1" not in exc.args[0])
|
||||
self.assertTrue("_dimz = Dim(" not in exc.args[0])
|
||||
|
||||
dimz = dimx + 2 # works, effectively = _dimz
|
||||
ep = export(
|
||||
|
|
@ -1888,6 +1920,55 @@ class TestExport(TestCase):
|
|||
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
|
||||
)
|
||||
|
||||
def test_suggested_fixes_new_roots(self):
    """Check that a modulo guard makes the suggested fixes introduce a new root.

    The first export is expected to fail and suggest a new root ``_dx`` with
    ``dx``, ``dy`` and ``dz`` derived from it; the export is then retried
    with dynamic shapes written in terms of ``_dx``.
    """
    from torch.export import dims

    # suggested fixes should introduce new root dim for modulo guard
    class Foo(torch.nn.Module):
        def forward(self, x, y, z):
            # dy = 3 * _dx
            # dx = 3 * _dx - 1
            # dz = 3 * _dx + 2
            # suggested fixes results will look something like
            # {"dx": {"eq": 3*_dx-1, "min": 5, "max": 36}, "dy": {"eq": dx+1}, ...}
            if x.shape[0] >= 5 and x.shape[0] <= 36 and y.shape[0] % 3 == 0:
                return x + y[1:] + z[3:]

    foo = Foo()
    inputs = (
        torch.randn(11),
        torch.randn(12),
        torch.randn(14),
    )
    dx, dy, dz = dims("dx", "dy", "dz")
    dynamic_shapes = {
        "x": (dx,),
        "y": (dy,),
        "z": (dz,),
    }
    # Raw strings: "\(", "\*" and "\+" are not valid escapes in plain
    # string literals.
    with self.assertRaisesRegex(
        torch._dynamo.exc.UserError,
        (
            r"Constraints violated.*!(.*\n)*.*"
            r"Suggested fixes(.*\n)*.*"
            r"_dx = Dim\('_dx', max=12\)(.*\n)*.*"
            r"dx = 3\*_dx - 1(.*\n)*.*"
            r"dy = 3\*_dx(.*\n)*.*"
            r"dz = 3\*_dx \+ 2"
        ),
    ):
        export(Foo(), inputs, dynamic_shapes=dynamic_shapes)

    # retry export
    _dx = Dim("_dx", min=2, max=12)
    dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)}
    export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_dynamic_shapes_spec_with_pytree(self):
|
||||
from torch.export import Dim, export
|
||||
from torch.utils._pytree import tree_map
|
||||
|
|
@ -5065,7 +5146,6 @@ def forward(self, x, y):
|
|||
torch._dynamo.exc.UserError,
|
||||
r".*Specializations unexpectedly required(.*\n)*"
|
||||
r"Suggested fixes:(.*\n)*"
|
||||
r".*dy = Dim.*(.*\n)*"
|
||||
r".*dw0 = 3(.*\n)*"
|
||||
r".*dw1 = 4(.*\n)*"
|
||||
r".*dx0 = 12(.*\n)*"
|
||||
|
|
@ -5082,9 +5162,6 @@ def forward(self, x, y):
|
|||
torch._dynamo.exc.UserError,
|
||||
r".*Constraints violated(.*\n)*"
|
||||
r"Suggested fixes:(.*\n)*"
|
||||
r".*dw0 = Dim.*(.*\n)*"
|
||||
r".*dw1 = Dim.*(.*\n)*"
|
||||
r".*dy = Dim.*(.*\n)*"
|
||||
r".*dz = dy(.*\n)*",
|
||||
) as msg:
|
||||
torch.export._trace._export(
|
||||
|
|
|
|||
|
|
@ -2435,7 +2435,7 @@ class TestDimConstraints(TestCase):
|
|||
def dummy_fn(a, b, c, d, e, f):
|
||||
pass
|
||||
|
||||
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
|
||||
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn), {})
|
||||
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
|
||||
expected_static = """
|
||||
def specializations(a, b, c, d, e, f):
|
||||
|
|
|
|||
|
|
@ -1333,7 +1333,10 @@ def export(
|
|||
dim_constraints.remove_redundant_dynamic_results()
|
||||
forced_specializations = dim_constraints.forced_specializations()
|
||||
msg = dim_constraints.prettify_results(
|
||||
original_signature, constraint_violation_error, forced_specializations
|
||||
original_signature,
|
||||
dynamic_shapes,
|
||||
constraint_violation_error,
|
||||
forced_specializations,
|
||||
)
|
||||
if constraint_violation_error:
|
||||
constraint_violation_error.args = (
|
||||
|
|
|
|||
|
|
@ -223,6 +223,7 @@ def _flatten_dynamic_shapes(
|
|||
def produce_guards_and_solve_constraints(
|
||||
fake_mode: FakeTensorMode,
|
||||
gm: torch.fx.GraphModule,
|
||||
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
|
||||
equalities_inputs: EqualityConstraint,
|
||||
original_signature: inspect.Signature,
|
||||
_disable_forced_specializations: Optional[bool] = False,
|
||||
|
|
@ -273,7 +274,10 @@ def produce_guards_and_solve_constraints(
|
|||
forced_specializations = dim_constraints.forced_specializations()
|
||||
if not _is_torch_jit_trace:
|
||||
msg = dim_constraints.prettify_results(
|
||||
original_signature, constraint_violation_error, forced_specializations
|
||||
original_signature,
|
||||
dynamic_shapes,
|
||||
constraint_violation_error,
|
||||
forced_specializations,
|
||||
)
|
||||
else:
|
||||
# FIXME(ycao): This is a hack to get around missing signature from ScriptMethod
|
||||
|
|
|
|||
|
|
@ -1316,6 +1316,7 @@ def _non_strict_export(
|
|||
produce_guards_and_solve_constraints(
|
||||
fake_mode,
|
||||
aten_export_artifact.gm,
|
||||
dynamic_shapes,
|
||||
equalities_inputs,
|
||||
original_signature,
|
||||
_disable_forced_specializations=_disable_forced_specializations,
|
||||
|
|
|
|||
|
|
@ -1914,9 +1914,192 @@ class DimConstraints:
|
|||
dynamic_results.add(dc)
|
||||
self._dynamic_results = dynamic_results
|
||||
|
||||
def _is_derived_dim(self, dim):
|
||||
return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
|
||||
|
||||
def _is_dim(self, dim):
|
||||
return (
|
||||
isinstance(dim, torch.export.dynamic_shapes._Dim)
|
||||
and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
|
||||
)
|
||||
|
||||
def _process_derived_dim_roots(
    self,
    results: Dict[str, Dict[str, Any]],
    name_to_dim: Dict[str, Any],
) -> None:
    """
    Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
    and 2) root swapping.

    1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests
    dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final
    suggested fixes handle this correctly, but we can get intermediate results that look like
    {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1", "min": 3, "max": 15}}
    and this routine prettifies this by unifying to a single root, and making each suggestion
    either a derived dim or min/max range, not both.

    2) With suggested fixes for derived dims, roots can be swapped,
    e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name,
    since this leads to messages like "dx - 1 = Dim("dx - 1", ...)".
    Instead we evaluate the new root value, and remove results for its derivations.

    First we find all the original roots (specified in dynamic_shapes), that are found in the
    values of results (i.e. used for computing suggested fix values). These original roots
    (suppose `dx`) are either specialized, unchanged, refined, or swapped
    (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value
    in results, and remove suggestions for derivations of `dx`, assuming the derived relation
    is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value,
    and then do the same with `dx`'s derivations.

    Assuming the originally specified derived relations are correct is valid, because:
        1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1))
           produce_guards() will catch this and crash before hand.
        2) if the relations are numerically correct but do not match the emitted guard,
           for example:

                def forward(self, x, y):
                    return x.reshape([-1]) + y  # guard: s0 * 2 = s1
                inputs = (torch.randn(6, 2), torch.randn(12))
                dx = Dim("dx", min=2, max=32)
                dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}  # this matches values but not op

           then this leads to 2 linear equations, and a) produce_guards() is able to solve for
           the unique solution of dx = 6 and specialize, and b) the export constraint solver will
           raise an issue due to range constraints (a unique solution means not all values in a
           range satisfy a guard) and also force specializations.

    Args:
        results: maps a dim name to its suggestion, a dict with an "eq" entry and/or
            "min"/"max" entries. Mutated in place: entries are added, rewritten and removed.
        name_to_dim: maps a dim name to the dim specified in dynamic_shapes. Mutated in
            place: an entry is added for every newly introduced root.
    """
    from torch.export.dynamic_shapes import Dim

    def _check_same_range(c, dim):
        # returns True if c & dim are both min/max ranges with same values
        return (
            self._is_dim(dim)
            and ("min" in c or "max" in c)
            and (dim.min < 2 or dim.min == c.get("min", 2))  # let pass if min < 2
            and dim.max == c.get("max", sys.maxsize - 1)
        )

    # 1) newly introduced roots
    # this part we handle adding newly introduced roots
    # these arise from guards like "x.shape[0] % 3 == 0"
    # leading to suggested fixes like "dx = 3*_dx"
    # extract _dx, and find appropriate min/max values
    #
    # before, we have something like:
    # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
    # we want instead:
    # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
    introduced_roots: Dict[str, str] = {}  # map new root -> old root
    # iterate over a snapshot: new root entries are inserted into results below
    for k, c in list(results.items()):
        if "eq" in c and isinstance(c["eq"], sympy.Expr):  # derived dim
            root = next(iter(c["eq"].free_symbols))
            root_name = str(root)
            if root_name not in name_to_dim:
                introduced_roots[root_name] = k
                # calculate necessary min & max
                modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
                c_min = c.get("min", 2)
                min_ = math.ceil((c_min - remainder) / modulus)
                c_max = c.get("max", sys.maxsize - 1)
                max_ = math.floor((c_max - remainder) / modulus)
                # create result & dim
                results[root_name] = {"min": min_, "max": max_}
                name_to_dim[root_name] = Dim(root_name, min=min_, max=max_)
                # remove old root min/max bounds
                c.pop("min", None)
                c.pop("max", None)

    # alter derivations that depend on old root, to unify to new root
    # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
    for old_root in introduced_roots.values():
        new_root_expr = results[old_root]["eq"]  # dx=3*_dx+1
        for c in results.values():
            if (
                "eq" in c
                and isinstance(c["eq"], sympy.Expr)
                and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
            ):  # derived dim with root = old_root
                c["eq"] = c["eq"].subs({symbol: new_root_expr})  # dy=(3*_dx+1)+1

    # 2) root swapping
    # collect all the original roots that are used for calculating values of suggested fixes
    # this consists of:
    # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
    # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
    modified_roots: Set[str] = set()
    for k, c in results.items():
        if k not in name_to_dim:  # _dynamo.export() may handle source directly
            continue
        if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c):  # case 1)
            modified_roots.add(k)
        elif "eq" in c and isinstance(c["eq"], sympy.Expr):  # case 2)
            root = next(iter(c["eq"].free_symbols))
            modified_roots.add(str(root))

    # exclude newly introduced roots, we've already processed these
    modified_roots = modified_roots.difference(introduced_roots)

    # evaluate the new value for each root
    # this is now either 1) unchanged, 2) refined with a new range,
    # or 3) specialized to a concrete value
    modified_root_values: Dict[str, Dict[str, Any]] = {}
    for root in modified_roots:
        swapped_root = True
        if root in results:
            c = results[root]
            if (
                ("min" in c or "max" in c)  # range
                or isinstance(c.get("eq"), int)  # specialized
            ):
                # here, the original root is a root Dim or concrete value in results.
                # if it is a derived dim, it is swapped, and we handle that below.
                if not _check_same_range(c, name_to_dim[root]):  # ignore if unchanged
                    modified_root_values[root] = c
                swapped_root = False

        if swapped_root:
            # if the original root has been swapped in results, that means the new root
            # is a range (if it had specialized, the original root would have too).
            # find this new root, and solve for the original root's range.
            for k, c in results.items():
                if k not in name_to_dim:
                    continue
                dim = name_to_dim[k]
                if self._is_derived_dim(dim) and dim.root.__name__ == root:
                    # only look for min/max root, otherwise root would have specialized
                    if "min" in c or "max" in c:
                        expr = sympy.sympify(k)
                        s = next(iter(expr.free_symbols))
                        # solve only for the bounds that are present; a missing
                        # bound stays unspecified rather than raising KeyError
                        result = {
                            bound: try_solve(sympy.Eq(expr, c[bound]), s)[1]  # type: ignore[arg-type]
                            for bound in ("min", "max")
                            if bound in c
                        }
                        if not _check_same_range(result, name_to_dim[root]):  # ignore if unchanged
                            modified_root_values[root] = result
                            break

    # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4})
    # we only want to suggest fixes for the root, to avoid derived names.
    # also, remove anything in modified_roots, since we either add new modified values after this,
    # or have decided they are unchanged.
    for k in list(results.keys()):
        if k not in name_to_dim:
            continue
        if self._is_derived_dim(name_to_dim[k]) or k in modified_roots:
            del results[k]

    # update results with modified root values
    # now results has the following properties:
    # - only contains original roots as keys
    # - each root is now either specialized, refined, or derived from another original root
    results.update(modified_root_values)
|
||||
|
||||
def prettify_results(
|
||||
self,
|
||||
original_signature: inspect.Signature,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
constraint_violation_error=None,
|
||||
forced_specializations=None,
|
||||
):
|
||||
|
|
@ -1929,6 +2112,8 @@ class DimConstraints:
|
|||
return s
|
||||
|
||||
results = defaultdict(dict)
|
||||
if dynamic_shapes is None:
|
||||
dynamic_shapes = {}
|
||||
|
||||
def flip(op):
|
||||
if op == "<=":
|
||||
|
|
@ -1955,6 +2140,18 @@ class DimConstraints:
|
|||
assert op == "=="
|
||||
results[expr]["eq"] = digit
|
||||
|
||||
# retrieve dynamic shapes
|
||||
name_to_dim = {}
|
||||
for dim in pytree.tree_flatten(
|
||||
dynamic_shapes,
|
||||
is_leaf=lambda x: self._is_derived_dim(x) or self._is_dim(x),
|
||||
)[0]:
|
||||
if dim is None or isinstance(dim, int):
|
||||
continue
|
||||
name_to_dim[dim.__name__] = dim
|
||||
if self._is_derived_dim(dim):
|
||||
name_to_dim[dim.root.__name__] = dim.root
|
||||
|
||||
for s in self._static_results.union(self._dynamic_results):
|
||||
t = transform(s)
|
||||
if t == s:
|
||||
|
|
@ -1981,23 +2178,26 @@ class DimConstraints:
|
|||
}
|
||||
|
||||
buf = ""
|
||||
debug_names = set()
|
||||
if forced_specializations:
|
||||
debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys())
|
||||
debug_names = set()
|
||||
for k in forced_specializations:
|
||||
dim = name_to_dim[k.split(" = ")[0]]
|
||||
if self._is_derived_dim(dim):
|
||||
debug_names.add(dim.root.__name__)
|
||||
else:
|
||||
debug_names.add(dim.__name__)
|
||||
|
||||
buf += (
|
||||
f"Specializations unexpectedly required ({', '.join(debug_names)})! "
|
||||
f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
|
||||
"For more information, run with TORCH_LOGS=\"+dynamic\".\n"
|
||||
)
|
||||
for s, val in forced_specializations.items():
|
||||
buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n"
|
||||
|
||||
self._process_derived_dim_roots(results, name_to_dim)
|
||||
|
||||
dims = []
|
||||
others = []
|
||||
match = None
|
||||
if constraint_violation_error:
|
||||
match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0])
|
||||
if match is not None:
|
||||
debug_names.update(match.expand(r'\1').split(', '))
|
||||
|
||||
# order results by source name
|
||||
results = {
|
||||
|
|
@ -2007,21 +2207,11 @@ class DimConstraints:
|
|||
)
|
||||
}
|
||||
for k, c in results.items():
|
||||
# if k not in debug_names:
|
||||
# continue
|
||||
if "eq" in c:
|
||||
other = c["eq"]
|
||||
if isinstance(other, int):
|
||||
others.append(f"{k} = {other}")
|
||||
elif self._is_supported_equivalence(other):
|
||||
s = next(iter(other.free_symbols))
|
||||
if str(s) not in results:
|
||||
modulus, remainder = sympy.polys.polytools.div(other, s)
|
||||
c_min = c.get("min", 2)
|
||||
min_ = math.ceil((c_min - remainder) / modulus)
|
||||
c_max = c.get("max", sys.maxsize - 1)
|
||||
max_ = math.floor((c_max - remainder) / modulus)
|
||||
dims.append(f"{s} = Dim('{s}', min={min_}, max={max_}) # {c_min} <= {other} <= {c_max}")
|
||||
others.append(f"{k} = {other}")
|
||||
else:
|
||||
min_ = c.get("min", None)
|
||||
|
|
@ -2037,8 +2227,12 @@ class DimConstraints:
|
|||
else:
|
||||
dims.append(f"{k} = Dim('{k}')")
|
||||
|
||||
buf += "\nSuggested fixes:\n "
|
||||
buf += "\n ".join(dims + others)
|
||||
# results will get filtered out if no new suggestions,
|
||||
# this can happen if guards are too complex.
|
||||
# in that case don't suggest fix
|
||||
if dims or others:
|
||||
buf += "\nSuggested fixes:\n "
|
||||
buf += "\n ".join(dims + others)
|
||||
|
||||
return buf
|
||||
|
||||
|
|
@ -3937,7 +4131,7 @@ class ShapeEnv:
|
|||
error_msgs.append(msg)
|
||||
debug_names.add(debug_name)
|
||||
if len(error_msgs) > 0:
|
||||
debug_names = ', '.join(debug_names)
|
||||
debug_names = ', '.join(sorted(debug_names))
|
||||
err = '\n'.join(error_msgs)
|
||||
raise ConstraintViolationError(
|
||||
f"Constraints violated ({debug_names})! "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user