diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py
index d5914d90afd..2d17ad882ce 100644
--- a/test/test_tensorexpr_pybind.py
+++ b/test/test_tensorexpr_pybind.py
@@ -206,5 +206,104 @@ graph(%a : Tensor, %b : Tensor):
         correct = TestModule().forward(x, y)
         np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
 
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_t(self):
+        def f(a):
+            return a.t()
+
+        device, size = 'cpu', (3, 4)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
+  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
+  return (%3)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_transpose(self):
+        def f(a):
+            return a.transpose(-1, -2)
+
+        device, size = 'cpu', (3, 4)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
+  %2 : int = prim::Constant[value=-1]()
+  %3 : int = prim::Constant[value=-2]()
+  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
+  return (%4)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_permute(self):
+        def f(a):
+            return a.permute([2, 1, 0])
+
+        device, size = 'cpu', (3, 4, 5)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=2]()
+  %2 : int = prim::Constant[value=1]()
+  %3 : int = prim::Constant[value=0]()
+  %4 : int[] = prim::ListConstruct(%1, %2, %3)
+  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
+  return (%5)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_expand(self):
+        def f(a):
+            return a.expand((2, 3, 4))
+
+        device = 'cpu'
+        x = torch.rand((1, 3, 1), device=device)
+        graph_str = """
+graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=2]()
+  %2 : int = prim::Constant[value=3]()
+  %3 : int = prim::Constant[value=4]()
+  %4 : int[] = prim::ListConstruct(%1, %2, %3)
+  %5 : bool = prim::Constant[value=0]()
+  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
+  return (%6)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp
index e6af453a86e..e47eeb7692f 100644
--- a/torch/csrc/jit/tensorexpr/external_functions.cpp
+++ b/torch/csrc/jit/tensorexpr/external_functions.cpp
@@ -225,6 +225,24 @@ void nnc_aten_addmm(
   }
 }
 
+void nnc_aten_digamma(
+    int64_t bufs_num,
+    void** buf_data,
+    int64_t* buf_ranks,
+    int64_t* buf_dims,
+    int8_t* buf_dtypes,
+    int64_t args_num,
+    int64_t* extra_args) {
+  std::vector<at::Tensor> tensors =
+      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
+  at::Tensor& r = tensors[0];
+  const at::Tensor& x = tensors[1];
+  try {
+    at::digamma_out(r, x);
+  } catch (...) {
+  }
+}
+
 #ifndef C10_MOBILE
 
 const static RegisterNNCExternalFunction nnc_conv2d(
@@ -244,6 +262,9 @@ const static RegisterNNCExternalFunction nnc_mean(
 const static RegisterNNCExternalFunction nnc_addmm(
     "nnc_aten_addmm",
     nnc_aten_addmm);
+const static RegisterNNCExternalFunction nnc_digamma(
+    "nnc_aten_digamma",
+    nnc_aten_digamma);
 
 #endif
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 1ad55dcbb97..a91cb19e59d 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1457,11 +1457,11 @@ Tensor* computeMatmul(
     dtype = Dtype(*outputType);
   }
   BufHandle ResultBuf("matmul", outputShape, dtype);
-  const Buf* a = c10::get<BufHandle>(inputs[0]).node();
-  const Buf* b = c10::get<BufHandle>(inputs[1]).node();
+  const BufHandle a = c10::get<BufHandle>(inputs[0]);
+  const BufHandle b = c10::get<BufHandle>(inputs[1]);
 
-  auto size_a = ExprVectorToExprHandleVector(a->dims());
-  auto size_b = ExprVectorToExprHandleVector(b->dims());
+  auto size_a = a.dims();
+  auto size_b = b.dims();
   const IntImm* total_size = dynamic_cast<const IntImm*>(
       IRSimplifier::simplify((size_a[0] * size_a[1] * size_b[1])).node());
@@ -1479,16 +1479,13 @@
         {{size_a[0], "M"}, {size_b[1], "N"}},
         Sum(),
         [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
-          BufHandle ah(a);
-          BufHandle bh(b);
-          return Load::make(ah, {m, k}) * Load::make(bh, {k, n});
+          return Load::make(a, {m, k}) * Load::make(b, {k, n});
         },
         {{size_a[1], "K"}});
   } else {
     return new Tensor(
         ResultBuf.node(),
-        ExternalCall::make(
-            ResultBuf, "nnc_aten_matmul", {BufHandle(a), BufHandle(b)}, {}));
+        ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
   }
 }
@@ -2426,9 +2423,60 @@ Tensor* tensorexpr::computeOperandValue(
               }
             }
-            return tensorOrConstant(inputs[0], indices);
+            return broadcast(c10::get<BufHandle>(inputs[0]), indices);
           });
     }
+    case aten::t: {
+      auto shape = valueShape(inputs[0]);
+      if (shape.size() == 1) {
+        // t() on a 1-d tensor is the identity: alias the input buffer.
+        return new Tensor(c10::get<BufHandle>(inputs[0]).node(), nullptr);
+      }
+      return computeOperandValue(
+          aten::transpose,
+          {inputs[0], (int64_t)1, (int64_t)0},
+          outputShape,
+          outputType);
+    }
+    case aten::transpose: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      auto start_dim =
+          at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
+      auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
+      return Compute(
+          "aten_transpose",
+          c10::fmap<DimArg>(outputShape),
+          [&](std::vector<VarHandle> axes) {
+            // Read the input with the two transposed axes swapped.
+            std::swap(axes[start_dim], axes[to_dim]);
+            return A.load(axes);
+          });
+    }
+    case aten::permute: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      auto permute_dims = c10::get<IntList>(inputs[1]);
+      return Compute(
+          "aten_permute",
+          c10::fmap<DimArg>(outputShape),
+          [&](const std::vector<VarHandle>& axes) {
+            // Output dimension i of a permute reads input dimension
+            // permute_dims[i], so output axis i becomes the input index at
+            // position permute_dims[i].
+            std::vector<ExprHandle> new_axes(axes.size());
+            assert(permute_dims.size() == axes.size());
+            for (size_t i = 0; i < permute_dims.size(); i++) {
+              auto new_dim = at::maybe_wrap_dim(permute_dims[i], A.ndim());
+              new_axes[new_dim] = axes[i];
+            }
+            return A.load(new_axes);
+          });
+    }
+    case aten::expand: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      return Compute(
+          "aten_expand",
+          c10::fmap<DimArg>(outputShape),
+          [&](const std::vector<VarHandle>& axes) {
+            std::vector<ExprHandle> indices(axes.begin(), axes.end());
+            // broadcast() maps the output indices onto the input, using
+            // index 0 along input dimensions of size one.
+            return broadcast(A, indices);
+          });
+    }
+    case aten::mm: // aten::mm is a subset of aten::matmul where both inputs
+                   // are rank-2 tensors
     case aten::matmul: {
       return computeMatmul(inputs, outputShape, outputType);
     }
@@ -2538,6 +2586,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     case aten::lgamma:
     case aten::slice:
     case aten::unsqueeze:
+    case aten::t:
+    case aten::transpose:
+    case aten::expand:
+    case aten::permute:
+    case aten::mm:
     case aten::matmul:
     case aten::cat:
     case aten::sum:
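
The new lowerings can be exercised by hand the same way the tests above do: parse a TorchScript IR string and compile it with torch._C._te.TensorExprKernel. Below is a minimal sketch, not part of the patch, assuming a build with the LLVM backend enabled (the same precondition as the tests' skipIf guard). It uses the non-involutive permutation [1, 2, 0], which a correct output-to-input index mapping must handle; an inverted mapping would pass only for self-inverse permutations such as the [2, 1, 0] used in the test.

import numpy as np
import torch

# Illustrative graph: permute a (3, 4, 5) tensor with dims [1, 2, 0].
graph_str = """
graph(%a : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=2]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(4, 5, 3, strides=[15, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a, %4)
  return (%5)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)

x = torch.rand(3, 4, 5)
res = kernel.run((x,))
# Compare the compiled kernel against the eager result.
np.testing.assert_allclose(res.numpy(), x.permute([1, 2, 0]).numpy(), atol=2e-3)

kernel.run executes the compiled kernel, while kernel.fallback runs the unoptimized graph, which is why the tests compare both paths against the eager result.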
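One behavioral note: in eager mode, t, transpose, permute, and expand all return views (expand with stride 0 along broadcast dimensions), while the Compute lowerings above materialize a contiguous output buffer, so the tests compare values rather than strides. A small eager-only illustration of the expand case:

import torch

x = torch.rand(1, 3, 1)            # contiguous, strides (3, 1, 1)
view = x.expand(2, 3, 4)           # eager expand copies no data
print(view.stride())               # (0, 1, 0): broadcast dims get stride 0
print(view.contiguous().stride())  # (12, 4, 1): the layout a compiled kernel writes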