Revert "[Inductor] Support native Inductor as backend for MTIA (#158526)"
This reverts commit cd68559d04.
Reverted https://github.com/pytorch/pytorch/pull/158526 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158526#issuecomment-3122186057))
parent 7dafab6a93
commit fe0ff12dab
@@ -16,5 +16,4 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
     :nosignatures:

     memory_stats
-    memory_allocated
 ```
@@ -1948,10 +1948,8 @@ def _mtia_isBuilt() -> _bool: ...
 def _mtia_isInBadFork() -> _bool: ...
 def _mtia_deviceSynchronize() -> None: ...
 def _mtia_getCurrentStream(device: _int) -> Stream: ...
-def _mtia_getCurrentRawStream(device: _int) -> _int: ...
 def _mtia_setCurrentStream(stream: Stream) -> None: ...
 def _mtia_getDefaultStream(device: _int) -> Stream: ...
-def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
 def _mtia_memoryStats(device: _int) -> dict[str, Any]: ...
 def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ...
 def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ...
@@ -2,10 +2,10 @@
 Device abstraction layer for TorchDynamo and Inductor backends.

 This module provides a unified interface for different hardware backends (CUDA, XPU,
-CPU, MPS, MTIA) through a common device interface. Key components include:
+CPU, MPS) through a common device interface. Key components include:

 - DeviceInterface: Base class defining the common API for all device types
-- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface
+- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
 - Device registration system for managing available backends
 - Worker APIs for multi-processing scenarios
 - Stream and event management across different devices
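The docstring above describes the device abstraction layer that this revert narrows back to CUDA/XPU/CPU/MPS. As a rough, non-authoritative sketch of how a backend plugs into that layer (the torch._dynamo.device_interface module path and the runtime behavior are assumptions; only the class and helper names appear in this diff):

```python
from torch._dynamo.device_interface import (  # assumed module path
    DeviceInterface,
    register_interface_for_device,
)


class MyAcceleratorInterface(DeviceInterface):
    """Hypothetical backend; a real one mirrors CudaInterface/XpuInterface."""

    @staticmethod
    def is_available() -> bool:
        # Kept trivially false so this sketch has no runtime requirements.
        return False


# Same registration call that the init_device_reg() hunk below removes for MTIA.
register_interface_for_device("my_accel", MyAcceleratorInterface)
```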
@@ -287,87 +287,6 @@ class CudaInterface(DeviceInterface):
         raise RuntimeError("triton not built with the 'nvidia' backend")


-get_mtia_stream: Optional[Callable[[int], int]]
-if torch.mtia._is_compiled():
-    from torch._C import _mtia_getCurrentRawStream as get_mtia_stream
-else:
-    get_mtia_stream = None
-
-
-class MtiaInterface(DeviceInterface):
-    device = torch.mtia.device  # type: ignore[assignment]
-    Event = torch.mtia.Event  # type: ignore[assignment]
-    Stream = torch.mtia.Stream  # type: ignore[assignment]
-
-    class Worker:
-        @staticmethod
-        def set_device(device: int) -> None:
-            caching_worker_current_devices["mtia"] = device
-
-        @staticmethod
-        def current_device() -> int:
-            if "mtia" in caching_worker_current_devices:
-                return caching_worker_current_devices["mtia"]
-            return torch.mtia.current_device()
-
-        @staticmethod
-        def get_device_properties(device: torch.types.Device = None) -> Any:
-            if device is not None:
-                if isinstance(device, str):
-                    device = torch.device(device)
-                    assert device.type == "mtia"
-                if isinstance(device, torch.device):
-                    device = device.index
-            if device is None:
-                device = MtiaInterface.Worker.current_device()
-
-            if "mtia" not in caching_worker_device_properties:
-                device_prop = [
-                    torch.mtia.get_device_properties(i)
-                    for i in range(torch.mtia.device_count())
-                ]
-                caching_worker_device_properties["mtia"] = device_prop
-
-            return caching_worker_device_properties["mtia"][device]
-
-    current_device = staticmethod(torch.mtia.current_device)
-    set_device = staticmethod(torch.mtia.set_device)  # type: ignore[assignment]
-    device_count = staticmethod(torch.mtia.device_count)
-    stream = staticmethod(torch.mtia.stream)  # type: ignore[assignment]
-    current_stream = staticmethod(torch.mtia.current_stream)
-    set_stream = staticmethod(torch.mtia.set_stream)  # type: ignore[assignment]
-    _set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id)  # type: ignore[assignment]
-    synchronize = staticmethod(torch.mtia.synchronize)
-    get_device_properties = staticmethod(torch.mtia.get_device_properties)  # type: ignore[assignment]
-    get_raw_stream = staticmethod(get_mtia_stream)  # type: ignore[assignment, arg-type]
-    exchange_device = staticmethod(torch.mtia._exchange_device)  # type: ignore[arg-type]
-    maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device)  # type: ignore[arg-type]
-    memory_allocated = staticmethod(torch.mtia.memory_allocated)  # type: ignore[assignment]
-    is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported)  # type: ignore[arg-type]
-
-    # Can be mock patched by @patch decorator.
-    @staticmethod
-    def is_available() -> bool:
-        ret = torch.mtia.is_available()
-        return ret
-
-    @staticmethod
-    def get_compute_capability(device: torch.types.Device = None) -> Any:
-        cc = torch.mtia.get_device_capability(device)
-        return cc
-
-    @staticmethod
-    def is_triton_capable(device: torch.types.Device = None) -> bool:
-        return True
-
-    @staticmethod
-    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
-        import triton.backends
-
-        if "mtia" not in triton.backends.backends:
-            raise RuntimeError("triton not built with the 'mtia' backend")
-
-
 get_xpu_stream: Optional[Callable[[int], int]]
 if torch.xpu._is_compiled():
     from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
@@ -590,10 +509,6 @@ def init_device_reg() -> None:
     for i in range(torch.xpu.device_count()):
         register_interface_for_device(f"xpu:{i}", XpuInterface)

-    register_interface_for_device("mtia", MtiaInterface)
-    for i in range(torch.mtia.device_count()):
-        register_interface_for_device(f"mtia:{i}", MtiaInterface)
-
     register_interface_for_device("cpu", CpuInterface)
     register_interface_for_device("mps", MpsInterface)

@@ -3976,7 +3976,7 @@ def is_compile_supported(device_type):
     compile_supported = is_dynamo_supported()
     if type == "cpu":
         pass
-    elif type in ["cuda", "xpu", "mtia"] and compile_supported:
+    elif type in ["cuda", "xpu"] and compile_supported:
         compile_supported = has_triton()
     else:
         compile_supported = False
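The hunk above restores the pre-MTIA device list in is_compile_supported. A hedged restatement of that decision logic (has_triton is the helper named in this diff; its import path and the stand-in flag for is_dynamo_supported are assumptions):

```python
from torch.utils._triton import has_triton  # import path assumed


def sketch_is_compile_supported(device_type: str, dynamo_supported: bool) -> bool:
    """Mirror of the post-revert branch structure, for illustration only."""
    dev = device_type.split(":")[0]  # e.g. "cuda:0" -> "cuda"
    if not dynamo_supported:  # stands in for the is_dynamo_supported() call
        return False
    if dev == "cpu":
        return True
    if dev in ("cuda", "xpu"):  # "mtia" is dropped by this revert
        return has_triton()
    return False
```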
@@ -16,7 +16,6 @@ from .template_heuristics import (
     BaseConfigHeuristic,
     CPUConfigHeuristic,
     CUDAConfigHeuristic,
-    MTIAConfigHeuristic,
     ROCmConfigHeuristic,
     XPUConfigHeuristic,
 )
@@ -66,8 +65,6 @@ class InductorChoices:
             return XPUConfigHeuristic()
         elif device_type == "cpu":
             return CPUConfigHeuristic()
-        elif device_type == "mtia":
-            return MTIAConfigHeuristic()
         else:
             return BaseConfigHeuristic()

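The two hunks above drop MTIAConfigHeuristic from the import list and from InductorChoices' dispatch. The same dispatch restated as a table, for illustration only (class names come from this diff; the module path, the dict-based form, and the omission of the ROCm branch are assumptions):

```python
from torch._inductor.template_heuristics import (  # path implied by the relative import above
    BaseConfigHeuristic,
    CPUConfigHeuristic,
    CUDAConfigHeuristic,
    XPUConfigHeuristic,
)

_HEURISTIC_BY_DEVICE = {
    "cuda": CUDAConfigHeuristic,
    "xpu": XPUConfigHeuristic,
    "cpu": CPUConfigHeuristic,
    # "mtia": MTIAConfigHeuristic,  # entry removed by this revert
}


def config_heuristic_for(device_type: str) -> BaseConfigHeuristic:
    # Unknown device types fall back to the base heuristic, as in the diff.
    return _HEURISTIC_BY_DEVICE.get(device_type, BaseConfigHeuristic)()
```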
@@ -492,7 +492,6 @@ def init_backend_registration() -> None:
     from .cuda_combined_scheduling import CUDACombinedScheduling
     from .halide import HalideScheduling
     from .mps import MetalScheduling
-    from .python_wrapper_mtia import PythonWrapperMtia
     from .triton import TritonScheduling
     from .wrapper import PythonWrapperCodegen

@@ -540,14 +539,6 @@ def init_backend_registration() -> None:
             CppWrapperMps,
         )

-    if get_scheduling_for_device("mtia") is None:
-        register_backend_for_device(
-            "mtia",
-            TritonScheduling,
-            PythonWrapperMtia,
-            CppWrapperGpu,
-        )
-
     private_backend = torch._C._get_privateuse1_backend_name()
     if (
         private_backend != "privateuseone"
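For context, the removed block wired MTIA into Inductor's backend registry. A hedged sketch of the same wiring for a hypothetical device (the registry helpers and codegen classes are named in this diff; the import paths are assumptions):

```python
from torch._inductor.codegen.common import (  # assumed path for the registry helpers
    get_scheduling_for_device,
    register_backend_for_device,
)
from torch._inductor.codegen.cpp_wrapper_gpu import CppWrapperGpu  # assumed path
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen

if get_scheduling_for_device("my_accel") is None:  # hypothetical device key
    register_backend_for_device(
        "my_accel",
        TritonScheduling,      # Triton-based scheduling, as the removed MTIA hookup used
        PythonWrapperCodegen,  # or a thin subclass like the deleted PythonWrapperMtia
        CppWrapperGpu,
    )
```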
@@ -593,7 +584,6 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides:
     if not device_op_overrides_dict:
         from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
         from .cuda import device_op_overrides  # noqa: F401
-        from .mtia import device_op_overrides as mtia_op_overrides  # noqa: F401
         from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

     return device_op_overrides_dict[device]
@ -1,20 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ..common import DeviceOpOverrides, register_device_op_overrides
|
||||
|
||||
|
||||
class MTIADeviceOpOverrides(DeviceOpOverrides):
|
||||
def import_get_raw_stream_as(self, name: str) -> str:
|
||||
return f"from torch._C import _mtia_getCurrentRawStream as {name}"
|
||||
|
||||
def set_device(self, device_idx: int) -> str:
|
||||
return f"torch.mtia.set_device({device_idx})"
|
||||
|
||||
def synchronize(self) -> str:
|
||||
return "torch.mtia.synchronize()"
|
||||
|
||||
def device_guard(self, device_idx: int) -> str:
|
||||
return f"torch.mtia.device({device_idx})"
|
||||
|
||||
|
||||
register_device_op_overrides("mtia", MTIADeviceOpOverrides())
|
||||
|
|
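The deleted module above supplied the small source strings Inductor splices into its generated wrapper; import_get_raw_stream_as, for example, yields the import that generated code uses to fetch a raw stream handle for kernel launches. A hedged CUDA analogue of that generated snippet (assumes a CUDA build; the exact generated file varies per model):

```python
# Line produced by the CUDA counterpart of import_get_raw_stream_as("get_raw_stream")
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

stream0 = get_raw_stream(0)  # raw stream handle for device 0, passed to kernel launches
```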
@ -1,34 +0,0 @@
|
|||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from torch._inductor import ir
|
||||
|
||||
from .wrapper import PythonWrapperCodegen
|
||||
|
||||
|
||||
class PythonWrapperMtia(PythonWrapperCodegen):
|
||||
"""
|
||||
A thin wrapper of PythonWrapperCodegen with MTIA specific logic
|
||||
"""
|
||||
|
||||
@override
|
||||
def write_header(self) -> None:
|
||||
super().write_header()
|
||||
|
||||
# MITA specific imports
|
||||
self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library")
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def create(
|
||||
is_subgraph: bool,
|
||||
subgraph_name: Optional[str],
|
||||
parent_wrapper: Optional[PythonWrapperCodegen],
|
||||
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
|
||||
) -> PythonWrapperCodegen:
|
||||
if is_subgraph:
|
||||
# Delegate to the parent class to handle the case of subgraph
|
||||
return PythonWrapperCodegen.create(
|
||||
is_subgraph, subgraph_name, parent_wrapper, partition_signatures
|
||||
)
|
||||
return PythonWrapperMtia()
|
||||
|
|
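The deleted wrapper subclass mostly forwards to PythonWrapperCodegen; its create() keeps subgraph codegen on the generic class and only uses the device-specific class at the top level. A self-contained toy illustration of that factory pattern (names here are made up, not PyTorch APIs):

```python
from typing import Optional


class GenericWrapper:
    @staticmethod
    def create(is_subgraph: bool, subgraph_name: Optional[str] = None) -> "GenericWrapper":
        return GenericWrapper()


class MyAccelWrapper(GenericWrapper):
    @staticmethod
    def create(is_subgraph: bool, subgraph_name: Optional[str] = None) -> GenericWrapper:
        if is_subgraph:
            # Subgraphs delegate back to the generic implementation.
            return GenericWrapper.create(is_subgraph, subgraph_name)
        return MyAccelWrapper()


assert type(MyAccelWrapper.create(False)) is MyAccelWrapper
assert type(MyAccelWrapper.create(True)) is GenericWrapper
```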
@@ -997,7 +997,6 @@ class PythonWrapperCodegen(CodeGen):
                 empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
                 empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
                 empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
-                empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
                 reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
                 alloc_from_pool = torch.ops.inductor._alloc_from_pool
                 async_compile = AsyncCompile()
@@ -2778,7 +2777,7 @@ class PythonWrapperCodegen(CodeGen):
                 allocation_shape
             )
             codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
-            if device.type in ("cpu", "cuda", "xpu", "mtia"):
+            if device.type in ("cpu", "cuda", "xpu"):
                 # optimized path for faster allocations, saving ~2us versus the stuff below
                 out = (
                     f"{name} = empty_strided_{device.type}("
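The hunk above takes "mtia" back out of the fast allocation path. For orientation, this is roughly what that path emits in generated wrapper code; the CPU variant is shown so the sketch runs on any build (shape, stride, and dtype are made up):

```python
import torch

# Binding that the generated header sets up (see the wrapper-header hunk above).
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu

# The fast-path line then looks like this in the generated module:
buf0 = empty_strided_cpu((4, 8), (8, 1), torch.float32)
```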
@@ -156,8 +156,6 @@ class DeviceProperties(typing.NamedTuple):
             elif device_type == "mps":
                 # TODO: Fetch the actual value from ioreg
                 multi_processor_count = 8
-            elif device_type == "mtia":
-                multi_processor_count = 64
             else:
                 raise
             return cls(
@@ -1201,9 +1201,3 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
     """
     Placeholder child class for XPU specific overrides.
     """
-
-
-class MTIAConfigHeuristic(BaseConfigHeuristic):
-    """
-    Placeholder child class for MTIA specific overrides.
-    """
@@ -94,7 +94,7 @@ if TYPE_CHECKING:
     from .scheduler import BaseSchedulerNode, SchedulerBuffer


-GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
+GPU_TYPES = ["cuda", "mps", "xpu"]
 T = TypeVar("T")


@@ -31,10 +31,6 @@
 #include <ATen/xpu/EmptyTensor.h>
 #endif

-#ifdef USE_MTIA
-#include <ATen/native/mtia/EmptyTensor.h>
-#endif
-
 #include <chrono>
 #include <sstream>
 #include <tuple>
@@ -1063,12 +1059,6 @@ static PyObject* _empty_strided_device(
     return THPVariable_Wrap(at::detail::empty_strided_xpu(
         sizes, strides, dtype, c10::DeviceType::XPU));
   }
 #endif
-#ifdef USE_MTIA
-  else if (device_type == c10::DeviceType::MTIA) {
-    return THPVariable_Wrap(at::detail::empty_strided_mtia(
-        sizes, strides, dtype, c10::DeviceType::MTIA));
-  }
-#endif
   else {
     TORCH_CHECK(
@@ -1094,10 +1084,6 @@ static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
   return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
 }

-static PyObject* _empty_strided_mtia(PyObject* dummy, PyObject* args) {
-  return _empty_strided_device(dummy, args, c10::DeviceType::MTIA);
-}
-
 static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
   HANDLE_TH_ERRORS;
   static PythonArgParser parser(
@@ -1129,7 +1115,6 @@ static PyMethodDef _methods[] = {
     {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
     {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
     {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
-    {"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr},
     {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
     {nullptr, nullptr, 0, nullptr}};

@@ -63,18 +63,6 @@ void initModule(PyObject* module) {
         return at::detail::getMTIAHooks().getDefaultStream(device_index);
       });

-  m.def(
-      "_mtia_setStream",
-      [](int64_t stream_id,
-         c10::DeviceIndex device_index,
-         int64_t device_type) {
-        torch::utils::device_lazy_init(at::kMTIA);
-        at::detail::getMTIAHooks().setCurrentStream(c10::Stream::unpack3(
-            stream_id,
-            device_index,
-            static_cast<c10::DeviceType>(device_type)));
-      });
-
   m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
     torch::utils::device_lazy_init(at::kMTIA);
     auto device = at::detail::getMTIAHooks().getCurrentDevice();
@@ -204,10 +204,6 @@ def attach_out_of_memory_observer(
     torch._C._mtia_attachOutOfMemoryObserver(observer)


-def is_bf16_supported(including_emulation: bool = True):
-    return True
-
-
 def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
     r"""Return capability of a given device as a tuple of (major version, minor version).

@@ -339,17 +335,6 @@ class StreamContext:
         torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


-def _set_stream_by_id(stream_id, device_index, device_type):
-    r"""set stream specified by the stream id, device index and
-    device type
-
-    Args: stream_id (int): stream id in stream pool
-        device_index (int): device index in topo
-        device_type (int): enum device type
-    """
-    torch._C._mtia_setStream(stream_id, device_index, device_type)
-
-
 def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
     r"""Wrap around the Context-manager StreamContext that selects a given stream.

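With the private _set_stream_by_id helper gone, stream selection on MTIA goes through the public context manager that this hunk keeps. A minimal hedged usage sketch (assumes an MTIA build with at least one device available):

```python
import torch

if torch.mtia.is_available():
    s = torch.mtia.Stream()   # new stream on the current MTIA device
    with torch.mtia.stream(s):
        pass                  # work issued here targets stream `s`
    torch.mtia.synchronize()
```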
@@ -407,7 +392,6 @@ __all__ = [
     "default_stream",
     "memory_stats",
     "max_memory_allocated",
-    "memory_allocated",
     "reset_peak_memory_stats",
     "get_device_capability",
     "get_device_properties",
@@ -421,5 +405,4 @@ __all__ = [
     "device",
     "set_rng_state",
     "get_rng_state",
-    "is_bf16_supported",
 ]
@@ -36,19 +36,6 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int:
     return memory_stats(device).get("dram", 0).get("peak_bytes", 0)


-def memory_allocated(device: Optional[_device_t] = None) -> int:
-    r"""Return the current MTIA memory occupied by tensors in bytes for a given device.
-
-    Args:
-        device (torch.device or int or str, optional): selected device. Returns
-            statistic for the current device, given by :func:`~torch.mtia.current_device`,
-            if :attr:`device` is ``None`` (default).
-    """
-    if not is_initialized():
-        return 0
-    return memory_stats(device).get("dram", 0).get("allocated_bytes", 0)
-
-
 def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
     r"""Reset the peak memory stats for a given device.

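The removed memory_allocated() was a thin accessor over memory_stats(); the underlying numbers remain reachable through the stats dict, as sketched below (assumes an available MTIA device; the "dram"/"allocated_bytes"/"peak_bytes" keys are the ones used in the removed code):

```python
import torch

if torch.mtia.is_available():
    stats = torch.mtia.memory_stats()
    dram = stats.get("dram", {})
    print(dram.get("allocated_bytes", 0))  # what the removed memory_allocated() returned
    print(dram.get("peak_bytes", 0))       # what max_memory_allocated() returns
```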
@@ -66,6 +53,5 @@ def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
 __all__ = [
     "memory_stats",
     "max_memory_allocated",
-    "memory_allocated",
     "reset_peak_memory_stats",
 ]
@@ -135,7 +135,6 @@ def has_triton() -> bool:
         "cuda": cuda_extra_check,
         "xpu": _return_true,
         "cpu": cpu_extra_check,
-        "mtia": _return_true,
     }

     def is_device_compatible_with_triton() -> bool:
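The final hunk removes "mtia" from has_triton()'s per-device check table. A self-contained sketch of that lookup pattern (the callables below are stand-ins, not the real extra checks):

```python
def _return_true(device) -> bool:
    return True


# Post-revert shape of the table: each device type maps to an extra check.
triton_device_checks = {
    "cuda": _return_true,  # real code verifies CUDA device capability here
    "xpu": _return_true,
    "cpu": _return_true,   # real code checks the Triton CPU backend here
    # "mtia" entry removed by this revert
}


def device_has_triton(device_type: str) -> bool:
    check = triton_device_checks.get(device_type)
    return check is not None and check(device_type)
```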