mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[user-streams] Add current stream source (#165211)"
This reverts commit79aee77381. Reverted https://github.com/pytorch/pytorch/pull/165211 on behalf of https://github.com/atalman due to failure: test/test_python_dispatch.py::TestPythonDispatch::test_return_stream [GH job link](https://github.com/pytorch/pytorch/actions/runs/18942517662/job/54086481693) [HUD commit link](7563f61cc8) ([comment](https://github.com/pytorch/pytorch/pull/165211#issuecomment-3468332362))
This commit is contained in:
parent
7563f61cc8
commit
ad02bd13df
|
|
@ -181,7 +181,6 @@ from .utils import (
|
||||||
common_constant_types,
|
common_constant_types,
|
||||||
dataclass_fields,
|
dataclass_fields,
|
||||||
dict_keys,
|
dict_keys,
|
||||||
get_current_stream,
|
|
||||||
get_custom_getattr,
|
get_custom_getattr,
|
||||||
get_torch_function_mode_stack,
|
get_torch_function_mode_stack,
|
||||||
get_torch_function_mode_stack_at,
|
get_torch_function_mode_stack_at,
|
||||||
|
|
@ -760,7 +759,6 @@ def _get_closure_vars() -> dict[str, object]:
|
||||||
"___dataclass_fields": dataclass_fields,
|
"___dataclass_fields": dataclass_fields,
|
||||||
"___namedtuple_fields": lambda x: x._fields,
|
"___namedtuple_fields": lambda x: x._fields,
|
||||||
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
|
||||||
"___get_current_stream": get_current_stream,
|
|
||||||
"__math_isnan": math.isnan,
|
"__math_isnan": math.isnan,
|
||||||
"__numpy_isnan": None if np is None else np.isnan,
|
"__numpy_isnan": None if np is None else np.isnan,
|
||||||
"inf": float("inf"),
|
"inf": float("inf"),
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ import functools
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from torch import device as device_type
|
|
||||||
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
@ -1080,30 +1079,6 @@ class ShapeEnvSource(Source):
|
||||||
return GuardSource.SHAPE_ENV
|
return GuardSource.SHAPE_ENV
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class CurrentStreamSource(Source):
|
|
||||||
device: device_type
|
|
||||||
|
|
||||||
def name(self) -> str:
|
|
||||||
return f"___get_current_stream({self.device})"
|
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
|
||||||
codegen.add_push_null(
|
|
||||||
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
|
|
||||||
)
|
|
||||||
num_args = 1
|
|
||||||
codegen.extend_output([codegen.create_load_const(self.device.type)])
|
|
||||||
if self.device.index is not None:
|
|
||||||
num_args += 1
|
|
||||||
codegen.extend_output([codegen.create_load_const(self.device.index)])
|
|
||||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
|
|
||||||
codegen.extend_output(create_call_function(num_args, False))
|
|
||||||
codegen.extend_output(create_call_function(1, False))
|
|
||||||
|
|
||||||
def guard_source(self) -> GuardSource:
|
|
||||||
return GuardSource.GLOBAL
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class BackwardStateSource(Source):
|
class BackwardStateSource(Source):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -4695,12 +4695,6 @@ def clear_torch_function_mode_stack() -> None:
|
||||||
_pop_torch_function_stack()
|
_pop_torch_function_stack()
|
||||||
|
|
||||||
|
|
||||||
def get_current_stream(device: torch.device) -> torch.Stream:
|
|
||||||
from .device_interface import get_interface_for_device
|
|
||||||
|
|
||||||
return get_interface_for_device(device).current_stream()
|
|
||||||
|
|
||||||
|
|
||||||
# call from C dynamo in order to inspect values in pdb
|
# call from C dynamo in order to inspect values in pdb
|
||||||
def _breakpoint_for_c_dynamo(*args: Any) -> None:
|
def _breakpoint_for_c_dynamo(*args: Any) -> None:
|
||||||
breakpoint()
|
breakpoint()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user