pytorch/torch/csrc/utils/device_lazy_init.cpp
rzou 889e3eeed3 Avoid cuda init to FakeTensorMode (#124413)
Also partially fixes #122109

This PR:
- Adds a C++ flag (only_lift_cpu_tensors) that toggles the
  torch.tensor(1, device='cuda') constructor strategy.
  When false (the default), we keep the current PyTorch behavior of
  unconditionally constructing a concrete CUDA tensor and then calling
  lift_fresh on it. When true, we instead construct a concrete CPU
  tensor, call lift_fresh, and then call Tensor.to(device) (under any
  ambient modes); see the sketch after this list.
- FakeTensorMode flips this flag depending on whether CUDA is
  available. We don't unconditionally set the flag to true because
  that would likely be BC-breaking.
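To make the two strategies concrete, here is a minimal C++ sketch of the
control flow described above. It is illustrative only: the
only_lift_cpu_tensors() accessor name follows the commit message but is
stubbed here, and tensor_ctor_sketch is a hypothetical wrapper, not the
actual ctor code path.

#include <ATen/ATen.h>

// Assumed flag accessor; stubbed so the sketch is self-contained. The
// real flag plumbing lives elsewhere in torch/csrc.
static bool g_only_lift_cpu_tensors = false;
bool only_lift_cpu_tensors() {
  return g_only_lift_cpu_tensors;
}

// Hypothetical wrapper sketching torch.tensor(value, device=...).
at::Tensor tensor_ctor_sketch(const at::Scalar& value, at::Device device) {
  if (!only_lift_cpu_tensors()) {
    // Default (flag == false): build a concrete tensor directly on the
    // target device, then lift it into any ambient dispatch modes.
    // Constructing on CUDA here is what forcibly initializes CUDA.
    auto concrete =
        at::scalar_tensor(value, at::TensorOptions().device(device));
    return at::lift_fresh(concrete);
  }
  // Flag == true: build on CPU, lift, then move to the target device
  // under ambient modes, so FakeTensorMode intercepts the .to() and the
  // real device is never touched.
  auto cpu = at::scalar_tensor(value, at::TensorOptions().device(at::kCPU));
  return at::lift_fresh(cpu).to(device);
}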

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124413
Approved by: https://github.com/eellison
2024-04-19 02:39:35 +00:00


#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/utils/device_lazy_init.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>

#include <iostream>

namespace torch::utils {
namespace {

std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};

} // anonymous namespace

bool is_device_initialized(at::DeviceType device_type) {
  pybind11::gil_scoped_acquire g;
  return is_initialized[static_cast<int>(device_type)];
}

void device_lazy_init(at::DeviceType device_type) {
  pybind11::gil_scoped_acquire g;
  // Protected by the GIL. We don't use call_once because under ASAN it
  // has a buggy implementation that deadlocks if an instance throws an
  // exception. In any case, call_once isn't necessary, because we
  // have taken a lock.
  if (is_device_initialized(device_type)) {
    return;
  }

  // If a FakeTensorMode is active, skip initialization entirely: fake
  // tensors must not trigger real device (e.g. CUDA) initialization.
  auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(
      c10::impl::TorchDispatchModeKey::FAKE);
  if (maybe_mode) {
    return;
  }

  std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
  auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
  if (!module) {
    throw python_error();
  }

  // Out-of-tree backends (PrivateUse1) are not required to provide a
  // Python-side _lazy_init; if it is absent, there is nothing to do.
  if (device_type == at::DeviceType::PrivateUse1) {
    auto has_lazy_init_method =
        PyObject_HasAttrString(module.get(), "_lazy_init") == 1;
    if (!has_lazy_init_method) {
      return;
    }
  }

  auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", ""));
  if (!res) {
    throw python_error();
  }

  is_initialized[static_cast<int>(device_type)] = true;
}

void set_requires_device_init(at::DeviceType device_type, bool value) {
  // Marking a device as "requires init" simply clears its initialized bit,
  // so the next device_lazy_init call re-runs the backend's _lazy_init.
  is_initialized[static_cast<int>(device_type)] = !value;
}

} // namespace torch::utils
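For context, a minimal caller-side sketch of how this hook is typically
used: a binding that is about to touch a device calls device_lazy_init
first, giving torch.cuda (or another backend module) a chance to run its
Python-side _lazy_init. Only the functions defined above and their header
are assumed; make_ones_on is a hypothetical caller. Note that
device_lazy_init takes the GIL, so it requires an initialized Python
interpreter.

#include <ATen/ATen.h>
#include <torch/csrc/utils/device_lazy_init.h>

// Hypothetical caller: ensure the backend is lazily initialized before
// allocating on it. Under an active FakeTensorMode, device_lazy_init
// returns early and no real device initialization happens (see above).
at::Tensor make_ones_on(at::Device device) {
  torch::utils::device_lazy_init(device.type());
  return at::ones({2, 2}, at::TensorOptions().device(device));
}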