Revert "[Inductor] Support native Inductor as backend for MTIA (#158526)"

This reverts commit cd68559d04.

Reverted https://github.com/pytorch/pytorch/pull/158526 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158526#issuecomment-3122186057))
PyTorch MergeBot 2025-07-26 17:58:00 +00:00
parent 7dafab6a93
commit fe0ff12dab
18 changed files with 5 additions and 228 deletions

View File

@ -16,5 +16,4 @@ The MTIA backend is implemented out of the tree, only interfaces are defined
:nosignatures:
memory_stats
memory_allocated
```

View File

@ -1948,10 +1948,8 @@ def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_getCurrentRawStream(device: _int) -> _int: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _mtia_memoryStats(device: _int) -> dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ...
def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ...

View File

@ -2,10 +2,10 @@
Device abstraction layer for TorchDynamo and Inductor backends.
This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS, MTIA) through a common device interface. Key components include:
CPU, MPS) through a common device interface. Key components include:
- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
@ -287,87 +287,6 @@ class CudaInterface(DeviceInterface):
raise RuntimeError("triton not built with the 'nvidia' backend")
get_mtia_stream: Optional[Callable[[int], int]]
if torch.mtia._is_compiled():
from torch._C import _mtia_getCurrentRawStream as get_mtia_stream
else:
get_mtia_stream = None
class MtiaInterface(DeviceInterface):
device = torch.mtia.device # type: ignore[assignment]
Event = torch.mtia.Event # type: ignore[assignment]
Stream = torch.mtia.Stream # type: ignore[assignment]
class Worker:
@staticmethod
def set_device(device: int) -> None:
caching_worker_current_devices["mtia"] = device
@staticmethod
def current_device() -> int:
if "mtia" in caching_worker_current_devices:
return caching_worker_current_devices["mtia"]
return torch.mtia.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None) -> Any:
if device is not None:
if isinstance(device, str):
device = torch.device(device)
assert device.type == "mtia"
if isinstance(device, torch.device):
device = device.index
if device is None:
device = MtiaInterface.Worker.current_device()
if "mtia" not in caching_worker_device_properties:
device_prop = [
torch.mtia.get_device_properties(i)
for i in range(torch.mtia.device_count())
]
caching_worker_device_properties["mtia"] = device_prop
return caching_worker_device_properties["mtia"][device]
current_device = staticmethod(torch.mtia.current_device)
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
device_count = staticmethod(torch.mtia.device_count)
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
current_stream = staticmethod(torch.mtia.current_stream)
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
synchronize = staticmethod(torch.mtia.synchronize)
get_device_properties = staticmethod(torch.mtia.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_mtia_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type]
memory_allocated = staticmethod(torch.mtia.memory_allocated) # type: ignore[assignment]
is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported) # type: ignore[arg-type]
# Can be mock patched by @patch decorator.
@staticmethod
def is_available() -> bool:
ret = torch.mtia.is_available()
return ret
@staticmethod
def get_compute_capability(device: torch.types.Device = None) -> Any:
cc = torch.mtia.get_device_capability(device)
return cc
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
@staticmethod
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
import triton.backends
if "mtia" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'mtia' backend")
get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
@ -590,10 +509,6 @@ def init_device_reg() -> None:
for i in range(torch.xpu.device_count()):
register_interface_for_device(f"xpu:{i}", XpuInterface)
register_interface_for_device("mtia", MtiaInterface)
for i in range(torch.mtia.device_count()):
register_interface_for_device(f"mtia:{i}", MtiaInterface)
register_interface_for_device("cpu", CpuInterface)
register_interface_for_device("mps", MpsInterface)

View File

@ -3976,7 +3976,7 @@ def is_compile_supported(device_type):
compile_supported = is_dynamo_supported()
if type == "cpu":
pass
elif type in ["cuda", "xpu", "mtia"] and compile_supported:
elif type in ["cuda", "xpu"] and compile_supported:
compile_supported = has_triton()
else:
compile_supported = False

View File

@ -16,7 +16,6 @@ from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
@ -66,8 +65,6 @@ class InductorChoices:
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
elif device_type == "mtia":
return MTIAConfigHeuristic()
else:
return BaseConfigHeuristic()

View File

@ -492,7 +492,6 @@ def init_backend_registration() -> None:
from .cuda_combined_scheduling import CUDACombinedScheduling
from .halide import HalideScheduling
from .mps import MetalScheduling
from .python_wrapper_mtia import PythonWrapperMtia
from .triton import TritonScheduling
from .wrapper import PythonWrapperCodegen
@ -540,14 +539,6 @@ def init_backend_registration() -> None:
CppWrapperMps,
)
if get_scheduling_for_device("mtia") is None:
register_backend_for_device(
"mtia",
TritonScheduling,
PythonWrapperMtia,
CppWrapperGpu,
)
private_backend = torch._C._get_privateuse1_backend_name()
if (
private_backend != "privateuseone"
@ -593,7 +584,6 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides:
if not device_op_overrides_dict:
from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401
from .cuda import device_op_overrides # noqa: F401
from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401
from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
return device_op_overrides_dict[device]
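
The hunk above removes the MTIA entry from Inductor's backend registry. The registration call pairs a device string with a scheduling class, a Python wrapper codegen, and a C++ wrapper codegen; a hedged sketch of that call shape with hypothetical placeholder classes:

```python
# Sketch only: scheduling_cls / python_wrapper_cls / cpp_wrapper_cls are
# hypothetical placeholders; the call shape mirrors the registration removed above.
from torch._inductor.codegen.common import (
    get_scheduling_for_device,
    register_backend_for_device,
)

def register_my_backend(scheduling_cls, python_wrapper_cls, cpp_wrapper_cls):
    # Register once, mirroring the `if get_scheduling_for_device(...) is None`
    # guard used for the in-tree backends.
    if get_scheduling_for_device("mydevice") is None:
        register_backend_for_device(
            "mydevice", scheduling_cls, python_wrapper_cls, cpp_wrapper_cls
        )
```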

View File

@ -1,20 +0,0 @@
from __future__ import annotations
from ..common import DeviceOpOverrides, register_device_op_overrides
class MTIADeviceOpOverrides(DeviceOpOverrides):
def import_get_raw_stream_as(self, name: str) -> str:
return f"from torch._C import _mtia_getCurrentRawStream as {name}"
def set_device(self, device_idx: int) -> str:
return f"torch.mtia.set_device({device_idx})"
def synchronize(self) -> str:
return "torch.mtia.synchronize()"
def device_guard(self, device_idx: int) -> str:
return f"torch.mtia.device({device_idx})"
register_device_op_overrides("mtia", MTIADeviceOpOverrides())
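
The deleted class above is essentially a table of source-code strings that the wrapper codegen splices into generated modules. Given the definition shown, the emitted fragments would be (sketch reproducing the f-string outputs; `MTIADeviceOpOverrides` is the class from the deleted file):

```python
# Sketch: the literal source fragments the overrides above produce.
overrides = MTIADeviceOpOverrides()  # class defined in the deleted file above
assert overrides.import_get_raw_stream_as("get_raw_stream") == (
    "from torch._C import _mtia_getCurrentRawStream as get_raw_stream"
)
assert overrides.set_device(0) == "torch.mtia.set_device(0)"
assert overrides.synchronize() == "torch.mtia.synchronize()"
assert overrides.device_guard(0) == "torch.mtia.device(0)"
```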

View File

@ -1,34 +0,0 @@
from typing import Optional
from typing_extensions import override
from torch._inductor import ir
from .wrapper import PythonWrapperCodegen
class PythonWrapperMtia(PythonWrapperCodegen):
"""
A thin wrapper around PythonWrapperCodegen with MTIA-specific logic
"""
@override
def write_header(self) -> None:
super().write_header()
# MTIA specific imports
self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library")
@override
@staticmethod
def create(
is_subgraph: bool,
subgraph_name: Optional[str],
parent_wrapper: Optional[PythonWrapperCodegen],
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
) -> PythonWrapperCodegen:
if is_subgraph:
# Delegate to the parent class to handle the case of subgraph
return PythonWrapperCodegen.create(
is_subgraph, subgraph_name, parent_wrapper, partition_signatures
)
return PythonWrapperMtia()
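
The `write_header` override above only extends the generated module's preamble. Under that reading, a wrapper produced by this class would begin roughly as follows (sketch of emitted source; the shared header contents are abbreviated and vary by build):

```python
# Roughly the start of a wrapper module PythonWrapperMtia would emit: the
# common PythonWrapperCodegen header plus the extra MTIA runtime import
# spliced in by the write_header() override shown above.
import torch
import mtia.host_runtime.torch_mtia.dynamic_library  # added by the override
```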

View File

@ -997,7 +997,6 @@ class PythonWrapperCodegen(CodeGen):
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
@ -2778,7 +2777,7 @@ class PythonWrapperCodegen(CodeGen):
allocation_shape
)
codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
if device.type in ("cpu", "cuda", "xpu", "mtia"):
if device.type in ("cpu", "cuda", "xpu"):
# optimized path for faster allocations, saving ~2us versus the stuff below
out = (
f"{name} = empty_strided_{device.type}("

View File

@ -156,8 +156,6 @@ class DeviceProperties(typing.NamedTuple):
elif device_type == "mps":
# TODO: Fetch the actual value from ioreg
multi_processor_count = 8
elif device_type == "mtia":
multi_processor_count = 64
else:
raise
return cls(

View File

@ -1201,9 +1201,3 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
"""
Placeholder child class for XPU specific overrides.
"""
class MTIAConfigHeuristic(BaseConfigHeuristic):
"""
Placeholder child class for MTIA specific overrides.
"""

View File

@ -94,7 +94,7 @@ if TYPE_CHECKING:
from .scheduler import BaseSchedulerNode, SchedulerBuffer
GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
GPU_TYPES = ["cuda", "mps", "xpu"]
T = TypeVar("T")

View File

@ -31,10 +31,6 @@
#include <ATen/xpu/EmptyTensor.h>
#endif
#ifdef USE_MTIA
#include <ATen/native/mtia/EmptyTensor.h>
#endif
#include <chrono>
#include <sstream>
#include <tuple>
@ -1063,12 +1059,6 @@ static PyObject* _empty_strided_device(
return THPVariable_Wrap(at::detail::empty_strided_xpu(
sizes, strides, dtype, c10::DeviceType::XPU));
}
#endif
#ifdef USE_MTIA
else if (device_type == c10::DeviceType::MTIA) {
return THPVariable_Wrap(at::detail::empty_strided_mtia(
sizes, strides, dtype, c10::DeviceType::MTIA));
}
#endif
else {
TORCH_CHECK(
@ -1094,10 +1084,6 @@ static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
}
static PyObject* _empty_strided_mtia(PyObject* dummy, PyObject* args) {
return _empty_strided_device(dummy, args, c10::DeviceType::MTIA);
}
static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
static PythonArgParser parser(
@ -1129,7 +1115,6 @@ static PyMethodDef _methods[] = {
{"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
{"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
{"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
{"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr},
{"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};

View File

@ -63,18 +63,6 @@ void initModule(PyObject* module) {
return at::detail::getMTIAHooks().getDefaultStream(device_index);
});
m.def(
"_mtia_setStream",
[](int64_t stream_id,
c10::DeviceIndex device_index,
int64_t device_type) {
torch::utils::device_lazy_init(at::kMTIA);
at::detail::getMTIAHooks().setCurrentStream(c10::Stream::unpack3(
stream_id,
device_index,
static_cast<c10::DeviceType>(device_type)));
});
m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
torch::utils::device_lazy_init(at::kMTIA);
auto device = at::detail::getMTIAHooks().getCurrentDevice();

View File

@ -204,10 +204,6 @@ def attach_out_of_memory_observer(
torch._C._mtia_attachOutOfMemoryObserver(observer)
def is_bf16_supported(including_emulation: bool = True):
return True
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Return capability of a given device as a tuple of (major version, minor version).
@ -339,17 +335,6 @@ class StreamContext:
torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type]
def _set_stream_by_id(stream_id, device_index, device_type):
r"""set stream specified by the stream id, device index and
device type
Args: stream_id (int): stream id in stream pool
device_index (int): device index in topo
device_type (int): enum device type
"""
torch._C._mtia_setStream(stream_id, device_index, device_type)
def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
r"""Wrap around the Context-manager StreamContext that selects a given stream.
@ -407,7 +392,6 @@ __all__ = [
"default_stream",
"memory_stats",
"max_memory_allocated",
"memory_allocated",
"reset_peak_memory_stats",
"get_device_capability",
"get_device_properties",
@ -421,5 +405,4 @@ __all__ = [
"device",
"set_rng_state",
"get_rng_state",
"is_bf16_supported",
]
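
The removed `_set_stream_by_id` helper forwarded the three integers that identify a stream to the `_mtia_setStream` binding deleted in the Module.cpp hunk above, which reassembled them with `c10::Stream::unpack3`. A hedged sketch of that wiring (attribute names follow `torch.Stream`; the MTIA calls themselves are exactly what this revert removes):

```python
# Sketch of the removed wiring, for context only: a torch.Stream already
# carries the three fields the C++ side needs to rebuild a c10::Stream.
import torch

s = torch.mtia.current_stream()          # a torch.mtia.Stream
torch.mtia._set_stream_by_id(
    stream_id=s.stream_id,               # id within the stream pool
    device_index=s.device_index,         # device index
    device_type=s.device_type,           # device type enum value
)
```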

View File

@ -36,19 +36,6 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int:
return memory_stats(device).get("dram", 0).get("peak_bytes", 0)
def memory_allocated(device: Optional[_device_t] = None) -> int:
r"""Return the current MTIA memory occupied by tensors in bytes for a given device.
Args:
device (torch.device or int or str, optional): selected device. Returns
statistic for the current device, given by :func:`~torch.mtia.current_device`,
if :attr:`device` is ``None`` (default).
"""
if not is_initialized():
return 0
return memory_stats(device).get("dram", 0).get("allocated_bytes", 0)
def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
r"""Reset the peak memory stats for a given device.
@ -66,6 +53,5 @@ def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
__all__ = [
"memory_stats",
"max_memory_allocated",
"memory_allocated",
"reset_peak_memory_stats",
]
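
Both accessors in this file read from the dictionary returned by `memory_stats`. The nesting they assume looks like the sketch below (key names are taken from the accessors above; the concrete contents come from the out-of-tree runtime, so treat the shape as an assumption):

```python
# Illustrative only: the stats layout memory_allocated()/max_memory_allocated()
# index into; this sketch uses {} as a defensive default for the "dram" entry.
import torch

stats = torch.mtia.memory_stats()
dram = stats.get("dram", {})             # {"allocated_bytes": ..., "peak_bytes": ...}
allocated = dram.get("allocated_bytes", 0)
peak = dram.get("peak_bytes", 0)
```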

View File

@ -135,7 +135,6 @@ def has_triton() -> bool:
"cuda": cuda_extra_check,
"xpu": _return_true,
"cpu": cpu_extra_check,
"mtia": _return_true,
}
def is_device_compatible_with_triton() -> bool: