#include <c10/core/Allocator.h>
#include <c10/core/StorageImpl.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/Resize.h>
#include <torch/extension.h>

static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;

static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;

// A dummy StorageImpl for our custom device, that secretly uses the CPU
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(
    c10::StorageImpl::use_byte_size_t,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable) {
  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
  if (data_ptr == nullptr) {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
  } else {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
  }
  storageImpl_counter += 1;
  return custom_storage_impl;
}

// Register our dummy StorageImpl create method.
void custom_storage_registry() {
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
}

bool custom_storageImpl_called() {
  if (storageImpl_counter > last_storageImpl_saved_value) {
    last_storageImpl_saved_value = storageImpl_counter;
    return true;
  }
  return false;
}

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  add_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  return dst.copy_(self, false);
}

// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
  c10::IntArrayRef stride = {};
  result.unsafeGetTensorImpl()->set_storage_offset(0);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ?
      at::OptionalIntArrayRef(stride) : std::nullopt;
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
                               new_size, stride_opt,
                               /*resize_storage=*/!result.is_meta());
  return result;
}
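// Illustrative sketch (not part of the original extension): how the storage
// hook above is expected to fire. The function name is hypothetical, and this
// assumes the full extension also registers an allocator and an
// `empty.memory_format` kernel for PrivateUse1, which this excerpt does not show.
bool storage_hook_smoke_test() {
  custom_storage_registry();
  // Allocating on the custom device routes storage creation through
  // make_custom_storage_impl, which bumps storageImpl_counter.
  at::Tensor t = at::empty(
      {2, 3}, at::TensorOptions().device(c10::DeviceType::PrivateUse1));
  return custom_storageImpl_called();
}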
// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("set_.source_Storage", &custom_set_source_Storage);
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

// Route a few ops that have no PrivateUse1 kernels through the boxed CPU fallback.
// (A catch-all variant is sketched at the bottom of this file.)
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}

// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could do that by letting the user pass in a device index in our exposed device function.
// Note that if you do that, you'll also need to register a device guard to core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
c10::Device get_custom_device() {
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
}

bool custom_add_called() {
  bool called = false;
  if (add_counter > last_saved_value) {
    called = true;
    last_saved_value = add_counter;
  }
  return called;
}

void fallback_with_undefined_tensor() {
  at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1);
  at::Tensor second = at::Tensor();
  at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
  at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
  at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
  at::TensorList tensors = {first, first};
  at::TensorList undefined_tensors = {first, second};
  at::TensorList steps = {step, step};
  return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
                           steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
                           grad_scale, found_inf);
}

// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("custom_device", &get_custom_device, "get custom device object");
  m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
  m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl create method");
  m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom storageImpl create method was called");
  m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
}
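// Sketch (an alternative, not used by this test file): rather than listing
// individual ops, a backend can install the boxed CPU fallback as a catch-all
// for every op that lacks a PrivateUse1 kernel, by registering against the
// wildcard namespace `_`. Left commented out because enabling it would change
// which ops this extension exercises.
//
// TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
//   m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
// }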