[ez] Make relaxed constraint error message more user friendly (#151407)

Fixes #151356

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151407
Approved by: https://github.com/Skylion007
This commit is contained in:
bobrenjc93 2025-04-16 03:12:35 -07:00 committed by PyTorch MergeBot
parent cedcdda0ed
commit bc934f57d7
3 changed files with 9 additions and 10 deletions

View File

@ -2516,7 +2516,7 @@ def forward(self, x):
dynamic_shapes = {"x": (dim0,)}
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Not all values.*valid.*inferred to be a constant",
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
):
torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True)

View File

@ -2371,7 +2371,7 @@ graph():
}
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
r"Not all values of dy .* in the specified range are valid because dy was inferred to be a constant",
r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
):
export(Foo(), inputs, dynamic_shapes=shapes)
@ -3981,7 +3981,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
# 4->5, 4->5, 3->4
bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}),
run_time_msg="Expected input.*to be equal to 3, but got 4",
compile_time_msg=r"Constraints violated.*\n.*was inferred to be a constant \(3\)",
compile_time_msg=r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
)
def test_mismatched_dynamic_shapes(self):
@ -5296,8 +5296,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Constraints violated \\(batch\\)!(.*\n)*.*"
"batch was inferred to be a constant(.*\n)*.*"
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.(.*\n)*.*"
"Suggested fixes:(.*\n)*.*"
"batch = 10"
),
@ -5463,8 +5462,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Constraints violated \\(K1\\)!(.*\n)*.*"
"K1 was inferred to be a constant(.*\n)*.*"
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO(.*\n)*"
"Suggested fixes:(.*\n)*.*"
"K1 = 3"
),
@ -13194,7 +13192,7 @@ def forward(self, x):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
r"Not all values of RelaxedUnspecConstraint.* are valid because .* was inferred to be a constant",
r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
):
ep = export(
Specialize(),

View File

@ -5129,8 +5129,9 @@ class ShapeEnv:
source, constraint
)
msg = (
f"Not all values of {var_with_range} are valid because "
f"{self._debug_name(source)} was inferred to be a constant ({val})."
f"You marked {self._debug_name(source)} as dynamic but your code "
f"specialized it to be a constant ({val}). Either remove the mark_dynamic "
f"or use a less strict API such as maybe_mark_dynamic or Dim.AUTO."
)
record_constraint_violation(
constraint.warn_only, self._debug_name(source), msg