mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nnc] Strides to Tensor (#72962)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72962 Test Plan: Imported from OSS Reviewed By: ZolotukhinM, cpuhrsch Differential Revision: D34589306 Pulled By: IvanKobzarev fbshipit-source-id: ecee5249760ecc0c8b2edb1842b90218899bc944 (cherry picked from commit 9e310c4c67389da30da89126d838ffe3864aba6f)
This commit is contained in:
parent
1a7e43be14
commit
939060925f
|
|
@ -783,15 +783,36 @@ struct TORCH_API TensorType : public SharedType {
|
|||
|
||||
static const TypeKind Kind = TypeKind::TensorType;
|
||||
|
||||
/// Computes the strides of a densely packed tensor of shape `sizes` laid out
/// in `memory_format`.
///
/// @param sizes          Extent of every dimension.
/// @param memory_format  Contiguous (default): the last dimension varies
///                       fastest. ChannelsLast: dimension order {1, 3, 2, 0}
///                       from fastest to slowest, requires exactly 4
///                       dimensions. ChannelsLast3d: order {1, 4, 3, 2, 0},
///                       requires exactly 5 dimensions.
/// @return One stride per dimension; an empty vector for a zero-dim shape.
/// @throws c10::Error (via TORCH_CHECK) when the rank of `sizes` does not
///         match a channels-last format. Previously this case indexed past
///         the end of `sizes` / `strides`.
static std::vector<int64_t> contiguousStridesOf(
    at::IntArrayRef sizes,
    at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
  const size_t ndims = sizes.size();
  std::vector<int64_t> strides(ndims);
  if (ndims == 0) { // zero-dim case
    return strides;
  }

  // dim_order lists the dimensions from fastest-varying to slowest-varying.
  std::vector<int64_t> dim_order;
  if (memory_format == MemoryFormat::ChannelsLast) {
    // The fixed permutation below is only valid for 4-D shapes.
    TORCH_CHECK(
        ndims == 4,
        "contiguousStridesOf: ChannelsLast requires 4 dimensions, got ",
        ndims);
    dim_order = {1, 3, 2, 0};
  } else if (memory_format == MemoryFormat::ChannelsLast3d) {
    // The fixed permutation below is only valid for 5-D shapes.
    TORCH_CHECK(
        ndims == 5,
        "contiguousStridesOf: ChannelsLast3d requires 5 dimensions, got ",
        ndims);
    dim_order = {1, 4, 3, 2, 0};
  } else {
    dim_order.resize(ndims);
    for (size_t i = 0; i < ndims; i++) {
      dim_order[i] = ndims - i - 1; // Reverse
    }
  }

  // The fastest dimension has stride 1; every following dimension steps over
  // the full extent of the one before it in dim_order.
  strides[dim_order[0]] = 1;
  for (size_t i = 1; i < dim_order.size(); i++) {
    const auto cur_dim = dim_order[i];
    const auto pre_dim = dim_order[i - 1];
    strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
  }
  return strides;
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -392,8 +392,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) {
|
|||
const int kChunkSize = 8;
|
||||
|
||||
te::BufHandle a("A", {M}, te::kFloat);
|
||||
te::Tensor b =
|
||||
te::computeSum({a, te::IntList({0}), false}, {}, at::kFloat, at::kCPU);
|
||||
te::Tensor b = te::computeSum(
|
||||
{a, te::IntList({0}), false}, {}, {}, at::kFloat, at::kCPU);
|
||||
te::LoopNest nest({b});
|
||||
|
||||
auto loops = nest.getLoopStmtsFor(b);
|
||||
|
|
@ -456,8 +456,8 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
|
|||
BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
|
||||
constexpr int kCacheSize = 1 << 12;
|
||||
te::BufHandle a("A", {M, N}, te::kFloat);
|
||||
te::Tensor b =
|
||||
te::computeSum({a, te::IntList({0}), false}, {N}, at::kFloat, at::kCPU);
|
||||
te::Tensor b = te::computeSum(
|
||||
{a, te::IntList({0}), false}, {N}, {1}, at::kFloat, at::kCPU);
|
||||
te::LoopNest nest({b});
|
||||
|
||||
auto sch = state.range(2);
|
||||
|
|
@ -565,8 +565,8 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand)->Args({1 << 18, 1 << 6});
|
|||
BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
|
||||
constexpr int kChunkSize = 8;
|
||||
te::BufHandle a("A", {M, N}, te::kFloat);
|
||||
te::Tensor b =
|
||||
te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
|
||||
te::Tensor b = te::computeSum(
|
||||
{a, te::IntList({1}), false}, {M}, {1}, at::kFloat, at::kCPU);
|
||||
te::LoopNest nest({b});
|
||||
|
||||
auto sch = state.range(2);
|
||||
|
|
|
|||
|
|
@ -944,6 +944,7 @@ TEST(ExternalCall, JitCustomFusionOp) {
|
|||
[external_func_name](
|
||||
const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
|
||||
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
|
||||
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
|
||||
const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
|
||||
at::Device device) {
|
||||
auto output_dtype = Dtype(*output_type);
|
||||
|
|
|
|||
|
|
@ -1598,12 +1598,14 @@ TEST_F(Kernel, CodegenInspection) {
|
|||
Tensor lowerNanToNum(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
auto input_buf = c10::get<BufHandle>(inputs[0]);
|
||||
auto e = Compute(
|
||||
"custom_nan_to_num",
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
auto load = input_buf.load(indices);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/jit/tensorexpr/eval.h>
|
||||
#include <torch/csrc/jit/tensorexpr/expr.h>
|
||||
#include <torch/csrc/jit/tensorexpr/loopnest.h>
|
||||
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
|
||||
#include <torch/torch.h>
|
||||
|
|
@ -29,7 +30,10 @@ TEST(Ops, Sum) {
|
|||
const auto& outShape = outputShapes[idx];
|
||||
|
||||
BufHandle a("a", {M, N}, kFloat);
|
||||
Tensor b = computeSum({a, dims, false}, outShape, c10::kFloat, at::kCPU);
|
||||
std::vector<ExprHandle> outStrides =
|
||||
c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
|
||||
Tensor b = computeSum(
|
||||
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
|
||||
auto cg = compile({a}, {b});
|
||||
|
||||
auto at = at::arange(M * N, at::kFloat).view({M, N});
|
||||
|
|
@ -41,3 +45,34 @@ TEST(Ops, Sum) {
|
|||
ASSERT_TRUE(at::allclose(bt, ref));
|
||||
}
|
||||
}
|
||||
|
||||
// Checks computeSum on a 5-D float buffer when the output strides are taken
// from make_channels_last_strides instead of make_contiguous_strides.
// Each entry of testDims is reduced in turn and the result is compared with
// at::sum over the same dimensions.
TEST(Ops, ChannelsLastSum) {
|
||||
// Extents of the five input dimensions.
constexpr int A = 2;
|
||||
constexpr int B = 3;
|
||||
constexpr int C = 4;
|
||||
constexpr int D = 5;
|
||||
constexpr int E = 6;
|
||||
// Dimension sets to reduce over; outputShapes[i] is the shape left after
// reducing testDims[i].
std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
|
||||
|
||||
std::vector<std::vector<ExprHandle>> outputShapes = {
|
||||
{B, C, D, E}, {A, C, D, E}, {C, D, E}};
|
||||
for (unsigned idx = 0; idx < testDims.size(); idx++) {
|
||||
const auto& dims = testDims[idx];
|
||||
const auto& outShape = outputShapes[idx];
|
||||
|
||||
BufHandle a("a", {A, B, C, D, E}, kFloat);
|
||||
// Request channels-last strides for the output of the reduction.
std::vector<ExprHandle> outStrides =
|
||||
c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
|
||||
Tensor b = computeSum(
|
||||
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
|
||||
auto cg = compile({a}, {b});
|
||||
|
||||
// Reference result computed with ATen on the same data.
auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
|
||||
auto ref = at::sum(at, dims);
|
||||
// NOTE(review): bt is allocated with at::empty_like(ref); whether its layout
// matches outStrides is not visible here - confirm the comparison below is
// layout-independent.
auto bt = at::empty_like(ref);
|
||||
|
||||
cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});
|
||||
|
||||
ASSERT_TRUE(at::allclose(bt, ref));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
|
|||
"""
|
||||
graph = torch._C.parse_ir(graph_str)
|
||||
|
||||
def my_custom_lowering(inputs, out_shape, out_type, device):
|
||||
def my_custom_lowering(inputs, out_shape, out_strides, out_type, device):
|
||||
def compute(idxs):
|
||||
load = inputs[0].as_buf().load(idxs)
|
||||
return te.ifThenElse(
|
||||
|
|
|
|||
|
|
@ -164,8 +164,7 @@ StrideInput summarizeOutputStrides(const TensorType& tt) {
|
|||
// otherwise we defer to contiguous
|
||||
// TODO: channels last 3d
|
||||
// NNC Channels last permutation for outputs causes slowdown, disable
|
||||
if (c10::is_channels_last_strides_2d(sizes, strides) &&
|
||||
!tt.device()->is_cpu()) {
|
||||
if (c10::is_channels_last_strides_2d(sizes, strides)) {
|
||||
return StrideInput::TENSOR_CONT_CHANNELS_LAST;
|
||||
}
|
||||
return StrideInput::TENSOR_CONT;
|
||||
|
|
|
|||
|
|
@ -477,6 +477,11 @@ bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
|
|||
return false;
|
||||
dim_order = {1, 4, 3, 2, 0};
|
||||
} else {
|
||||
if (dims_.empty()) {
|
||||
// Scalar tensor
|
||||
TORCH_CHECK(strides_.empty());
|
||||
return true; // Align with the isContiguous logic in the kernel.cpp
|
||||
}
|
||||
for (size_t i = 0; i < ndims; i++) {
|
||||
dim_order[i] = ndims - i - 1; // Reverse
|
||||
}
|
||||
|
|
|
|||
|
|
@ -185,9 +185,9 @@ class TORCH_API Var : public ExprNode<Var> {
|
|||
std::string name_hint_;
|
||||
};
|
||||
|
||||
std::vector<ExprPtr> make_contiguous_strides(
|
||||
TORCH_API std::vector<ExprPtr> make_contiguous_strides(
|
||||
const std::vector<ExprHandle>& dims);
|
||||
std::vector<ExprPtr> make_channels_last_strides(
|
||||
TORCH_API std::vector<ExprPtr> make_channels_last_strides(
|
||||
const std::vector<ExprHandle>& dims);
|
||||
|
||||
class TORCH_API Buf : public ExprNode<Buf> {
|
||||
|
|
@ -323,7 +323,6 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
|||
bool is_cont_with(int cur_dim, int adjacent_dim) const;
|
||||
bool is_stride_one(int cur_dim) const;
|
||||
|
||||
private:
|
||||
VarPtr base_handle_;
|
||||
std::vector<ExprPtr> dims_;
|
||||
std::vector<ExprPtr> strides_;
|
||||
|
|
|
|||
|
|
@ -208,7 +208,9 @@ std::vector<int64_t> _pair_int(IValue v) {
|
|||
}
|
||||
}
|
||||
|
||||
static bool isContiguous(const torch::jit::Value* v) {
|
||||
static bool isContiguous(
|
||||
const torch::jit::Value* v,
|
||||
at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
|
||||
auto const& tt = v->type()->cast<TensorType>();
|
||||
if (!tt) {
|
||||
return false;
|
||||
|
|
@ -221,6 +223,14 @@ static bool isContiguous(const torch::jit::Value* v) {
|
|||
if (!sizes || !strides) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check dimension size first
|
||||
int ndims = (*sizes).size();
|
||||
if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
|
||||
(memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare against the strides of the requested layout. Without forwarding
// memory_format, the ChannelsLast / ChannelsLast3d queries only test plain
// row-major contiguity. The rank check above guarantees that sizes has the
// rank the channels-last formats expect.
return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
|
||||
}
|
||||
|
||||
|
|
@ -475,8 +485,50 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
|||
hasRandom_ = true;
|
||||
}
|
||||
|
||||
// Check whether the node creates a new tensor
|
||||
// TODO: It is hard to deduce the layout of the generated tensor.
|
||||
bool is_tensor_creation = true;
|
||||
|
||||
// Check if the tensor is a contiguous tensor
|
||||
bool is_contiguous = false;
|
||||
// Check if the tensor is a channels-last contiguous tensor
|
||||
bool is_channels_last_contiguous = false;
|
||||
for (auto input : inputs) {
|
||||
if (input->type()->kind() != TypeKind::TensorType)
|
||||
continue;
|
||||
|
||||
is_tensor_creation = false;
|
||||
TORCH_CHECK(bufs_.count(input) > 0);
|
||||
auto buf_ = bufs_.at(input);
|
||||
|
||||
auto _is_contiguous = buf_->is_contiguous();
|
||||
if (_is_contiguous) {
|
||||
is_contiguous |= _is_contiguous;
|
||||
} else {
|
||||
is_channels_last_contiguous |=
|
||||
(buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
|
||||
buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
|
||||
buf_->is_channels_last_1d_contiguous());
|
||||
}
|
||||
|
||||
// Does not support mixing the contiguous tensor and channels-last contiguous
|
||||
// tensor
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
is_tensor_creation ||
|
||||
((is_contiguous ^ is_channels_last_contiguous) &&
|
||||
(is_contiguous || is_channels_last_contiguous)));
|
||||
}
|
||||
|
||||
auto outputType = findDtypeForValue(v);
|
||||
std::vector<ExprHandle> outputShape = sizesForValue(v);
|
||||
std::vector<ExprHandle> outputStrides;
|
||||
if (is_channels_last_contiguous) {
|
||||
outputStrides =
|
||||
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
|
||||
} else {
|
||||
// Default
|
||||
outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
|
||||
}
|
||||
|
||||
std::vector<ArgValue> argInputs;
|
||||
if (op == prim::ConstantChunk) {
|
||||
|
|
@ -521,12 +573,14 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
|||
}
|
||||
|
||||
if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
|
||||
return custom_lowering(argInputs, outputShape, outputType, device_);
|
||||
return custom_lowering(
|
||||
argInputs, outputShape, outputStrides, outputType, device_);
|
||||
}
|
||||
if (v->node()->maybeSchema()) {
|
||||
if (NNCLoweringFunction lowering =
|
||||
getStandardLoweringFor(c10::toString(v->node()->schema()))) {
|
||||
return lowering(argInputs, outputShape, outputType, device_);
|
||||
return lowering(
|
||||
argInputs, outputShape, outputStrides, outputType, device_);
|
||||
}
|
||||
}
|
||||
std::string msg = std::string("Unhandled node kind (in computeValue): ") +
|
||||
|
|
@ -995,28 +1049,55 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
|
|||
auto const& outputs = input->owningGraph()->outputs();
|
||||
std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
|
||||
|
||||
auto is_concrete_cont = [](const torch::jit::Value* input) {
|
||||
if (input->isCompleteTensor()) {
|
||||
return isContiguous(input) ||
|
||||
isContiguous(input, at::MemoryFormat::ChannelsLast);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc) {
|
||||
if (desc.size() == 1) {
|
||||
return desc[0] == torch::jit::StrideInput::TENSOR_CONT ||
|
||||
desc[0] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
Tensor result(nullptr, nullptr);
|
||||
switch (t->kind()) {
|
||||
case TypeKind::TensorType: {
|
||||
auto tt = input->type()->cast<TensorType>();
|
||||
bool contiguous_concrete_tensor =
|
||||
(input->isCompleteTensor() && isContiguous(input));
|
||||
bool contiguous_sym_tensor = false;
|
||||
bool contiguous_concrete_tensor = is_concrete_cont(input);
|
||||
bool contiguous_symbolic_tensor = false;
|
||||
if (has_symbolic_shapes_) {
|
||||
auto desc = getSymbolicInputStrideDesc(input);
|
||||
contiguous_sym_tensor =
|
||||
desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
|
||||
contiguous_symbolic_tensor = is_symbolic_cont(desc);
|
||||
}
|
||||
|
||||
// Get input size and strides
|
||||
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
auto inputTensorStrides = getInputStrides(input, size_handles);
|
||||
|
||||
// We don't need to copy the input if:
|
||||
// 1) it is not an output AND
|
||||
// 2) it is contiguous
|
||||
bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
|
||||
bool contiguous =
|
||||
contiguous_concrete_tensor || contiguous_symbolic_tensor;
|
||||
if (!outputs_set.count(input) && contiguous) {
|
||||
BufHandle inBuffer(
|
||||
"t" + input_name_map_[input],
|
||||
sizesFromSymbolicShape(tt->symbolic_sizes()),
|
||||
inputTensorStrides,
|
||||
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
inBuffer.node()->is_contiguous() ||
|
||||
inBuffer.node()->is_channels_last_1d_contiguous() ||
|
||||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
|
||||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
|
||||
bufs_.emplace(input, inBuffer.node());
|
||||
bufferArgs_.emplace_back(inBuffer);
|
||||
break;
|
||||
|
|
@ -1025,8 +1106,6 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
|
|||
// if the input isn't contiguous or is an output,
|
||||
// write strided input into contiguous buffer that is
|
||||
// then used in all further compute
|
||||
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
auto inputTensorStrides = getInputStrides(input, size_handles);
|
||||
ExprHandle flat_size = 1;
|
||||
for (size_t i = 0; i < size_handles.size(); ++i) {
|
||||
auto size = size_handles[i];
|
||||
|
|
@ -1168,11 +1247,12 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
|
|||
"Ouput tensor has no corresponding bufs in the fuser."));
|
||||
BufPtr buf = bufs_.at(v);
|
||||
// output is contiguous, no work to do
|
||||
if (tensorOutputStrideDesc_[v->offset()] ==
|
||||
torch::jit::StrideInput::TENSOR_CONT) {
|
||||
auto stride_desc = tensorOutputStrideDesc_[v->offset()];
|
||||
if (stride_desc == torch::jit::StrideInput::TENSOR_CONT ||
|
||||
stride_desc == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
|
||||
return Tensor(buf, nullptr);
|
||||
;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tensorOutputStrideDesc_[v->offset()] ==
|
||||
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -32,6 +32,7 @@ using ArgValue = c10::variant<
|
|||
using NNCLoweringFunction = std::function<Tensor(
|
||||
const std::vector<ArgValue>&,
|
||||
const std::vector<ExprHandle>&,
|
||||
const std::vector<ExprHandle>&,
|
||||
const c10::optional<ScalarType>&,
|
||||
at::Device)>;
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ Tensor conv2d_depthwise_static(
|
|||
Tensor conv = Reduce(
|
||||
"conv2d_depthwise",
|
||||
{N, K, OH, OW},
|
||||
c10::nullopt, // TODO
|
||||
Sum(),
|
||||
[&](const std::vector<VarHandle>& v) { return init_func(v); },
|
||||
[&](const std::vector<VarHandle>& v) {
|
||||
|
|
@ -121,6 +122,7 @@ Tensor conv2d_depthwise_dynamic(
|
|||
return Reduce(
|
||||
"conv2d_depthwise",
|
||||
{N, K, OH, OW},
|
||||
c10::nullopt, // TODO
|
||||
Sum(),
|
||||
[&](const std::vector<VarHandle>& v) { return init_func(v); },
|
||||
[&](const std::vector<VarHandle>& v) {
|
||||
|
|
@ -308,6 +310,7 @@ bool conv2dIsSupported(
|
|||
Tensor computeConv2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -355,6 +358,7 @@ Tensor computeConv2d(
|
|||
Tensor computeConv1d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -388,6 +392,7 @@ Tensor computeConv1d(
|
|||
Tensor computePrepackedConv2dClampRun(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -406,6 +411,7 @@ Tensor computePrepackedConv2dClampRun(
|
|||
Tensor computePrepackedLinearClampRun(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
|
|||
|
|
@ -66,21 +66,25 @@ bool conv2dIsSupported(
|
|||
Tensor computeConv2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeConv1d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computePrepackedConv2dClampRun(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computePrepackedLinearClampRun(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ namespace tensorexpr {
|
|||
Tensor computeMatmul(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -54,6 +55,7 @@ Tensor computeMatmul(
|
|||
Tensor computeAddMM(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
|
|||
|
|
@ -9,11 +9,13 @@ namespace tensorexpr {
|
|||
Tensor computeMatmul(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeAddMM(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
|
|||
|
|
@ -320,6 +320,7 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
|
|||
Tensor computeChunk(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
return Compute(
|
||||
|
|
@ -353,6 +354,7 @@ Tensor computeChunk(
|
|||
Tensor computeTranspose(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
|
|
@ -379,6 +381,7 @@ Tensor computeTranspose(
|
|||
Tensor computeExpand(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
|
|
@ -392,6 +395,7 @@ Tensor computeExpand(
|
|||
Tensor computeReshape(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
|
|
@ -459,6 +463,7 @@ Tensor computeReshape(
|
|||
Tensor computeFlatten(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
std::vector<int64_t> outputShapeVec;
|
||||
|
|
@ -468,7 +473,8 @@ Tensor computeFlatten(
|
|||
std::vector<ArgValue> reshapeInputs;
|
||||
reshapeInputs.push_back(inputs[0]);
|
||||
reshapeInputs.emplace_back(outputShapeVec);
|
||||
return computeReshape(reshapeInputs, outputShape, outputType, device);
|
||||
return computeReshape(
|
||||
reshapeInputs, outputShape, outputStrides, outputType, device);
|
||||
}
|
||||
|
||||
static std::pair<ScalarType, std::vector<BufHandle>> processCatList(
|
||||
|
|
@ -586,6 +592,7 @@ Tensor computeCatWoConditionals(
|
|||
Tensor computeCat(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
if (device == at::kCPU && getCatWoConditionals()) {
|
||||
|
|
@ -598,7 +605,10 @@ Tensor computeCat(
|
|||
ScalarType highType = catInfo.first;
|
||||
std::vector<BufHandle> nonEmptyInputs = catInfo.second;
|
||||
return Compute(
|
||||
"aten_cat", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
"aten_cat",
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
if (nonEmptyInputs.size() == 0) {
|
||||
return ExprHandle(0);
|
||||
}
|
||||
|
|
@ -645,6 +655,7 @@ Tensor computeCat(
|
|||
Tensor computeEmbedding(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
|
|||
|
|
@ -50,26 +50,31 @@ ExprHandle clamp(
|
|||
Tensor computeChunk(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeTranspose(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeExpand(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeReshape(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeFlatten(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeCatWoConditionals(
|
||||
|
|
@ -78,11 +83,13 @@ Tensor computeCatWoConditionals(
|
|||
Tensor computeCat(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeEmbedding(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ namespace tensorexpr {
|
|||
Tensor computeBatchNorm(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
bool hasWeight = true;
|
||||
|
|
@ -22,7 +23,10 @@ Tensor computeBatchNorm(
|
|||
}
|
||||
|
||||
return Compute(
|
||||
"aten_batch_norm", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
"aten_batch_norm",
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
TORCH_INTERNAL_ASSERT(axes.size() >= 2);
|
||||
// axes: N, C, H, W
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ namespace tensorexpr {
|
|||
Tensor computeBatchNorm(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
|
|||
|
|
@ -9,28 +9,32 @@ using namespace torch::jit::tensorexpr;
|
|||
|
||||
// Builds the "aten_sign" tensor: for every output index the value is
// (0 < x) - (x < 0), promoted back to the scalar type of the input x.
// outputStrides is forwarded unchanged to Compute.
Tensor computeSign(
    const std::vector<ArgValue>& inputValues,
    const std::vector<ExprHandle>& outputShape,
    c10::optional<std::vector<ExprHandle>> outputStrides) {
  auto signAt = [&](ParameterList& axes) {
    std::vector<ExprHandle> indices(axes.begin(), axes.end());
    // Load (or materialize) the single operand at the current index.
    auto operand = tensorOrConstant(inputValues[0], indices);
    auto zero = ExprHandle(immLike(operand, 0.0f));
    auto sign = (zero < operand) - (operand < zero);
    return promoteToDtype(sign, operand.dtype().scalar_type());
  };
  return Compute("aten_sign", outputShape, outputStrides, signAt);
}
|
||||
|
||||
Tensor computeOneOperand(
|
||||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
|
||||
const int checkParamTypes) {
|
||||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr, checkParamTypes](
|
||||
const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
@ -46,12 +50,14 @@ Tensor computeTwoOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -69,12 +75,14 @@ Tensor computeTwoOperandWithAlpha(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -93,6 +101,7 @@ Tensor computeConditionWithTwoOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<
|
||||
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
|
||||
|
|
@ -100,6 +109,7 @@ Tensor computeConditionWithTwoOperand(
|
|||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -120,6 +130,7 @@ Tensor computeThreeOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<
|
||||
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
|
||||
|
|
@ -128,6 +139,7 @@ Tensor computeThreeOperand(
|
|||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr, promote_inputs](
|
||||
const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
@ -148,6 +160,7 @@ Tensor computeFourOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(
|
||||
const ExprHandle&,
|
||||
|
|
@ -157,6 +170,7 @@ Tensor computeFourOperand(
|
|||
return Compute(
|
||||
name,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -176,18 +190,23 @@ Tensor computeFourOperand(
|
|||
// computeNoop: delegates to computeOneOperand with the name "copy" and an
// identity expression (the lambda returns its argument unchanged).
// `device` is not referenced anywhere in the body shown here.
//
// NOTE(review): two argument lists follow `computeOneOperand(` below -- a
// compact one without outputStrides and an expanded one that passes
// outputStrides. This looks like the removed and the added side of a diff
// rendered together; confirm that only the outputStrides form is kept.
Tensor computeNoop(
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
return computeOneOperand(
|
||||
"copy", inputValues, outputShape, outputType, [](const ExprHandle& a) {
|
||||
return a;
|
||||
});
|
||||
"copy",
|
||||
inputValues,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
outputType,
|
||||
[](const ExprHandle& a) { return a; });
|
||||
}
|
||||
|
||||
Tensor computeScalar(
|
||||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr) {
|
||||
|
|
|
|||
|
|
@ -8,12 +8,14 @@ namespace tensorexpr {
|
|||
|
||||
TORCH_API Tensor computeSign(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape);
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
c10::optional<std::vector<ExprHandle>> outputStrides = c10::nullopt);
|
||||
|
||||
Tensor computeOneOperand(
|
||||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
|
||||
const int checkParamTypes = kAllTypes);
|
||||
|
|
@ -21,6 +23,7 @@ Tensor computeTwoOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr);
|
||||
|
|
@ -28,6 +31,7 @@ Tensor computeTwoOperandWithAlpha(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr);
|
||||
|
|
@ -35,6 +39,7 @@ Tensor computeConditionWithTwoOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<
|
||||
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
|
||||
|
|
@ -43,6 +48,7 @@ Tensor computeThreeOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<
|
||||
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
|
||||
|
|
@ -52,6 +58,7 @@ Tensor computeFourOperand(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(
|
||||
const ExprHandle&,
|
||||
|
|
@ -61,6 +68,7 @@ Tensor computeFourOperand(
|
|||
Tensor computeNoop(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
@ -68,6 +76,7 @@ Tensor computeScalar(
|
|||
const std::string& name,
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
|
||||
innerExpr);
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ ExprHandle dequant(
|
|||
Tensor computeQuantizePerTensor(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>&,
|
||||
at::Device) {
|
||||
std::vector<VarPtr> vars;
|
||||
|
|
@ -177,6 +178,7 @@ Tensor computeQuantizePerTensor(
|
|||
Tensor computeQuantizedAdd(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
const BufHandle& QA = c10::get<BufHandle>(inputs[0]);
|
||||
|
|
@ -219,6 +221,7 @@ Tensor computeQuantizedAdd(
|
|||
Tensor computeQuantizePerTensorExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
|
|
@ -251,6 +254,7 @@ Tensor computeQuantizePerTensorExternalCall(
|
|||
Tensor computeDequantizeExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -275,6 +279,7 @@ Tensor computeDequantizeExternalCall(
|
|||
Tensor computeQuantizedConv2dPrepack(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -323,6 +328,7 @@ Tensor computeQuantizedConv2dPrepack(
|
|||
Tensor computeQuantizedConv1d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -354,6 +360,7 @@ Tensor computeQuantizedConv1d(
|
|||
Tensor computeQuantizedConv2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -385,6 +392,7 @@ Tensor computeQuantizedConv2d(
|
|||
Tensor computeQuantizedConv2dRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -416,6 +424,7 @@ Tensor computeQuantizedConv2dRelu(
|
|||
Tensor computeQuantizedLinear(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -447,6 +456,7 @@ Tensor computeQuantizedLinear(
|
|||
Tensor computeQuantizedLinearRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -478,6 +488,7 @@ Tensor computeQuantizedLinearRelu(
|
|||
Tensor computeQuantizedAddExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -521,6 +532,7 @@ Tensor computeQuantizedAddExternalCall(
|
|||
Tensor computeQuantizedMul(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -551,6 +563,7 @@ Tensor computeQuantizedMul(
|
|||
Tensor computeQuantizedMulScalar(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -577,6 +590,7 @@ Tensor computeQuantizedMulScalar(
|
|||
Tensor computeQuantizedRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -607,6 +621,7 @@ Tensor computeQuantizedRelu(
|
|||
Tensor computeQuantizedCat(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
|
|
@ -645,6 +660,7 @@ Tensor computeQuantizedCat(
|
|||
Tensor computeDequantize(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -676,6 +692,7 @@ Tensor computeDequantize(
|
|||
Tensor computeUpsampleNearest2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
|
|
@ -721,6 +738,7 @@ Tensor computeUpsampleNearest2d(
|
|||
Tensor computeUpsampleNearest2dExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -779,6 +797,7 @@ Tensor computeUpsampleNearest2dExternalCall(
|
|||
Tensor computeQuantizedSigmoidExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
|
|
|
|||
|
|
@ -19,120 +19,140 @@ TORCH_API bool isQuantized(const BufHandle& qx);
|
|||
TORCH_API Tensor computeQuantizePerTensor(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizePerTensorExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedConv1d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedConv2dPrepack(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
// Declaration of computeQuantizedConv1d (inputs, output shape, output strides,
// optional output scalar type, device).
// NOTE(review): an identical declaration of computeQuantizedConv1d, with the
// same parameter list, already appears earlier in this header (just before
// computeQuantizedConv2dPrepack). This repeat is redundant and can likely be
// removed.
TORCH_API Tensor computeQuantizedConv1d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedConv2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedConv2dRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedLinear(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedLinearRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedAdd(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
Tensor computeQuantizedAddExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedMul(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedMulScalar(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedCat(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeDequantize(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeDequantizeExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeUpsampleNearest2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeUpsampleNearest2dExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
TORCH_API Tensor computeQuantizedSigmoidExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device);
|
||||
} // namespace tensorexpr
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ namespace tensorexpr {
|
|||
Tensor computeSum(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
std::vector<size_t> axes;
|
||||
|
|
@ -72,6 +73,7 @@ Tensor computeSum(
|
|||
return Reduce(
|
||||
"sum",
|
||||
outputDims,
|
||||
outputStrides,
|
||||
Sum(),
|
||||
[&](ParameterList& indices) {
|
||||
// "Squeeze" out indices inserted when keepdim is set.
|
||||
|
|
@ -105,6 +107,7 @@ Tensor computeSum(
|
|||
Tensor computeMean(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -137,6 +140,7 @@ Tensor computeMean(
|
|||
Tensor computeMax(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
@ -160,6 +164,7 @@ Tensor computeMax(
|
|||
Tensor computeAdaptiveAvgPool2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device) {
|
||||
Dtype dtype = kFloat;
|
||||
|
|
|
|||
|
|
@ -9,21 +9,25 @@ namespace tensorexpr {
|
|||
TORCH_API Tensor computeSum(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
TORCH_API Tensor computeMean(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
TORCH_API Tensor computeAdaptiveAvgPool2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
Tensor computeMax(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ using namespace torch::jit::tensorexpr;
|
|||
Tensor computeSoftmax(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
bool log_softmax) {
|
||||
// Softmax is computed as follows:
|
||||
// softmax(vi) = exp(vi) / sum(exp(vi))
|
||||
|
|
@ -102,14 +103,18 @@ Tensor computeSoftmax(
|
|||
auto max = Reduce(
|
||||
"aten_softmax_max",
|
||||
non_softmax_dims,
|
||||
c10::nullopt,
|
||||
Maximum(dtype),
|
||||
[&](ParameterList& indices) {
|
||||
return tensorOrConstant(
|
||||
inputs[0], move_softmax_dim_index_to_pos(indices));
|
||||
},
|
||||
{outputShape[softmax_dim]});
|
||||
auto e =
|
||||
Compute("aten_softmax_exp", outputShape, [&](ParameterList& indices) {
|
||||
auto e = Compute(
|
||||
"aten_softmax_exp",
|
||||
outputShape,
|
||||
c10::nullopt,
|
||||
[&](ParameterList& indices) {
|
||||
auto inp = tensorOrConstant(
|
||||
inputs[0], convert_indices_to_expr_handle(indices));
|
||||
return exp(inp - max.load(remove_softmax_dim_index(indices)));
|
||||
|
|
@ -117,14 +122,15 @@ Tensor computeSoftmax(
|
|||
auto sum = Reduce(
|
||||
"aten_softmax_sum",
|
||||
non_softmax_dims,
|
||||
c10::nullopt,
|
||||
Sum(),
|
||||
[&](ParameterList& indices) {
|
||||
return e.load(move_softmax_dim_index_to_pos(indices));
|
||||
},
|
||||
{outputShape[softmax_dim]});
|
||||
if (!log_softmax) {
|
||||
auto result =
|
||||
Compute("aten_softmax", outputShape, [&](ParameterList& indices) {
|
||||
auto result = Compute(
|
||||
"aten_softmax", outputShape, c10::nullopt, [&](ParameterList& indices) {
|
||||
return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
|
||||
});
|
||||
return Tensor(
|
||||
|
|
@ -134,11 +140,15 @@ Tensor computeSoftmax(
|
|||
}
|
||||
|
||||
auto log_sum = Compute(
|
||||
"aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
|
||||
return log(sum.load(indices));
|
||||
});
|
||||
auto result =
|
||||
Compute("aten_log_softmax", outputShape, [&](ParameterList& indices) {
|
||||
"aten_softmax_log_sum",
|
||||
non_softmax_dims,
|
||||
c10::nullopt,
|
||||
[&](ParameterList& indices) { return log(sum.load(indices)); });
|
||||
auto result = Compute(
|
||||
"aten_log_softmax",
|
||||
outputShape,
|
||||
c10::nullopt,
|
||||
[&](ParameterList& indices) {
|
||||
auto inp = tensorOrConstant(
|
||||
inputs[0], convert_indices_to_expr_handle(indices));
|
||||
auto non_softmax_indices = remove_softmax_dim_index(indices);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ namespace tensorexpr {
|
|||
Tensor computeSoftmax(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const std::vector<ExprHandle>& outputStrides,
|
||||
bool log_softmax);
|
||||
|
||||
} // namespace tensorexpr
|
||||
|
|
|
|||
|
|
@ -39,9 +39,31 @@ StmtPtr Tensor::constructStmt(
|
|||
}
|
||||
}
|
||||
|
||||
for (const auto i : c10::irange(ndim)) {
|
||||
// Going in reverse order: from innermost loop to the outermost
|
||||
size_t dim_index = ndim - i - 1;
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
buf_->is_contiguous() ||
|
||||
buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
|
||||
buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
|
||||
buf_->is_channels_last_1d_contiguous());
|
||||
|
||||
auto loop_order_fn = [&]() {
|
||||
std::vector<int32_t> loop_order;
|
||||
if (buf_->is_contiguous()) {
|
||||
for (int32_t i = args.size() - 1; i >= 0; i--) {
|
||||
loop_order.push_back(i);
|
||||
}
|
||||
} else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
|
||||
loop_order = {1, 3, 2, 0};
|
||||
} else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
|
||||
loop_order = {1, 4, 3, 2, 0};
|
||||
} else {
|
||||
loop_order = {1, 2, 0};
|
||||
}
|
||||
|
||||
return loop_order;
|
||||
};
|
||||
|
||||
auto loop_order = loop_order_fn();
|
||||
for (auto dim_index : loop_order) {
|
||||
auto const& dim = buf()->dim(dim_index);
|
||||
s = alloc<For>(args[dim_index], immLike(dim, 0), dim, s);
|
||||
}
|
||||
|
|
@ -51,16 +73,24 @@ StmtPtr Tensor::constructStmt(
|
|||
// Compute overload taking optional explicit strides and a body function that
// receives all index variables as one std::vector<VarHandle>.
//
// Steps visible in the body:
//   1. create_index_vars(dims) produces the index variables `args`;
//   2. body_func(args) produces the body expression;
//   3. Buf::make creates the output buffer named `name` over `dims`, using
//      the dtype of the body expression;
//   4. the result is Tensor(buf, args, body).
//
// NOTE(review): two consecutive `BufHandle buf = Buf::make(...)` declarations
// appear below. This looks like the removed and the added side of a diff
// rendered together; only the second one forwards `strides`. Confirm that the
// first line is gone in the real file (as written, `buf` would be redeclared).
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
|
||||
std::vector<VarHandle> args = create_index_vars(dims);
|
||||
ExprHandle body = body_func(args);
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype());
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(buf, args, body);
|
||||
}
|
||||
// Compute overload without a strides argument, for a body function that takes
// the index variables as a std::vector<VarHandle>.
// It only forwards to the four-argument Compute call below, passing
// c10::nullopt in the strides position; name, dims and body_func are passed
// through unchanged.
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
|
||||
return Compute(name, dims, c10::nullopt, body_func);
|
||||
}
|
||||
|
||||
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const VarHandle&)>& body_func) {
|
||||
if (dims.size() != 1) {
|
||||
throw malformed_input("mismatch between body and arg size (1)");
|
||||
|
|
@ -68,13 +98,20 @@ Tensor Compute(
|
|||
|
||||
std::vector<VarHandle> args = create_index_vars(dims);
|
||||
ExprHandle body = body_func(args[0]);
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype());
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(buf, args, body);
|
||||
}
|
||||
// Compute overload without a strides argument, for a body function of a
// single VarHandle.
// It only forwards to the four-argument Compute call below, passing
// c10::nullopt in the strides position; name, dims and body_func are passed
// through unchanged.
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(const VarHandle&)>& body_func) {
|
||||
return Compute(name, dims, c10::nullopt, body_func);
|
||||
}
|
||||
|
||||
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
|
||||
body_func) {
|
||||
if (dims.size() != 2) {
|
||||
|
|
@ -82,13 +119,21 @@ Tensor Compute(
|
|||
}
|
||||
std::vector<VarHandle> args = create_index_vars(dims);
|
||||
ExprHandle body = body_func(args[0], args[1]);
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype());
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(buf, args, body);
|
||||
}
|
||||
// Compute overload without a strides argument, for a body function of two
// VarHandle arguments.
// It only forwards to the four-argument Compute call below, passing
// c10::nullopt in the strides position; name, dims and body_func are passed
// through unchanged.
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
|
||||
body_func) {
|
||||
return Compute(name, dims, c10::nullopt, body_func);
|
||||
}
|
||||
|
||||
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<
|
||||
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
|
||||
body_func) {
|
||||
|
|
@ -97,13 +142,22 @@ Tensor Compute(
|
|||
}
|
||||
std::vector<VarHandle> args = create_index_vars(dims);
|
||||
ExprHandle body = body_func(args[0], args[1], args[2]);
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype());
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(buf, args, body);
|
||||
}
|
||||
// Compute overload without a strides argument, for a body function of three
// VarHandle arguments.
// It only forwards to the four-argument Compute call below, passing
// c10::nullopt in the strides position; name, dims and body_func are passed
// through unchanged.
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<
|
||||
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
|
||||
body_func) {
|
||||
return Compute(name, dims, c10::nullopt, body_func);
|
||||
}
|
||||
|
||||
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
|
|
@ -114,37 +168,67 @@ Tensor Compute(
|
|||
}
|
||||
std::vector<VarHandle> args = create_index_vars(dims);
|
||||
ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype());
|
||||
BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(buf, args, body);
|
||||
}
|
||||
// Compute overload without a strides argument, for a body function of four
// VarHandle arguments.
// It only forwards to the four-argument Compute call below, passing
// c10::nullopt in the strides position; name, dims and body_func are passed
// through unchanged.
Tensor Compute(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&)>& body_func) {
|
||||
return Compute(name, dims, c10::nullopt, body_func);
|
||||
}
|
||||
|
||||
// Reduce overload whose reduction body reads from an existing buffer.
// It wraps `buffer.load(p)` in a lambda and forwards name, dims, strides,
// reducer and reduce_dims unchanged to another Reduce overload that accepts a
// body function.
// NOTE(review): the lambda captures `buffer` by reference ([&]). That is fine
// provided the callee does not retain the lambda past this call -- confirm
// against the templated Reduce.
Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
const BufHandle& buffer,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(
|
||||
name,
|
||||
dims,
|
||||
strides,
|
||||
reducer,
|
||||
[&](ParameterList& p) { return buffer.load(p); },
|
||||
reduce_dims);
|
||||
}
|
||||
// Reduce overload over a BufHandle without a strides argument.
// It only forwards to the six-argument Reduce call below, passing
// c10::nullopt in the strides position; all other arguments are passed
// through unchanged.
Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const Reducer& reducer,
|
||||
const BufHandle& buffer,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(name, dims, c10::nullopt, reducer, buffer, reduce_dims);
|
||||
}
|
||||
|
||||
// Reduce overload whose reduction body reads from another Tensor.
// It wraps `tensor.load(p)` in a lambda and forwards name, dims, strides,
// reducer and reduce_dims unchanged to another Reduce overload that accepts a
// body function.
// NOTE(review): `tensor` is a by-value parameter and the lambda captures it by
// reference ([&]). That is fine provided the callee does not retain the lambda
// past this call -- confirm against the templated Reduce.
Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
Tensor tensor,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(
|
||||
name,
|
||||
dims,
|
||||
strides,
|
||||
reducer,
|
||||
[&](ParameterList& p) { return tensor.load(p); },
|
||||
reduce_dims);
|
||||
}
|
||||
// Reduce overload over a Tensor without a strides argument.
// It only forwards to the six-argument Reduce call below, passing
// c10::nullopt in the strides position; all other arguments are passed
// through unchanged.
Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const Reducer& reducer,
|
||||
Tensor tensor,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(name, dims, c10::nullopt, reducer, tensor, reduce_dims);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -73,12 +73,30 @@ class TORCH_API Tensor {
|
|||
// Compute overload taking optional `strides`; body_func receives one VarHandle.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const VarHandle&)>& body_func);
|
||||
// Same one-VarHandle body_func signature, without a `strides` parameter.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(const VarHandle&)>& body_func);
|
||||
// Compute overload taking optional `strides`; body_func receives two
// VarHandles.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
|
||||
body_func);
|
||||
// Same two-VarHandle body_func signature, without a `strides` parameter.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
|
||||
body_func);
|
||||
// Compute overload taking optional `strides`; body_func receives three
// VarHandles.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<
|
||||
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
|
||||
body_func);
|
||||
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
|
|
@ -88,11 +106,25 @@ TORCH_API Tensor Compute(
|
|||
// Compute overload taking optional `strides`; body_func receives four
// VarHandles.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&)>& body_func);
|
||||
// Same four-VarHandle body_func signature, without a `strides` parameter.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const std::function<ExprHandle(
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&,
|
||||
const VarHandle&)>& body_func);
|
||||
// Compute overload taking optional `strides`; body_func receives all of its
// VarHandles as a single std::vector.
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
|
||||
TORCH_API Tensor Compute(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
|
|
@ -114,6 +146,7 @@ template <typename InitFunc, typename BodyFunc>
|
|||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
const InitFunc& init_func,
|
||||
const BodyFunc& body_func,
|
||||
|
|
@ -125,7 +158,8 @@ Tensor Reduce(
|
|||
// copy
|
||||
if (reduce_vars.empty()) {
|
||||
ExprHandle body = Reducer::getReduceBody(body_func, vars);
|
||||
BufHandle func_result = Buf::make(func_name, dims, body.dtype());
|
||||
BufHandle func_result =
|
||||
Buf::make(func_name, dims, body.dtype(), c10::nullopt, strides);
|
||||
return Tensor(func_result, vars, body);
|
||||
}
|
||||
|
||||
|
|
@ -141,7 +175,41 @@ Tensor Reduce(
|
|||
Tensor t = Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
|
||||
return t;
|
||||
}
|
||||
// Reduce overload with explicit init_func and body_func but no strides
// parameter. It forwards to Reduce<InitFunc, BodyFunc> with c10::nullopt
// passed in the strides position; all other arguments are passed through
// unchanged.
template <typename InitFunc, typename BodyFunc>
|
||||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const Reducer& reducer,
|
||||
const InitFunc& init_func,
|
||||
const BodyFunc& body_func,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce<InitFunc, BodyFunc>(
|
||||
func_name,
|
||||
dims,
|
||||
c10::nullopt,
|
||||
reducer,
|
||||
init_func,
|
||||
body_func,
|
||||
reduce_dims);
|
||||
}
|
||||
|
||||
// Reduce overload that takes optional strides and a body_func but no
// init_func. It forwards to a Reduce overload that also takes an initializer
// callable, supplying a lambda built from the reducer.
template <typename BodyFunc>
|
||||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
const BodyFunc& body_func,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(
|
||||
func_name,
|
||||
dims,
|
||||
strides,
|
||||
reducer,
|
||||
// The initializer lambda ignores its ParameterList argument and always
// returns ExprHandle(reducer.initializer()); `reducer` is captured by
// reference.
[&](ParameterList p) { return ExprHandle(reducer.initializer()); },
|
||||
body_func,
|
||||
reduce_dims);
|
||||
}
|
||||
template <typename BodyFunc>
|
||||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
|
|
@ -149,13 +217,8 @@ Tensor Reduce(
|
|||
const Reducer& reducer,
|
||||
const BodyFunc& body_func,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(
|
||||
func_name,
|
||||
dims,
|
||||
reducer,
|
||||
[&](ParameterList p) { return ExprHandle(reducer.initializer()); },
|
||||
body_func,
|
||||
reduce_dims);
|
||||
return Reduce<BodyFunc>(
|
||||
func_name, dims, c10::nullopt, reducer, body_func, reduce_dims);
|
||||
}
|
||||
|
||||
// Overload which allows inline lambda functions for the body_func.
|
||||
|
|
@ -163,12 +226,29 @@ template <typename BodyFunc>
|
|||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
const BodyFunc&& body_func,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(func_name, dims, reducer, body_func, reduce_dims);
|
||||
return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims);
|
||||
}
|
||||
// Reduce overload that takes body_func as `const BodyFunc&&` and has no
// strides parameter. It forwards to a Reduce overload that takes strides,
// passing c10::nullopt in that position.
template <typename BodyFunc>
|
||||
Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
const Reducer& reducer,
|
||||
const BodyFunc&& body_func,
|
||||
const std::vector<ExprHandle>& reduce_dims) {
|
||||
return Reduce(func_name, dims, c10::nullopt, reducer, body_func, reduce_dims);
|
||||
}
|
||||
|
||||
// Declaration of the Reduce overload whose input is a BufHandle and which
// accepts optional strides.
TORCH_API Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
const BufHandle& buffer,
|
||||
const std::vector<ExprHandle>& reduce_dims);
|
||||
TORCH_API Tensor Reduce(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
|
|
@ -178,6 +258,13 @@ TORCH_API Tensor Reduce(
|
|||
|
||||
// Overload for the common case of all dimensions of a previously Computed
|
||||
// Tensor.
|
||||
// Declaration of the Reduce overload whose input is a Tensor and which
// accepts optional strides.
TORCH_API Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
c10::optional<std::vector<ExprHandle>> strides,
|
||||
const Reducer& reducer,
|
||||
Tensor tensor,
|
||||
const std::vector<ExprHandle>& reduce_dims);
|
||||
TORCH_API Tensor Reduce(
|
||||
const std::string& func_name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
|
|
|
|||
|
|
@ -707,8 +707,14 @@ void initTensorExprBindings(PyObject* module) {
|
|||
}
|
||||
if (NNCLoweringFunction lowering =
|
||||
getStandardLoweringFor(op.toQualString())) {
|
||||
std::vector<ExprHandle> outputStrides =
|
||||
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
|
||||
return lowering(
|
||||
argInputs, outputShape, outputType.scalar_type(), at::kCPU);
|
||||
argInputs,
|
||||
outputShape,
|
||||
outputStrides,
|
||||
outputType.scalar_type(),
|
||||
at::kCPU);
|
||||
}
|
||||
std::string msg = std::string("Unhandled node kind (in te.lower): ") +
|
||||
op.toQualString();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user