mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[user-streams] Add current stream source (#165211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211 Approved by: https://github.com/anijain2305 ghstack dependencies: #164304, #164522, #164819
This commit is contained in:
parent
f5cb9a4c68
commit
79aee77381
|
|
@ -181,6 +181,7 @@ from .utils import (
|
|||
common_constant_types,
|
||||
dataclass_fields,
|
||||
dict_keys,
|
||||
get_current_stream,
|
||||
get_custom_getattr,
|
||||
get_torch_function_mode_stack,
|
||||
get_torch_function_mode_stack_at,
|
||||
|
|
@ -761,6 +762,7 @@ def _get_closure_vars() -> dict[str, object]:
|
|||
"___dataclass_fields": dataclass_fields,
|
||||
"___namedtuple_fields": lambda x: x._fields,
|
||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||
"___get_current_stream": get_current_stream,
|
||||
"__math_isnan": math.isnan,
|
||||
"__numpy_isnan": None if np is None else np.isnan,
|
||||
"inf": float("inf"),
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import functools
|
|||
from collections.abc import Callable
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch import device as device_type
|
||||
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
||||
|
||||
from . import utils
|
||||
|
|
@ -1079,6 +1080,30 @@ class ShapeEnvSource(Source):
|
|||
return GuardSource.SHAPE_ENV
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CurrentStreamSource(Source):
|
||||
device: device_type
|
||||
|
||||
def name(self) -> str:
|
||||
return f"___get_current_stream({self.device})"
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
|
||||
)
|
||||
num_args = 1
|
||||
codegen.extend_output([codegen.create_load_const(self.device.type)])
|
||||
if self.device.index is not None:
|
||||
num_args += 1
|
||||
codegen.extend_output([codegen.create_load_const(self.device.index)])
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
|
||||
codegen.extend_output(create_call_function(num_args, False))
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BackwardStateSource(Source):
|
||||
def name(self) -> str:
|
||||
|
|
|
|||
|
|
@ -4695,6 +4695,12 @@ def clear_torch_function_mode_stack() -> None:
|
|||
_pop_torch_function_stack()
|
||||
|
||||
|
||||
def get_current_stream(device: torch.device) -> torch.Stream:
|
||||
from .device_interface import get_interface_for_device
|
||||
|
||||
return get_interface_for_device(device).current_stream()
|
||||
|
||||
|
||||
# call from C dynamo in order to inspect values in pdb
|
||||
def _breakpoint_for_c_dynamo(*args: Any) -> None:
|
||||
breakpoint()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user