Use torch.Stream&torch.Event for Dynamo capature (#134850)

# Motivation
This PR aims to solve the multiple Inheritance problem.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134850
Approved by: https://github.com/yf225, https://github.com/EikanWang
This commit is contained in:
Yu, Guangye 2024-10-02 09:13:19 +00:00 committed by PyTorch MergeBot
parent bf73af4b4e
commit d29094888b
7 changed files with 46 additions and 81 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import time
import torch
from torch._dynamo import device_interface # noqa: PLC2701 import-private-name
@ -13,9 +14,7 @@ class DeviceProperties:
class DeviceInterface(device_interface.DeviceInterface):
class Event(
device_interface._EventBase
): # pyright: ignore [reportPrivateImportUsage]
class Event(torch.Event):
def __init__(
self,
enable_timing: bool = False,

View File

@ -1,11 +1,9 @@
# mypy: allow-untyped-defs
import inspect
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
import torch
from torch._streambase import _EventBase, _StreamBase
get_cuda_stream: Optional[Callable[[int], int]]
@ -21,21 +19,7 @@ caching_worker_device_properties: Dict[str, Any] = {}
caching_worker_current_devices: Dict[str, int] = {}
class DeviceInterfaceMeta(type):
def __new__(metacls, *args, **kwargs):
class_member = args[2]
if "Event" in class_member:
assert inspect.isclass(class_member["Event"]) and issubclass(
class_member["Event"], _EventBase
), "DeviceInterface member Event should be inherit from _EventBase"
if "Stream" in class_member:
assert inspect.isclass(class_member["Stream"]) and issubclass(
class_member["Stream"], _StreamBase
), "DeviceInterface member Stream should be inherit from _StreamBase"
return super().__new__(metacls, *args, **kwargs)
class DeviceInterface(metaclass=DeviceInterfaceMeta):
class DeviceInterface:
"""
This is a simple device runtime interface for Inductor. It enables custom
backends to be integrated with Inductor in a device-agnostic semantic.
@ -45,6 +29,18 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
def __new__(cls, device: _device_t):
raise NotImplementedError
class Event:
def __new__(cls, *args, **kwargs):
raise NotImplementedError(
"Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
)
class Stream:
def __new__(cls, *args, **kwargs):
raise NotImplementedError(
"Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
)
class Worker:
"""
Worker API to query device properties that will work in multi processing
@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface):
device = torch.cuda.device
# register Event and Stream class into the backend interface
# make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
# make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
Event = torch.cuda.Event
Stream = torch.cuda.Stream
@ -303,14 +299,14 @@ class CpuDeviceProperties:
class CpuInterface(DeviceInterface):
class Event(_EventBase):
class Event(torch.Event):
def __init__(self, enable_timing=True):
self.time = 0.0
def elapsed_time(self, end_event) -> float:
return (end_event.time - self.time) * 1000
def record(self):
def record(self, stream=None):
self.time = time.perf_counter()
@staticmethod

View File

@ -36,7 +36,6 @@ from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
from torch._utils_internal import justknobs_check
@ -822,7 +821,7 @@ class VariableBuilder:
stream_source = AttrSource(self.source, "stream")
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
return StreamContextVariable.create(self.tx, stream_var)
elif isinstance(value, _StreamBase):
elif isinstance(value, torch.Stream):
self.install_guards(GuardBuilder.ID_MATCH)
stream_proxy = self.tx.output.create_proxy(
"call_function",
@ -847,7 +846,7 @@ class VariableBuilder:
elif isinstance(value, torch._C._SDPBackend):
self.install_guards(GuardBuilder.ID_MATCH)
return ConstantVariable(value)
elif isinstance(value, _EventBase):
elif isinstance(value, torch.Event):
self.install_guards(GuardBuilder.ID_MATCH)
torch._dynamo.utils.store_user_object_weakref(value)
event_proxy = self.tx.output.create_proxy(
@ -2265,7 +2264,7 @@ def wrap_fx_proxy_cls(
return SymNodeVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, _StreamBase)
and issubclass(proxy.node.target, torch.Stream)
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
@ -2273,7 +2272,8 @@ def wrap_fx_proxy_cls(
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, example_value.device, **options)
elif (
inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)
) or proxy.node.target in [
device_interface.Event
for _, device_interface in get_registered_device_interfaces()
@ -2285,7 +2285,7 @@ def wrap_fx_proxy_cls(
return ConstantVariable(example_value, **options)
elif (
example_value is not None
and isinstance(example_value, _EventBase)
and isinstance(example_value, torch.Event)
and proxy.node.target == "record_event"
and proxy.node.op == "call_method"
):

View File

@ -13,7 +13,6 @@ import torch.fx
import torch.nn
from torch._guards import TracingContext
from torch._logging import warning_once
from torch._streambase import _StreamBase
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
from .. import config, polyfills, variables
@ -267,7 +266,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
assert len(args) <= 1 and len(kwargs) == 0
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
return InferenceModeVariable.create(tx, inf_mode)
elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
return wrap_fx_proxy_cls(

View File

@ -1,46 +1,20 @@
# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing_extensions import deprecated
import torch
class _StreamBase(ABC):
r"""Base stream class abstraction for multi backends Stream to herit from"""
@abstractmethod
def wait_event(self, event) -> None:
raise NotImplementedError
@abstractmethod
def wait_stream(self, stream) -> None:
raise NotImplementedError
@abstractmethod
def record_event(self, event=None) -> None:
raise NotImplementedError
@abstractmethod
def query(self) -> bool:
raise NotImplementedError
@abstractmethod
def synchronize(self) -> None:
raise NotImplementedError
@abstractmethod
def __eq__(self, stream) -> bool:
raise NotImplementedError
# Preserved only for BC reasons
@deprecated(
"`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
category=FutureWarning,
)
class _StreamBase(torch.Stream):
pass
class _EventBase(ABC):
r"""Base Event class abstraction for multi backends Event to herit from"""
@abstractmethod
def wait(self, stream=None) -> None:
raise NotImplementedError
@abstractmethod
def query(self) -> bool:
raise NotImplementedError
@abstractmethod
def synchronize(self) -> None:
raise NotImplementedError
@deprecated(
"`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
category=FutureWarning,
)
class _EventBase(torch.Event):
pass

View File

@ -2,7 +2,6 @@
import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type
@ -12,7 +11,7 @@ if not hasattr(torch._C, "_CudaStreamBase"):
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
class Stream(torch._C._CudaStreamBase, _StreamBase):
class Stream(torch._C._CudaStreamBase):
r"""Wrapper around a CUDA stream.
A CUDA stream is a linear sequence of execution that belongs to a specific
@ -138,7 +137,7 @@ class ExternalStream(Stream):
return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
class Event(torch._C._CudaEventBase, _EventBase):
class Event(torch._C._CudaEventBase):
r"""Wrapper around a CUDA event.
CUDA events are synchronization markers that can be used to monitor the

View File

@ -2,9 +2,7 @@
import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from .._utils import _dummy_type
from torch._utils import _dummy_type
if not hasattr(torch._C, "_XpuStreamBase"):
@ -13,7 +11,7 @@ if not hasattr(torch._C, "_XpuStreamBase"):
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
class Stream(torch._C._XpuStreamBase, _StreamBase):
class Stream(torch._C._XpuStreamBase):
r"""Wrapper around a XPU stream.
A XPU stream is a linear sequence of execution that belongs to a specific
@ -98,7 +96,7 @@ class Stream(torch._C._XpuStreamBase, _StreamBase):
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
class Event(torch._C._XpuEventBase, _EventBase):
class Event(torch._C._XpuEventBase):
r"""Wrapper around a XPU event.
XPU events are synchronization markers that can be used to monitor the