[export] add is_exporting flag (#142425)

We added an `is_exporting` flag, exposed as `torch.compiler.is_exporting()`. This comes in handy when we need special-case logic at the user level and at the system level (e.g. higher up in the stack).

In increasing scope (a usage sketch follows the list):
- `_is_fx_tracing` is set to True when we trace under symbolic_trace or make_fx.
- `is_exporting` is set to True when we're doing strict or non-strict export, which internally has a step that calls make_fx and sets `_is_fx_tracing` to True.
- `is_compiling` is set to True when we're doing strict export, non-strict export, or torch.compile.
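For a concrete picture of the user-facing flag, here is a minimal sketch (the module and tensor shapes are made up; it mirrors the new test added below): in eager mode the `cos` branch runs, while the graph captured by `torch.export` contains only `sin`, because `is_exporting()` reads True during export tracing.

```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        # True only while torch.export is tracing (strict or non-strict).
        if torch.compiler.is_exporting():
            return x.sin()
        return x.cos()

print(M()(torch.randn(3)))           # eager call: cos path
ep = export(M(), (torch.randn(3),))
print(ep.graph_module.code)          # exported graph: calls aten.sin
```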

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142425
Approved by: https://github.com/avikchaudhuri
Yidi Wu 2024-12-17 14:33:16 -08:00 committed by PyTorch MergeBot
parent 894d47b91b
commit 1e201422ed
8 changed files with 85 additions and 18 deletions


@ -24,3 +24,4 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
cudagraph_mark_step_begin
is_compiling
is_dynamo_compiling
is_exporting


@ -28,6 +28,7 @@ disable compilation are listed in the following table:
"``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``."
"``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()."
"``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used."
"``torch.compiler.is_exporting``", "Indicates whether a graph is traced via export. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when torch.export is used."
``torch.compiler.disable``
~~~~~~~~~~~~~~~~~~~~~~~~~~


@ -1881,6 +1881,52 @@ def forward(self, p_linear_weight, p_linear_bias, x):
        ]
        self.assertEqual(actual_torch_fns, exp_torch_fns)

    def test_is_exporting(self):
        class Mod(torch.nn.Module):
            def forward(self, pred, x):
                def f(x):
                    return x.sin() if torch.compiler.is_exporting() else x.cos()

                y = f(x)

                def true_fn(x):
                    return f(x) - 1 if torch.compiler.is_exporting() else f(x) + 1

                def false_fn(x):
                    return f(x) + 1 if torch.compiler.is_exporting() else f(x) - 1

                return torch.cond(pred, true_fn, false_fn, (x,)) * y

        ep = export(
            Mod(),
            (
                torch.tensor(False),
                torch.randn(3, 4),
            ),
        )
        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
            ep.graph_module.code
        )
        FileCheck().check_count("torch.ops.higher_order.cond", 1, exactly=True).run(
            ep.graph_module.code
        )

        # True graph should contain sin and sub
        FileCheck().check_count("torch.ops.aten.sub", 1, exactly=True).run(
            ep.graph_module.true_graph_0.code
        )
        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
            ep.graph_module.true_graph_0.code
        )

        # False graph should contain sin and add
        FileCheck().check_count("torch.ops.aten.add", 1, exactly=True).run(
            ep.graph_module.false_graph_0.code
        )
        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
            ep.graph_module.false_graph_0.code
        )

    def test_duplicate_modules_with_non_persistent_buffers(self):
        class FooWithBuf(torch.nn.Module):
            def __init__(self):


@ -154,6 +154,7 @@ manual_torch_name_rule_map = {
"torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_exporting": TorchInGraphFunctionVariable,
"torch.autograd._profiler_enabled": SkipFunctionVariable,
"torch._C._to_dlpack": SkipFunctionVariable,
"torch.to_dlpack": SkipFunctionVariable,


@ -147,6 +147,7 @@ tracing_state_functions = {
torch._utils.is_compiling: True,
torch.compiler.is_compiling: True,
torch.compiler.is_dynamo_compiling: True,
torch.compiler.is_exporting: True,
torch.nn.modules.activation._is_make_fx_tracing: False,
}
@ -410,6 +411,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                torch._dynamo.external_utils.is_compiling,
                torch.compiler.is_compiling,
                torch.compiler.is_dynamo_compiling,
                torch.compiler.is_exporting,
            ):
                tx.mark_inconsistent_side_effects()
            return ConstantVariable.create(tracing_state_functions[self.value])
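For context, these entries let Dynamo resolve the tracing-state calls at trace time instead of putting them into the graph. A short sketch with a made-up function, using the pre-existing is_compiling() flag:

```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    # Dynamo folds the call to a constant, so only one branch is traced.
    if torch.compiler.is_compiling():
        return x + 1
    return x - 1

print(f(torch.ones(2)))  # tensor([2., 2.]): the is_compiling() branch was taken
```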


@ -1137,3 +1137,16 @@ def _get_decomp_for_cia(op: "OperatorBase"):
)
return functools.partial(_special_op_to_decompose_cia, kernel=op)

@contextmanager
def _compiling_state_context():
    old_compiling_flag = torch.compiler._is_compiling_flag
    old_exporting_flag = torch.compiler._is_exporting_flag
    try:
        torch.compiler._is_compiling_flag = True
        torch.compiler._is_exporting_flag = True
        yield
    finally:
        torch.compiler._is_compiling_flag = old_compiling_flag
        torch.compiler._is_exporting_flag = old_exporting_flag
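A rough sketch of the contract this helper provides (it is a private utility that export itself uses; the last hunk in this diff imports it from torch._export.utils, so treat this as an illustration rather than a supported API):

```python
import torch
from torch._export.utils import _compiling_state_context

assert not torch.compiler.is_compiling()
assert not torch.compiler.is_exporting()

with _compiling_state_context():
    # Both flags read True inside export's tracing step.
    assert torch.compiler.is_compiling()
    assert torch.compiler.is_exporting()

# The previous values are restored on exit.
assert not torch.compiler.is_exporting()
```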


@ -17,6 +17,7 @@ __all__ = [
"wrap_numpy",
"is_compiling",
"is_dynamo_compiling",
"is_exporting",
]
@ -362,6 +363,7 @@ def wrap_numpy(fn):
_is_compiling_flag: bool = False
_is_exporting_flag: bool = False
def is_compiling() -> bool:
@ -402,3 +404,21 @@ def is_dynamo_compiling() -> bool:
        >>>     # ...rest of the function...
    """
    return False


def is_exporting() -> bool:
    """
    Indicates whether we are tracing or exporting with torch.export.

    It's stricter than the is_compiling() flag, as it is only set to True when
    torch.export is used.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_exporting():
        >>>         pass  # ...logic that is not needed in export...
        >>>
        >>>     # ...rest of the function...
    """
    return _is_exporting_flag


@ -36,6 +36,7 @@ from torch._export.passes.lift_constants_pass import (
)
from torch._export.utils import (
_collect_param_buffer_metadata,
_compiling_state_context,
_populate_param_buffer_metadata_to_new_gm,
_update_gm_meta_if_possible,
apply_runtime_assertion_pass,
@ -719,15 +720,6 @@ def _export_to_aten_ir(
    if not pre_dispatch and is_grad_enabled:
        grad_safe_guard = AutogradStateOpsFailSafeguard()  # type: ignore[assignment]

    @contextmanager
    def _compiling_state_context():
        old_value = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old_value

    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
@ -1427,15 +1419,6 @@ def _export_to_aten_ir_make_fx(
    produce_guards_callback=None,
    transform=lambda x: x,
) -> ATenExportArtifact:
    @contextmanager
    def _compiling_state_context():
        old_value = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old_value

    def _make_fx_helper(mod, args, kwargs, **flags):
        from torch._functorch._aot_autograd.schemas import GraphSignature