Use torch.Stream & torch.Event for Dynamo capture (#134850)

# Motivation
This PR aims to solve the multiple inheritance problem: backend `Stream`/`Event` classes previously had to derive from both their C-binding base (e.g. `torch._C._CudaStreamBase`) and the private ABCs `_StreamBase`/`_EventBase`, which Dynamo special-cased for capture. They now derive from `torch.Stream` and `torch.Event` directly.
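As a minimal illustration of the resulting invariant (a sketch; the runtime checks assume a CUDA build):

```python
import torch

# Backend streams/events are now ordinary subclasses of the public bases,
# so Dynamo can match them with plain isinstance/issubclass checks instead
# of special-casing the private _StreamBase/_EventBase ABCs.
if torch.cuda.is_available():
    assert isinstance(torch.cuda.Stream(), torch.Stream)
    assert isinstance(torch.cuda.Event(), torch.Event)
```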

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134850
Approved by: https://github.com/yf225, https://github.com/EikanWang
Author: Yu, Guangye
Date: 2024-10-02 09:13:19 +00:00
Committed by: PyTorch MergeBot
Parent: bf73af4b4e
Commit: d29094888b

7 changed files with 46 additions and 81 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 import time

+import torch
 from torch._dynamo import device_interface  # noqa: PLC2701 import-private-name
@@ -13,9 +14,7 @@ class DeviceProperties:

 class DeviceInterface(device_interface.DeviceInterface):
-    class Event(
-        device_interface._EventBase
-    ):  # pyright: ignore [reportPrivateImportUsage]
+    class Event(torch.Event):
         def __init__(
             self,
             enable_timing: bool = False,

View File

@@ -1,11 +1,9 @@
 # mypy: allow-untyped-defs
-import inspect
 import time
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

 import torch
-from torch._streambase import _EventBase, _StreamBase

 get_cuda_stream: Optional[Callable[[int], int]]
@@ -21,21 +19,7 @@ caching_worker_device_properties: Dict[str, Any] = {}
 caching_worker_current_devices: Dict[str, int] = {}


-class DeviceInterfaceMeta(type):
-    def __new__(metacls, *args, **kwargs):
-        class_member = args[2]
-        if "Event" in class_member:
-            assert inspect.isclass(class_member["Event"]) and issubclass(
-                class_member["Event"], _EventBase
-            ), "DeviceInterface member Event should be inherit from _EventBase"
-        if "Stream" in class_member:
-            assert inspect.isclass(class_member["Stream"]) and issubclass(
-                class_member["Stream"], _StreamBase
-            ), "DeviceInterface member Stream should be inherit from _StreamBase"
-        return super().__new__(metacls, *args, **kwargs)
-
-
-class DeviceInterface(metaclass=DeviceInterfaceMeta):
+class DeviceInterface:
     """
     This is a simple device runtime interface for Inductor. It enables custom
     backends to be integrated with Inductor in a device-agnostic semantic.
@@ -45,6 +29,18 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
     def __new__(cls, device: _device_t):
         raise NotImplementedError

+    class Event:
+        def __new__(cls, *args, **kwargs):
+            raise NotImplementedError(
+                "Event must inherit from torch.Event; otherwise it cannot be captured by Dynamo."
+            )
+
+    class Stream:
+        def __new__(cls, *args, **kwargs):
+            raise NotImplementedError(
+                "Stream must inherit from torch.Stream; otherwise it cannot be captured by Dynamo."
+            )
+
     class Worker:
         """
         Worker API to query device properties that will work in multi processing
@@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface):
     device = torch.cuda.device

     # register Event and Stream class into the backend interface
-    # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
+    # make sure Event and Stream are implemented and inherit from torch.Event and torch.Stream
     Event = torch.cuda.Event
     Stream = torch.cuda.Stream
@@ -303,14 +299,14 @@ class CpuDeviceProperties:

 class CpuInterface(DeviceInterface):
-    class Event(_EventBase):
+    class Event(torch.Event):
         def __init__(self, enable_timing=True):
             self.time = 0.0

         def elapsed_time(self, end_event) -> float:
             return (end_event.time - self.time) * 1000

-        def record(self):
+        def record(self, stream=None):
             self.time = time.perf_counter()

     @staticmethod
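The placeholder `Event`/`Stream` classes above replace the old `DeviceInterfaceMeta` class-creation asserts: a backend that fails to override them with real subclasses now errors at first instantiation instead. A minimal sketch of what a custom backend is expected to provide (the backend name and interface class are hypothetical; `register_interface_for_device` is the existing registration helper in this module):

```python
import torch
from torch._dynamo.device_interface import (
    DeviceInterface,
    register_interface_for_device,
)

class MyDeviceInterface(DeviceInterface):
    # Overrides must derive from the public bases so Dynamo can capture them;
    # the default placeholders raise NotImplementedError on construction.
    class Event(torch.Event):
        pass

    class Stream(torch.Stream):
        pass

# "my_device" is a made-up key for illustration.
register_interface_for_device("my_device", MyDeviceInterface)
```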

View File

@@ -36,7 +36,6 @@ from torch import SymInt
 from torch._guards import GuardSource, TracingContext
 from torch._higher_order_ops.torchbind import call_torchbind
 from torch._ops import HigherOrderOperator
-from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch._subclasses.meta_utils import is_sparse_any, safe_grad
 from torch._utils_internal import justknobs_check
@@ -822,7 +821,7 @@ class VariableBuilder:
             stream_source = AttrSource(self.source, "stream")
             stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
             return StreamContextVariable.create(self.tx, stream_var)
-        elif isinstance(value, _StreamBase):
+        elif isinstance(value, torch.Stream):
             self.install_guards(GuardBuilder.ID_MATCH)
             stream_proxy = self.tx.output.create_proxy(
                 "call_function",
@@ -847,7 +846,7 @@ class VariableBuilder:
         elif isinstance(value, torch._C._SDPBackend):
             self.install_guards(GuardBuilder.ID_MATCH)
             return ConstantVariable(value)
-        elif isinstance(value, _EventBase):
+        elif isinstance(value, torch.Event):
             self.install_guards(GuardBuilder.ID_MATCH)
             torch._dynamo.utils.store_user_object_weakref(value)
             event_proxy = self.tx.output.create_proxy(
@@ -2265,7 +2264,7 @@ def wrap_fx_proxy_cls(
         return SymNodeVariable(proxy, example_value, **options)
     elif (
         inspect.isclass(proxy.node.target)
-        and issubclass(proxy.node.target, _StreamBase)
+        and issubclass(proxy.node.target, torch.Stream)
     ) or proxy.node.target in [
         device_interface.current_stream
         for _, device_interface in get_registered_device_interfaces()
@@ -2273,7 +2272,8 @@ def wrap_fx_proxy_cls(
         set_example_value(proxy.node, example_value)
         return StreamVariable(proxy, example_value, example_value.device, **options)
     elif (
-        inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
+        inspect.isclass(proxy.node.target)
+        and issubclass(proxy.node.target, torch.Event)
     ) or proxy.node.target in [
         device_interface.Event
         for _, device_interface in get_registered_device_interfaces()
@@ -2285,7 +2285,7 @@ def wrap_fx_proxy_cls(
         return ConstantVariable(example_value, **options)
     elif (
         example_value is not None
-        and isinstance(example_value, _EventBase)
+        and isinstance(example_value, torch.Event)
         and proxy.node.target == "record_event"
         and proxy.node.op == "call_method"
     ):
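With the builder keyed on the public types, both stream/event objects crossing the graph boundary and constructor calls traced inside a compiled region are wrapped as stream/event variables. A hedged sketch of the kind of user code this matching targets (how much of a given stream program Dynamo captures without graph breaks is beyond this diff):

```python
import torch

@torch.compile
def step(x):
    s = torch.cuda.Stream()   # class target: issubclass(..., torch.Stream)
    e = torch.cuda.Event()    # class target: issubclass(..., torch.Event)
    with torch.cuda.stream(s):
        x = x + 1
    e.record()
    return x

if torch.cuda.is_available():
    step(torch.randn(8, device="cuda"))
```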

View File

@@ -13,7 +13,6 @@ import torch.fx
 import torch.nn
 from torch._guards import TracingContext
 from torch._logging import warning_once
-from torch._streambase import _StreamBase
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

 from .. import config, polyfills, variables
@@ -267,7 +266,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             assert len(args) <= 1 and len(kwargs) == 0
             inf_mode = args[0].as_python_constant() if len(args) == 1 else True
             return InferenceModeVariable.create(tx, inf_mode)
-        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
+        elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
             from torch._dynamo.variables.builder import wrap_fx_proxy_cls

             return wrap_fx_proxy_cls(

View File

@@ -1,46 +1,20 @@
-# mypy: allow-untyped-defs
-from abc import ABC, abstractmethod
-
-
-class _StreamBase(ABC):
-    r"""Base stream class abstraction for multi backends Stream to herit from"""
-
-    @abstractmethod
-    def wait_event(self, event) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def wait_stream(self, stream) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def record_event(self, event=None) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def query(self) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def synchronize(self) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __eq__(self, stream) -> bool:
-        raise NotImplementedError
-
-
-class _EventBase(ABC):
-    r"""Base Event class abstraction for multi backends Event to herit from"""
-
-    @abstractmethod
-    def wait(self, stream=None) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def query(self) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def synchronize(self) -> None:
-        raise NotImplementedError
+from typing_extensions import deprecated
+
+import torch
+
+
+# Preserved only for BC reasons
+@deprecated(
+    "`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
+    category=FutureWarning,
+)
+class _StreamBase(torch.Stream):
+    pass
+
+
+@deprecated(
+    "`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
+    category=FutureWarning,
+)
+class _EventBase(torch.Event):
+    pass
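These shims keep old imports working while pointing users at the public classes. Note that `typing_extensions.deprecated` fires when the decorated class is instantiated or subclassed, not at import time, so a bare import stays silent. A small check (a sketch; the warning text is the one defined above):

```python
import warnings

from torch._streambase import _StreamBase  # importing alone does not warn

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    class LegacyStream(_StreamBase):  # subclassing emits the FutureWarning
        pass

assert any(issubclass(w.category, FutureWarning) for w in caught)
```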

View File

@@ -2,7 +2,6 @@
 import ctypes

 import torch
-from torch._streambase import _EventBase, _StreamBase
 from torch._utils import _dummy_type
@@ -12,7 +11,7 @@ if not hasattr(torch._C, "_CudaStreamBase"):
     torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")


-class Stream(torch._C._CudaStreamBase, _StreamBase):
+class Stream(torch._C._CudaStreamBase):
     r"""Wrapper around a CUDA stream.

     A CUDA stream is a linear sequence of execution that belongs to a specific
@@ -138,7 +137,7 @@ class ExternalStream(Stream):
         return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)


-class Event(torch._C._CudaEventBase, _EventBase):
+class Event(torch._C._CudaEventBase):
     r"""Wrapper around a CUDA event.

     CUDA events are synchronization markers that can be used to monitor the

View File

@@ -2,9 +2,7 @@
 import ctypes

 import torch
-from torch._streambase import _EventBase, _StreamBase
-
-from .._utils import _dummy_type
+from torch._utils import _dummy_type


 if not hasattr(torch._C, "_XpuStreamBase"):
@@ -13,7 +11,7 @@ if not hasattr(torch._C, "_XpuStreamBase"):
     torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


-class Stream(torch._C._XpuStreamBase, _StreamBase):
+class Stream(torch._C._XpuStreamBase):
     r"""Wrapper around a XPU stream.

     A XPU stream is a linear sequence of execution that belongs to a specific
@@ -98,7 +96,7 @@ class Stream(torch._C._XpuStreamBase):
         return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


-class Event(torch._C._XpuEventBase, _EventBase):
+class Event(torch._C._XpuEventBase):
     r"""Wrapper around a XPU event.

     XPU events are synchronization markers that can be used to monitor the
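Together with the CUDA change above, this means every in-tree backend's streams and events now share `torch.Stream`/`torch.Event` as common ancestors, so device-agnostic code no longer needs the private module. A minimal sketch of the pattern this enables:

```python
import torch

def synchronize_all(objs):
    # Uniform across backends: torch.cuda.Stream, torch.xpu.Stream,
    # torch.cuda.Event, torch.xpu.Event all derive from the public bases.
    for obj in objs:
        assert isinstance(obj, (torch.Stream, torch.Event))
        obj.synchronize()
```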