remove redundant dynamic_dim (#107815)

Differential Revision: D48618472

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107815
Approved by: https://github.com/tugsbayasgalan, https://github.com/gmagogsfm
This commit is contained in:
Avik Chaudhuri 2023-08-24 10:46:21 +00:00 committed by PyTorch MergeBot
parent 8354d32f6b
commit cf76938f70
4 changed files with 43 additions and 11 deletions

View File

@ -2312,6 +2312,23 @@ def forward(self, x):
constraints.append(dynamic_dim(z, 0) == dynamic_dim(x, 0))
torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z)
def test_remove_redundant_dynamic_dim_in_error_message(self):
def foo(x, y):
if x.shape[0] == y["k"].shape[0]:
return x + 1
else:
return x - 1
a = torch.randn(3)
b = torch.randn(3)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"\\[\n.*\n.*dynamic_dim.*==.*dynamic_dim.*\n.*\\]",
):
torch._export.export(
foo, (a, {"k": b}), constraints=[dynamic_dim(a, 0), dynamic_dim(b, 0)]
)
@config.patch(
capture_dynamic_output_shape_ops=True,
specialize_int=True,

View File

@ -1829,6 +1829,7 @@ class TestDimConstraints(TestCase):
dim_constraints.add(s5 >= 2)
dim_constraints.solve()
dim_constraints.remove_redundant_dynamic_results()
self.assertEqual(dim_constraints._static_results, {
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
@ -1843,7 +1844,6 @@ class TestDimConstraints(TestCase):
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
"2 <= dynamic_dim(L['c'], 1)",
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
})
@ -1877,9 +1877,6 @@ def specializations(a, b, c, d, e, f):
expected_dynamic = '''
def specify_constraints(a, b, c, d, e, f):
return [
# c:
dynamic_dim(c, 1),
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),

View File

@ -1135,6 +1135,7 @@ def export(
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
):
dim_constraints.solve()
dim_constraints.remove_redundant_dynamic_results()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
if forced_specializations:

View File

@ -1891,6 +1891,28 @@ class DimConstraints:
if s in self._marked_dynamic
])
def remove_redundant_dynamic_results(self):
candidates_for_removal = []
dynamic_results = set()
for dc in self._dynamic_results:
# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
# There is no change in behavior since 2 is the default lower bound.
dc_ = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
if dc != dc_:
candidates_for_removal.append(dc_)
else:
dynamic_results.add(dc_)
for dc in candidates_for_removal:
# remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also
# appears as part of another constraint
found = False
for other_dc in dynamic_results:
if dc in other_dc:
found = True
if not found:
dynamic_results.add(dc)
self._dynamic_results = dynamic_results
def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
# LocalSource.name() wraps the name with L[""]. We use regular
@ -1927,11 +1949,6 @@ class DimConstraints:
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups
# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
# There is no change in behavior since 2 is the default lower bound.
def remove_default_lower_bound(dc):
return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
@ -1965,13 +1982,13 @@ class DimConstraints:
if self._dynamic_results:
grouped_dynamic_results = group(self._dynamic_results, args_index)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += "\nPlease use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
print_results(
grouped_dynamic_results,
indent * 2,
lambda result: f"{remove_default_lower_bound(result)},",
lambda result: f"{result},",
)
buf += f"\n{indent}]\n```\n"
return buf