diff --git a/docs/source/torch.compiler_api.rst b/docs/source/torch.compiler_api.rst
index bcf9772351a..88a373067f1 100644
--- a/docs/source/torch.compiler_api.rst
+++ b/docs/source/torch.compiler_api.rst
@@ -24,3 +24,4 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
     cudagraph_mark_step_begin
     is_compiling
     is_dynamo_compiling
+    is_exporting
diff --git a/docs/source/torch.compiler_fine_grain_apis.rst b/docs/source/torch.compiler_fine_grain_apis.rst
index 9c0ebf29187..7f61d88a269 100644
--- a/docs/source/torch.compiler_fine_grain_apis.rst
+++ b/docs/source/torch.compiler_fine_grain_apis.rst
@@ -28,6 +28,7 @@ disable compilation are listed in the following table:
    "``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``."
    "``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()."
    "``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used."
+   "``torch.compiler.is_exporting``", "Indicates whether a graph is traced via torch.export. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when torch.export is used."
 
 ``torch.compiler.disable``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 78edb3e0f60..292d3ec686f 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -1881,6 +1881,52 @@ def forward(self, p_linear_weight, p_linear_bias, x):
         ]
         self.assertEqual(actual_torch_fns, exp_torch_fns)
 
+    def test_is_exporting(self):
+        class Mod(torch.nn.Module):
+            def forward(self, pred, x):
+                def f(x):
+                    return x.sin() if torch.compiler.is_exporting() else x.cos()
+
+                y = f(x)
+
+                def true_fn(x):
+                    return f(x) - 1 if torch.compiler.is_exporting() else f(x) + 1
+
+                def false_fn(x):
+                    return f(x) + 1 if torch.compiler.is_exporting() else f(x) - 1
+
+                return torch.cond(pred, true_fn, false_fn, (x,)) * y
+
+        ep = export(
+            Mod(),
+            (
+                torch.tensor(False),
+                torch.randn(3, 4),
+            ),
+        )
+        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
+            ep.graph_module.code
+        )
+        FileCheck().check_count("torch.ops.higher_order.cond", 1, exactly=True).run(
+            ep.graph_module.code
+        )
+
+        # True graph should contain sin and sub
+        FileCheck().check_count("torch.ops.aten.sub", 1, exactly=True).run(
+            ep.graph_module.true_graph_0.code
+        )
+        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
+            ep.graph_module.true_graph_0.code
+        )
+
+        # False graph should contain sin and add
+        FileCheck().check_count("torch.ops.aten.add", 1, exactly=True).run(
+            ep.graph_module.false_graph_0.code
+        )
+        FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run(
+            ep.graph_module.false_graph_0.code
+        )
+
     def test_duplicate_modules_with_non_persistent_buffers(self):
         class FooWithBuf(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 9e29d308ea7..7fa91fb308b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -154,6 +154,7 @@ manual_torch_name_rule_map = {
     "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
+    "torch.compiler.is_exporting": TorchInGraphFunctionVariable,
     "torch.autograd._profiler_enabled": SkipFunctionVariable,
     "torch._C._to_dlpack": SkipFunctionVariable,
     "torch.to_dlpack": SkipFunctionVariable,
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index bb4198e3869..8021a76ea8a 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -147,6 +147,7 @@ tracing_state_functions = {
     torch._utils.is_compiling: True,
     torch.compiler.is_compiling: True,
     torch.compiler.is_dynamo_compiling: True,
+    torch.compiler.is_exporting: True,
     torch.nn.modules.activation._is_make_fx_tracing: False,
 }
 
@@ -410,6 +411,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 torch._dynamo.external_utils.is_compiling,
                 torch.compiler.is_compiling,
                 torch.compiler.is_dynamo_compiling,
+                torch.compiler.is_exporting,
             ):
                 tx.mark_inconsistent_side_effects()
             return ConstantVariable.create(tracing_state_functions[self.value])
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index 73aded96a11..98af35c3e6d 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -1137,3 +1137,16 @@ def _get_decomp_for_cia(op: "OperatorBase"):
     )
 
     return functools.partial(_special_op_to_decompose_cia, kernel=op)
+
+
+@contextmanager
+def _compiling_state_context():
+    old_compiling_flag = torch.compiler._is_compiling_flag
+    old_exporting_flag = torch.compiler._is_exporting_flag
+    try:
+        torch.compiler._is_compiling_flag = True
+        torch.compiler._is_exporting_flag = True
+        yield
+    finally:
+        torch.compiler._is_compiling_flag = old_compiling_flag
+        torch.compiler._is_exporting_flag = old_exporting_flag
diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py
index ca5123314a8..038f36d20e2 100644
--- a/torch/compiler/__init__.py
+++ b/torch/compiler/__init__.py
@@ -17,6 +17,7 @@ __all__ = [
     "wrap_numpy",
     "is_compiling",
     "is_dynamo_compiling",
+    "is_exporting",
 ]
 
 
@@ -362,6 +363,7 @@ def wrap_numpy(fn):
 
 
 _is_compiling_flag: bool = False
+_is_exporting_flag: bool = False
 
 
 def is_compiling() -> bool:
@@ -402,3 +404,21 @@ def is_dynamo_compiling() -> bool:
         >>>     # ...rest of the function...
     """
     return False
+
+
+def is_exporting() -> bool:
+    """
+    Indicates whether a graph is traced via torch.export.
+
+    It's stricter than the is_compiling() flag, as it is only set to True when
+    torch.export is used.
+
+    Example::
+
+        >>> def forward(self, x):
+        >>>     if not torch.compiler.is_exporting():
+        >>>        pass # ...logic that is not needed in export...
+        >>>
+        >>>     # ...rest of the function...
+    """
+    return _is_exporting_flag
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 1186bd8e634..cc644f5bf32 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -36,6 +36,7 @@ from torch._export.passes.lift_constants_pass import (
 )
 from torch._export.utils import (
     _collect_param_buffer_metadata,
+    _compiling_state_context,
     _populate_param_buffer_metadata_to_new_gm,
     _update_gm_meta_if_possible,
     apply_runtime_assertion_pass,
@@ -719,15 +720,6 @@ def _export_to_aten_ir(
     if not pre_dispatch and is_grad_enabled:
         grad_safe_guard = AutogradStateOpsFailSafeguard()  # type: ignore[assignment]
 
-    @contextmanager
-    def _compiling_state_context():
-        old_value = torch.compiler._is_compiling_flag
-        try:
-            torch.compiler._is_compiling_flag = True
-            yield
-        finally:
-            torch.compiler._is_compiling_flag = old_value
-
     # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
     # otherwise aot_export_module will error out because it sees a mix of fake_modes.
     # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
@@ -1427,15 +1419,6 @@ def _export_to_aten_ir_make_fx(
     produce_guards_callback=None,
     transform=lambda x: x,
 ) -> ATenExportArtifact:
-    @contextmanager
-    def _compiling_state_context():
-        old_value = torch.compiler._is_compiling_flag
-        try:
-            torch.compiler._is_compiling_flag = True
-            yield
-        finally:
-            torch.compiler._is_compiling_flag = old_value
-
     def _make_fx_helper(mod, args, kwargs, **flags):
         from torch._functorch._aot_autograd.schemas import GraphSignature