[nnc] Strides to Tensor (#72962)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72962

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM, cpuhrsch

Differential Revision: D34589306

Pulled By: IvanKobzarev

fbshipit-source-id: ecee5249760ecc0c8b2edb1842b90218899bc944
(cherry picked from commit 9e310c4c67389da30da89126d838ffe3864aba6f)
Ivan Kobzarev 2022-04-23 12:30:25 -07:00 committed by PyTorch MergeBot
parent 1a7e43be14
commit 939060925f
31 changed files with 710 additions and 92 deletions

View File

@ -783,15 +783,36 @@ struct TORCH_API TensorType : public SharedType {
static const TypeKind Kind = TypeKind::TensorType;
-  static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
-    std::vector<int64_t> strides(sizes.size());
-    if (sizes.empty()) // zero-dim case
-      return strides;
-    strides.back() = 1;
-    for (size_t i = strides.size() - 1; i > 0; i--) {
-      strides[i - 1] = strides[i] * sizes[i];
-    }
-    return strides;
+  static std::vector<int64_t> contiguousStridesOf(
+      at::IntArrayRef sizes,
+      at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
+    auto contiguous_fn = [](const at::IntArrayRef& sizes,
+                            const std::vector<int64_t>& dim_order) {
+      std::vector<int64_t> strides(sizes.size());
+      if (sizes.empty()) // zero-dim case
+        return strides;
+      strides[dim_order[0]] = 1;
+      for (size_t i = 1; i < dim_order.size(); i++) {
+        auto cur_dim = dim_order[i];
+        auto pre_dim = dim_order[i - 1];
+        strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
+      }
+      return strides;
+    };
+    std::vector<int64_t> dim_order(sizes.size());
+    if (memory_format == MemoryFormat::ChannelsLast) {
+      dim_order = {1, 3, 2, 0};
+    } else if (memory_format == MemoryFormat::ChannelsLast3d) {
+      dim_order = {1, 4, 3, 2, 0};
+    } else {
+      auto ndims = sizes.size();
+      for (size_t i = 0; i < ndims; i++) {
+        dim_order[i] = ndims - i - 1; // Reverse
+      }
+    }
+    return contiguous_fn(sizes, dim_order);
}
private:
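
For intuition, a minimal standalone sketch of the dim_order scheme above (the helper name is illustrative, not from the patch): dim_order lists dimensions from innermost (stride 1) outward, so ChannelsLast makes the channel dimension dense.

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors contiguousStridesOf's lambda: walk dim_order from the innermost
// dimension outward, multiplying the stride up as we go.
static std::vector<int64_t> stridesFor(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& dim_order) {
  std::vector<int64_t> strides(sizes.size());
  strides[dim_order[0]] = 1;
  for (size_t i = 1; i < dim_order.size(); i++) {
    strides[dim_order[i]] = strides[dim_order[i - 1]] * sizes[dim_order[i - 1]];
  }
  return strides;
}

int main() {
  std::vector<int64_t> nchw = {2, 3, 4, 5}; // N=2, C=3, H=4, W=5
  auto contig = stridesFor(nchw, {3, 2, 1, 0}); // prints 60 20 5 1
  auto chlast = stridesFor(nchw, {1, 3, 2, 0}); // prints 60 1 15 3
  for (auto s : contig) std::printf("%lld ", static_cast<long long>(s));
  std::printf("\n");
  for (auto s : chlast) std::printf("%lld ", static_cast<long long>(s));
  std::printf("\n");
  return 0;
}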

View File

@ -392,8 +392,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) {
const int kChunkSize = 8;
te::BufHandle a("A", {M}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({0}), false}, {}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({0}), false}, {}, {}, at::kFloat, at::kCPU);
te::LoopNest nest({b});
auto loops = nest.getLoopStmtsFor(b);
@ -456,8 +456,8 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
constexpr int kCacheSize = 1 << 12;
te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({0}), false}, {N}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({0}), false}, {N}, {1}, at::kFloat, at::kCPU);
te::LoopNest nest({b});
auto sch = state.range(2);
@ -565,8 +565,8 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand)->Args({1 << 18, 1 << 6});
BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
constexpr int kChunkSize = 8;
te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({1}), false}, {M}, {1}, at::kFloat, at::kCPU);
te::LoopNest nest({b});
auto sch = state.range(2);
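
For readers tracking the call sites: computeSum now takes (inputs, outputShape, outputStrides, outputType, device), so the extra {} / {1} literals above are the new outputStrides argument: empty for the zero-dim full reduction, a unit stride for the 1-D outputs. An annotated sketch (the parameter-name comments are my gloss, not the API's):

te::Tensor b = te::computeSum(
    {a, te::IntList({1}), false}, // inputs: buffer, reduce dims, keepdim
    /*outputShape=*/{M},
    /*outputStrides=*/{1},
    at::kFloat,
    at::kCPU);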

View File

@ -944,6 +944,7 @@ TEST(ExternalCall, JitCustomFusionOp) {
[external_func_name](
const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
at::Device device) {
auto output_dtype = Dtype(*output_type);

View File

@ -1598,12 +1598,14 @@ TEST_F(Kernel, CodegenInspection) {
Tensor lowerNanToNum(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto input_buf = c10::get<BufHandle>(inputs[0]);
auto e = Compute(
"custom_nan_to_num",
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
auto load = input_buf.load(indices);

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <torch/torch.h>
@ -29,7 +30,10 @@ TEST(Ops, Sum) {
const auto& outShape = outputShapes[idx];
BufHandle a("a", {M, N}, kFloat);
-    Tensor b = computeSum({a, dims, false}, outShape, c10::kFloat, at::kCPU);
+    std::vector<ExprHandle> outStrides =
+        c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
+    Tensor b = computeSum(
+        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
auto cg = compile({a}, {b});
auto at = at::arange(M * N, at::kFloat).view({M, N});
@ -41,3 +45,34 @@ TEST(Ops, Sum) {
ASSERT_TRUE(at::allclose(bt, ref));
}
}
TEST(Ops, ChannelsLastSum) {
constexpr int A = 2;
constexpr int B = 3;
constexpr int C = 4;
constexpr int D = 5;
constexpr int E = 6;
std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
std::vector<std::vector<ExprHandle>> outputShapes = {
{B, C, D, E}, {A, C, D, E}, {C, D, E}};
for (unsigned idx = 0; idx < testDims.size(); idx++) {
const auto& dims = testDims[idx];
const auto& outShape = outputShapes[idx];
BufHandle a("a", {A, B, C, D, E}, kFloat);
std::vector<ExprHandle> outStrides =
c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
Tensor b = computeSum(
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
auto cg = compile({a}, {b});
auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
auto ref = at::sum(at, dims);
auto bt = at::empty_like(ref);
cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});
ASSERT_TRUE(at::allclose(bt, ref));
}
}
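
As a sanity check on the new test (the arithmetic is mine, not from the patch): in the first case the output shape is {B, C, D, E} = {3, 4, 5, 6}, and channels-last strides follow dim order {1, 3, 2, 0}:

// strides[1] = 1                                (channels innermost)
// strides[3] = strides[1] * sizes[1] = 1 * 4 = 4
// strides[2] = strides[3] * sizes[3] = 4 * 6 = 24
// strides[0] = strides[2] * sizes[2] = 24 * 5 = 120
// so make_channels_last_strides({3, 4, 5, 6}) corresponds to {120, 1, 24, 4}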

View File

@ -348,7 +348,7 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
"""
graph = torch._C.parse_ir(graph_str)
-    def my_custom_lowering(inputs, out_shape, out_type, device):
+    def my_custom_lowering(inputs, out_shape, out_strides, out_type, device):
def compute(idxs):
load = inputs[0].as_buf().load(idxs)
return te.ifThenElse(

View File

@ -164,8 +164,7 @@ StrideInput summarizeOutputStrides(const TensorType& tt) {
// otherwise we defer to contiguous
// TODO: channels last 3d
// NNC Channels last permutation for outputs causes slowdown, disable
-  if (c10::is_channels_last_strides_2d(sizes, strides) &&
-      !tt.device()->is_cpu()) {
+  if (c10::is_channels_last_strides_2d(sizes, strides)) {
return StrideInput::TENSOR_CONT_CHANNELS_LAST;
}
return StrideInput::TENSOR_CONT;

View File

@ -477,6 +477,11 @@ bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
return false;
dim_order = {1, 4, 3, 2, 0};
} else {
if (dims_.empty()) {
// Scalar tensor
TORCH_CHECK(strides_.empty());
return true; // Align with the isContiguous logic in kernel.cpp
}
for (size_t i = 0; i < ndims; i++) {
dim_order[i] = ndims - i - 1; // Reverse
}

View File

@ -185,9 +185,9 @@ class TORCH_API Var : public ExprNode<Var> {
std::string name_hint_;
};
-std::vector<ExprPtr> make_contiguous_strides(
+TORCH_API std::vector<ExprPtr> make_contiguous_strides(
     const std::vector<ExprHandle>& dims);
-std::vector<ExprPtr> make_channels_last_strides(
+TORCH_API std::vector<ExprPtr> make_channels_last_strides(
const std::vector<ExprHandle>& dims);
class TORCH_API Buf : public ExprNode<Buf> {
@ -323,7 +323,6 @@ class TORCH_API Buf : public ExprNode<Buf> {
bool is_cont_with(int cur_dim, int adjacent_dim) const;
bool is_stride_one(int cur_dim) const;
private:
VarPtr base_handle_;
std::vector<ExprPtr> dims_;
std::vector<ExprPtr> strides_;

View File

@ -208,7 +208,9 @@ std::vector<int64_t> _pair_int(IValue v) {
}
}
-static bool isContiguous(const torch::jit::Value* v) {
+static bool isContiguous(
+    const torch::jit::Value* v,
+    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
auto const& tt = v->type()->cast<TensorType>();
if (!tt) {
return false;
@ -221,6 +223,14 @@ static bool isContiguous(const torch::jit::Value* v) {
if (!sizes || !strides) {
return false;
}
// Check dimension size first
int ndims = (*sizes).size();
if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
(memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
return false;
}
-  return *strides == TensorType::contiguousStridesOf(*sizes);
+  return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
}
@ -475,8 +485,50 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
hasRandom_ = true;
}
// Check whether the node creates a new tensor
// TODO: It is hard to deduce the layout of the generated tensor.
bool is_tensor_creation = true;
// Check if the tensor is a contiguous tensor
bool is_contiguous = false;
// Check if the tensor is a channels-last contiguous tensor
bool is_channels_last_contiguous = false;
for (auto input : inputs) {
if (input->type()->kind() != TypeKind::TensorType)
continue;
is_tensor_creation = false;
TORCH_CHECK(bufs_.count(input) > 0);
auto buf_ = bufs_.at(input);
auto _is_contiguous = buf_->is_contiguous();
if (_is_contiguous) {
is_contiguous |= _is_contiguous;
} else {
is_channels_last_contiguous |=
(buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
buf_->is_channels_last_1d_contiguous());
}
// Mixing contiguous tensors and channels-last contiguous tensors is not
// supported
TORCH_INTERNAL_ASSERT(
is_tensor_creation ||
((is_contiguous ^ is_channels_last_contiguous) &&
(is_contiguous || is_channels_last_contiguous)));
}
auto outputType = findDtypeForValue(v);
std::vector<ExprHandle> outputShape = sizesForValue(v);
std::vector<ExprHandle> outputStrides;
if (is_channels_last_contiguous) {
outputStrides =
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
} else {
// Default
outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
}
std::vector<ArgValue> argInputs;
if (op == prim::ConstantChunk) {
@ -521,12 +573,14 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
}
if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
-    return custom_lowering(argInputs, outputShape, outputType, device_);
+    return custom_lowering(
+        argInputs, outputShape, outputStrides, outputType, device_);
}
if (v->node()->maybeSchema()) {
if (NNCLoweringFunction lowering =
getStandardLoweringFor(c10::toString(v->node()->schema()))) {
-      return lowering(argInputs, outputShape, outputType, device_);
+      return lowering(
+          argInputs, outputShape, outputStrides, outputType, device_);
}
}
std::string msg = std::string("Unhandled node kind (in computeValue): ") +
@ -995,28 +1049,55 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
auto const& outputs = input->owningGraph()->outputs();
std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
auto is_concrete_cont = [](const torch::jit::Value* input) {
if (input->isCompleteTensor()) {
return isContiguous(input) ||
isContiguous(input, at::MemoryFormat::ChannelsLast);
} else {
return false;
}
};
auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc) {
if (desc.size() == 1) {
return desc[0] == torch::jit::StrideInput::TENSOR_CONT ||
desc[0] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST;
} else {
return false;
}
};
Tensor result(nullptr, nullptr);
switch (t->kind()) {
case TypeKind::TensorType: {
auto tt = input->type()->cast<TensorType>();
-      bool contiguous_concrete_tensor =
-          (input->isCompleteTensor() && isContiguous(input));
-      bool contiguous_sym_tensor = false;
+      bool contiguous_concrete_tensor = is_concrete_cont(input);
+      bool contiguous_symbolic_tensor = false;
       if (has_symbolic_shapes_) {
         auto desc = getSymbolicInputStrideDesc(input);
-        contiguous_sym_tensor =
-            desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
+        contiguous_symbolic_tensor = is_symbolic_cont(desc);
}
// Get input size and strides
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
auto inputTensorStrides = getInputStrides(input, size_handles);
// We don't need to copy the input if:
// 1) it is not an output AND
// 2) it is contiguous
-      bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
+      bool contiguous =
+          contiguous_concrete_tensor || contiguous_symbolic_tensor;
if (!outputs_set.count(input) && contiguous) {
BufHandle inBuffer(
"t" + input_name_map_[input],
sizesFromSymbolicShape(tt->symbolic_sizes()),
inputTensorStrides,
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
inBuffer.node()->is_contiguous() ||
inBuffer.node()->is_channels_last_1d_contiguous() ||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
bufs_.emplace(input, inBuffer.node());
bufferArgs_.emplace_back(inBuffer);
break;
@ -1025,8 +1106,6 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
// if the input isn't contiguous or is an output,
// write strided input into contiguous buffer that is
// then used in all further compute
-      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
-      auto inputTensorStrides = getInputStrides(input, size_handles);
ExprHandle flat_size = 1;
for (size_t i = 0; i < size_handles.size(); ++i) {
auto size = size_handles[i];
@ -1168,11 +1247,12 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
"Ouput tensor has no corresponding bufs in the fuser."));
BufPtr buf = bufs_.at(v);
// output is contiguous, no work to do
-  if (tensorOutputStrideDesc_[v->offset()] ==
-      torch::jit::StrideInput::TENSOR_CONT) {
+  auto stride_desc = tensorOutputStrideDesc_[v->offset()];
+  if (stride_desc == torch::jit::StrideInput::TENSOR_CONT ||
+      stride_desc == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
     return Tensor(buf, nullptr);
   }
-  TORCH_INTERNAL_ASSERT(
-      tensorOutputStrideDesc_[v->offset()] ==
-      torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);

File diff suppressed because it is too large

View File

@ -32,6 +32,7 @@ using ArgValue = c10::variant<
using NNCLoweringFunction = std::function<Tensor(
const std::vector<ArgValue>&,
const std::vector<ExprHandle>&,
const std::vector<ExprHandle>&,
const c10::optional<ScalarType>&,
at::Device)>;
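
To make the widened signature concrete, here is a hypothetical lowering written against it (illustrative only; the commit's real lowerings follow below). The new third parameter carries the requested output strides, which a lowering can forward into Compute so the produced buffer keeps the caller's layout:

NNCLoweringFunction lowerIdentity =
    [](const std::vector<ArgValue>& inputs,
       const std::vector<ExprHandle>& outputShape,
       const std::vector<ExprHandle>& outputStrides,
       const c10::optional<ScalarType>& outputType,
       at::Device device) {
      auto in = c10::get<BufHandle>(inputs[0]);
      return Compute(
          "identity",
          outputShape,
          outputStrides, // honor the requested layout, e.g. channels-last
          [&](const std::vector<VarHandle>& axes) {
            std::vector<ExprHandle> indices(axes.begin(), axes.end());
            return in.load(indices);
          });
    };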

View File

@ -50,6 +50,7 @@ Tensor conv2d_depthwise_static(
Tensor conv = Reduce(
"conv2d_depthwise",
{N, K, OH, OW},
c10::nullopt, // TODO
Sum(),
[&](const std::vector<VarHandle>& v) { return init_func(v); },
[&](const std::vector<VarHandle>& v) {
@ -121,6 +122,7 @@ Tensor conv2d_depthwise_dynamic(
return Reduce(
"conv2d_depthwise",
{N, K, OH, OW},
c10::nullopt, // TODO
Sum(),
[&](const std::vector<VarHandle>& v) { return init_func(v); },
[&](const std::vector<VarHandle>& v) {
@ -308,6 +310,7 @@ bool conv2dIsSupported(
Tensor computeConv2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -355,6 +358,7 @@ Tensor computeConv2d(
Tensor computeConv1d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -388,6 +392,7 @@ Tensor computeConv1d(
Tensor computePrepackedConv2dClampRun(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -406,6 +411,7 @@ Tensor computePrepackedConv2dClampRun(
Tensor computePrepackedLinearClampRun(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;

View File

@ -66,21 +66,25 @@ bool conv2dIsSupported(
Tensor computeConv2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeConv1d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computePrepackedConv2dClampRun(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computePrepackedLinearClampRun(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);

View File

@ -8,6 +8,7 @@ namespace tensorexpr {
Tensor computeMatmul(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -54,6 +55,7 @@ Tensor computeMatmul(
Tensor computeAddMM(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;

View File

@ -9,11 +9,13 @@ namespace tensorexpr {
Tensor computeMatmul(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeAddMM(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);

View File

@ -320,6 +320,7 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
Tensor computeChunk(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
return Compute(
@ -353,6 +354,7 @@ Tensor computeChunk(
Tensor computeTranspose(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
@ -379,6 +381,7 @@ Tensor computeTranspose(
Tensor computeExpand(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
@ -392,6 +395,7 @@ Tensor computeExpand(
Tensor computeReshape(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
@ -459,6 +463,7 @@ Tensor computeReshape(
Tensor computeFlatten(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
std::vector<int64_t> outputShapeVec;
@ -468,7 +473,8 @@ Tensor computeFlatten(
std::vector<ArgValue> reshapeInputs;
reshapeInputs.push_back(inputs[0]);
reshapeInputs.emplace_back(outputShapeVec);
-  return computeReshape(reshapeInputs, outputShape, outputType, device);
+  return computeReshape(
+      reshapeInputs, outputShape, outputStrides, outputType, device);
}
static std::pair<ScalarType, std::vector<BufHandle>> processCatList(
@ -586,6 +592,7 @@ Tensor computeCatWoConditionals(
Tensor computeCat(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
if (device == at::kCPU && getCatWoConditionals()) {
@ -598,7 +605,10 @@ Tensor computeCat(
ScalarType highType = catInfo.first;
std::vector<BufHandle> nonEmptyInputs = catInfo.second;
return Compute(
"aten_cat", outputShape, [&](const std::vector<VarHandle>& axes) {
"aten_cat",
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
if (nonEmptyInputs.size() == 0) {
return ExprHandle(0);
}
@ -645,6 +655,7 @@ Tensor computeCat(
Tensor computeEmbedding(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;

View File

@ -50,26 +50,31 @@ ExprHandle clamp(
Tensor computeChunk(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeTranspose(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeExpand(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeReshape(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeFlatten(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeCatWoConditionals(
@ -78,11 +83,13 @@ Tensor computeCatWoConditionals(
Tensor computeCat(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeEmbedding(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);

View File

@ -8,6 +8,7 @@ namespace tensorexpr {
Tensor computeBatchNorm(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
bool hasWeight = true;
@ -22,7 +23,10 @@ Tensor computeBatchNorm(
}
return Compute(
"aten_batch_norm", outputShape, [&](const std::vector<VarHandle>& axes) {
"aten_batch_norm",
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
TORCH_INTERNAL_ASSERT(axes.size() >= 2);
// axes: N, C, H, W
std::vector<ExprHandle> indices(axes.begin(), axes.end());

View File

@ -9,6 +9,7 @@ namespace tensorexpr {
Tensor computeBatchNorm(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);

View File

@ -9,28 +9,32 @@ using namespace torch::jit::tensorexpr;
Tensor computeSign(
const std::vector<ArgValue>& inputValues,
-    const std::vector<ExprHandle>& outputShape) {
-  return Compute("aten_sign", outputShape, [&](ParameterList& axes) {
-    std::vector<ExprHandle> indices(axes.begin(), axes.end());
-    std::vector<ExprHandle> inputs = {
-        tensorOrConstant(inputValues[0], indices)};
-    auto inp = inputs[0];
-    auto zero = ExprHandle(immLike(inp, 0.0f));
-    auto res = (zero < inp) - (inp < zero);
-    return promoteToDtype(res, inp.dtype().scalar_type());
-  });
+    const std::vector<ExprHandle>& outputShape,
+    c10::optional<std::vector<ExprHandle>> outputStrides) {
+  return Compute(
+      "aten_sign", outputShape, outputStrides, [&](ParameterList& axes) {
+        std::vector<ExprHandle> indices(axes.begin(), axes.end());
+        std::vector<ExprHandle> inputs = {
+            tensorOrConstant(inputValues[0], indices)};
+        auto inp = inputs[0];
+        auto zero = ExprHandle(immLike(inp, 0.0f));
+        auto res = (zero < inp) - (inp < zero);
+        return promoteToDtype(res, inp.dtype().scalar_type());
+      });
}
Tensor computeOneOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
const int checkParamTypes) {
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr, checkParamTypes](
const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -46,12 +50,14 @@ Tensor computeTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr) {
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -69,12 +75,14 @@ Tensor computeTwoOperandWithAlpha(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr) {
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -93,6 +101,7 @@ Tensor computeConditionWithTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
@ -100,6 +109,7 @@ Tensor computeConditionWithTwoOperand(
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -120,6 +130,7 @@ Tensor computeThreeOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
@ -128,6 +139,7 @@ Tensor computeThreeOperand(
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr, promote_inputs](
const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -148,6 +160,7 @@ Tensor computeFourOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(
const ExprHandle&,
@ -157,6 +170,7 @@ Tensor computeFourOperand(
return Compute(
name,
outputShape,
outputStrides,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -176,18 +190,23 @@ Tensor computeFourOperand(
Tensor computeNoop(
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
-  return computeOneOperand(
-      "copy", inputValues, outputShape, outputType, [](const ExprHandle& a) {
-        return a;
-      });
+  return computeOneOperand(
+      "copy",
+      inputValues,
+      outputShape,
+      outputStrides,
+      outputType,
+      [](const ExprHandle& a) { return a; });
}
Tensor computeScalar(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr) {

View File

@ -8,12 +8,14 @@ namespace tensorexpr {
TORCH_API Tensor computeSign(
const std::vector<ArgValue>& inputs,
-    const std::vector<ExprHandle>& outputShape);
+    const std::vector<ExprHandle>& outputShape,
+    c10::optional<std::vector<ExprHandle>> outputStrides = c10::nullopt);
Tensor computeOneOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
const int checkParamTypes = kAllTypes);
@ -21,6 +23,7 @@ Tensor computeTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr);
@ -28,6 +31,7 @@ Tensor computeTwoOperandWithAlpha(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr);
@ -35,6 +39,7 @@ Tensor computeConditionWithTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
@ -43,6 +48,7 @@ Tensor computeThreeOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
@ -52,6 +58,7 @@ Tensor computeFourOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(
const ExprHandle&,
@ -61,6 +68,7 @@ Tensor computeFourOperand(
Tensor computeNoop(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
@ -68,6 +76,7 @@ Tensor computeScalar(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr);

View File

@ -138,6 +138,7 @@ ExprHandle dequant(
Tensor computeQuantizePerTensor(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>&,
at::Device) {
std::vector<VarPtr> vars;
@ -177,6 +178,7 @@ Tensor computeQuantizePerTensor(
Tensor computeQuantizedAdd(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& QA = c10::get<BufHandle>(inputs[0]);
@ -219,6 +221,7 @@ Tensor computeQuantizedAdd(
Tensor computeQuantizePerTensorExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {
@ -251,6 +254,7 @@ Tensor computeQuantizePerTensorExternalCall(
Tensor computeDequantizeExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
Dtype dtype = kFloat;
@ -275,6 +279,7 @@ Tensor computeDequantizeExternalCall(
Tensor computeQuantizedConv2dPrepack(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
Dtype dtype = kFloat;
@ -323,6 +328,7 @@ Tensor computeQuantizedConv2dPrepack(
Tensor computeQuantizedConv1d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -354,6 +360,7 @@ Tensor computeQuantizedConv1d(
Tensor computeQuantizedConv2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -385,6 +392,7 @@ Tensor computeQuantizedConv2d(
Tensor computeQuantizedConv2dRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -416,6 +424,7 @@ Tensor computeQuantizedConv2dRelu(
Tensor computeQuantizedLinear(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -447,6 +456,7 @@ Tensor computeQuantizedLinear(
Tensor computeQuantizedLinearRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -478,6 +488,7 @@ Tensor computeQuantizedLinearRelu(
Tensor computeQuantizedAddExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -521,6 +532,7 @@ Tensor computeQuantizedAddExternalCall(
Tensor computeQuantizedMul(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -551,6 +563,7 @@ Tensor computeQuantizedMul(
Tensor computeQuantizedMulScalar(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -577,6 +590,7 @@ Tensor computeQuantizedMulScalar(
Tensor computeQuantizedRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -607,6 +621,7 @@ Tensor computeQuantizedRelu(
Tensor computeQuantizedCat(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
@ -645,6 +660,7 @@ Tensor computeQuantizedCat(
Tensor computeDequantize(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
Dtype dtype = kFloat;
@ -676,6 +692,7 @@ Tensor computeDequantize(
Tensor computeUpsampleNearest2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
auto A = c10::get<BufHandle>(inputs[0]);
@ -721,6 +738,7 @@ Tensor computeUpsampleNearest2d(
Tensor computeUpsampleNearest2dExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
Dtype dtype = kFloat;
@ -779,6 +797,7 @@ Tensor computeUpsampleNearest2dExternalCall(
Tensor computeQuantizedSigmoidExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {

View File

@ -19,120 +19,140 @@ TORCH_API bool isQuantized(const BufHandle& qx);
TORCH_API Tensor computeQuantizePerTensor(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizePerTensorExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedConv1d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedConv2dPrepack(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedConv1d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedConv2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedConv2dRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedLinear(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedLinearRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedAdd(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeQuantizedAddExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedMul(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedMulScalar(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedCat(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedRelu(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeDequantize(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeDequantizeExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeUpsampleNearest2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeUpsampleNearest2dExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeQuantizedSigmoidExternalCall(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device);
} // namespace tensorexpr

View File

@ -22,6 +22,7 @@ namespace tensorexpr {
Tensor computeSum(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
std::vector<size_t> axes;
@ -72,6 +73,7 @@ Tensor computeSum(
return Reduce(
"sum",
outputDims,
outputStrides,
Sum(),
[&](ParameterList& indices) {
// "Squeeze" out indices inserted when keepdim is set.
@ -105,6 +107,7 @@ Tensor computeSum(
Tensor computeMean(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -137,6 +140,7 @@ Tensor computeMean(
Tensor computeMax(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;
@ -160,6 +164,7 @@ Tensor computeMax(
Tensor computeAdaptiveAvgPool2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
Dtype dtype = kFloat;

View File

@ -9,21 +9,25 @@ namespace tensorexpr {
TORCH_API Tensor computeSum(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeMean(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
TORCH_API Tensor computeAdaptiveAvgPool2d(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);
Tensor computeMax(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device);

View File

@ -9,6 +9,7 @@ using namespace torch::jit::tensorexpr;
Tensor computeSoftmax(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
bool log_softmax) {
// Softmax is computed as follows:
// softmax(vi) = exp(vi) / sum(exp(vi))
@ -102,14 +103,18 @@ Tensor computeSoftmax(
auto max = Reduce(
"aten_softmax_max",
non_softmax_dims,
c10::nullopt,
Maximum(dtype),
[&](ParameterList& indices) {
return tensorOrConstant(
inputs[0], move_softmax_dim_index_to_pos(indices));
},
{outputShape[softmax_dim]});
-  auto e =
-      Compute("aten_softmax_exp", outputShape, [&](ParameterList& indices) {
+  auto e = Compute(
+      "aten_softmax_exp",
+      outputShape,
+      c10::nullopt,
+      [&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
return exp(inp - max.load(remove_softmax_dim_index(indices)));
@ -117,14 +122,15 @@ Tensor computeSoftmax(
auto sum = Reduce(
"aten_softmax_sum",
non_softmax_dims,
c10::nullopt,
Sum(),
[&](ParameterList& indices) {
return e.load(move_softmax_dim_index_to_pos(indices));
},
{outputShape[softmax_dim]});
if (!log_softmax) {
-    auto result =
-        Compute("aten_softmax", outputShape, [&](ParameterList& indices) {
+    auto result = Compute(
+        "aten_softmax", outputShape, c10::nullopt, [&](ParameterList& indices) {
return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
});
return Tensor(
@ -134,11 +140,15 @@ Tensor computeSoftmax(
}
auto log_sum = Compute(
"aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
return log(sum.load(indices));
});
auto result =
Compute("aten_log_softmax", outputShape, [&](ParameterList& indices) {
"aten_softmax_log_sum",
non_softmax_dims,
c10::nullopt,
[&](ParameterList& indices) { return log(sum.load(indices)); });
auto result = Compute(
"aten_log_softmax",
outputShape,
c10::nullopt,
[&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
auto non_softmax_indices = remove_softmax_dim_index(indices);

View File

@ -9,6 +9,7 @@ namespace tensorexpr {
Tensor computeSoftmax(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
bool log_softmax);
} // namespace tensorexpr

View File

@ -39,9 +39,31 @@ StmtPtr Tensor::constructStmt(
}
}
-  for (const auto i : c10::irange(ndim)) {
-    // Going in reverse order: from innermost loop to the outermost
-    size_t dim_index = ndim - i - 1;
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      buf_->is_contiguous() ||
+      buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
+      buf_->is_channels_last_1d_contiguous());
+
+  auto loop_order_fn = [&]() {
+    std::vector<int32_t> loop_order;
+    if (buf_->is_contiguous()) {
+      for (int32_t i = args.size() - 1; i >= 0; i--) {
+        loop_order.push_back(i);
+      }
+    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
+      loop_order = {1, 3, 2, 0};
+    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
+      loop_order = {1, 4, 3, 2, 0};
+    } else {
+      loop_order = {1, 2, 0};
+    }
+    return loop_order;
+  };
+  auto loop_order = loop_order_fn();
+  for (auto dim_index : loop_order) {
auto const& dim = buf()->dim(dim_index);
s = alloc<For>(args[dim_index], immLike(dim, 0), dim, s);
}
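
Reading loop_order (explanation mine, not from the patch): entries run from the innermost loop to the outermost, because each For wraps the statement built so far. For a channels-last 4-D buffer, {1, 3, 2, 0} therefore generates:

// for n in dim 0:         <- pushed last, outermost
//   for h in dim 2:
//     for w in dim 3:
//       for c in dim 1:   <- pushed first, innermost
//         body(n, c, h, w);
// i.e. the loop nest matches NHWC memory order, keeping stores sequential.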
@ -51,16 +73,24 @@ StmtPtr Tensor::constructStmt(
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args);
-  BufHandle buf = Buf::make(name, dims, body.dtype());
+  BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(buf, args, body);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
return Compute(name, dims, c10::nullopt, body_func);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const VarHandle&)>& body_func) {
if (dims.size() != 1) {
throw malformed_input("mismatch between body and arg size (1)");
@ -68,13 +98,20 @@ Tensor Compute(
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0]);
-  BufHandle buf = Buf::make(name, dims, body.dtype());
+  BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(buf, args, body);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func) {
return Compute(name, dims, c10::nullopt, body_func);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func) {
if (dims.size() != 2) {
@ -82,13 +119,21 @@ Tensor Compute(
}
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1]);
-  BufHandle buf = Buf::make(name, dims, body.dtype());
+  BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(buf, args, body);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func) {
return Compute(name, dims, c10::nullopt, body_func);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func) {
@ -97,13 +142,22 @@ Tensor Compute(
}
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2]);
-  BufHandle buf = Buf::make(name, dims, body.dtype());
+  BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(buf, args, body);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func) {
return Compute(name, dims, c10::nullopt, body_func);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
@ -114,37 +168,67 @@ Tensor Compute(
}
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
-  BufHandle buf = Buf::make(name, dims, body.dtype());
+  BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(buf, args, body);
}
Tensor Compute(
const std::string& name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
const VarHandle&,
const VarHandle&)>& body_func) {
return Compute(name, dims, c10::nullopt, body_func);
}
Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dims,
strides,
reducer,
[&](ParameterList& p) { return buffer.load(p); },
reduce_dims);
}
Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(name, dims, c10::nullopt, reducer, buffer, reduce_dims);
}
Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
Tensor tensor,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dims,
strides,
reducer,
[&](ParameterList& p) { return tensor.load(p); },
reduce_dims);
}
Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
Tensor tensor,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(name, dims, c10::nullopt, reducer, tensor, reduce_dims);
}
} // namespace tensorexpr
} // namespace jit
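
A minimal usage sketch of the new stride-aware overloads (variable names hypothetical): explicit strides are threaded into the Buf, while the stride-less overloads forward c10::nullopt and keep the old contiguous behavior.

// Integer dims convert implicitly to ExprHandle, as in the tests above.
std::vector<ExprHandle> dims = {2, 3, 4, 5};
auto nhwc = c10::fmap<ExprHandle>(make_channels_last_strides(dims));

// Result buffer laid out channels-last:
Tensor t = Compute(
    "t", dims, nhwc, [&](const std::vector<VarHandle>& axes) {
      return ExprHandle(0.0f);
    });

// Same as before this change; strides default to contiguous:
Tensor u = Compute("u", dims, [&](const std::vector<VarHandle>& axes) {
  return ExprHandle(0.0f);
});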

View File

@ -73,12 +73,30 @@ class TORCH_API Tensor {
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
@ -88,11 +106,25 @@ TORCH_API Tensor Compute(
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
const VarHandle&,
const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
const VarHandle&,
const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
@ -114,6 +146,7 @@ template <typename InitFunc, typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
const InitFunc& init_func,
const BodyFunc& body_func,
@ -125,7 +158,8 @@ Tensor Reduce(
// copy
if (reduce_vars.empty()) {
ExprHandle body = Reducer::getReduceBody(body_func, vars);
-    BufHandle func_result = Buf::make(func_name, dims, body.dtype());
+    BufHandle func_result =
+        Buf::make(func_name, dims, body.dtype(), c10::nullopt, strides);
return Tensor(func_result, vars, body);
}
@ -141,7 +175,41 @@ Tensor Reduce(
Tensor t = Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
return t;
}
template <typename InitFunc, typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const InitFunc& init_func,
const BodyFunc& body_func,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce<InitFunc, BodyFunc>(
func_name,
dims,
c10::nullopt,
reducer,
init_func,
body_func,
reduce_dims);
}
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
const BodyFunc& body_func,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
func_name,
dims,
strides,
reducer,
[&](ParameterList p) { return ExprHandle(reducer.initializer()); },
body_func,
reduce_dims);
}
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
@ -149,13 +217,8 @@ Tensor Reduce(
const Reducer& reducer,
const BodyFunc& body_func,
const std::vector<ExprHandle>& reduce_dims) {
-  return Reduce(
-      func_name,
-      dims,
-      reducer,
-      [&](ParameterList p) { return ExprHandle(reducer.initializer()); },
-      body_func,
-      reduce_dims);
+  return Reduce<BodyFunc>(
+      func_name, dims, c10::nullopt, reducer, body_func, reduce_dims);
}
// Overload which allows inline lambda functions for the body_func.
@ -163,12 +226,29 @@ template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
const BodyFunc&& body_func,
const std::vector<ExprHandle>& reduce_dims) {
-  return Reduce(func_name, dims, reducer, body_func, reduce_dims);
+  return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims);
}
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BodyFunc&& body_func,
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(func_name, dims, c10::nullopt, reducer, body_func, reduce_dims);
}
TORCH_API Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<ExprHandle>& reduce_dims);
TORCH_API Tensor Reduce(
const std::string& name,
const std::vector<ExprHandle>& dims,
@ -178,6 +258,13 @@ TORCH_API Tensor Reduce(
// Overload for the common case of all dimensions of a previously Computed
// Tensor.
TORCH_API Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,
c10::optional<std::vector<ExprHandle>> strides,
const Reducer& reducer,
Tensor tensor,
const std::vector<ExprHandle>& reduce_dims);
TORCH_API Tensor Reduce(
const std::string& func_name,
const std::vector<ExprHandle>& dims,

View File

@ -707,8 +707,14 @@ void initTensorExprBindings(PyObject* module) {
}
if (NNCLoweringFunction lowering =
getStandardLoweringFor(op.toQualString())) {
std::vector<ExprHandle> outputStrides =
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
return lowering(
-          argInputs, outputShape, outputType.scalar_type(), at::kCPU);
+          argInputs,
+          outputShape,
+          outputStrides,
+          outputType.scalar_type(),
+          at::kCPU);
}
std::string msg = std::string("Unhandled node kind (in te.lower): ") +
op.toQualString();