[nnc] Added NNC lowerings for t/transpose/permute/expand + other cleaning (#57426)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57426

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28293191

Pulled By: Chillee

fbshipit-source-id: b8fc44299acf2569c11e87e1991a2b724434b15d
This commit is contained in:
parent
c88167d2ed
commit
b38f153d91
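As a quick orientation before the diff, here is a minimal eager-mode sketch (plain torch, no NNC involved) of the value semantics the new lowerings have to reproduce; shapes mirror the tests added below.

import torch

# t() on a 2-D tensor is transpose(0, 1); on a 1-D tensor it is a no-op,
# which is how the aten::t lowering below dispatches.
a = torch.rand(3, 4)
assert torch.equal(a.t(), a.transpose(0, 1))
v = torch.rand(5)
assert torch.equal(v.t(), v)

# permute([2, 1, 0]) on a 3-D tensor is the same reordering as swapping dims 0 and 2.
b = torch.rand(3, 4, 5)
assert torch.equal(b.permute([2, 1, 0]), b.transpose(0, 2))

# expand broadcasts size-1 dims; value-wise it matches tiling with repeat.
c = torch.rand(1, 3, 1)
assert torch.equal(c.expand(2, 3, 4), c.repeat(2, 1, 4))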
@@ -206,5 +206,104 @@ graph(%a : Tensor, %b : Tensor):
         correct = TestModule().forward(x, y)
         np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
 
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_t(self):
+        def f(a):
+            return a.t()
+
+        device, size = 'cpu', (3, 4)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
+  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
+  return (%3)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_transpose(self):
+        def f(a):
+            return a.transpose(-1, -2)
+
+        device, size = 'cpu', (3, 4)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
+  %2 : int = prim::Constant[value=-1]()
+  %3 : int = prim::Constant[value=-2]()
+  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
+  return (%4)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_permute(self):
+        def f(a):
+            return a.permute([2,1,0])
+
+        device, size = 'cpu', (3, 4, 5)
+        x = torch.rand(size, device=device)
+
+        graph_str = """
+graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=2]()
+  %2 : int = prim::Constant[value=1]()
+  %3 : int = prim::Constant[value=0]()
+  %4 : int[] = prim::ListConstruct(%1, %2, %3)
+  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
+  return (%5)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
+    def test_kernel_with_expand(self):
+        def f(a):
+            return a.expand((2,3,4))
+
+        device = 'cpu'
+        x = torch.rand((1,3,1), device=device)
+        graph_str = """
+graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
+  %1 : int = prim::Constant[value=2]()
+  %2 : int = prim::Constant[value=3]()
+  %3 : int = prim::Constant[value=4]()
+  %4 : int[] = prim::ListConstruct(%1, %2, %3)
+  %5 : bool = prim::Constant[value=0]()
+  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
+  return (%6)
+        """
+        graph = torch._C.parse_ir(graph_str)
+
+        kernel = torch._C._te.TensorExprKernel(graph)
+        res1 = kernel.run((x,))
+        res2 = kernel.fallback((x,))
+        correct = f(x)
+        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
+        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
+
+
 if __name__ == '__main__':
     run_tests()
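The four new tests repeat the same parse-IR / compile / compare pattern. As a sketch only (reusing the torch._C._te bindings exercised above), a hypothetical helper that factors that pattern out could look like this:

import numpy as np
import torch

def check_te_kernel(graph_str, f, inputs, atol=2e-3):
    # Hypothetical helper mirroring the tests above: compile hand-written IR
    # with TensorExprKernel, then compare both the compiled path and the
    # interpreter fallback against eager execution.
    graph = torch._C.parse_ir(graph_str)
    kernel = torch._C._te.TensorExprKernel(graph)
    correct = f(*inputs)
    for res in (kernel.run(tuple(inputs)), kernel.fallback(tuple(inputs))):
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=atol)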
@@ -225,6 +225,24 @@ void nnc_aten_addmm(
   }
 }
 
+void nnc_aten_digamma(
+    int64_t bufs_num,
+    void** buf_data,
+    int64_t* buf_ranks,
+    int64_t* buf_dims,
+    int8_t* buf_dtypes,
+    int64_t args_num,
+    int64_t* extra_args) {
+  std::vector<at::Tensor> tensors =
+      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
+  at::Tensor& r = tensors[0];
+  const at::Tensor& x = tensors[1];
+  try {
+    at::digamma_out(r, x);
+  } catch (...) {
+  }
+}
+
 #ifndef C10_MOBILE
 
 const static RegisterNNCExternalFunction nnc_conv2d(
@@ -244,6 +262,9 @@ const static RegisterNNCExternalFunction nnc_mean(
 const static RegisterNNCExternalFunction nnc_addmm(
     "nnc_aten_addmm",
     nnc_aten_addmm);
+const static RegisterNNCExternalFunction nnc_digamma(
+    "nnc_aten_digamma",
+    nnc_aten_digamma);
 
 #endif
 
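For reference, the eager computation the new external call wraps: at::digamma_out writes into a preallocated output tensor, which matches the buffer convention of NNC external functions (r is the output, x the input). A small Python sketch of the same call:

import torch

x = torch.rand(3, 4) + 0.5          # keep away from the poles of digamma
r = torch.empty_like(x)
torch.digamma(x, out=r)             # same computation as at::digamma_out(r, x)
assert torch.allclose(r, torch.special.digamma(x))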
@@ -1457,11 +1457,11 @@ Tensor* computeMatmul(
     dtype = Dtype(*outputType);
   }
   BufHandle ResultBuf("matmul", outputShape, dtype);
-  const Buf* a = c10::get<BufHandle>(inputs[0]).node();
-  const Buf* b = c10::get<BufHandle>(inputs[1]).node();
+  const BufHandle a = c10::get<BufHandle>(inputs[0]);
+  const BufHandle b = c10::get<BufHandle>(inputs[1]);
 
-  auto size_a = ExprVectorToExprHandleVector(a->dims());
-  auto size_b = ExprVectorToExprHandleVector(b->dims());
+  auto size_a = a.dims();
+  auto size_b = b.dims();
   const IntImm* total_size = dynamic_cast<const IntImm*>(
       IRSimplifier::simplify((size_a[0] * size_a[1] * size_b[1])).node());
 
@@ -1479,16 +1479,13 @@ Tensor* computeMatmul(
         {{size_a[0], "M"}, {size_b[1], "N"}},
         Sum(),
         [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
-          BufHandle ah(a);
-          BufHandle bh(b);
-          return Load::make(ah, {m, k}) * Load::make(bh, {k, n});
+          return Load::make(a, {m, k}) * Load::make(b, {k, n});
         },
         {{size_a[1], "K"}});
   } else {
     return new Tensor(
         ResultBuf.node(),
-        ExternalCall::make(
-            ResultBuf, "nnc_aten_matmul", {BufHandle(a), BufHandle(b)}, {}));
+        ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
   }
 }
 
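Besides the BufHandle cleanup, this hunk shows the two matmul paths: an inlined Sum()-based reduction and the nnc_aten_matmul external call, selected by the total_size check (the condition itself lies outside the hunk). A plain-PyTorch sketch of the reduction the Compute lambda encodes:

import torch

a = torch.rand(3, 4)
b = torch.rand(4, 5)
# out[m, n] = sum over k of a[m, k] * b[k, n], i.e. the body of the Compute
# lambda above with K as the reduction axis.
reference = torch.einsum('mk,kn->mn', a, b)
assert torch.allclose(reference, a @ b)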
@@ -2426,9 +2423,60 @@ Tensor* tensorexpr::computeOperandValue(
               }
             }
 
-            return tensorOrConstant(inputs[0], indices);
+            return broadcast(c10::get<BufHandle>(inputs[0]), indices);
           });
     }
+    case aten::t: {
+      auto shape = valueShape(inputs[0]);
+      if (shape.size() == 1) {
+        return new Tensor(c10::get<BufHandle>(inputs[0]).node(), nullptr);
+      }
+      return computeOperandValue(
+          aten::transpose,
+          {inputs[0], (int64_t)1, (int64_t)0},
+          outputShape,
+          outputType);
+    }
+    case aten::transpose: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      auto start_dim =
+          at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
+      auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
+      return Compute(
+          "aten_transpose",
+          c10::fmap<DimArg>(outputShape),
+          [&](std::vector<VarHandle> axes) {
+            std::swap(axes[start_dim], axes[to_dim]);
+            return A.load(axes);
+          });
+    }
+    case aten::permute: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      auto permute_dims = c10::get<IntList>(inputs[1]);
+      return Compute(
+          "aten_permute",
+          c10::fmap<DimArg>(outputShape),
+          [&](const std::vector<VarHandle>& axes) {
+            std::vector<VarHandle> new_axes;
+            assert(permute_dims.size() == axes.size());
+            for (auto i : permute_dims) {
+              new_axes.push_back(axes[i]);
+            }
+            return A.load(new_axes);
+          });
+    }
+    case aten::expand: {
+      auto A = c10::get<BufHandle>(inputs[0]);
+      return Compute(
+          "aten_expand",
+          c10::fmap<DimArg>(outputShape),
+          [&](const std::vector<VarHandle>& axes) {
+            std::vector<ExprHandle> indices(axes.begin(), axes.end());
+            return broadcast(A, indices);
+          });
+    }
+    case aten::mm: // aten::mm is a subset of aten::matmul where both inputs are
+                   // rank 2
     case aten::matmul: {
       return computeMatmul(inputs, outputShape, outputType);
     }
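The new lowerings work purely in index space: aten_transpose swaps two loop axes before the load, aten_permute reorders them according to permute_dims, and aten_expand loads through a broadcast so size-1 input dims always read index 0. A small sketch of the corresponding index relations on eager tensors (same shapes as the tests):

import torch

x = torch.rand(3, 4, 5)
t = x.transpose(0, 2)
p = x.permute([2, 1, 0])
e = torch.rand(1, 3, 1).expand(2, 3, 4)

assert t[1, 2, 0] == x[0, 2, 1]   # axes 0 and 2 swapped
assert p[4, 3, 2] == x[2, 3, 4]   # axes reversed, matching permute_dims = [2, 1, 0]
assert e[1, 2, 3] == e[0, 2, 0]   # broadcast dims all read the same element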
@@ -2538,6 +2586,24 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     case aten::lgamma:
     case aten::slice:
     case aten::unsqueeze:
+    case aten::t:
+    case aten::transpose:
+    case aten::expand:
+    case aten::permute:
     case aten::mm:
     case aten::matmul:
     case aten::cat:
     case aten::sum: