Use torch.Stream & torch.Event for Dynamo capture (#134850)
# Motivation

This PR aims to solve the multiple inheritance problem: backend Stream and Event classes previously had to derive both from their C++ base classes (e.g. torch._C._CudaStreamBase) and from the abstract _StreamBase/_EventBase helpers. They now derive only from the C++ bases, which already extend torch.Stream and torch.Event, so Dynamo can capture streams and events by checking against torch.Stream and torch.Event directly. _StreamBase and _EventBase are kept solely as deprecated aliases for backward compatibility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134850
Approved by: https://github.com/yf225, https://github.com/EikanWang
parent bf73af4b4e
commit d29094888b
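As a quick illustration (a minimal sketch, not part of this commit): since torch.cuda.Stream and torch.cuda.Event now ultimately derive from torch.Stream and torch.Event, Dynamo can recognize stream and event objects inside a compiled region. The eager backend and the tensor values below are illustrative assumptions only.

    import torch


    @torch.compile(backend="eager")
    def run_on_side_stream(x):
        # Stream/Event construction is visible to Dynamo because these classes
        # are subclasses of torch.Stream / torch.Event after this change.
        s = torch.cuda.Stream()
        e = torch.cuda.Event()
        with torch.cuda.stream(s):
            y = x * 2
        e.record(s)
        e.synchronize()
        return y


    if torch.cuda.is_available():  # the sketch assumes a CUDA device is present
        print(run_on_side_stream(torch.ones(4, device="cuda")))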
@@ -2,6 +2,7 @@ from __future__ import annotations

 import time

+import torch
 from torch._dynamo import device_interface  # noqa: PLC2701 import-private-name


@@ -13,9 +14,7 @@ class DeviceProperties:


 class DeviceInterface(device_interface.DeviceInterface):
-    class Event(
-        device_interface._EventBase
-    ):  # pyright: ignore [reportPrivateImportUsage]
+    class Event(torch.Event):
         def __init__(
             self,
             enable_timing: bool = False,
@@ -1,11 +1,9 @@
 # mypy: allow-untyped-defs
-import inspect
 import time
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

 import torch
-from torch._streambase import _EventBase, _StreamBase


 get_cuda_stream: Optional[Callable[[int], int]]
@@ -21,21 +19,7 @@ caching_worker_device_properties: Dict[str, Any] = {}
 caching_worker_current_devices: Dict[str, int] = {}


-class DeviceInterfaceMeta(type):
-    def __new__(metacls, *args, **kwargs):
-        class_member = args[2]
-        if "Event" in class_member:
-            assert inspect.isclass(class_member["Event"]) and issubclass(
-                class_member["Event"], _EventBase
-            ), "DeviceInterface member Event should be inherit from _EventBase"
-        if "Stream" in class_member:
-            assert inspect.isclass(class_member["Stream"]) and issubclass(
-                class_member["Stream"], _StreamBase
-            ), "DeviceInterface member Stream should be inherit from _StreamBase"
-        return super().__new__(metacls, *args, **kwargs)
-
-
-class DeviceInterface(metaclass=DeviceInterfaceMeta):
+class DeviceInterface:
     """
     This is a simple device runtime interface for Inductor. It enables custom
     backends to be integrated with Inductor in a device-agnostic semantic.
@@ -45,6 +29,18 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
     def __new__(cls, device: _device_t):
         raise NotImplementedError

+    class Event:
+        def __new__(cls, *args, **kwargs):
+            raise NotImplementedError(
+                "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
+            )
+
+    class Stream:
+        def __new__(cls, *args, **kwargs):
+            raise NotImplementedError(
+                "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
+            )
+
     class Worker:
         """
         Worker API to query device properties that will work in multi processing
@@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface):
     device = torch.cuda.device

     # register Event and Stream class into the backend interface
-    # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
+    # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
     Event = torch.cuda.Event
     Stream = torch.cuda.Stream

@@ -303,14 +299,14 @@ class CpuDeviceProperties:


 class CpuInterface(DeviceInterface):
-    class Event(_EventBase):
+    class Event(torch.Event):
         def __init__(self, enable_timing=True):
             self.time = 0.0

         def elapsed_time(self, end_event) -> float:
             return (end_event.time - self.time) * 1000

-        def record(self):
+        def record(self, stream=None):
             self.time = time.perf_counter()

     @staticmethod
@@ -36,7 +36,6 @@ from torch import SymInt
 from torch._guards import GuardSource, TracingContext
 from torch._higher_order_ops.torchbind import call_torchbind
 from torch._ops import HigherOrderOperator
-from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch._subclasses.meta_utils import is_sparse_any, safe_grad
 from torch._utils_internal import justknobs_check
@@ -822,7 +821,7 @@ class VariableBuilder:
             stream_source = AttrSource(self.source, "stream")
             stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
             return StreamContextVariable.create(self.tx, stream_var)
-        elif isinstance(value, _StreamBase):
+        elif isinstance(value, torch.Stream):
             self.install_guards(GuardBuilder.ID_MATCH)
             stream_proxy = self.tx.output.create_proxy(
                 "call_function",
@@ -847,7 +846,7 @@ class VariableBuilder:
         elif isinstance(value, torch._C._SDPBackend):
             self.install_guards(GuardBuilder.ID_MATCH)
             return ConstantVariable(value)
-        elif isinstance(value, _EventBase):
+        elif isinstance(value, torch.Event):
             self.install_guards(GuardBuilder.ID_MATCH)
             torch._dynamo.utils.store_user_object_weakref(value)
             event_proxy = self.tx.output.create_proxy(
@@ -2265,7 +2264,7 @@ def wrap_fx_proxy_cls(
         return SymNodeVariable(proxy, example_value, **options)
     elif (
         inspect.isclass(proxy.node.target)
-        and issubclass(proxy.node.target, _StreamBase)
+        and issubclass(proxy.node.target, torch.Stream)
     ) or proxy.node.target in [
         device_interface.current_stream
         for _, device_interface in get_registered_device_interfaces()
@@ -2273,7 +2272,8 @@ def wrap_fx_proxy_cls(
         set_example_value(proxy.node, example_value)
         return StreamVariable(proxy, example_value, example_value.device, **options)
     elif (
-        inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
+        inspect.isclass(proxy.node.target)
+        and issubclass(proxy.node.target, torch.Event)
     ) or proxy.node.target in [
         device_interface.Event
         for _, device_interface in get_registered_device_interfaces()
@@ -2285,7 +2285,7 @@ def wrap_fx_proxy_cls(
         return ConstantVariable(example_value, **options)
     elif (
         example_value is not None
-        and isinstance(example_value, _EventBase)
+        and isinstance(example_value, torch.Event)
         and proxy.node.target == "record_event"
         and proxy.node.op == "call_method"
     ):
@@ -13,7 +13,6 @@ import torch.fx
 import torch.nn
 from torch._guards import TracingContext
 from torch._logging import warning_once
-from torch._streambase import _StreamBase
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

 from .. import config, polyfills, variables
@@ -267,7 +266,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             assert len(args) <= 1 and len(kwargs) == 0
             inf_mode = args[0].as_python_constant() if len(args) == 1 else True
             return InferenceModeVariable.create(tx, inf_mode)
-        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
+        elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
             from torch._dynamo.variables.builder import wrap_fx_proxy_cls

             return wrap_fx_proxy_cls(
@@ -1,46 +1,20 @@
-# mypy: allow-untyped-defs
-from abc import ABC, abstractmethod
+from typing_extensions import deprecated
+
+import torch


-class _StreamBase(ABC):
-    r"""Base stream class abstraction for multi backends Stream to herit from"""
-
-    @abstractmethod
-    def wait_event(self, event) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def wait_stream(self, stream) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def record_event(self, event=None) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def query(self) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def synchronize(self) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __eq__(self, stream) -> bool:
-        raise NotImplementedError
+# Preserved only for BC reasons
+@deprecated(
+    "`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
+    category=FutureWarning,
+)
+class _StreamBase(torch.Stream):
+    pass


-class _EventBase(ABC):
-    r"""Base Event class abstraction for multi backends Event to herit from"""
-
-    @abstractmethod
-    def wait(self, stream=None) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def query(self) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def synchronize(self) -> None:
-        raise NotImplementedError
+@deprecated(
+    "`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
+    category=FutureWarning,
+)
+class _EventBase(torch.Event):
+    pass
@@ -2,7 +2,6 @@
 import ctypes

 import torch
-from torch._streambase import _EventBase, _StreamBase
 from torch._utils import _dummy_type


@@ -12,7 +11,7 @@ if not hasattr(torch._C, "_CudaStreamBase"):
     torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")


-class Stream(torch._C._CudaStreamBase, _StreamBase):
+class Stream(torch._C._CudaStreamBase):
     r"""Wrapper around a CUDA stream.

     A CUDA stream is a linear sequence of execution that belongs to a specific
@@ -138,7 +137,7 @@ class ExternalStream(Stream):
         return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)


-class Event(torch._C._CudaEventBase, _EventBase):
+class Event(torch._C._CudaEventBase):
     r"""Wrapper around a CUDA event.

     CUDA events are synchronization markers that can be used to monitor the
@@ -2,9 +2,7 @@
 import ctypes

 import torch
-from torch._streambase import _EventBase, _StreamBase
-
-from .._utils import _dummy_type
+from torch._utils import _dummy_type


 if not hasattr(torch._C, "_XpuStreamBase"):
@@ -13,7 +11,7 @@ if not hasattr(torch._C, "_XpuStreamBase"):
     torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


-class Stream(torch._C._XpuStreamBase, _StreamBase):
+class Stream(torch._C._XpuStreamBase):
     r"""Wrapper around a XPU stream.

     A XPU stream is a linear sequence of execution that belongs to a specific
@@ -98,7 +96,7 @@ class Stream(torch._C._XpuStreamBase, _StreamBase):
         return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


-class Event(torch._C._XpuEventBase, _EventBase):
+class Event(torch._C._XpuEventBase):
     r"""Wrapper around a XPU event.

     XPU events are synchronization markers that can be used to monitor the
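Note on the new convention (a sketch, not from this PR): the DeviceInterface.Event and DeviceInterface.Stream stubs added above raise NotImplementedError so that each backend defines these members as subclasses of torch.Event and torch.Stream, which is what lets Dynamo capture them. A rough out-of-tree backend might look like the following; the "my_backend" name is a made-up assumption, torch.Event/torch.Stream stand in for the backend's real wrapper classes, and register_interface_for_device is assumed to be the existing registration hook behind get_registered_device_interfaces().

    import torch
    from torch._dynamo.device_interface import (
        DeviceInterface,
        register_interface_for_device,
    )


    class MyBackendInterface(DeviceInterface):
        # Event/Stream must derive from torch.Event/torch.Stream, mirroring
        # CudaInterface's Event = torch.cuda.Event / Stream = torch.cuda.Stream.
        Event = torch.Event    # stand-in for the backend's real Event subclass
        Stream = torch.Stream  # stand-in for the backend's real Stream subclass


    register_interface_for_device("my_backend", MyBackendInterface)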