mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ez] Make relaxed constraint error message more user friendly (#151407)
Fixes #151356 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151407 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
cedcdda0ed
commit
bc934f57d7
|
|
@ -2516,7 +2516,7 @@ def forward(self, x):
|
|||
dynamic_shapes = {"x": (dim0,)}
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Not all values.*valid.*inferred to be a constant",
|
||||
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
|
||||
):
|
||||
torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -2371,7 +2371,7 @@ graph():
|
|||
}
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
r"Not all values of dy .* in the specified range are valid because dy was inferred to be a constant",
|
||||
r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
|
||||
):
|
||||
export(Foo(), inputs, dynamic_shapes=shapes)
|
||||
|
||||
|
|
@ -3981,7 +3981,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
# 4->5, 4->5, 3->4
|
||||
bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}),
|
||||
run_time_msg="Expected input.*to be equal to 3, but got 4",
|
||||
compile_time_msg=r"Constraints violated.*\n.*was inferred to be a constant \(3\)",
|
||||
compile_time_msg=r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
|
||||
)
|
||||
|
||||
def test_mismatched_dynamic_shapes(self):
|
||||
|
|
@ -5296,8 +5296,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
(
|
||||
"Constraints violated \\(batch\\)!(.*\n)*.*"
|
||||
"batch was inferred to be a constant(.*\n)*.*"
|
||||
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.(.*\n)*.*"
|
||||
"Suggested fixes:(.*\n)*.*"
|
||||
"batch = 10"
|
||||
),
|
||||
|
|
@ -5463,8 +5462,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
(
|
||||
"Constraints violated \\(K1\\)!(.*\n)*.*"
|
||||
"K1 was inferred to be a constant(.*\n)*.*"
|
||||
"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO(.*\n)*"
|
||||
"Suggested fixes:(.*\n)*.*"
|
||||
"K1 = 3"
|
||||
),
|
||||
|
|
@ -13194,7 +13192,7 @@ def forward(self, x):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
r"Not all values of RelaxedUnspecConstraint.* are valid because .* was inferred to be a constant",
|
||||
r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.",
|
||||
):
|
||||
ep = export(
|
||||
Specialize(),
|
||||
|
|
|
|||
|
|
@ -5129,8 +5129,9 @@ class ShapeEnv:
|
|||
source, constraint
|
||||
)
|
||||
msg = (
|
||||
f"Not all values of {var_with_range} are valid because "
|
||||
f"{self._debug_name(source)} was inferred to be a constant ({val})."
|
||||
f"You marked {self._debug_name(source)} as dynamic but your code "
|
||||
f"specialized it to be a constant ({val}). Either remove the mark_dynamic "
|
||||
f"or use a less strict API such as maybe_mark_dynamic or Dim.AUTO."
|
||||
)
|
||||
record_constraint_violation(
|
||||
constraint.warn_only, self._debug_name(source), msg
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user