[nnc] Added NNC lowerings for t/transpose/permute/expand + other cleaning (#57426)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57426

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28293191

Pulled By: Chillee

fbshipit-source-id: b8fc44299acf2569c11e87e1991a2b724434b15d
Authored by Horace He on 2021-05-07 15:37:22 -07:00; committed by Facebook GitHub Bot
parent c88167d2ed
commit b38f153d91
3 changed files with 183 additions and 10 deletions
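All four new lowerings reduce to index remapping: the generated loop nest reads the input at rearranged coordinates. As a quick reference for the semantics each lowering must reproduce, here is a plain-PyTorch sketch (the sketch is ours, not part of the change):

import torch

x = torch.rand(3, 4, 5)

# t()/transpose: swap two axes of the index vector.
assert torch.equal(x[..., 0].t(), x[..., 0].transpose(1, 0))
assert x.transpose(-1, -2)[1, 2, 3] == x[1, 3, 2]

# permute(dims): out[i_0, ..., i_n] reads in[j] where j[dims[k]] = i_k.
y = x.permute([2, 1, 0])
assert y[4, 2, 1] == x[1, 2, 4]

# expand: size-1 input dims are broadcast, i.e. their output index is ignored.
z = torch.rand(1, 3, 1).expand(2, 3, 4)
assert z[1, 2, 3] == z[0, 2, 0]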

test/test_tensorexpr_pybind.py

@@ -206,5 +206,104 @@ graph(%a : Tensor, %b : Tensor):
        correct = TestModule().forward(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_t(self):
        def f(a):
            return a.t()

        device, size = 'cpu', (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = torch._C._te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_transpose(self):
        def f(a):
            return a.transpose(-1, -2)

        device, size = 'cpu', (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=-1]()
  %3 : int = prim::Constant[value=-2]()
  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
  return (%4)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = torch._C._te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_permute(self):
        def f(a):
            return a.permute([2, 1, 0])

        device, size = 'cpu', (3, 4, 5)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
  return (%5)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = torch._C._te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_expand(self):
        def f(a):
            return a.expand((2, 3, 4))

        device = 'cpu'
        x = torch.rand((1, 3, 1), device=device)

        graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=3]()
  %3 : int = prim::Constant[value=4]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : bool = prim::Constant[value=0]()
  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
  return (%6)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = torch._C._te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)


if __name__ == '__main__':
    run_tests()
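The tests above exercise one op per graph. A composition sketch in the same style (graph and shapes are ours, not part of the PR; it assumes aten::relu is handled by the same TensorExprKernel path, as other NNC pointwise ops are):

import numpy as np
import torch

graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=-1]()
  %2 : int = prim::Constant[value=-2]()
  %3 : Float(4, 3, strides=[3, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %1, %2)
  %4 : Float(4, 3, strides=[3, 1], requires_grad=0, device=cpu) = aten::relu(%3)
  return (%4)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
x = torch.rand(3, 4)
correct = x.transpose(-1, -2).relu()
np.testing.assert_allclose(kernel.run((x,)).numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(kernel.fallback((x,)).numpy(), correct.numpy(), atol=2e-3)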

torch/csrc/jit/tensorexpr/external_functions.cpp

@@ -225,6 +225,24 @@ void nnc_aten_addmm(
  }
}

void nnc_aten_digamma(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  try {
    at::digamma_out(r, x);
  } catch (...) {
  }
}

#ifndef C10_MOBILE
const static RegisterNNCExternalFunction nnc_conv2d(

@@ -244,6 +262,9 @@ const static RegisterNNCExternalFunction nnc_mean(
const static RegisterNNCExternalFunction nnc_addmm(
    "nnc_aten_addmm",
    nnc_aten_addmm);

const static RegisterNNCExternalFunction nnc_digamma(
    "nnc_aten_digamma",
    nnc_aten_digamma);
#endif
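For readers new to this file: constructTensors rebuilds at::Tensors from the flat buffer/rank/dims/dtype arrays the compiled kernel passes in, and buffer 0 is the preallocated output. A rough Python analogue of what nnc_aten_digamma does with the reconstructed tensors (ours, purely illustrative):

import torch

def nnc_aten_digamma_py(tensors):
    # tensors[0] is the output buffer, tensors[1] the input,
    # mirroring at::digamma_out(r, x); errors are swallowed as in the C++.
    r, x = tensors[0], tensors[1]
    try:
        torch.digamma(x, out=r)
    except RuntimeError:
        pass
    return r

out = nnc_aten_digamma_py([torch.empty(3, 4), torch.rand(3, 4)])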

torch/csrc/jit/tensorexpr/kernel.cpp

@@ -1457,11 +1457,11 @@ Tensor* computeMatmul(
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("matmul", outputShape, dtype);
- const Buf* a = c10::get<BufHandle>(inputs[0]).node();
- const Buf* b = c10::get<BufHandle>(inputs[1]).node();
+ const BufHandle a = c10::get<BufHandle>(inputs[0]);
+ const BufHandle b = c10::get<BufHandle>(inputs[1]);
- auto size_a = ExprVectorToExprHandleVector(a->dims());
- auto size_b = ExprVectorToExprHandleVector(b->dims());
+ auto size_a = a.dims();
+ auto size_b = b.dims();
  const IntImm* total_size = dynamic_cast<const IntImm*>(
      IRSimplifier::simplify((size_a[0] * size_a[1] * size_b[1])).node());

@@ -1479,16 +1479,13 @@ Tensor* computeMatmul(
        {{size_a[0], "M"}, {size_b[1], "N"}},
        Sum(),
        [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
-         BufHandle ah(a);
-         BufHandle bh(b);
-         return Load::make(ah, {m, k}) * Load::make(bh, {k, n});
+         return Load::make(a, {m, k}) * Load::make(b, {k, n});
        },
        {{size_a[1], "K"}});
  } else {
    return new Tensor(
        ResultBuf.node(),
-       ExternalCall::make(
-           ResultBuf, "nnc_aten_matmul", {BufHandle(a), BufHandle(b)}, {}));
+       ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
  }
}
@@ -2426,9 +2423,60 @@ Tensor* tensorexpr::computeOperandValue(
              }
            }
-           return tensorOrConstant(inputs[0], indices);
+           return broadcast(c10::get<BufHandle>(inputs[0]), indices);
          });
    }
    case aten::t: {
      // t() is transpose(1, 0) on 2-D inputs and the identity on 1-D inputs.
      auto shape = valueShape(inputs[0]);
      if (shape.size() == 1) {
        return new Tensor(c10::get<BufHandle>(inputs[0]).node(), nullptr);
      }
      return computeOperandValue(
          aten::transpose,
          {inputs[0], (int64_t)1, (int64_t)0},
          outputShape,
          outputType);
    }
    case aten::transpose: {
      auto A = c10::get<BufHandle>(inputs[0]);
      auto start_dim =
          at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
      auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
      return Compute(
          "aten_transpose",
          c10::fmap<DimArg>(outputShape),
          [&](std::vector<VarHandle> axes) {
            // Swap the two requested axes when indexing into the input.
            std::swap(axes[start_dim], axes[to_dim]);
            return A.load(axes);
          });
    }
    case aten::permute: {
      auto A = c10::get<BufHandle>(inputs[0]);
      auto permute_dims = c10::get<IntList>(inputs[1]);
      return Compute(
          "aten_permute",
          c10::fmap<DimArg>(outputShape),
          [&](const std::vector<VarHandle>& axes) {
            // Output axis i indexes input dimension permute_dims[i], so
            // scatter the output axes back into input-dimension order.
            std::vector<VarHandle> new_axes(axes.size());
            assert(permute_dims.size() == axes.size());
            for (size_t i = 0; i < permute_dims.size(); i++) {
              auto dim = at::maybe_wrap_dim(permute_dims[i], A.ndim());
              new_axes[dim] = axes[i];
            }
            return A.load(new_axes);
          });
    }
    case aten::expand: {
      auto A = c10::get<BufHandle>(inputs[0]);
      return Compute(
          "aten_expand",
          c10::fmap<DimArg>(outputShape),
          [&](const std::vector<VarHandle>& axes) {
            // broadcast() maps output indices onto the input, using index 0
            // along input dims of extent 1.
            std::vector<ExprHandle> indices(axes.begin(), axes.end());
            return broadcast(A, indices);
          });
    }
    case aten::mm: // aten::mm is a subset of aten::matmul where both inputs
                   // are rank 2
    case aten::matmul: {
      return computeMatmul(inputs, outputShape, outputType);
    }
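Since computeMatmul is reworked above to use BufHandle throughout instead of raw Buf*, here is a usage sketch in the style of the new tests; whether a given shape takes the inline Compute path or the nnc_aten_matmul ExternalCall depends on the total_size heuristic shown earlier (example ours, not in the PR):

import numpy as np
import torch

graph_str = """
graph(%a : Float(8, 16, strides=[16, 1], requires_grad=0, device=cpu), %b : Float(16, 8, strides=[8, 1], requires_grad=0, device=cpu)):
  %2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a, %b)
  return (%2)
"""
graph = torch._C.parse_ir(graph_str)
kernel = torch._C._te.TensorExprKernel(graph)
x, y = torch.rand(8, 16), torch.rand(16, 8)
np.testing.assert_allclose(kernel.run((x, y)).numpy(), (x @ y).numpy(), atol=2e-3)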
@@ -2538,6 +2586,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
    case aten::lgamma:
    case aten::slice:
    case aten::unsqueeze:
+   case aten::t:
+   case aten::transpose:
+   case aten::expand:
+   case aten::permute:
+   case aten::mm:
    case aten::matmul:
    case aten::cat:
    case aten::sum: