Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
# Building guards should be under metrics_context (#163967)

Differential Revision: [D83354042](https://our.internmc.facebook.com/intern/diff/D83354042)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163967
Approved by: https://github.com/avikchaudhuri
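The commit moves `build_guards` out of the tail of the export wrapper and into a dedicated `_suggest_or_raise_constraint_violation` helper that is invoked right after graph capture, so that guard construction happens while the compile's metrics context is still active (per the title). For orientation, here is a minimal sketch of the metrics pattern the title refers to, using the `get_metrics_context` and `dynamo_timed` utilities the file already imports; the wrapper function and the `"build_guards"` key below are illustrative, not the commit's literal code:

```python
from torch._dynamo.utils import dynamo_timed, get_metrics_context


def capture_and_build_guards(build_fn):
    # Enter the compile-wide metrics context so counters and timings
    # recorded while building guards are attached to this compile.
    # (Sketch only; the real call sites live in dynamo's compile path.)
    with get_metrics_context():
        # dynamo_timed attributes the elapsed time to the given key.
        # "build_guards" is an illustrative key, not from the commit.
        with dynamo_timed("build_guards"):
            return build_fn()
```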
Parent: 38ed608956
Commit: 1fdd99de71
```diff
@@ -2,14 +2,14 @@ import inspect
 import logging
 import traceback
 from collections import namedtuple
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import sympy
 
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
-from torch._dynamo.convert_frame import fullgraph_capture, get_traced_fn
+from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
 from torch._dynamo.eval_frame import argument_names
 from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._export.utils import _compiling_state_context
```
```diff
@@ -23,6 +23,10 @@ from torch.fx.experimental.symbolic_shapes import (
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
 
+if TYPE_CHECKING:
+    from torch._subclasses.fake_tensor import FakeTensorMode
+
+
 log = logging.getLogger(__name__)
 
 
```
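`FakeTensorMode` is imported under `typing.TYPE_CHECKING` so it is visible to type checkers for the `Optional["FakeTensorMode"]` annotation on the helper added below, without being imported at runtime. A quick self-contained illustration of the pattern (the function here is made up for the example):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from torch._subclasses.fake_tensor import FakeTensorMode


def has_shape_env(fake_mode: Optional["FakeTensorMode"]) -> bool:
    # The quoted annotation means FakeTensorMode is never looked up
    # at runtime, mirroring how the new helper reads fake_mode.
    return getattr(fake_mode, "shape_env", None) is not None
```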
```diff
@@ -332,6 +336,68 @@ class DynamoGraphTransformer(torch.fx.Transformer):
         return result_gm
 
 
+def _suggest_or_raise_constraint_violation(
+    module_to_trace: torch.nn.Module,
+    orig_callable: Callable,  # type: ignore[type-arg]
+    fake_mode: Optional["FakeTensorMode"],
+    graph_capture_output: CaptureOutput,
+    args: Any,
+    kwargs: Any,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
+):
+    constraint_violation_error = None
+    try:
+        # Check if we have any constraint violations
+        fn, _ = get_traced_fn(module_to_trace)
+        graph_capture_output.graph_capture_output.build_guards(fn.__code__)
+    except ConstraintViolationError as e:
+        constraint_violation_error = e
+
+    if (
+        (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+        and (dim_constraints := shape_env.dim_constraints) is not None
+        and not isinstance(
+            module_to_trace.forward,
+            torch._ops.OpOverloadPacket | torch._ops.OpOverload,
+        )
+    ):
+        dim_constraints.solve()
+        forced_specializations = dim_constraints.forced_specializations()
+        msg = dim_constraints.prettify_results(
+            inspect.signature(orig_callable),  # type: ignore[attr-defined]
+            dynamic_shapes,
+            constraint_violation_error,
+            forced_specializations,
+        )
+        if constraint_violation_error:
+            constraint_violation_error.args = (
+                constraint_violation_error.args[0] + msg,
+            )
+        else:
+            if forced_specializations:
+                constraint_violation_error = ConstraintViolationError(msg)
+            else:
+                log.info(
+                    "Summary of dimension constraints:%s",
+                    msg,
+                )
+
+        # Error if we have any constraints on static values
+        for k in shape_env.var_to_range.keys():
+            if isinstance(k, sympy.Integer):
+                constraint_violation_error = ConstraintViolationError(
+                    f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+                    "It appears that you're trying to set a constraint on a "
+                    f"value which we evaluated to have a static value of {k}. "
+                    'Set TORCH_LOGS="+export" for more information.'
+                )
+    if constraint_violation_error:
+        constraint_violation_error = post_process_error_msg(
+            constraint_violation_error, orig_callable, args, kwargs
+        )
+        raise constraint_violation_error
+
+
 def _dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
     *,
```
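Note one incidental modernization: the old inline block (see the removed lines further down) checked `isinstance(..., (torch._ops.OpOverloadPacket, torch._ops.OpOverload))` with a tuple, while the new helper passes the PEP 604 union `torch._ops.OpOverloadPacket | torch._ops.OpOverload`. On Python 3.10+ the two spellings are interchangeable for `isinstance`, as this standalone snippet shows:

```python
# On Python 3.10+, isinstance accepts a PEP 604 union of types
# as well as the classic tuple form; the checks are equivalent.
assert isinstance(3, (int, float))
assert isinstance(3, int | float)
assert not isinstance("3", int | float)
```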
```diff
@@ -367,6 +433,7 @@ def _dynamo_graph_capture_for_export(
         with _compiling_state_context():
             flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
             module_to_trace = ModuleToTrace(mod, in_spec)
+            orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod
 
             constraints: Optional[list[Constraint]] = _constraints
             dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = (
```
```diff
@@ -401,6 +468,27 @@ def _dynamo_graph_capture_for_export(
 
             assert out.graph_capture_output.output_graph is not None
 
+            example_inputs: list[Any] = []
+            if out.backend_input is not None:
+                graph = out.backend_input.graph_module
+                fake_mode = out.backend_input.fake_mode
+                example_inputs = out.backend_input.example_inputs
+            else:
+                graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
+                graph.graph.output(None)
+                graph.recompile()
+                fake_mode = None
+
+            _suggest_or_raise_constraint_violation(
+                module_to_trace,
+                orig_callable,
+                fake_mode,
+                out,
+                args,
+                kwargs,
+                dynamic_shapes,
+            )
+
             # Extract export metadata from the new location
             export_metadata = out.graph_capture_output.output_graph.export_metadata
             graph_inputs = export_metadata.graph_input_idx_to_local_source
```
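When capture yields no `backend_input`, the fallback branch fabricates an empty `GraphModule` whose graph simply returns `None`. Run in isolation, the same three `torch.fx` calls produce that placeholder; the prints are only added here to show what it contains:

```python
import torch
import torch.fx

# Same construction as the fallback branch: an empty root module
# plus a graph whose single output node returns None.
graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
graph.graph.output(None)
graph.recompile()

print(graph.code)  # generated forward simply returns None
print(graph())     # -> None
```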
```diff
@@ -408,17 +496,6 @@ def _dynamo_graph_capture_for_export(
             out_spec = export_metadata.out_spec
             module_call_spec = export_metadata.module_call_spec
 
-            example_inputs: list[Any] = []
-            if out.backend_input is not None:
-                graph = out.backend_input.graph_module
-                fake_mode = out.backend_input.fake_mode
-                example_inputs = out.backend_input.example_inputs
-            else:
-                graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
-                graph.graph.output(None)
-                graph.recompile()
-                fake_mode = None
-
             # Compute dynamic dimensions for each input based on constraints
             flat_args_dynamic_dims = [
                 {
```
```diff
@@ -453,8 +530,6 @@ def _dynamo_graph_capture_for_export(
                 fake_mode,
             ).transform()
 
-            orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod
-
             # Set up PyTree codegen for proper input/output handling
             transformed_graph.graph._codegen = _PyTreeCodeGen(
                 _PyTreeInfo(
```
```diff
@@ -472,58 +547,6 @@ def _dynamo_graph_capture_for_export(
 
             transformed_graph.meta["module_call_specs"] = module_call_spec
 
-            constraint_violation_error = None
-            try:
-                # Check if we have any constraint violations
-                fn, _ = get_traced_fn(module_to_trace)
-                out.graph_capture_output.build_guards(fn.__code__)
-            except ConstraintViolationError as e:
-                constraint_violation_error = e
-
-            if (
-                (shape_env := getattr(fake_mode, "shape_env", None)) is not None
-                and (dim_constraints := shape_env.dim_constraints) is not None
-                and not isinstance(
-                    module_to_trace.forward,
-                    (torch._ops.OpOverloadPacket, torch._ops.OpOverload),
-                )
-            ):
-                dim_constraints.solve()
-                forced_specializations = dim_constraints.forced_specializations()
-                msg = dim_constraints.prettify_results(
-                    inspect.signature(orig_callable),  # type: ignore[attr-defined]
-                    dynamic_shapes,
-                    constraint_violation_error,
-                    forced_specializations,
-                )
-                if constraint_violation_error:
-                    constraint_violation_error.args = (
-                        constraint_violation_error.args[0] + msg,
-                    )
-                else:
-                    if forced_specializations:
-                        constraint_violation_error = ConstraintViolationError(msg)
-                    else:
-                        log.info(
-                            "Summary of dimension constraints:%s",
-                            msg,
-                        )
-
-                # Error if we have any constraints on static values
-                for k in shape_env.var_to_range.keys():
-                    if isinstance(k, sympy.Integer):
-                        constraint_violation_error = ConstraintViolationError(
-                            f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
-                            "It appears that you're trying to set a constraint on a "
-                            f"value which we evaluated to have a static value of {k}. "
-                            'Set TORCH_LOGS="+export" for more information.'
-                        )
-            if constraint_violation_error:
-                constraint_violation_error = post_process_error_msg(
-                    constraint_violation_error, orig_callable, args, kwargs
-                )
-                raise constraint_violation_error
-
             return transformed_graph
 
     return inner
```
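With the old tail block removed, `inner` now ends by returning the transformed graph; guard building and the constraint-violation diagnostics run earlier, inside `_suggest_or_raise_constraint_violation`, right after capture. Call-side behavior should be unchanged; roughly, assuming the function lives in `torch._dynamo.functional_export` (the diff does not show the file path, so treat the import as an assumption):

```python
import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# The private entry point returns a capture callable; invoking it with
# example inputs runs fullgraph capture, builds guards (now inside
# _suggest_or_raise_constraint_violation), and returns a GraphModule.
gm = _dynamo_graph_capture_for_export(M())(torch.randn(3))
print(isinstance(gm, torch.fx.GraphModule))  # True
```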