Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Add test_cpp_extensions tests for stream_and_event and mita_backend (#123614)"
This reverts commit b6f0159db0.
Reverted https://github.com/pytorch/pytorch/pull/123614 on behalf of https://github.com/jeffdaily due to This broke ROCm. see test_overrides.py ([comment](https://github.com/pytorch/pytorch/pull/123611#issuecomment-2067363780))
This commit is contained in:
parent
f8f7cfbeee
commit
52da03edeb
Deleted file: cpp_extensions/mtia_extension.cpp (the fake MTIA backend that the tests below build and load)
@@ -1,219 +0,0 @@ (entire file removed by this revert)

// Fake MTIA backend used only for testing: a minimal DeviceGuardImplInterface
// and MTIAHooksInterface implementation with two fake devices, so the generic
// Stream/Event machinery can be exercised without real hardware.
#include <ATen/detail/MTIAHooksInterface.h>
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Logging.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <thread>

namespace torch::mtia {

constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA;
constexpr c10::DeviceIndex kMTIADeviceCount = 2;
static thread_local c10::DeviceIndex current_device = 0;
static thread_local std::array<c10::Stream, kMTIADeviceCount> current_streams =
    {c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
     c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};
static int64_t stream_id_gen = 1;
static int64_t event_id_gen = 1;
static std::array<c10::Stream, kMTIADeviceCount> default_streams = {
    c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
    c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};

struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  MTIAGuardImpl() = default;
  explicit MTIAGuardImpl(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == kMTIADeviceType);
  }
  c10::DeviceType type() const override {
    return kMTIADeviceType;
  }
  c10::Device exchangeDevice(c10::Device d) const override {
    c10::Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      setDevice(d);
    }
    return old_device;
  }
  c10::Device getDevice() const override {
    return c10::Device(kMTIADeviceType, current_device);
  }

  void setDevice(c10::Device d) const override {
    // Update the thread-local current device index.
    if (current_device != d.index()) {
      current_device = d.index();
    }
  }
  void uncheckedSetDevice(c10::Device d) const noexcept override {
    (void)d;
  }
  c10::Stream getStream(c10::Device d) const noexcept override {
    return current_streams[d.index()];
  }
  c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
    (void)priority;
    return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
  }
  c10::Stream getDefaultStream(c10::Device d) const override {
    return default_streams[d.index()];
  }
  c10::Stream getStreamFromGlobalPool(
      c10::Device d,
      bool isHighPriority = false) const override {
    return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
  }
  // NB: These do NOT set the current device
  c10::Stream exchangeStream(c10::Stream s) const noexcept override {
    c10::Stream old_stream = getStream(s.device());
    return old_stream;
  }
  c10::DeviceIndex deviceCount() const noexcept override {
    return kMTIADeviceCount;
  }

  void destroyEvent(void* event, const c10::DeviceIndex device_index)
      const noexcept override {
    (void)device_index;
  }

  void record(
      void** event,
      const c10::Stream& stream,
      const c10::DeviceIndex device_index,
      const c10::EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    const auto orig_device = getDevice();

    setDevice(stream.device());

    // Events are just monotonically increasing integers in this fake backend.
    if (*event == nullptr) {
      *event = reinterpret_cast<void*>(event_id_gen++);
    }
    setDevice(orig_device);
  }

  void block(void* event, const c10::Stream& stream) const override {
    (void)event;
    (void)stream;
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    (void)event;
    return true;
  }

  // Stream-related functions
  bool queryStream(const c10::Stream& stream) const override {
    (void)stream;
    return true;
  }

  void synchronizeStream(const c10::Stream& stream) const override {
    (void)stream;
  }

  void recordDataPtrOnStream(
      const c10::DataPtr& data_ptr,
      const c10::Stream& stream) const override {
    (void)data_ptr;
    (void)stream;
  }

  double elapsedTime(void* event1, void* event2) const override {
    // Fake timing: always report the same elapsed time.
    uint64_t elapsed_time = 1e6;
    return (double)(elapsed_time / 1e6);
  }

  void synchronizeEvent(void* event) const override {
    (void)event;
  }
};

struct MTIAHooks : public at::MTIAHooksInterface {
  explicit MTIAHooks(at::MTIAHooksArgs) {}
  void initMTIA() const override {}

  bool hasMTIA() const override {
    return true;
  }

  c10::DeviceIndex deviceCount() const override {
    torch::utils::device_lazy_init(at::kMTIA);
    return c10::DeviceIndex(2);
  }

  void deviceSynchronize(c10::DeviceIndex device_index) const override {
    torch::utils::device_lazy_init(at::kMTIA);
    (void)device_index;
  }

  std::string showConfig() const override {
    return "None config";
  }

  c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override {
    torch::utils::device_lazy_init(at::kMTIA);
    auto orig_device = current_device;
    if (current_device != device) {
      current_device = device;
    }
    return orig_device;
  }

  c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override {
    torch::utils::device_lazy_init(at::kMTIA);

    auto orig_device = current_device;
    if (current_device != device) {
      current_device = device;
    }
    return orig_device;
  }

  c10::Stream getDefaultStream(c10::DeviceIndex device) const override {
    torch::utils::device_lazy_init(at::kMTIA);

    return default_streams[device];
  }

  c10::Stream getCurrentStream(c10::DeviceIndex device) const override {
    torch::utils::device_lazy_init(at::kMTIA);

    return current_streams[device];
  }

  void setCurrentStream(const c10::Stream& stream) const override {
    torch::utils::device_lazy_init(at::kMTIA);

    current_streams[stream.device_index()] = stream;
  }

  c10::DeviceIndex getCurrentDevice() const override {
    torch::utils::device_lazy_init(at::kMTIA);

    return current_device;
  }

  void setCurrentDevice(c10::DeviceIndex device) const override {
    torch::utils::device_lazy_init(at::kMTIA);

    if (current_device != device) {
      current_device = device;
    }
  }
};

using at::MTIAHooksRegistry;
using at::RegistererMTIAHooksRegistry;

REGISTER_MTIA_HOOKS(MTIAHooks);
C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl);

} // namespace torch::mtia
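For orientation, here is a minimal sketch of what this fake backend reports through the torch.mtia API; it assumes the extension has already been built and loaded in the process (the load call itself appears in the test files below), and the expected values follow directly from the constants defined above.

import torch

# Assumes the mtia_extension library has already been loaded via
# torch.utils.cpp_extension.load(..., is_python_module=False), as the tests below do.
print(torch.mtia.current_device())         # 0: thread-local current_device starts at 0
default_stream = torch.mtia.current_stream()
print(default_stream.stream_id)            # 0: default streams are packed with stream id 0
user_stream = torch.mtia.Stream()          # new ids come from stream_id_gen, which starts at 1
print(user_stream.stream_id != 0)          # True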
@@ -191,8 +191,6 @@ XPU_TEST = [
 RUN_PARALLEL_BLOCKLIST = [
     "test_cpp_extensions_jit",
     "test_cpp_extensions_open_device_registration",
-    "test_cpp_extensions_stream_and_event",
-    "test_cpp_extensions_mtia_backend",
     "test_jit_disabled",
     "test_mobile_optimizer",
     "test_multiprocessing",
Deleted file: test_cpp_extensions_mtia_backend.py (MTIA backend tests built on the C++ extension above)
@@ -1,154 +0,0 @@ (entire file removed by this revert)

# Owner(s): ["module: mtia"]

import os
import shutil
import sys
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


def remove_build_path():
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests MTIA backend with C++ extensions."""

    module = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # return the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()

    @classmethod
    def setUpClass(cls):
        remove_build_path()
        build_dir = tempfile.mkdtemp()
        # Load the fake device guard impl.
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=["cpp_extensions/mtia_extension.cpp"],
            build_directory=build_dir,
            extra_include_paths=[
                "cpp_extensions",
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )
        print(torch._C.Stream.__mro__)
        print(torch.cuda.Stream.__mro__)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        print(mtia_stream_0)
        print(mtia_stream_1)
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertTrue(current_stream == mtia_stream_0, msg=msg)

        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertTrue(current_stream == mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        print(mtia_stream_0)
        print(mtia_stream_1)
        orig_current_device = torch.mtia.current_device()
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            self.assertTrue(torch.mtia.current_device() == device_0.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertTrue(current_stream == mtia_stream_0, msg=msg)
        self.assertTrue(torch.mtia.current_device() == orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            self.assertTrue(torch.mtia.current_device() == device_1.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertTrue(current_stream == mtia_stream_1, msg=msg)
        self.assertTrue(torch.mtia.current_device() == orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertTrue(torch.mtia.current_device() == device_0.index)

        with torch.mtia.device(device_1):
            self.assertTrue(torch.mtia.current_device() == device_1.index)


if __name__ == "__main__":
    common.run_tests()
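A note on the setUpClass above: with is_python_module=False, torch.utils.cpp_extension.load builds a plain shared library and loads it into the process rather than importing it as a Python module; loading alone is enough because the extension's static registrars (REGISTER_MTIA_HOOKS and C10_REGISTER_GUARD_IMPL) hook the fake backend into PyTorch. A minimal sketch of that pattern, with the source path relative to the test directory as in the test:

import torch
import torch.utils.cpp_extension

# Build the extension and load it into the process; loading runs the static
# registrars, which register the fake MTIA backend with PyTorch.
torch.utils.cpp_extension.load(
    name="mtia_extension",
    sources=["cpp_extensions/mtia_extension.cpp"],
    is_python_module=False,
    verbose=True,
)

print(torch.mtia.current_device())  # served by the fake MTIAHooks once loaded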
Deleted file: test_cpp_extensions_stream_and_event.py (generic Stream/Event tests backed by the fake MTIA extension)
@@ -1,108 +0,0 @@ (entire file removed by this revert)

# Owner(s): ["module: mtia"]

import os
import shutil
import sys
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


def remove_build_path():
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


# Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutually exclusive to each other.
# The test will be skipped if any of the following conditions are met:
@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests Stream and Event with C++ extensions."""

    module = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # return the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()

    @classmethod
    def setUpClass(cls):
        remove_build_path()
        build_dir = tempfile.mkdtemp()
        # Load the fake device guard impl.
        src = f"{os.path.abspath(os.path.dirname(__file__))}/cpp_extensions/mtia_extension.cpp"
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[src],
            build_directory=build_dir,
            extra_include_paths=[
                "cpp_extensions",
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        s = torch.Stream()
        self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event()
        self.assertTrue(e.device.type, "mtia")
        # Should be nullptr by default
        self.assertTrue(e.event_id == 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertTrue(e.event_id != 0)
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertTrue(e2.event_id != 0)
        self.assertTrue(e2.event_id != e.event_id)
        e.synchronize()
        e2.synchronize()
        time_elapsed = e.elapsed_time(e2)
        print(f"time elapsed between e1 and e2: {time_elapsed}")
        old_event_id = e.event_id
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        self.assertTrue(e.event_id == old_event_id)


if __name__ == "__main__":
    common.run_tests()
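The assertions in test_stream_event line up with the fake backend above: event ids come from event_id_gen starting at 1, so a freshly recorded event has a nonzero id, recording an existing event again keeps its id, and elapsedTime always reports the same fixed value. A minimal sketch of that interaction, assuming the extension is already loaded as in setUpClass:

import torch

# Assumes the fake MTIA extension has been loaded, as in setUpClass above.
s = torch.Stream()             # generic stream, backed by the MTIA guard impl
e = s.record_event()           # event id allocated by event_id_gen, so nonzero
e.synchronize()                # no-op in the fake backend
e2 = s.record_event()
print(e.elapsed_time(e2))      # fixed fake value from MTIAGuardImpl::elapsedTime
e.record(stream=s)             # re-recording keeps the existing event id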
@@ -21,8 +21,6 @@ TARGET_DET_LIST = [
     "test_cpp_extensions_aot_no_ninja",
     "test_cpp_extensions_jit",
     "test_cpp_extensions_open_device_registration",
-    "test_cpp_extensions_stream_and_event",
-    "test_cpp_extensions_mtia_backend",
     "test_cuda",
     "test_cuda_primary_ctx",
     "test_dataloader",