mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR implements two things:
1. Support the device-agnostic stream and runtime APIs captured by dynamo.
2. Support the stream methods (including event methods) captured by dynamo.
Here are the details for the first item.
Previously, the stream captured in dynamo was tightly bound to CUDA. Here we implement a global singleton container named `StreamMethodContainer` through which different backends register their stream methods with dynamo. When a backend's package is imported, its stream operations can be registered directly by calling
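```python
# method_1 ... method_4 are the backend's own implementations of each operation,
# and device_name is the backend's device string.
device_stream_method = {
    'current_stream': method_1,
    'create_stream_context': method_2,
    'set_stream': method_3,
    'set_stream_by_id': method_4,
}
torch._dynamo.stream.register_stream_method(device_name, device_stream_method)
```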
The stream methods passed to this API must follow the exact semantics indicated by their keys in `device_stream_method`. After registration, dynamo can use these methods to capture the stream operations in a user's script, for example getting the current stream or setting a specific stream. Additionally, the wrapped stream variable and the stream context variable are now device-agnostic: their proxy functions are assigned from the associated methods in the container. All of this is illustrated below.
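To make the registry idea concrete, here is a minimal, hypothetical sketch of a singleton container of the kind described above. It is not dynamo's implementation: only the names `StreamMethodContainer` and `register_stream_method` and the four dictionary keys come from the description; the class body, the `register`/`get` methods, the key validation, and the `toy` backend callables are assumptions made for illustration.

```python
import contextlib
from typing import Callable, Dict

# The four keys named in the description; each maps to a backend-provided callable.
REQUIRED_KEYS = frozenset(
    {"current_stream", "create_stream_context", "set_stream", "set_stream_by_id"}
)


class StreamMethodContainer:
    """Hypothetical sketch: one shared registry of stream methods per device name."""

    _instance = None

    def __new__(cls):
        # Every construction returns the same object, so all backends share one registry.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._methods = {}
        return cls._instance

    def register(self, device: str, methods: Dict[str, Callable]) -> None:
        # Reject registrations that omit any of the required operations.
        missing = REQUIRED_KEYS.difference(methods)
        if missing:
            raise ValueError(f"{device}: missing stream methods {sorted(missing)}")
        self._methods[device] = dict(methods)

    def get(self, device: str, name: str) -> Callable:
        # What a device-agnostic stream variable would call instead of a CUDA-only API.
        return self._methods[device][name]


def register_stream_method(device: str, methods: Dict[str, Callable]) -> None:
    StreamMethodContainer().register(device, methods)


# --- Usage with a hypothetical "toy" backend that tracks one current stream ---
_toy_state = {"stream": "stream-0"}


def toy_current_stream():
    return _toy_state["stream"]


def toy_set_stream(stream):
    _toy_state["stream"] = stream


@contextlib.contextmanager
def toy_stream_context(stream):
    # Switch to `stream` for the duration of the block, then restore the previous one.
    previous = toy_current_stream()
    toy_set_stream(stream)
    try:
        yield stream
    finally:
        toy_set_stream(previous)


register_stream_method("toy", {
    "current_stream": toy_current_stream,
    "create_stream_context": toy_stream_context,
    "set_stream": toy_set_stream,
    "set_stream_by_id": lambda stream_id: toy_set_stream(f"stream-{stream_id}"),
})
```

Keeping the container a singleton means a backend registers once at import time and any later lookup sees the same table, which is the property the description relies on.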

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108312
Approved by: https://github.com/jansel, https://github.com/jgong5
from abc import ABC, abstractmethod


class _StreamBase(ABC):
    r"""Base stream class abstraction for multi-backend streams to inherit from"""

    @abstractmethod
    def wait_event(self, event):
        raise NotImplementedError()

    @abstractmethod
    def wait_stream(self, stream):
        raise NotImplementedError()

    @abstractmethod
    def record_event(self, event=None):
        raise NotImplementedError()

    @abstractmethod
    def query(self):
        raise NotImplementedError()

    @abstractmethod
    def synchronize(self):
        raise NotImplementedError()

    @abstractmethod
    def __eq__(self, stream):
        raise NotImplementedError()


class _EventBase(ABC):
    r"""Base event class abstraction for multi-backend events to inherit from"""

    @abstractmethod
    def wait(self, stream=None):
        raise NotImplementedError()

    @abstractmethod
    def query(self):
        raise NotImplementedError()

    @abstractmethod
    def synchronize(self):
        raise NotImplementedError()
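
The two abstract bases only fix the interface a backend must provide. As a sketch of how a backend might fill it in, the example below assumes a hypothetical device that runs all work eagerly, so `query()` is always `True` and `synchronize()` is a no-op. The names `ImmediateStream`, `ImmediateEvent`, `IncompleteStream`, and the `record` helper are hypothetical and not PyTorch APIs; the bases are restated in condensed form so the sketch runs standalone.

```python
from abc import ABC, abstractmethod


# Condensed restatement of the two base classes above, so this sketch runs standalone.
class _StreamBase(ABC):
    @abstractmethod
    def wait_event(self, event): ...
    @abstractmethod
    def wait_stream(self, stream): ...
    @abstractmethod
    def record_event(self, event=None): ...
    @abstractmethod
    def query(self): ...
    @abstractmethod
    def synchronize(self): ...
    @abstractmethod
    def __eq__(self, stream): ...


class _EventBase(ABC):
    @abstractmethod
    def wait(self, stream=None): ...
    @abstractmethod
    def query(self): ...
    @abstractmethod
    def synchronize(self): ...


class ImmediateEvent(_EventBase):
    """Hypothetical event for a toy backend that runs all work eagerly."""

    def __init__(self):
        self.recorded_on = None  # stream on which the event was last recorded

    def record(self, stream):
        # Hypothetical helper; not part of _EventBase.
        self.recorded_on = stream

    def wait(self, stream=None):
        pass  # work is already finished, so there is nothing to wait for

    def query(self):
        return True  # always complete on an eager backend

    def synchronize(self):
        pass


class ImmediateStream(_StreamBase):
    """Hypothetical stream for the same toy backend."""

    def __init__(self, stream_id=0):
        self.stream_id = stream_id
        self.waited_on = []  # events this stream was asked to wait for

    def wait_event(self, event):
        self.waited_on.append(event)
        event.wait(self)

    def wait_stream(self, stream):
        # Express a stream-on-stream wait as "record on the other stream, then wait".
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        # Follow the base signature: create a fresh event when none is supplied.
        if event is None:
            event = ImmediateEvent()
        event.record(self)
        return event

    def query(self):
        return True

    def synchronize(self):
        pass

    def __eq__(self, stream):
        if not isinstance(stream, ImmediateStream):
            return NotImplemented
        return self.stream_id == stream.stream_id

    def __hash__(self):
        # A class that defines __eq__ gets __hash__ = None unless it defines __hash__ too.
        return hash(self.stream_id)


class IncompleteStream(_StreamBase):
    # Implements only query(), so the ABC machinery refuses to instantiate it.
    def query(self):
        return True
```

Because every method is marked `@abstractmethod`, a backend that forgets one (like `IncompleteStream`) fails with a `TypeError` at construction time rather than later inside a traced program.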