Building guards should be under metrics_context (#163967)

Differential Revision: [D83354042](https://our.internmc.facebook.com/intern/diff/D83354042)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163967
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2025-09-26 13:47:59 -07:00 committed by PyTorch MergeBot
parent 38ed608956
commit 1fdd99de71

View File

@ -2,14 +2,14 @@ import inspect
import logging
import traceback
from collections import namedtuple
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import fullgraph_capture, get_traced_fn
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._export.utils import _compiling_state_context
@ -23,6 +23,10 @@ from torch.fx.experimental.symbolic_shapes import (
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
if TYPE_CHECKING:
from torch._subclasses.fake_tensor import FakeTensorMode
log = logging.getLogger(__name__)
@ -332,6 +336,68 @@ class DynamoGraphTransformer(torch.fx.Transformer):
return result_gm
def _suggest_or_raise_constraint_violation(
module_to_trace: torch.nn.Module,
orig_callable: Callable, # type: ignore[type-arg]
fake_mode: Optional["FakeTensorMode"],
graph_capture_output: CaptureOutput,
args: Any,
kwargs: Any,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
):
constraint_violation_error = None
try:
# Check if we have any constraint violations
fn, _ = get_traced_fn(module_to_trace)
graph_capture_output.graph_capture_output.build_guards(fn.__code__)
except ConstraintViolationError as e:
constraint_violation_error = e
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not isinstance(
module_to_trace.forward,
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
)
):
dim_constraints.solve()
forced_specializations = dim_constraints.forced_specializations()
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable), # type: ignore[attr-defined]
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "
'Set TORCH_LOGS="+export" for more information.'
)
if constraint_violation_error:
constraint_violation_error = post_process_error_msg(
constraint_violation_error, orig_callable, args, kwargs
)
raise constraint_violation_error
def _dynamo_graph_capture_for_export(
mod: Callable[..., Any],
*,
@ -367,6 +433,7 @@ def _dynamo_graph_capture_for_export(
with _compiling_state_context():
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
module_to_trace = ModuleToTrace(mod, in_spec)
orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod
constraints: Optional[list[Constraint]] = _constraints
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
@ -401,6 +468,27 @@ def _dynamo_graph_capture_for_export(
assert out.graph_capture_output.output_graph is not None
example_inputs: list[Any] = []
if out.backend_input is not None:
graph = out.backend_input.graph_module
fake_mode = out.backend_input.fake_mode
example_inputs = out.backend_input.example_inputs
else:
graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
graph.graph.output(None)
graph.recompile()
fake_mode = None
_suggest_or_raise_constraint_violation(
module_to_trace,
orig_callable,
fake_mode,
out,
args,
kwargs,
dynamic_shapes,
)
# Extract export metadata from the new location
export_metadata = out.graph_capture_output.output_graph.export_metadata
graph_inputs = export_metadata.graph_input_idx_to_local_source
@ -408,17 +496,6 @@ def _dynamo_graph_capture_for_export(
out_spec = export_metadata.out_spec
module_call_spec = export_metadata.module_call_spec
example_inputs: list[Any] = []
if out.backend_input is not None:
graph = out.backend_input.graph_module
fake_mode = out.backend_input.fake_mode
example_inputs = out.backend_input.example_inputs
else:
graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
graph.graph.output(None)
graph.recompile()
fake_mode = None
# Compute dynamic dimensions for each input based on constraints
flat_args_dynamic_dims = [
{
@ -453,8 +530,6 @@ def _dynamo_graph_capture_for_export(
fake_mode,
).transform()
orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod
# Set up PyTree codegen for proper input/output handling
transformed_graph.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
@ -472,58 +547,6 @@ def _dynamo_graph_capture_for_export(
transformed_graph.meta["module_call_specs"] = module_call_spec
constraint_violation_error = None
try:
# Check if we have any constraint violations
fn, _ = get_traced_fn(module_to_trace)
out.graph_capture_output.build_guards(fn.__code__)
except ConstraintViolationError as e:
constraint_violation_error = e
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not isinstance(
module_to_trace.forward,
(torch._ops.OpOverloadPacket, torch._ops.OpOverload),
)
):
dim_constraints.solve()
forced_specializations = dim_constraints.forced_specializations()
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable), # type: ignore[attr-defined]
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "
'Set TORCH_LOGS="+export" for more information.'
)
if constraint_violation_error:
constraint_violation_error = post_process_error_msg(
constraint_violation_error, orig_callable, args, kwargs
)
raise constraint_violation_error
return transformed_graph
return inner