mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
remove redundant dynamic_dim (#107815)
Differential Revision: D48618472 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107815 Approved by: https://github.com/tugsbayasgalan, https://github.com/gmagogsfm
This commit is contained in:
parent
8354d32f6b
commit
cf76938f70
|
|
@ -2312,6 +2312,23 @@ def forward(self, x):
|
|||
constraints.append(dynamic_dim(z, 0) == dynamic_dim(x, 0))
|
||||
torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z)
|
||||
|
||||
def test_remove_redundant_dynamic_dim_in_error_message(self):
|
||||
def foo(x, y):
|
||||
if x.shape[0] == y["k"].shape[0]:
|
||||
return x + 1
|
||||
else:
|
||||
return x - 1
|
||||
|
||||
a = torch.randn(3)
|
||||
b = torch.randn(3)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"\\[\n.*\n.*dynamic_dim.*==.*dynamic_dim.*\n.*\\]",
|
||||
):
|
||||
torch._export.export(
|
||||
foo, (a, {"k": b}), constraints=[dynamic_dim(a, 0), dynamic_dim(b, 0)]
|
||||
)
|
||||
|
||||
@config.patch(
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
specialize_int=True,
|
||||
|
|
|
|||
|
|
@ -1829,6 +1829,7 @@ class TestDimConstraints(TestCase):
|
|||
dim_constraints.add(s5 >= 2)
|
||||
|
||||
dim_constraints.solve()
|
||||
dim_constraints.remove_redundant_dynamic_results()
|
||||
self.assertEqual(dim_constraints._static_results, {
|
||||
"L['c'].size()[0] == 8",
|
||||
"L['d'].size()[0] == 8",
|
||||
|
|
@ -1843,7 +1844,6 @@ class TestDimConstraints(TestCase):
|
|||
})
|
||||
self.assertEqual(dim_constraints._dynamic_results, {
|
||||
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
|
||||
"2 <= dynamic_dim(L['c'], 1)",
|
||||
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
|
||||
})
|
||||
|
||||
|
|
@ -1877,9 +1877,6 @@ def specializations(a, b, c, d, e, f):
|
|||
expected_dynamic = '''
|
||||
def specify_constraints(a, b, c, d, e, f):
|
||||
return [
|
||||
# c:
|
||||
dynamic_dim(c, 1),
|
||||
|
||||
# d:
|
||||
dynamic_dim(d, 1) == dynamic_dim(c, 1),
|
||||
|
||||
|
|
|
|||
|
|
@ -1135,6 +1135,7 @@ def export(
|
|||
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
|
||||
):
|
||||
dim_constraints.solve()
|
||||
dim_constraints.remove_redundant_dynamic_results()
|
||||
msg = dim_constraints.prettify_results(original_signature)
|
||||
forced_specializations = dim_constraints.forced_specializations()
|
||||
if forced_specializations:
|
||||
|
|
|
|||
|
|
@ -1891,6 +1891,28 @@ class DimConstraints:
|
|||
if s in self._marked_dynamic
|
||||
])
|
||||
|
||||
def remove_redundant_dynamic_results(self):
|
||||
candidates_for_removal = []
|
||||
dynamic_results = set()
|
||||
for dc in self._dynamic_results:
|
||||
# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
|
||||
# There is no change in behavior since 2 is the default lower bound.
|
||||
dc_ = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
|
||||
if dc != dc_:
|
||||
candidates_for_removal.append(dc_)
|
||||
else:
|
||||
dynamic_results.add(dc_)
|
||||
for dc in candidates_for_removal:
|
||||
# remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also
|
||||
# appears as part of another constraint
|
||||
found = False
|
||||
for other_dc in dynamic_results:
|
||||
if dc in other_dc:
|
||||
found = True
|
||||
if not found:
|
||||
dynamic_results.add(dc)
|
||||
self._dynamic_results = dynamic_results
|
||||
|
||||
def prettify_results(self, original_signature: inspect.Signature):
|
||||
# Note: Model inputs are wrapped as LocalSource in dynamo.
|
||||
# LocalSource.name() wraps the name with L[""]. We use regular
|
||||
|
|
@ -1927,11 +1949,6 @@ class DimConstraints:
|
|||
sorted_groups.append((arg, sorted(dcs)))
|
||||
return sorted_groups
|
||||
|
||||
# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
|
||||
# There is no change in behavior since 2 is the default lower bound.
|
||||
def remove_default_lower_bound(dc):
|
||||
return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
|
||||
|
||||
signature = original_signature.replace(return_annotation=inspect.Signature.empty)
|
||||
args_index = {}
|
||||
for i, arg in enumerate(signature.parameters.keys()):
|
||||
|
|
@ -1965,13 +1982,13 @@ class DimConstraints:
|
|||
if self._dynamic_results:
|
||||
grouped_dynamic_results = group(self._dynamic_results, args_index)
|
||||
buf += "\nThe following dimensions CAN be dynamic."
|
||||
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
|
||||
buf += "\nPlease use the following code to specify the constraints they must satisfy:"
|
||||
buf += f"\n```\ndef specify_constraints{str(signature)}:"
|
||||
buf += f"\n{indent}return ["
|
||||
print_results(
|
||||
grouped_dynamic_results,
|
||||
indent * 2,
|
||||
lambda result: f"{remove_default_lower_bound(result)},",
|
||||
lambda result: f"{result},",
|
||||
)
|
||||
buf += f"\n{indent}]\n```\n"
|
||||
return buf
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user