// NOTE: the original include names were lost in extraction; the list below is
// reconstructed from the symbols used in this file.
#include <ATen/ATen.h>
#include <ATen/xpu/XPUContext.h>
#include <c10/util/CallOnce.h>
#include <c10/xpu/XPUFunctions.h>
#include <torch/csrc/Module.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>

#include <pthread.h>
#include <sstream>

using namespace torch;

static bool in_bad_fork = false; // True for children forked after xpu init

// Called in the forked child if xpu has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_device_init(at::kXPU, true);
}

// Should be called before the first xpu call. It is mainly called in
// lazy_init. Note: This is distinct from initExtension because a stub xpu
// implementation has some working functions (e.g. device_count) but cannot
// fully initialize.
static void poison_fork() {
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
}

// XPU management methods

static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(in_bad_fork);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to set_device");

  int device = THPUtils_unpackInt(arg);
  c10::xpu::set_device(static_cast<c10::DeviceIndex>(device));

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchange_device");

  int device = THPUtils_unpackInt(arg);
  if (device < 0) {
    return THPUtils_packInt32(-1);
  }

  torch::utils::device_lazy_init(at::kXPU);
  int current_device = c10::xpu::exchange_device(device);

  return THPUtils_packInt32(current_device);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_maybeExchangeDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg), "invalid argument to maybe_exchange_device");

  int device = THPUtils_unpackInt(arg);
  if (device < 0) {
    return THPUtils_packInt32(-1);
  }

  torch::utils::device_lazy_init(at::kXPU);
  int current_device = c10::xpu::maybe_exchange_device(device);

  return THPUtils_packInt32(current_device);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  auto device = static_cast<int32_t>(c10::xpu::current_device());

  return THPUtils_packInt32(device);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  poison_fork();
  return THPUtils_packUInt64(at::xpu::device_count());
  END_HANDLE_TH_ERRORS
}

// XPU module initialization

static void registerXpuDeviceProperties(PyObject* module) {
  // Add _XpuDeviceProperties class to torch._C
  using namespace c10::xpu;
  auto get_device_type = [](const DeviceProp& prop) {
    std::ostringstream stream;
    using namespace sycl::info;
    switch (prop.device_type) {
      case device_type::cpu:
        stream << "cpu";
        break;
      case device_type::gpu:
        stream << "gpu";
        break;
      case device_type::accelerator:
        stream << "accelerator";
        break;
      case device_type::host:
        stream << "host";
        break;
      default:
        stream << "unknown device type:"
               << static_cast<std::underlying_type<device_type>::type>(
                      prop.device_type);
        break;
    }
    return stream.str();
  };
  auto gpu_subslice_count = [](const DeviceProp& prop) {
    return (prop.gpu_eu_count / prop.gpu_eu_count_per_subslice);
  };
  auto m = py::handle(module).cast<py::module>();
  py::class_<DeviceProp>(m, "_XpuDeviceProperties")
      .def_readonly("name", &DeviceProp::name)
      .def_readonly("platform_name", &DeviceProp::platform_name)
      .def_readonly("total_memory", &DeviceProp::global_mem_size)
      .def_readonly("max_compute_units", &DeviceProp::max_compute_units)
      .def_readonly("gpu_eu_count", &DeviceProp::gpu_eu_count)
      .def_property_readonly("gpu_subslice_count", gpu_subslice_count)
      .def_readonly("max_work_group_size", &DeviceProp::max_work_group_size)
      .def_readonly("max_num_sub_groups", &DeviceProp::max_num_sub_groups)
      .def_readonly("sub_group_sizes", &DeviceProp::sub_group_sizes)
      .def_property_readonly("type", get_device_type)
      .def(
          "__repr__",
          // Capture the helper lambdas by value: this callable outlives the
          // enclosing scope, so reference captures would dangle.
          [get_device_type, gpu_subslice_count](const DeviceProp& prop) {
            std::ostringstream stream;
            stream << "_XpuDeviceProperties(name='" << prop.name
                   << "', platform_name='" << prop.platform_name << "', type='"
                   << get_device_type(prop)
                   << "', total_memory=" << prop.global_mem_size / (1024 * 1024)
                   << "MB, max_compute_units=" << prop.max_compute_units
                   << ", gpu_eu_count=" << prop.gpu_eu_count
                   << ", gpu_subslice_count=" << gpu_subslice_count(prop)
                   << ", max_work_group_size=" << prop.max_work_group_size
                   << ", max_num_sub_groups=" << prop.max_num_sub_groups
                   << ", sub_group_sizes=[" << prop.sub_group_sizes << "])";
            return stream.str();
          });
}
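// Illustrative Python-side view of the class registered above (a sketch with
// placeholder values, assuming a build with XPU support; the accessor is
// bound below in bindGetDeviceProperties):
//
//   >>> import torch
//   >>> props = torch.xpu._get_device_properties(0)
//   >>> props.type, props.gpu_subslice_count
//   ('gpu', ...)
//   >>> props
//   _XpuDeviceProperties(name='...', platform_name='...', type='gpu',
//   total_memory=...MB, max_compute_units=..., ...)
//
// The attribute names mirror the def_readonly/def_property_readonly
// registrations, e.g. props.name, props.total_memory, props.sub_group_sizes.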
static void bindGetDeviceProperties(PyObject* module) {
  // Add method to torch.xpu
  auto m = py::handle(module).cast<py::module>();
  m.def(
      "_get_device_properties",
      [](int device) -> c10::xpu::DeviceProp* {
        return at::xpu::getDeviceProperties(device);
      },
      py::return_value_policy::reference);
}

// Callback for python part. Used for additional initialization of python
// classes
static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
  if (!m)
    throw python_error();

  bindGetDeviceProperties(m);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _THXPModule_methods[] = {
    {"_xpu_init", THXPModule_initExtension, METH_NOARGS, nullptr},
    {"_xpu_setDevice", THXPModule_setDevice_wrap, METH_O, nullptr},
    {"_xpu_exchangeDevice", THXPModule_exchangeDevice_wrap, METH_O, nullptr},
    {"_xpu_maybeExchangeDevice",
     THXPModule_maybeExchangeDevice_wrap,
     METH_O,
     nullptr},
    {"_xpu_getDevice", THXPModule_getDevice_wrap, METH_NOARGS, nullptr},
    {"_xpu_getDeviceCount",
     THXPModule_getDeviceCount_wrap,
     METH_NOARGS,
     nullptr},
    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
    {nullptr}};

PyMethodDef* THXPModule_methods() {
  return _THXPModule_methods;
}

namespace torch::xpu {

void initModule(PyObject* module) {
  registerXpuDeviceProperties(module);
}

} // namespace torch::xpu
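// A minimal sketch of how this table is typically consumed during torch._C
// module init (hypothetical excerpt; the actual wiring lives elsewhere, e.g.
// in torch/csrc/Module.cpp):
//
//   std::vector<PyMethodDef> methods;
//   THPUtils_addPyMethodDefs(methods, THXPModule_methods());
//
// Once merged, the wrappers are reachable from Python as
// torch._C._xpu_init(), torch._C._xpu_getDeviceCount(),
// torch._C._xpu_setDevice(idx), etc., and torch.xpu's lazy_init drives
// poison_fork()/device_lazy_init so forked children are flagged via
// _xpu_isInBadFork.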