mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39459 Update to this PR: this code isn't going to fully solve https://github.com/pytorch/pytorch/issues/37010. The changes required for 37010 are more extensive than this PR initially planned. Instead, this PR switches op registration of RNG-related tests to use the new API (similar to what was done in #36925) Test Plan: 1) unit tests Imported from OSS Reviewed By: ezyang Differential Revision: D22264889 fbshipit-source-id: 82488ac6e3b762a756818434e22c2a0f9cb9dd47
67 lines
2.3 KiB
C++
67 lines
2.3 KiB
C++
#include <torch/extension.h>
|
|
#include <torch/library.h>
|
|
#include <ATen/Generator.h>
|
|
#include <ATen/Tensor.h>
|
|
#include <ATen/native/DistributionTemplates.h>
|
|
#include <ATen/native/cpu/DistributionTemplates.h>
|
|
#include <memory>
|
|
|
|
using namespace at;
|
|
|
|
// Number of TestCPUGenerator objects currently alive: the constructor
// increments it and the destructor decrements it. Read via getInstanceCount().
static size_t instance_count{0};
|
|
|
|
struct TestCPUGenerator : public c10::GeneratorImpl {
|
|
TestCPUGenerator(uint64_t value) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) {
|
|
++instance_count;
|
|
}
|
|
~TestCPUGenerator() {
|
|
--instance_count;
|
|
}
|
|
uint32_t random() { return static_cast<uint32_t>(value_); }
|
|
uint64_t random64() { return value_; }
|
|
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
|
|
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
|
|
uint64_t seed() override { throw std::runtime_error("not implemented"); }
|
|
TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
|
|
|
|
static DeviceType device_type() { return DeviceType::CPU; }
|
|
|
|
uint64_t value_;
|
|
};
|
|
|
|
// In-place random_ kernel used for the CustomRNGKeyId key. It delegates to
// ATen's shared random_impl template, instantiated with the CPU RandomKernel
// and this file's TestCPUGenerator.
Tensor& random_(Tensor& self, c10::optional<Generator> generator) {
  namespace tmpl = at::native::templates;
  return tmpl::random_impl<tmpl::cpu::RandomKernel, TestCPUGenerator>(self, generator);
}
|
|
|
|
// In-place random_(from, to) kernel used for the CustomRNGKeyId key. `to` is
// optional and is passed through as-is to ATen's random_from_to_impl template,
// which is instantiated with the CPU RandomFromToKernel and TestCPUGenerator.
Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> generator) {
  namespace tmpl = at::native::templates;
  return tmpl::random_from_to_impl<tmpl::cpu::RandomFromToKernel, TestCPUGenerator>(
      self, from, to, generator);
}
|
|
|
|
// In-place random_(to) kernel: implemented as random_from_to with a fixed
// lower bound of 0.
Tensor& random_to(Tensor& self, int64_t to, c10::optional<Generator> generator) {
  constexpr int64_t kLowerBound = 0;
  return random_from_to(self, kLowerBound, to, generator);
}
|
|
|
|
// Factory exposed to Python: builds a Generator backed by a new
// TestCPUGenerator whose draws all return `value`.
Generator createTestCPUGenerator(uint64_t value) {
  auto wrapped = at::make_generator<TestCPUGenerator>(value);
  return wrapped;
}
|
|
|
|
// Hands back the Generator it receives, unchanged. It is exposed to Python
// (presumably so tests can round-trip a generator through C++; TODO confirm).
Generator identity(Generator g) { return g; }
|
|
|
|
size_t getInstanceCount() {
|
|
return instance_count;
|
|
}
|
|
|
|
// Kernel registrations for the aten::random_ overloads under the
// CustomRNGKeyId dispatch key. Each uses the unboxed registration API.
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
  m.impl_UNBOXED("aten::random_", random_);              // (self, generator)
  m.impl_UNBOXED("aten::random_.to", random_to);         // (self, to, generator)
  m.impl_UNBOXED("aten::random_.from", random_from_to);  // (self, from, optional to, generator)
}
|
|
|
|
// Python bindings for the extension. The generator factory and the
// pass-through helper come first, followed by the live-instance counter query.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("createTestCPUGenerator", &createTestCPUGenerator);
  m.def("identity", &identity);
  m.def("getInstanceCount", &getInstanceCount);
}
|