mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] add is_exporting flag (#142425)
We added an is_export flag under torch.compiler.is_exporting. This comes handy when we try to do some special logic in user-level and system-level (e.g. in upper of the stack). In increasing-scope: - `_is_fx_tracing` is set to True when we use under symbolic_trace or make_fx. - `is_exporting` is set to True when we're doing strict or non-strict export, which internally has a step that calls make_fx and set _is_fx_tracing to be True. - `is_compiling` is set to True when we're either doing strict, non-strict export or torch.compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142425 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
parent
894d47b91b
commit
1e201422ed
|
|
@ -24,3 +24,4 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
|
|||
cudagraph_mark_step_begin
|
||||
is_compiling
|
||||
is_dynamo_compiling
|
||||
is_exporting
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ disable compilation are listed in the following table:
|
|||
"``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``."
|
||||
"``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()."
|
||||
"``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used."
|
||||
"``torch.compiler.is_exporting``", "Indicates whether a graph is traced via export. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when torch.export is used."
|
||||
|
||||
``torch.compiler.disable``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
|||
|
|
@ -1881,6 +1881,52 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
|||
]
|
||||
self.assertEqual(actual_torch_fns, exp_torch_fns)
|
||||
|
||||
def test_is_exporting(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, pred, x):
|
||||
def f(x):
|
||||
return x.sin() if torch.compiler.is_exporting() else x.cos()
|
||||
|
||||
y = f(x)
|
||||
|
||||
def true_fn(x):
|
||||
return f(x) - 1 if torch.compiler.is_exporting() else f(x) + 1
|
||||
|
||||
def false_fn(x):
|
||||
return f(x) + 1 if torch.compiler.is_exporting() else f(x) - 1
|
||||
|
||||
return torch.cond(pred, true_fn, false_fn, (x,)) * y
|
||||
|
||||
ep = export(
|
||||
Mod(),
|
||||
(
|
||||
torch.tensor(False),
|
||||
torch.randn(3, 4),
|
||||
),
|
||||
)
|
||||
FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
|
||||
ep.graph_module.code
|
||||
)
|
||||
FileCheck().check_count("torch.ops.higher_order.cond", 1, exactly=True).run(
|
||||
ep.graph_module.code
|
||||
)
|
||||
|
||||
# True graph should contain sin and sub
|
||||
FileCheck().check_count("torch.ops.aten.sub", 1, exactly=True).run(
|
||||
ep.graph_module.true_graph_0.code
|
||||
)
|
||||
FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
|
||||
ep.graph_module.true_graph_0.code
|
||||
)
|
||||
|
||||
# False graph should contain sin and add
|
||||
FileCheck().check_count("torch.ops.aten.add", 1, exactly=True).run(
|
||||
ep.graph_module.false_graph_0.code
|
||||
)
|
||||
FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
|
||||
ep.graph_module.false_graph_0.code
|
||||
)
|
||||
|
||||
def test_duplicate_modules_with_non_persistent_buffers(self):
|
||||
class FooWithBuf(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -154,6 +154,7 @@ manual_torch_name_rule_map = {
|
|||
"torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
|
||||
"torch.compiler.is_compiling": TorchInGraphFunctionVariable,
|
||||
"torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
|
||||
"torch.compiler.is_exporting": TorchInGraphFunctionVariable,
|
||||
"torch.autograd._profiler_enabled": SkipFunctionVariable,
|
||||
"torch._C._to_dlpack": SkipFunctionVariable,
|
||||
"torch.to_dlpack": SkipFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -147,6 +147,7 @@ tracing_state_functions = {
|
|||
torch._utils.is_compiling: True,
|
||||
torch.compiler.is_compiling: True,
|
||||
torch.compiler.is_dynamo_compiling: True,
|
||||
torch.compiler.is_exporting: True,
|
||||
torch.nn.modules.activation._is_make_fx_tracing: False,
|
||||
}
|
||||
|
||||
|
|
@ -410,6 +411,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
torch._dynamo.external_utils.is_compiling,
|
||||
torch.compiler.is_compiling,
|
||||
torch.compiler.is_dynamo_compiling,
|
||||
torch.compiler.is_exporting,
|
||||
):
|
||||
tx.mark_inconsistent_side_effects()
|
||||
return ConstantVariable.create(tracing_state_functions[self.value])
|
||||
|
|
|
|||
|
|
@ -1137,3 +1137,16 @@ def _get_decomp_for_cia(op: "OperatorBase"):
|
|||
)
|
||||
|
||||
return functools.partial(_special_op_to_decompose_cia, kernel=op)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _compiling_state_context():
|
||||
old_compiling_flag = torch.compiler._is_compiling_flag
|
||||
old_exporting_flag = torch.compiler._is_exporting_flag
|
||||
try:
|
||||
torch.compiler._is_compiling_flag = True
|
||||
torch.compiler._is_exporting_flag = True
|
||||
yield
|
||||
finally:
|
||||
torch.compiler._is_compiling_flag = old_compiling_flag
|
||||
torch.compiler._is_exporting_flag = old_exporting_flag
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ __all__ = [
|
|||
"wrap_numpy",
|
||||
"is_compiling",
|
||||
"is_dynamo_compiling",
|
||||
"is_exporting",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -362,6 +363,7 @@ def wrap_numpy(fn):
|
|||
|
||||
|
||||
_is_compiling_flag: bool = False
|
||||
_is_exporting_flag: bool = False
|
||||
|
||||
|
||||
def is_compiling() -> bool:
|
||||
|
|
@ -402,3 +404,21 @@ def is_dynamo_compiling() -> bool:
|
|||
>>> # ...rest of the function...
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def is_exporting() -> bool:
|
||||
"""
|
||||
Indicated whether we're under exporting.
|
||||
|
||||
It's stricter than is_compiling() flag, as it would only be set to True when
|
||||
torch.export is used.
|
||||
|
||||
Example::
|
||||
|
||||
>>> def forward(self, x):
|
||||
>>> if not torch.compiler.is_exporting():
|
||||
>>> pass # ...logic that is not needed in export...
|
||||
>>>
|
||||
>>> # ...rest of the function...
|
||||
"""
|
||||
return _is_exporting_flag
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from torch._export.passes.lift_constants_pass import (
|
|||
)
|
||||
from torch._export.utils import (
|
||||
_collect_param_buffer_metadata,
|
||||
_compiling_state_context,
|
||||
_populate_param_buffer_metadata_to_new_gm,
|
||||
_update_gm_meta_if_possible,
|
||||
apply_runtime_assertion_pass,
|
||||
|
|
@ -719,15 +720,6 @@ def _export_to_aten_ir(
|
|||
if not pre_dispatch and is_grad_enabled:
|
||||
grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment]
|
||||
|
||||
@contextmanager
|
||||
def _compiling_state_context():
|
||||
old_value = torch.compiler._is_compiling_flag
|
||||
try:
|
||||
torch.compiler._is_compiling_flag = True
|
||||
yield
|
||||
finally:
|
||||
torch.compiler._is_compiling_flag = old_value
|
||||
|
||||
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
|
||||
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
|
||||
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
|
||||
|
|
@ -1427,15 +1419,6 @@ def _export_to_aten_ir_make_fx(
|
|||
produce_guards_callback=None,
|
||||
transform=lambda x: x,
|
||||
) -> ATenExportArtifact:
|
||||
@contextmanager
|
||||
def _compiling_state_context():
|
||||
old_value = torch.compiler._is_compiling_flag
|
||||
try:
|
||||
torch.compiler._is_compiling_flag = True
|
||||
yield
|
||||
finally:
|
||||
torch.compiler._is_compiling_flag = old_value
|
||||
|
||||
def _make_fx_helper(mod, args, kwargs, **flags):
|
||||
from torch._functorch._aot_autograd.schemas import GraphSignature
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user