[user-streams] Add current stream source (#165211)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522, #164819
This commit is contained in:
Michael Lazos 2025-10-29 17:30:03 -07:00 committed by PyTorch MergeBot
parent f5cb9a4c68
commit 79aee77381
3 changed files with 33 additions and 0 deletions

View File

@ -181,6 +181,7 @@ from .utils import (
common_constant_types,
dataclass_fields,
dict_keys,
get_current_stream,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
@ -761,6 +762,7 @@ def _get_closure_vars() -> dict[str, object]:
"___dataclass_fields": dataclass_fields,
"___namedtuple_fields": lambda x: x._fields,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"___get_current_stream": get_current_stream,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),

View File

@ -23,6 +23,7 @@ import functools
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING, Union
from torch import device as device_type
from torch._guards import ChainedSource, Guard, GuardSource, Source
from . import utils
@ -1079,6 +1080,30 @@ class ShapeEnvSource(Source):
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class CurrentStreamSource(Source):
device: device_type
def name(self) -> str:
return f"___get_current_stream({self.device})"
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
)
num_args = 1
codegen.extend_output([codegen.create_load_const(self.device.type)])
if self.device.index is not None:
num_args += 1
codegen.extend_output([codegen.create_load_const(self.device.index)])
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
codegen.extend_output(create_call_function(num_args, False))
codegen.extend_output(create_call_function(1, False))
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self) -> str:

View File

@ -4695,6 +4695,12 @@ def clear_torch_function_mode_stack() -> None:
_pop_torch_function_stack()
def get_current_stream(device: torch.device) -> torch.Stream:
from .device_interface import get_interface_for_device
return get_interface_for_device(device).current_stream()
# call from C dynamo in order to inspect values in pdb
def _breakpoint_for_c_dynamo(*args: Any) -> None:
breakpoint()