From cbee9c1fd2b08b04af5b0bdf472254ffd3bbe231 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Wed, 7 Aug 2024 00:05:20 +0000
Subject: [PATCH] Revert "Deprecate `torch._utils.is_compiling()` and `torch._dynamo.external_utils.is_compiling()` (#127690)"

This reverts commit 0e7e61f7cec82a43f2de52b83eff152d703be7a3.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
---
 test/dynamo/test_skip_non_tensor.py              | 10 +++++-----
 test/export/test_torchbind.py                    |  2 +-
 test/functorch/test_memory_efficient_fusion.py   |  2 +-
 test/inductor/test_distributed_patterns.py       |  4 ++--
 test/test_optim.py                               |  2 +-
 torch/_dynamo/decorators.py                      |  3 ++-
 torch/_dynamo/external_utils.py                  |  5 -----
 torch/_functorch/apis.py                         |  6 +++---
 torch/_functorch/eager_transforms.py             |  4 ++--
 torch/_higher_order_ops/associative_scan.py      |  2 +-
 torch/_utils.py                                  |  6 +-----
 .../_composable/fsdp/_fsdp_common.py             | 17 +++++++++--------
 .../algorithms/ddp_comm_hooks/default_hooks.py   |  4 ++--
 torch/distributed/tensor/parallel/_utils.py      | 14 ++++++++------
 torch/nn/parallel/distributed.py                 |  4 ++--
 torch/optim/_adafactor.py                        |  2 +-
 torch/optim/adadelta.py                          |  8 ++++----
 torch/optim/adagrad.py                           |  2 +-
 torch/optim/adam.py                              |  8 ++++----
 torch/optim/adamax.py                            |  8 ++++----
 torch/optim/adamw.py                             |  8 ++++----
 torch/optim/asgd.py                              |  6 +++---
 torch/optim/nadam.py                             |  6 +++---
 torch/optim/optimizer.py                         | 11 ++++++-----
 torch/optim/radam.py                             |  6 +++---
 torch/optim/rmsprop.py                           |  8 ++++----
 torch/optim/rprop.py                             |  8 ++++----
 torch/optim/sgd.py                               |  2 +-
 .../testing/_internal/optests/generate_tests.py  |  2 +-
 29 files changed, 83 insertions(+), 87 deletions(-)

diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py
index 48c4022ef28..72153d26a1f 100644
--- a/test/dynamo/test_skip_non_tensor.py
+++ b/test/dynamo/test_skip_non_tensor.py
@@ -12,12 +12,12 @@ _variable_2 = 0
 def user_function():
-    return torch.compiler.is_compiling()
+    return torch._utils.is_compiling()
 def user_generator():
     for _ in range(1):
-        yield torch.compiler.is_compiling()
+        yield torch._utils.is_compiling()
     return
@@ -38,7 +38,7 @@ class MyModule(torch.nn.Module):
         global _variable, _variable_2
         if self.mode == 1:
-            if torch.compiler.is_compiling():
+            if torch._utils.is_compiling():
                 _variable += 1
             else:
                 _variable_2 += 1
@@ -46,7 +46,7 @@ class MyModule(torch.nn.Module):
             if user_function():
                 _variable += 1
         elif self.mode == 3:
-            lambda_f = lambda: torch.compiler.is_compiling()  # noqa: E731
+            lambda_f = lambda: torch._utils.is_compiling()  # noqa: E731
             if lambda_f():
                 _variable += 1
         elif self.mode == 4:
@@ -163,7 +163,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
     def test_do_not_skip_side_effects(self):
         # https://github.com/pytorch/pytorch/issues/110765
-        # By invoking torch.compiler.is_compiling(),
+        # By invoking torch._utils.is_compiling(),
         # there may be side-effects inconsistent with eager when
         # compiling. Thus we force dynamo to commit the graph,
         # even if it does not perform any tensor operation
diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py
index 98aa36c0a3b..119064e1dd5 100644
--- a/test/export/test_torchbind.py
+++ b/test/export/test_torchbind.py
@@ -1281,7 +1281,7 @@ class TestCompileTorchbind(TestCase):
             f(_empty_tensor_queue(), x),
             torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
         )
-        if not torch.compiler.is_compiling() and backend == "eager":
+        if not torch._dynamo.is_compiling() and backend == "eager":
             self.assertExpectedInline(
                 backend.graphs[0].code.strip(),
                 """\
diff --git a/test/functorch/test_memory_efficient_fusion.py b/test/functorch/test_memory_efficient_fusion.py
index d07fb136f5e..bfca66d333b 100644
--- a/test/functorch/test_memory_efficient_fusion.py
+++ b/test/functorch/test_memory_efficient_fusion.py
@@ -278,7 +278,7 @@ class NoChangeTestCase(TestCase):
         # Test to repro issue with fx_graph_cse when
         # hash((primals_2, 1.0)) == hash((primals_2, 1))
-        if torch.compiler.is_compiling():
+        if torch._dynamo.is_compiling():
             self.skipTest("Unsupported if test run is compiled")
         def f(inpt, osize):
diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py
index 6a78ea45d58..a647a16c421 100644
--- a/test/inductor/test_distributed_patterns.py
+++ b/test/inductor/test_distributed_patterns.py
@@ -110,13 +110,13 @@ def init_fake_distributed(device="cpu"):
 def init_module_bw_hooks(allow_eager):
     def bw_pre_hook(mod, gO):
-        assert allow_eager or torch.compiler.is_compiling()
+        assert allow_eager or torch._dynamo.is_compiling()
         assert mod.weight.size() == (10, 10)
         mod.hook_count_pre.add_(1)
         return (torch.sin(gO[0] + 1.2),)
     def bw_post_hook(mod, gI, gO):
-        assert allow_eager or torch.compiler.is_compiling()
+        assert allow_eager or torch._dynamo.is_compiling()
         assert mod.weight.size() == (10, 10)
         mod.hook_count_post.add_(1)
         return (torch.sin(gI[0] + 3.4),)
diff --git a/test/test_optim.py b/test/test_optim.py
index e5e6512a05f..d7a8cc92e63 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -288,7 +288,7 @@ class TestOptimRenewed(TestCase):
         inpt = torch.randn(5, device=device, dtype=dtype)
         # avoid endless recompiles by wrapping LR in a tensor if we're compiling
-        lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01
+        lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
         optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
         schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index c7cb8027530..f541ecae985 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -10,6 +10,7 @@ from . import trace_rules, variables
 from .comptime import comptime
 from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
 from .exc import IncorrectUsage
+from .external_utils import is_compiling
 if TYPE_CHECKING:
@@ -291,7 +292,7 @@ def mark_static(t, index=None):
     Unlike mark_dynamic, this can be done inside a graph, in which case it
     induces specialization on the tensor.
     """
-    if torch.compiler.is_compiling():
+    if is_compiling():
         if index is None:
             for s in t.size():
                 comptime.force_static(s)
diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py
index 4f6a7c924d7..c2f192e28b2 100644
--- a/torch/_dynamo/external_utils.py
+++ b/torch/_dynamo/external_utils.py
@@ -3,7 +3,6 @@ import functools
 from typing import List
-from typing_extensions import deprecated
 import torch
 import torch.utils._pytree as pytree
@@ -15,10 +14,6 @@ except ModuleNotFoundError:
     np = None  # type: ignore[assignment]
-@deprecated(
-    "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
-    category=FutureWarning,
-)
 def is_compiling() -> bool:
     """
     Indicates whether we are tracing/compiling with torch.compile() or torch.export().
diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py
index db252d8ca6d..d906f3c906c 100644
--- a/torch/_functorch/apis.py
+++ b/torch/_functorch/apis.py
@@ -191,7 +191,7 @@ def vmap(
         vmap does not provide general autobatching or handle variable-length sequences out of the box.
     """
-    from torch.compiler import is_compiling
+    from torch._dynamo import is_compiling
     _check_randomness_arg(randomness)
     if not (chunk_size is None or chunk_size > 0):
@@ -393,7 +393,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
     """
     # To avoid cyclical dependency.
     import torch._functorch.eager_transforms as eager_transforms
-    from torch.compiler import is_compiling
+    from torch._dynamo import is_compiling
     def wrapper(*args, **kwargs):
         return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
@@ -435,8 +435,8 @@ def grad_and_value(
     See :func:`grad` for examples
     """
+    from torch._dynamo import is_compiling
     from torch._functorch import eager_transforms
-    from torch.compiler import is_compiling
     def wrapper(*args, **kwargs):
         return eager_transforms.grad_and_value_impl(
diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py
index ad8d4881858..e18c77cf087 100644
--- a/torch/_functorch/eager_transforms.py
+++ b/torch/_functorch/eager_transforms.py
@@ -764,7 +764,7 @@ def jacrev(
     # Dynamo does not support HOP composition if their inner function is
    # annotated with @functools.wraps(...). We circumvent this issue by applying
     # wraps only if we're not tracing with dynamo.
-    if not torch.compiler.is_compiling():
+    if not torch._dynamo.is_compiling():
         wrapper_fn = wraps(func)(wrapper_fn)
     return wrapper_fn
@@ -1344,7 +1344,7 @@ def jacfwd(
     # Dynamo does not support HOP composition if their inner function is
     # annotated with @functools.wraps(...). We circumvent this issue by applying
     # wraps only if we're not tracing with dynamo.
-    if not torch.compiler.is_compiling():
+    if not torch._dynamo.is_compiling():
         wrapper_fn = wraps(func)(wrapper_fn)
     return wrapper_fn
diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py
index e3c3383ece1..8839cd4fbfc 100644
--- a/torch/_higher_order_ops/associative_scan.py
+++ b/torch/_higher_order_ops/associative_scan.py
@@ -74,7 +74,7 @@ def associative_scan(
     assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}"
     assert isinstance(dim, int), "dim must be an int, but got {type(dim)}"
-    if not torch.compiler.is_compiling():
+    if not torch._dynamo.is_compiling():
         with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
             return torch.compile(associative_scan, fullgraph=True)(
                 combine_fn, input, dim
diff --git a/torch/_utils.py b/torch/_utils.py
index eacf6fd93c7..938392fa971 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -8,7 +8,7 @@ import traceback
 import warnings
 from collections import defaultdict
 from typing import Any, Callable, DefaultDict, Generic, List, Optional
-from typing_extensions import deprecated, ParamSpec
+from typing_extensions import ParamSpec
 import torch
@@ -868,10 +868,6 @@ def classproperty(func):
     return _ClassPropertyDescriptor(func)
-@deprecated(
-    "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
-    category=FutureWarning,
-)
 def is_compiling() -> bool:
     """
     Indicates whether we are tracing/compiling with torch.compile() or torch.export().
diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py
index 0cbf0b4b261..36b181250f2 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_common.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -6,6 +6,7 @@ from enum import auto, Enum
 from typing import Any, cast, List, Optional
 import torch
+import torch._dynamo.compiled_autograd as ca
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.contract import _get_registry
@@ -114,7 +115,6 @@ def _from_local_no_grad(
     This method is similar to ``DTensor.from_local()`` except that in eager mode
     it avoids some CPU overhead by avoiding default args and not being differentiable.
     """
-    import torch._dynamo.compiled_autograd as ca
     if not ca.compiled_autograd_enabled:
         return DTensor(
@@ -124,13 +124,14 @@ def _from_local_no_grad(
             sharding_spec,
             requires_grad=local_tensor.requires_grad,
         )
-    return DTensor.from_local(
-        local_tensor,
-        sharding_spec.mesh,
-        sharding_spec.placements,
-        shape=sharding_spec.shape,
-        stride=sharding_spec.stride,
-    )
+    else:
+        return DTensor.from_local(
+            local_tensor,
+            sharding_spec.mesh,
+            sharding_spec.placements,
+            shape=sharding_spec.shape,
+            stride=sharding_spec.stride,
+        )
 def _to_dtype_if_needed(
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index 16ce811ea33..b1296ae712f 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -87,7 +87,7 @@ def fp16_compress_hook(
         decompressed_tensor.copy_(value)
         return decompressed_tensor
-    if torch.compiler.is_compiling():
+    if torch._utils.is_compiling():
         grad = dist._functional_collectives.all_reduce(
             compressed_tensor, "sum", group_to_use
         )
@@ -136,7 +136,7 @@ def bf16_compress_hook(
         decompressed_tensor.copy_(value)
         return decompressed_tensor
-    if torch.compiler.is_compiling():
+    if torch._utils.is_compiling():
         grad = dist._functional_collectives.all_reduce(
             compressed_tensor, "sum", group_to_use
         )
diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py
index 3208aaa8ea0..3f47ec6f1ef 100644
--- a/torch/distributed/tensor/parallel/_utils.py
+++ b/torch/distributed/tensor/parallel/_utils.py
@@ -2,20 +2,22 @@ import warnings
 from typing import Tuple, Union
-import torch
 from torch.distributed._tensor import DeviceMesh
 from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import _mesh_resources
+try:
+    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
+except Exception:
+
+    def is_torchdynamo_compiling():  # type: ignore[misc]
+        return False
+
+
 LayoutsType = Union[Placement, Tuple[Placement, ...]]
-def is_torchdynamo_compiling() -> bool:
-    # Use local function to avoid circular imports
-    return torch.compiler.is_compiling()
-
-
 def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
     """
     Inject common validation logics for `_prepare_input` funcs via this decorator.
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 049e0c5c83a..a2257d73676 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1484,7 +1484,7 @@ class DistributedDataParallel(Module, Joinable):
     def _should_disable_cpp_reducer(self) -> bool:
         return self._use_python_reducer and (
-            torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer
+            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
         )
     def _pre_forward(self, *inputs, **kwargs):
@@ -1497,7 +1497,7 @@ class DistributedDataParallel(Module, Joinable):
                 h.remove()
             self._accum_grad_hooks.clear()
-        if not self._lazy_init_ran and not torch.compiler.is_compiling():
+        if not self._lazy_init_ran and not torch._utils.is_compiling():
             self._lazy_init()
         if self._delay_all_reduce_all_params:
diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py
index 8c0fc1f19cd..73ba8658750 100644
--- a/torch/optim/_adafactor.py
+++ b/torch/optim/_adafactor.py
@@ -423,7 +423,7 @@ def adafactor(
     See :class:`~torch.optim.Adafactor` for details.
     """
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index 0d75d7547dd..d1a05d6df70 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -259,7 +259,7 @@ def _single_tensor_adadelta(
     has_complex: bool,
 ):
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -315,7 +315,7 @@ def _multi_tensor_adadelta(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -347,7 +347,7 @@ def _multi_tensor_adadelta(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -418,7 +418,7 @@ def adadelta(
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 1ea56f6a28f..ba8a1c895a3 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -451,7 +451,7 @@ def _multi_tensor_adagrad(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 58afea691a4..97648b86bec 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -361,7 +361,7 @@ def _single_tensor_adam(
         step_t = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -474,7 +474,7 @@ def _multi_tensor_adam(
     )
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -521,7 +521,7 @@ def _multi_tensor_adam(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -754,7 +754,7 @@ def adam(
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index 2482bb70f79..7cb5e464f5a 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -248,7 +248,7 @@ def _single_tensor_adamax(
         step_t = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -320,7 +320,7 @@ def _multi_tensor_adamax(
         return
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -352,7 +352,7 @@ def _multi_tensor_adamax(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -429,7 +429,7 @@ def adamax(
     See :class:`~torch.optim.Adamax` for details.
     """
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 5c21e6b46ea..345b4369050 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -362,7 +362,7 @@ def _single_tensor_adamw(
         step_t = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -475,7 +475,7 @@ def _multi_tensor_adamw(
     )
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -521,7 +521,7 @@ def _multi_tensor_adamw(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -739,7 +739,7 @@ def adamw(
     See :class:`~torch.optim.AdamW` for details.
     """
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index d0cd25ac7da..1d8402edc48 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -219,7 +219,7 @@ def _single_tensor_asgd(
         step_t = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type
@@ -292,7 +292,7 @@ def _multi_tensor_asgd(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -326,7 +326,7 @@ def _multi_tensor_asgd(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 96b385779aa..54cc8df5a9b 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -310,7 +310,7 @@ def _single_tensor_nadam(
             exp_avg_sq = torch.view_as_real(exp_avg_sq)
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == mu_product.device.type == step_t.device.type
@@ -396,7 +396,7 @@ def _multi_tensor_nadam(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -430,7 +430,7 @@ def _multi_tensor_nadam(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 3c05fa6b0b2..5a28f98c96c 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -26,6 +26,7 @@ from typing_extensions import ParamSpec, Self, TypeAlias
 import torch
 import torch.utils.hooks as hooks
+from torch._utils import is_compiling
 from torch.utils._foreach_utils import (
     _get_foreach_kernels_supported_devices,
     _get_fused_kernels_supported_devices,
@@ -99,14 +100,14 @@ def _use_grad_for_differentiable(func):
 def _get_value(x):
     # item is significantly faster than a cpu tensor in eager mode
-    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+    if not torch.jit.is_scripting() and is_compiling():
         return x
     else:
         return x.item() if isinstance(x, torch.Tensor) else x
 def _stack_if_compiling(x):
-    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+    if not torch.jit.is_scripting() and is_compiling():
         return torch.stack(x)
     else:
         return x
@@ -138,7 +139,7 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
         # the capturable flag. If capturable=True, this is not a problem.
         @functools.wraps(func)
         def maybe_fallback(*args, **kwargs):
-            if torch.compiler.is_compiling() and (
+            if is_compiling() and (
                 not kwargs.get("capturable", False)
                 and has_state_steps
                 and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
@@ -413,7 +414,7 @@ class Optimizer:
        # Thus, when compiling, inductor will determine if cudagraphs
        # can be enabled based on whether there is input mutation or CPU tensors.
        if (
-            not torch.compiler.is_compiling()
+            not is_compiling()
            and torch.backends.cuda.is_built()
            and torch.cuda.is_available()
        ):
@@ -501,7 +502,7 @@ class Optimizer:
         Skips this step if we are compiling since this will occur during inductor lowering.
         """
-        if torch.compiler.is_compiling():
+        if is_compiling():
             return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
         else:
             return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 1683fb9a41d..24949ea4e05 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -276,7 +276,7 @@ def _single_tensor_radam(
         step_t = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -374,7 +374,7 @@ def _multi_tensor_radam(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -398,7 +398,7 @@ def _multi_tensor_radam(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 86bd64ff727..c9b33684f48 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -283,7 +283,7 @@ def _single_tensor_rmsprop(
         step = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step.device.type
@@ -356,7 +356,7 @@ def _multi_tensor_rmsprop(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices()
         assert all(
             p.device.type == step.device.type
@@ -392,7 +392,7 @@ def _multi_tensor_rmsprop(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -475,7 +475,7 @@ def rmsprop(
     """
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 25aaec211d6..ba0be649a8f 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -243,7 +243,7 @@ def _single_tensor_rprop(
         step = state_steps[i]
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch.compiler.is_compiling() and capturable:
+        if not torch._utils.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step.device.type
@@ -309,7 +309,7 @@ def _multi_tensor_rprop(
     assert not differentiable, "_foreach ops don't support autograd"
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch.compiler.is_compiling() and capturable:
+    if not torch._utils.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices()
         assert all(
             p.device.type == step.device.type
@@ -331,7 +331,7 @@ def _multi_tensor_rprop(
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -421,7 +421,7 @@ def rprop(
     """
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch.compiler.is_compiling() and not all(
+    if not torch._utils.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index 2bb37f3ff7f..c9b2b169b1a 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -434,7 +434,7 @@ def _multi_tensor_sgd(
         if not device_has_sparse_grad:
             # handle internal item() call if lr is a tensor
-            if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
+            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
                 grads_x_lr = torch._foreach_mul(device_grads, -lr)
                 torch._foreach_add_(device_params, grads_x_lr)
             else:
diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py
index 0a3774afc3e..7fac1e57c6a 100644
--- a/torch/testing/_internal/optests/generate_tests.py
+++ b/torch/testing/_internal/optests/generate_tests.py
@@ -567,7 +567,7 @@ class OpCheckMode(TorchFunctionMode):
         if (
             torch.jit.is_tracing()
             or torch.jit.is_scripting()
-            or torch.compiler.is_compiling()
+            or torch._dynamo.is_compiling()
         ):
             return func(*args, **kwargs)
         # Pre-existing code may not use the .default overload. If we see an
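
Note for downstream callers (not part of the patch): if code needs an is_compiling() check that works both with and without the deprecation this commit reverts, the guarded import the patch restores in torch/distributed/tensor/parallel/_utils.py generalizes well. The sketch below is illustrative only; the scale_lr helper is a hypothetical example, and the sketch assumes torch.compiler.is_compiling() exists on newer releases while torch._utils.is_compiling() exists on older ones.

    import torch

    try:
        # Public entry point on recent PyTorch releases.
        from torch.compiler import is_compiling
    except ImportError:
        try:
            # Private entry point whose deprecation this commit reverts.
            from torch._utils import is_compiling
        except ImportError:

            def is_compiling() -> bool:
                # Fallback when neither helper is available: assume eager mode.
                return False


    def scale_lr(lr: float):
        # Hypothetical helper mirroring the pattern in test/test_optim.py:
        # wrap the LR in a tensor only while tracing, to avoid recompiling
        # on every LR change.
        return torch.tensor(lr) if is_compiling() else lr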