Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36258

Previously we had a &&-chaining style API. There are some downsides to this API:

- It's easy to forget the 'static' qualifier in front, leading to subtle ODR bugs.
- It is not compatible with torchbind class_ definitions, as these need multiple levels of chaining. So in practice people end up having to define multiple static initializers, one per class.
- It's not like pybind11.
- There's no way to conveniently get the file and line number of the registration, as there is no macro point in the API.
- The old API doesn't really encourage people to put all of their definitions for a library in one place, or to give it a custom namespace. Similarly, the old API wasn't very DRY, because you had to keep repeating the namespace/dispatch key you were writing implementations for.

The new API is modeled exactly off of the PYBIND11_MODULE macro: you write

```
TORCH_LIBRARY(aten, m) {
  m.def("aten::add(Tensor self, Tensor other) -> Tensor");
  ...
}
```

in a non-chaining fashion. Under the hood, the macro expands to define a function and a static initializer that allocates a c10::Library (previously called c10::Module, but renamed to avoid confusion with the existing NN module concept), passes it to your function, and then retains it for the rest of the lifetime of the program. Specification of the namespace is mandatory, and in a later commit I plan to make it a hard error to TORCH_LIBRARY the same library name twice.

If you are specifying an implementation for an existing operator (e.g., you're the XLA backend, or you're just putting registrations for implementations at the implementation site), you should use TORCH_LIBRARY_IMPL, which takes a backend argument instead of a namespace and is used to specify an implementation for that backend. Unlike TORCH_LIBRARY, you can have as many of these as you want for a backend.

This needs updates to the mobile code analyzer.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929257

Pulled By: ezyang

fbshipit-source-id: ba04d78492e8c93ae7190165fb936f6872896ada
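For context on how the two macros are meant to compose, here is a minimal sketch. The `myops` namespace, the `my_add` schema, and the CPU-only kernel below are hypothetical placeholders for illustration (they are not part of this PR or of the test file below), and the sketch assumes the `torch/library.h` header shipped with libtorch:

```cpp
#include <torch/library.h>

// Placeholder CPU kernel for the hypothetical myops::my_add operator.
at::Tensor my_add(const at::Tensor& self, const at::Tensor& other) {
  return self + other;
}

// TORCH_LIBRARY declares the operator schema once, under a mandatory namespace.
// Defining the same library name twice is intended to become a hard error.
TORCH_LIBRARY(myops, m) {
  m.def("my_add(Tensor self, Tensor other) -> Tensor");
}

// TORCH_LIBRARY_IMPL registers an implementation for one dispatch key.
// Unlike TORCH_LIBRARY, a library may have many of these blocks, one per backend.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("my_add", my_add);
}
```

Once an extension containing such registrations is loaded, the operator would be reachable from Python as `torch.ops.myops.my_add`, and additional `TORCH_LIBRARY_IMPL(myops, CUDA, m)` or `TORCH_LIBRARY_IMPL(myops, XLA, m)` blocks could supply further backends.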
175 lines
6.9 KiB
Python
import os
import unittest

import torch.testing._internal.common_utils as common
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
from torch.testing._internal.common_cuda import TEST_CUDA
import torch
import torch.backends.cudnn
import torch.utils.cpp_extension

try:
    import torch_test_cpp_extension.cpp as cpp_extension
    import torch_test_cpp_extension.msnpu as msnpu_extension
    import torch_test_cpp_extension.rng as rng_extension
except ImportError:
    raise RuntimeError(
        "test_cpp_extensions_aot.py cannot be invoked directly. Run "
        "`python run_test.py -i test_cpp_extensions_aot_ninja` instead."
    )


class TestCppExtensionAOT(common.TestCase):
    """Tests ahead-of-time cpp extensions

    NOTE: run_test.py's test_cpp_extensions_aot_ninja target
    also runs this test case, but with ninja enabled. If you are debugging
    a test failure here from the CI, check the logs for which target
    (test_cpp_extensions_aot_no_ninja vs test_cpp_extensions_aot_ninja)
    failed.
    """

    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_backward(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

        expected_weights_grad = tensor.t().mm(torch.ones([4, 4], dtype=torch.double))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4], dtype=torch.double).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

    @skipIfRocm
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
        # On Python 2 there is no ABI suffix anyway.
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
        self.assertEqual(len(matches), 1, str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))

    def test_optional(self):
        has_value = cpp_extension.function_taking_optional(torch.ones(5))
        self.assertTrue(has_value)
        has_value = cpp_extension.function_taking_optional(None)
        self.assertFalse(has_value)


class TestMSNPUTensor(common.TestCase):
    def test_unregistered(self):
        a = torch.arange(0, 10, device='cpu')
        with self.assertRaisesRegex(RuntimeError, "Could not run"):
            b = torch.arange(0, 10, device='msnpu')

    def test_zeros(self):
        a = torch.empty(5, 5, device='cpu')
        self.assertEqual(a.device, torch.device('cpu'))

        b = torch.empty(5, 5, device='msnpu')
        self.assertEqual(b.device, torch.device('msnpu', 0))
        self.assertEqual(msnpu_extension.get_test_int(), 0)
        self.assertEqual(torch.get_default_dtype(), b.dtype)

        c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)
        self.assertEqual(torch.int64, c.dtype)

    def test_add(self):
        a = torch.empty(5, 5, device='msnpu', requires_grad=True)
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        b = torch.empty(5, 5, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        c = a + b
        self.assertEqual(msnpu_extension.get_test_int(), 1)

    def test_conv_backend_override(self):
        # To simplify the test, we use a 4d input here to avoid calling view4d
        # (which needs more overrides) in _convolution.
        input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
        weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
        bias = torch.empty(6, device='msnpu')

        # Make sure forward is overridden
        out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
        self.assertEqual(msnpu_extension.get_test_int(), 2)
        self.assertEqual(out.shape[0], input.shape[0])
        self.assertEqual(out.shape[1], weight.shape[0])

        # Make sure backward is overridden.
        # Double backward is dispatched to _convolution_double_backward.
        # It is not tested here as it involves more computation/overrides.
        grad = torch.autograd.grad(out, input, out, create_graph=True)
        self.assertEqual(msnpu_extension.get_test_int(), 3)
        self.assertEqual(grad[0].shape, input.shape)


class TestRNGExtension(common.TestCase):

    def test_rng(self):
        forty_two = torch.full((10,), 42, dtype=torch.int64)

        t = torch.empty(10, dtype=torch.int64).random_()
        self.assertNotEqual(t, forty_two)

        gen = torch.Generator(device='cpu')
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertNotEqual(t, forty_two)

        self.assertEqual(rng_extension.getInstanceCount(), 0)
        gen = rng_extension.createTestCPUGenerator(42)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        copy = gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy)
        copy2 = rng_extension.identity(copy)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy2)
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(t, forty_two)
        del gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy2
        self.assertEqual(rng_extension.getInstanceCount(), 0)


if __name__ == "__main__":
    common.run_tests()