mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Revert "[user-streams] Add current stream source (#165211)"
This reverts commit79aee77381. Reverted https://github.com/pytorch/pytorch/pull/165211 on behalf of https://github.com/atalman due to failure: test/test_python_dispatch.py::TestPythonDispatch::test_return_stream [GH job link](https://github.com/pytorch/pytorch/actions/runs/18942517662/job/54086481693) [HUD commit link](7563f61cc8) ([comment](https://github.com/pytorch/pytorch/pull/165211#issuecomment-3468332362))
This commit is contained in:
parent
7563f61cc8
commit
ad02bd13df
|
|
@ -181,7 +181,6 @@ from .utils import (
|
|||
common_constant_types,
|
||||
dataclass_fields,
|
||||
dict_keys,
|
||||
get_current_stream,
|
||||
get_custom_getattr,
|
||||
get_torch_function_mode_stack,
|
||||
get_torch_function_mode_stack_at,
|
||||
|
|
@ -760,7 +759,6 @@ def _get_closure_vars() -> dict[str, object]:
|
|||
"___dataclass_fields": dataclass_fields,
|
||||
"___namedtuple_fields": lambda x: x._fields,
|
||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||
"___get_current_stream": get_current_stream,
|
||||
"__math_isnan": math.isnan,
|
||||
"__numpy_isnan": None if np is None else np.isnan,
|
||||
"inf": float("inf"),
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import functools
|
|||
from collections.abc import Callable
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch import device as device_type
|
||||
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
||||
|
||||
from . import utils
|
||||
|
|
@ -1080,30 +1079,6 @@ class ShapeEnvSource(Source):
|
|||
return GuardSource.SHAPE_ENV
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CurrentStreamSource(Source):
|
||||
device: device_type
|
||||
|
||||
def name(self) -> str:
|
||||
return f"___get_current_stream({self.device})"
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
|
||||
)
|
||||
num_args = 1
|
||||
codegen.extend_output([codegen.create_load_const(self.device.type)])
|
||||
if self.device.index is not None:
|
||||
num_args += 1
|
||||
codegen.extend_output([codegen.create_load_const(self.device.index)])
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
|
||||
codegen.extend_output(create_call_function(num_args, False))
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BackwardStateSource(Source):
|
||||
def name(self) -> str:
|
||||
|
|
|
|||
|
|
@ -4695,12 +4695,6 @@ def clear_torch_function_mode_stack() -> None:
|
|||
_pop_torch_function_stack()
|
||||
|
||||
|
||||
def get_current_stream(device: torch.device) -> torch.Stream:
|
||||
from .device_interface import get_interface_for_device
|
||||
|
||||
return get_interface_for_device(device).current_stream()
|
||||
|
||||
|
||||
# call from C dynamo in order to inspect values in pdb
|
||||
def _breakpoint_for_c_dynamo(*args: Any) -> None:
|
||||
breakpoint()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user