From ad02bd13dfa017f69def846b265a566c4ec5cb3f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 30 Oct 2025 14:34:43 +0000 Subject: [PATCH] Revert "[user-streams] Add current stream source (#165211)" This reverts commit 79aee77381b21d41c77148e5ff84c4b351aaf144. Reverted https://github.com/pytorch/pytorch/pull/165211 on behalf of https://github.com/atalman due to failure: test/test_python_dispatch.py::TestPythonDispatch::test_return_stream [GH job link](https://github.com/pytorch/pytorch/actions/runs/18942517662/job/54086481693) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/7563f61cc8a40a5ba21a498a2d98895b4eec3f39) ([comment](https://github.com/pytorch/pytorch/pull/165211#issuecomment-3468332362)) --- torch/_dynamo/guards.py | 2 -- torch/_dynamo/source.py | 25 ------------------------- torch/_dynamo/utils.py | 6 ------ 3 files changed, 33 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 7d50e48f7a7..0b429335279 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -181,7 +181,6 @@ from .utils import ( common_constant_types, dataclass_fields, dict_keys, - get_current_stream, get_custom_getattr, get_torch_function_mode_stack, get_torch_function_mode_stack_at, @@ -760,7 +759,6 @@ def _get_closure_vars() -> dict[str, object]: "___dataclass_fields": dataclass_fields, "___namedtuple_fields": lambda x: x._fields, "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, - "___get_current_stream": get_current_stream, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 88614309cb9..9fb4f32d68a 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -23,7 +23,6 @@ import functools from collections.abc import Callable from typing import Any, Optional, TYPE_CHECKING, Union -from torch import device as device_type from torch._guards import ChainedSource, Guard, GuardSource, Source from . import utils @@ -1080,30 +1079,6 @@ class ShapeEnvSource(Source): return GuardSource.SHAPE_ENV -@dataclasses.dataclass(frozen=True) -class CurrentStreamSource(Source): - device: device_type - - def name(self) -> str: - return f"___get_current_stream({self.device})" - - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null( - lambda: codegen.load_import_from(utils.__name__, "get_current_stream") - ) - num_args = 1 - codegen.extend_output([codegen.create_load_const(self.device.type)]) - if self.device.index is not None: - num_args += 1 - codegen.extend_output([codegen.create_load_const(self.device.index)]) - codegen.add_push_null(lambda: codegen.load_import_from("torch", "device")) - codegen.extend_output(create_call_function(num_args, False)) - codegen.extend_output(create_call_function(1, False)) - - def guard_source(self) -> GuardSource: - return GuardSource.GLOBAL - - @dataclasses.dataclass(frozen=True) class BackwardStateSource(Source): def name(self) -> str: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 7700b00a132..d07ad52ab32 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4695,12 +4695,6 @@ def clear_torch_function_mode_stack() -> None: _pop_torch_function_stack() -def get_current_stream(device: torch.device) -> torch.Stream: - from .device_interface import get_interface_for_device - - return get_interface_for_device(device).current_stream() - - # call from C dynamo in order to inspect values in pdb def _breakpoint_for_c_dynamo(*args: Any) -> None: breakpoint()