pytorch/torch/csrc/jit/runtime/static/ops.cpp
Brian Hirsh 779cae9e42 port at::pow to structured (#53669)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53669

This PR does two things:
* Ports `pow` to be structured
* Fixes a bug with how pow handles mixed cpu and cuda tensors

**bug fix**
Pow is a binary op, and all binary ops that use TensorIterator are currently written to handle the case where one input is a CUDA tensor and the other is a zero-dimensional CPU tensor.

As it happens, `pow` only handles one of the two orderings: it fails when the CUDA tensor is passed as the exponent, e.g. `torch.pow(torch.tensor(2.0, device='cpu'), torch.tensor([2, 2], device='cuda'))`. Porting `pow` to structured happened to change the error that was emitted from a `TORCH_CHECK` in TensorIterator to an `INTERNAL_ASSERT` in loop.cuh, so I ended up fixing the underlying error and updating the tests. I added more details in a comment on the PR.
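For reference, a minimal C++ sketch of the case in question (not part of this PR; it assumes a CUDA build, and that the fixed behavior matches the other TensorIterator binary ops, which accept a zero-dim CPU tensor alongside a CUDA tensor):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Zero-dim CPU base with a CUDA exponent: the ordering that used to fail.
  auto base = at::scalar_tensor(2.0, at::device(at::kCPU).dtype(at::kFloat));
  auto exponent = at::full({2}, 2.0, at::device(at::kCUDA).dtype(at::kFloat));
  auto result = at::pow(base, exponent); // expected: [4., 4.] on the CUDA device
  std::cout << result << std::endl;
  return 0;
}
```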

**notes on the structured port**
Pow is a little weird, so I wrote down a couple of issues I noticed during the port:
* Multiple independent overloads. `pow` has two overloads that have their own cpu/cuda kernels, meaning one doesn't call the other. I had to update the names of the kernel overloads to make the compiler happy, since the codegen would otherwise try to generate two classes with the same name. `pow` actually has 3 overloads that all have `out` variants, so I ported all 3 to structured; one of them just happens to redispatch to one of the others in most cases.
* Name propagation. Is name propagation implemented per operator, or is it expected to work for most/all ops by default? Right now it looks like it happens for TensorIterator ops by default. For ops that don't use TensorIterator, we need to explicitly pass the names through to the `set_output()` call in the meta function. This happened to matter for `pow` because it has 3 overloads, but only two of them directly use TensorIterator; I had to pass names directly to `set_output` in the 3rd overload to make the tests happy.
* Lack of `const Tensor &` in the C++ API. It's a goal to slowly make all `Tensor &` arguments const as part of the structured port, but in this case I needed to explicitly cast constness away because one structured kernel called back into the C++ API, which still takes ordinary `Tensor &` arguments (a hypothetical sketch of this pattern follows this list). This probably isn't something we'll fix soon, since we have boxing logic that actually relies on the `Tensor &` / `const Tensor &` distinction in some places.
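As a rough illustration of that last point, here is a minimal sketch; `legacy_pow_like_out` and `structured_impl` are made-up names standing in for a legacy out variant and a structured-style kernel, not the actual pow functions:

```cpp
#include <ATen/ATen.h>

// Stand-in for a C++ API function that still takes a mutable Tensor&.
at::Tensor& legacy_pow_like_out(at::Tensor& result, const at::Tensor& self) {
  result.copy_(self);
  return result;
}

// A structured-style signature only hands us `const Tensor&`, so the call
// site has to cast constness away before calling back into the legacy API.
void structured_impl(const at::Tensor& self, const at::Tensor& result) {
  legacy_pow_like_out(const_cast<at::Tensor&>(result), self);
}
```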

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D27029821

Pulled By: bdhirsh

fbshipit-source-id: c1786e770de6e6c2474b9a48210b88057ab1018e
2021-03-19 14:30:48 -07:00

#include <torch/csrc/jit/runtime/static/ops.h>
#include <ATen/CPUFunctions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
namespace at {
namespace native {
// The out variants of view ops can't be moved to aten because they don't
// exactly follow the semantics of the aten ops. aten::reshape/flatten create
// views, t, that are tracked by autograd and t.is_view() returns true. Here
// t.is_view() would return false instead.
at::Tensor& reshape_out(
at::Tensor& out,
const at::Tensor& self,
const std::vector<int64_t>& proposed_shape,
bool infer_size = true) {
auto shape = infer_size ? at::infer_size(proposed_shape, self.numel())
: proposed_shape;
auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
if (stride.has_value()) {
// create view
if (!out.defined() || !out.storage().is_alias_of(self.storage())) {
auto impl = c10::make_intrusive<c10::TensorImpl>(
c10::Storage(self.storage()), self.key_set(), self.dtype());
out = at::Tensor(std::move(impl));
}
c10::TensorImpl* impl = out.unsafeGetTensorImpl();
impl->set_storage_offset(self.storage_offset());
impl->set_sizes_and_strides(shape, *stride);
} else {
// copy over tensor
if (!out.defined()) {
out = at::native::empty_like(
self, self.options(), at::MemoryFormat::Contiguous);
}
// copy first and set shape/strides later. It doesn't work the other way
// around.
at::native::copy_(out, self);
stride = at::detail::computeStride(out.sizes(), out.strides(), shape);
c10::TensorImpl* impl = out.unsafeGetTensorImpl();
impl->set_sizes_and_strides(shape, *stride);
}
// namedinference::propagate_names(output, self);
return out;
}
at::Tensor& flatten_out(
at::Tensor& out,
const at::Tensor& self,
int64_t start_dim,
int64_t end_dim) {
start_dim =
start_dim < 0 ? c10::maybe_wrap_dim(start_dim, self.dim()) : start_dim;
end_dim = end_dim < 0 ? c10::maybe_wrap_dim(end_dim, self.dim()) : end_dim;
TORCH_CHECK(
start_dim <= end_dim,
"flatten() has invalid args: start_dim cannot come after end_dim");
if (self.dim() == 0) {
return reshape_out(out, self, {1}, false);
}
if (start_dim == end_dim) {
out = self;
return out;
}
// We don't want to infer_size on the entire shape, because that can give us
// an extra degree of freedom we don't want; for example, consider shape [0,
// 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape [0,
// 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on any
// value and satisfy the constraints.
auto iter = self.sizes().data();
auto slice_numel = std::accumulate(
iter + start_dim,
iter + end_dim + 1,
static_cast<int64_t>(1),
std::multiplies<int64_t>());
std::vector<int64_t> shape;
shape.reserve(self.dim() - end_dim + start_dim);
for (int64_t i = 0; i < start_dim; i++) {
shape.push_back(self.sizes()[i]);
}
shape.push_back(slice_numel);
for (int64_t i = end_dim + 1; i < self.dim(); i++) {
shape.push_back(self.sizes()[i]);
}
return reshape_out(out, self, shape, false);
}
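// Resize `out` to `self`'s sizes (and strides, when no memory format is pinned
// on `out`), then copy the data; used by the static_runtime::to_copy op below.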
at::Tensor& to_copy_out(Tensor& out, const Tensor& self, bool non_blocking) {
if (!out.options().memory_format_opt().has_value()) {
at::native::resize_impl_cpu_(
out.unsafeGetTensorImpl(), self.sizes(), self.strides());
at::native::copy_(out, self, non_blocking);
return out;
}
at::native::resize_(out, self.sizes(), c10::nullopt);
at::native::copy_(out, self, non_blocking);
return out;
}
} // namespace native
} // namespace at
namespace torch {
namespace jit {
C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
bool canRunOutOfPlace(Node* n) {
auto op_name = std::string(n->kind().toQualString());
return SROperatorRegistry()->Has(op_name);
}
// Keep canReuseInputsOutputs as a separate function because the name is more
// informative at its call sites.
bool canReuseInputsOutputs(Node* n) {
return canRunOutOfPlace(n);
}
// TODO: expand to include all view producing ops, mostly in
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
bool canRunNatively(Node* n) {
// In alphabetical order
const static std::unordered_set<std::string> native_nodes{
"aten::flatten",
"aten::reshape",
"aten::slice",
"aten::transpose",
"aten::to",
"prim::ListConstruct",
"prim::ListUnpack",
"prim::TupleConstruct"};
auto str = std::string(n->kind().toQualString());
if (!native_nodes.count(str)) {
return false;
}
if (str == "aten::to") {
return n->inputs().size() == 5;
}
return true;
}
// Returns true if the producers of the inputs
// to this operation are out of place.
// This means the IValues will not change from run to run.
bool inputsCanRunOutOfPlace(Node* n) {
for (auto* input : n->inputs()) {
if (!canRunOutOfPlace(input->node())) {
return false;
}
}
return true;
}
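// A ListConstruct/TupleConstruct output can be reused across runs only if it
// contains tensors and all of its inputs come from out-of-place ops (so the
// input IValues do not change from run to run).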
bool canOptimizeConstruct(Node* n) {
const auto& type = n->output()->type();
if (type->kind() == TypeKind::ListType) {
const auto& list_type = type->expectRef<ListType>();
bool is_tensor_list =
list_type.getElementType()->kind() == TypeKind::TensorType;
return is_tensor_list && inputsCanRunOutOfPlace(n);
} else if (type->kind() == TypeKind::TupleType) {
const auto& tuple_type = type->expectRef<TupleType>();
auto types = tuple_type.containedTypes();
const auto& iter =
std::find_if(types.begin(), types.end(), [](const TypePtr& elem) {
return elem->kind() == TypeKind::TensorType;
});
bool is_tensor_tuple = iter != types.end();
return is_tensor_tuple && inputsCanRunOutOfPlace(n);
}
return false;
}
REGISTER_OPERATOR_FUNCTOR(
prim::ListConstruct,
prim_ListConstruct,
[](Node* n) -> SROperator {
const auto& type = n->output()->type()->expectRef<ListType>();
bool can_optimize = canOptimizeConstruct(n);
return [can_optimize, &type](ProcessedNode* p_node) {
const auto& out_l = p_node->Output(0);
if (!out_l.isNone() && can_optimize) {
return;
}
const size_t size = p_node->inputs().size();
c10::List<IValue> vals(type.getElementType());
vals.reserve(size);
for (size_t i = 0; i < size; i++) {
vals.push_back(p_node->Input(i));
}
p_node->Output(0) = std::move(vals);
};
});
REGISTER_OPERATOR_FUNCTOR(
prim::TupleConstruct,
prim_TupleConstruct,
[](Node* n) -> SROperator {
bool can_optimize = canOptimizeConstruct(n);
return [can_optimize](ProcessedNode* p_node) {
const auto& out_l = p_node->Output(0);
if (!out_l.isNone() && can_optimize) {
return;
}
// prepare inputs
const size_t size = p_node->inputs().size();
std::vector<IValue> vals;
vals.reserve(size);
for (size_t i = 0; i < size; i++) {
vals.push_back(p_node->Input(i));
}
p_node->Output(0) = c10::ivalue::Tuple::create(std::move(vals));
};
});
REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto& in1_t = p_node->Input(1).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::cpu::mul_out(out_t, in0_t, in1_t);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto& in1_t = p_node->Input(1).toTensor();
const auto& in2_t = p_node->Input(2).toTensor();
const auto in3_s = p_node->Input(3).toScalar();
const auto in4_s = p_node->Input(4).toScalar();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_s = p_node->Input(1).toScalar();
const auto in2_s = p_node->Input(2).toScalar();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto& in1_t = p_node->Input(1).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::bmm_out_cpu(in0_t, in1_t, out_t);
};
});
REGISTER_OPERATOR_FUNCTOR(
aten::nan_to_num,
aten_nan_to_num,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto input_size = p_node->inputs().size();
const auto& in0_t = p_node->Input(0).toTensor();
const double in1_d = input_size > 1 ? p_node->Input(1).toDouble() : 0;
const double in2_d = input_size > 2
? p_node->Input(2).toDouble()
: std::numeric_limits<double>::infinity();
const double in3_d = input_size > 3
? p_node->Input(3).toDouble()
: -std::numeric_limits<double>::infinity();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto in0_tl = p_node->Input(0).toTensorVector();
const auto in1_i = p_node->Input(1).toInt();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_tl[0]);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::_cat_out_cpu(out_t, in0_tl, in1_i);
};
});
// Split out into a function to appease MSVC's pre-processor
SROperator aten_stack(Node* n) {
return [](ProcessedNode* p_node) {
const auto inputs = p_node->Input(0).toTensorVector();
const auto dim = p_node->Input(1).toInt();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(inputs[0]);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::_stack_out_cpu(inputs, dim, out_t);
};
}
REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack);
REGISTER_OPERATOR_FUNCTOR(
aten::leaky_relu,
aten_leaky_relu,
[](Node* n) -> SROperator {
const auto in1 = toIValue(n->inputs()[1]);
if (in1) {
const auto in1_s = in1->toScalar();
return [=](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
at::native::leaky_relu_out(out_t, in0_t, in1_s);
};
} else {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_s = p_node->Input(1).toScalar();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
at::native::leaky_relu_out(out_t, in0_t, in1_s);
};
}
});
namespace {
// Use the width of an AVX-512 vector by default; this happens to work OK for
// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case
// they are vectorized by twice this constant. An exception is logit, since it
// contains FP divide, which is single-ported.
static constexpr int kVectorWidth = 16;
#ifdef TORCH_ENABLE_LLVM
struct TEWrapper {
tensorexpr::KernelArena ka;
tensorexpr::KernelScope ks;
std::unique_ptr<tensorexpr::LLVMCodeGen> cg;
TEWrapper() = default;
void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_) {
cg = std::move(cg_);
}
template <typename... Ts>
void operator()(const Ts&... ts) {
std::vector<tensorexpr::CodeGen::CallArg> args(
{tensorexpr::CodeGen::CallArg(ts)...});
cg->call(args);
}
inline bool supports(const at::Tensor& t) {
return t.is_contiguous() && t.dtype().Match<float>();
}
};
void optimizePointwise(
tensorexpr::LoopNest* ln,
tensorexpr::Tensor* target,
int width) {
using namespace torch::jit::tensorexpr;
std::vector<For*> loops = ln->getLoopStmtsFor(target);
For *outer, *inner, *tail;
TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
ln->splitWithTail(loops[0], width, &outer, &inner, &tail);
ln->vectorize(inner);
}
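// Lower `out` to an NNC loop nest, vectorize the split inner loop, and JIT it
// to an LLVM kernel that is stored in `wrap`.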
std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
tensorexpr::Placeholder& in,
tensorexpr::Tensor* out,
tensorexpr::VarHandle& dim,
int width = kVectorWidth) {
using namespace torch::jit::tensorexpr;
LoopNest ln({out});
optimizePointwise(&ln, out, width);
ln.prepareForCodegen();
Stmt* s = ln.root_stmt();
s = tensorexpr::IRSimplifier::simplify(s);
std::vector<CodeGen::BufferArg> args;
args.emplace_back(out);
args.emplace_back(in);
args.emplace_back(dim);
auto cg = std::make_unique<LLVMCodeGen>(s, args);
wrap->update(std::move(cg));
return wrap;
};
#else
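// Stub used when LLVM codegen is disabled: supports() always returns false,
// so callers fall back to the at::native implementations.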
struct TEWrapper {
TEWrapper() = default;
template <typename... Ts>
void operator()(const Ts&... ts) {
DCHECK(0 && "Invalid call");
}
inline bool supports(const at::Tensor& t) {
return false;
}
};
std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
tensorexpr::Placeholder& in,
tensorexpr::Tensor* out,
tensorexpr::VarHandle& dim,
int width = kVectorWidth) {
return wrap;
};
#endif
} // namespace
std::shared_ptr<TEWrapper> createLogit(c10::optional<float> clamp) {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto A_elem = [&]() {
if (!clamp) {
return A.load(i);
} else {
auto elem = A.load(i);
auto min = FloatImm::make(*clamp);
auto max = FloatImm::make(1.0f - *clamp);
elem = CompareSelect::make(elem, min, min, elem, kLT);
return CompareSelect::make(elem, max, max, elem, kGT);
}
}();
return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
});
return wrapTECompute(wrap, A, B, N);
}
std::shared_ptr<TEWrapper> createRelu() {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto zero = FloatImm::make(0.f);
auto a = A.load(i);
return ifThenElse(a < zero, zero, a);
});
return wrapTECompute(wrap, A, B, N);
}
std::shared_ptr<TEWrapper> createTanh() {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto a = A.load(i);
return fast_tanh(a);
});
return wrapTECompute(wrap, A, B, N);
}
std::shared_ptr<TEWrapper> createSigmoid() {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
Tensor* B =
Compute("B", {N}, [&](const VarHandle& i) { return sigmoid(A.load(i)); });
// NNC uses sleef for vectorizing sigmoid, which comes in an 8-wide flavor
// (Sleef_expf8).
constexpr int kSleefWidth = 8;
return wrapTECompute(wrap, A, B, N, kSleefWidth);
}
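// The pointwise unary ops below (relu/tanh/sigmoid/logit) run the NNC kernel
// when the input is supported (contiguous float, LLVM enabled) and fall back
// to the at::native out variants otherwise.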
REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
auto te = createRelu();
return [te](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
fastResizeToZero(out_t);
at::native::threshold_out(out_t, in0_t, 0, 0);
} else {
at::native::resize_as_(out_t, in0_t, c10::nullopt);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});
REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
auto te = createTanh();
return [te](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
fastResizeToZero(out_t);
at::native::tanh_out(out_t, in0_t);
} else {
at::native::resize_as_(out_t, in0_t, c10::nullopt);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});
REGISTER_OPERATOR_FUNCTOR(
aten::sigmoid,
aten_sigmoid,
[](Node* n) -> SROperator {
auto te = createSigmoid();
return [te](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
fastResizeToZero(out_t);
at::native::sigmoid_out(out_t, in0_t);
} else {
at::native::resize_as_(out_t, in0_t, c10::nullopt);
(*te)(
out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});
REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
c10::optional<float> clamp;
if (n->inputs().size() > 1) {
TORCH_CHECK(n->inputs().at(1)->node()->kind() == prim::Constant);
clamp = toIValue(n->inputs().at(1))->toDouble();
}
auto te = createLogit(clamp);
return [te](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
const auto in0_t = p_node->Input(0).toTensor();
const double in1_d =
p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0;
fastResizeToZero(out_t);
at::native::logit_out(out_t, in0_t, in1_d);
} else {
at::native::resize_as_(out_t, in0_t, c10::nullopt);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});
REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
at::native::resize_as_(out_t, in0_t, c10::nullopt);
at::native::copy_(out_t, in0_t, false);
};
});
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_byte_rowwise_offsets,
quantized_embedding_bag_byte_rowwise_offsets,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& weight = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toTensor();
const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
const auto pruned_weights = p_node->Input(5).toBool();
const auto per_sample_weights =
p_node->Input(6).toOptional<at::Tensor>();
const auto compressed_indices_mapping =
p_node->Input(7).toOptional<at::Tensor>();
const auto include_last_offset = p_node->Input(8).toBool();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(weight, at::kFloat);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
return at::native::embedding_bag_byte_rowwise_offsets_out(
out_t,
weight,
indices,
offsets,
false, // unused scale_grad_by_freq
0, // unused mode
pruned_weights,
per_sample_weights,
compressed_indices_mapping,
include_last_offset);
};
});
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_4bit_rowwise_offsets,
embedding_bag_4bit_rowwise_offsets,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& weight = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toTensor();
const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
const auto pruned_weights = p_node->Input(5).toBool();
const auto per_sample_weights =
p_node->Input(6).toOptional<at::Tensor>();
const auto compressed_indices_mapping =
p_node->Input(7).toOptional<at::Tensor>();
const auto include_last_offset = p_node->Input(8).toBool();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(weight, at::kFloat);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
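// NB: this dispatches to the byte rowwise-offsets out variant even though
// the registered op is the 4-bit one.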
return at::native::embedding_bag_byte_rowwise_offsets_out(
out_t,
weight,
indices,
offsets,
false, // unused scale_grad_by_freq
0, // unused mode
pruned_weights,
per_sample_weights,
compressed_indices_mapping,
include_last_offset);
};
});
// The out variant takes precedence over native
REGISTER_OPERATOR_FUNCTOR(
aten::narrow_copy,
aten_narrow_copy,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor(); // self
const auto dim = p_node->Input(1).toInt(); // dim
int64_t start = 0;
if (p_node->Input(2).isScalar()) {
start = p_node->Input(2).toInt();
} else {
auto& t = p_node->Input(2).toTensor();
start = t.item<int64_t>();
}
auto length = p_node->Input(3).toInt(); // length
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(self);
}
auto& output = p_node->Output(0).toTensor();
fastResizeToZero(output);
at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_l =
at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::index_out(out_t, in0_t, in1_l);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
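// The output dtype depends on both operands (tensor/tensor, tensor/scalar,
// or scalar/tensor), so compute it with result_type before allocating.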
if (p_node->Output(0).isNone()) {
c10::ScalarType dtype;
if (p_node->Input(0).isTensor()) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Input(1).isTensor()) {
dtype = at::native::result_type(in0_t, p_node->Input(1).toTensor());
p_node->Output(0) = create_empty_from(in0_t, dtype);
} else {
dtype = at::native::result_type(in0_t, p_node->Input(1).toScalar());
p_node->Output(0) = at::native::empty_like(
in0_t, in0_t.options().dtype(dtype), at::MemoryFormat::Preserve);
}
} else {
const auto& in1_t = p_node->Input(1).toTensor();
dtype = at::native::result_type(p_node->Input(0).toScalar(), in1_t);
p_node->Output(0) = at::native::empty_like(
in1_t, in1_t.options().dtype(dtype), at::MemoryFormat::Preserve);
}
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
if (p_node->Input(0).isTensor()) {
if (p_node->Input(1).isTensor()) {
at::cpu::pow_out(
out_t, p_node->Input(0).toTensor(), p_node->Input(1).toTensor());
} else {
at::cpu::pow_out(
out_t, p_node->Input(0).toTensor(), p_node->Input(1).toScalar());
}
} else {
at::cpu::pow_out(
out_t, p_node->Input(0).toScalar(), p_node->Input(1).toTensor());
}
};
});
// out variant takes precedence over native
REGISTER_OPERATOR_FUNCTOR(
static_runtime::to_copy,
aten_to_copy,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
// support 4- or 5-arg for adindexer/adfinder models
DCHECK(p_node->inputs().size() >= 4);
const auto& in0_t = p_node->Input(0).toTensor();
auto in2_i = p_node->Input(2).toBool(); // non_blocking
// ignore input 3 (copy)
if (p_node->Output(0).isNone()) {
auto in1_i = p_node->Input(1).toScalarType();
c10::optional<c10::MemoryFormat> in4_o = c10::nullopt;
if (p_node->inputs().size() > 4 && p_node->Input(4).isInt()) {
in4_o = p_node->Input(4).toOptional<c10::MemoryFormat>();
}
if (in4_o.value_or(c10::MemoryFormat::Preserve) ==
c10::MemoryFormat::Preserve) {
if (in0_t.is_non_overlapping_and_dense()) {
in4_o = c10::nullopt;
} else {
in4_o = in0_t.suggest_memory_format();
}
}
// See Note [Explicit nullopt MemoryFormat argument]
p_node->Output(0) = at::detail::empty_cpu(
{0}, in1_i, in0_t.layout(), in0_t.device(), c10::nullopt, in4_o);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::to_copy_out(out_t, in0_t, in2_i);
};
});
// Out variants for view ops are registered to a separate registry because
// their outputs (views) can't participate in memory reuse.
REGISTER_OPERATOR_FUNCTOR(
aten::reshape,
aten_reshape,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor(); // self
const auto proposed_shape = p_node->Input(1).toIntVector(); // shape
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::Tensor();
}
auto& out = p_node->Output(0).toTensor();
at::native::reshape_out(out, self, proposed_shape, true);
};
});
REGISTER_OPERATOR_FUNCTOR(
aten::flatten,
aten_flatten,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
DCHECK(p_node->inputs().size() == 3);
const auto& self = p_node->Input(0).toTensor();
const auto start_dim = p_node->Input(1).toInt();
const auto end_dim = p_node->Input(2).toInt();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::Tensor();
}
auto& out = p_node->Output(0).toTensor();
at::native::flatten_out(out, self, start_dim, end_dim);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const at::Tensor& self = p_node->Input(0).toTensor();
std::vector<int64_t> dim = {};
if ((p_node->inputs().size() > 1) && (!p_node->Input(1).isNone())) {
dim = p_node->Input(1).toIntList().vec();
}
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(self);
}
auto& output = p_node->Output(0).toTensor();
fastResizeToZero(output);
if (p_node->inputs().size() > 2) {
at::native::sum_out(
output,
self,
dim,
p_node->Input(2).toBool(),
p_node->Input(3).toOptional<at::ScalarType>());
return;
}
at::native::sum_out(output, self, dim);
};
});
std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
auto op_name = n->kind().toQualString();
if (SROperatorRegistry()->Has(op_name)) {
return SROperatorRegistry()->Create(op_name)->Generate(n);
}
return [](ProcessedNode*) { TORCH_CHECK(0); };
}
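// Native implementations assign a fresh IValue (often a view of the input) to
// the output on every run.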
std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_i = p_node->Input(1).toInt();
const auto in2_i = p_node->Input(2).toInt();
p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) {
return [](ProcessedNode* p_node) {
DCHECK(p_node->inputs().size() == 3);
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_i = p_node->Input(1).toInt();
const auto in2_i = p_node->Input(2).toInt();
p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
};
} else if (n->kind() == prim::TupleConstruct) {
return [](ProcessedNode* p_node) {
// prepare inputs
std::vector<IValue> stack;
const size_t size = p_node->inputs().size();
stack.reserve(size);
for (size_t i = 0; i < size; i++) {
stack.emplace_back(p_node->Input(i));
}
// run op
auto* node = p_node->node();
const auto& type = node->output()->type()->expect<TupleType>();
if (type->name().has_value()) {
namedTupleConstruct(stack, type, node->inputs().size());
} else {
tupleConstruct(stack, node->inputs().size());
}
// put output back
p_node->Output(0) = std::move(stack[0]);
};
} else if (n->kind() == prim::ListConstruct) {
return [](ProcessedNode* p_node) {
// prepare inputs
std::vector<IValue> stack;
const size_t size = p_node->inputs().size();
stack.reserve(size);
for (size_t i = 0; i < size; i++) {
stack.emplace_back(p_node->Input(i));
}
// run op
listConstruct(
stack,
p_node->node()->output()->type()->expectRef<ListType>(),
p_node->inputs().size());
// put output back
p_node->Output(0) = std::move(stack[0]);
};
} else if (n->kind() == prim::ListUnpack) {
return [](ProcessedNode* p_node) {
// prepare inputs
std::vector<IValue> stack;
const size_t size = p_node->inputs().size();
stack.reserve(size);
for (size_t i = 0; i < size; i++) {
stack.emplace_back(p_node->Input(i));
}
// run op
size_t num_outputs = p_node->outputs().size();
listUnpack(stack, num_outputs);
// put output back
DCHECK_EQ(stack.size(), num_outputs);
for (size_t i = 0; i < num_outputs; i++) {
p_node->Output(i) = std::move(stack[i]);
}
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_iv = p_node->Input(1).toIntVector();
p_node->Output(0) = at::native::permute(in0_t, in1_iv);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_iv = p_node->Input(1).toIntVector();
p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_i = p_node->Input(1).toInt();
const auto in2_i = p_node->Input(2).toInt();
const auto in3_i = p_node->Input(3).toInt();
const auto in4_i = p_node->Input(4).toInt();
p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor(); // self
const auto dim = p_node->Input(1).toInt(); // dim
int64_t start = 0;
if (p_node->Input(2).isScalar()) {
start = p_node->Input(2).toInt();
} else {
auto& t = p_node->Input(2).toTensor();
start = t.item<int64_t>();
}
const auto length = p_node->Input(3).toInt(); // length
TORCH_CHECK(
self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
auto cur_size = self.sizes()[dim];
if (start != cur_size && start < 0) { // start being the end is valid, but
// not a valid dim specification.
start = at::maybe_wrap_dim(start, cur_size);
}
TORCH_CHECK(
length >= 0 && start <= cur_size - length,
"start (",
start,
") + length (",
length,
") exceeds dimension size (",
cur_size,
").");
p_node->Output(0) =
at::native::slice(self, dim, start, start + length, 1);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
return [](ProcessedNode* p_node) {
DCHECK(p_node->inputs().size() == 5);
const auto& in0_t = p_node->Input(0).toTensor();
const auto in1_i = p_node->Input(1).toScalarType();
const auto in2_i = p_node->Input(2).toBool();
const auto in3_i = p_node->Input(3).toBool();
if (p_node->Input(4).isNone()) {
p_node->Output(0) =
at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
} else {
const auto in4_o = p_node->Input(4).toMemoryFormat();
p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
}
};
}
return [](ProcessedNode*) { TORCH_CHECK(0); };
}
REGISTER_OPERATOR_FUNCTOR(
aten::embedding_bag,
aten_embedding_bag,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
TORCH_CHECK(
p_node->inputs().size() == 8,
"Expected number of inputs are 8, but got " +
std::to_string(p_node->inputs().size()));
const auto& weight = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toTensor();
const auto& offsets = p_node->Input(2).toTensor();
auto scale_grad_by_freq = p_node->Input(3).toBool();
auto mode = p_node->Input(4).to<int64_t>();
auto sparse = p_node->Input(5).toBool();
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
auto include_last_offset = p_node->Input(7).toBool();
at::native::check_arguments(
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset);
std::ignore = scale_grad_by_freq;
std::ignore = sparse;
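// Outputs: 0 = pooled output, 1 = offset2bag, 2 = bag_size, 3 = max_indices.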
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::empty(
{include_last_offset ? offsets.sizes()[0] - 1
: offsets.sizes()[0],
weight.sizes()[1]},
weight.options());
} else {
at::native::resize_(
p_node->Output(0).toTensor(),
{include_last_offset ? offsets.sizes()[0] - 1
: offsets.sizes()[0],
weight.sizes()[1]},
c10::nullopt);
}
at::Tensor& output = p_node->Output(0).toTensor();
if (p_node->Output(1).isNone()) {
p_node->Output(1) = at::empty({0}, offsets.options());
}
at::Tensor& offset2bag = p_node->Output(1).toTensor();
at::native::make_offset2bag_out(
offset2bag,
output,
weight,
indices,
offsets,
mode,
per_sample_weights);
if (p_node->Output(2).isNone()) {
p_node->Output(2) = at::empty(offsets.sizes(), offsets.options());
}
at::Tensor& bag_size = p_node->Output(2).toTensor();
at::native::make_bag_size_out(
bag_size, offsets, indices, mode, include_last_offset, false);
if (p_node->Output(3).isNone()) {
p_node->Output(3) = at::empty(bag_size.sizes(), offsets.options());
}
at::Tensor& max_indices = p_node->Output(3).toTensor();
at::native::make_max_indices_out(
max_indices,
weight,
indices,
offsets,
bag_size,
mode,
include_last_offset);
at::native::_embedding_bag_cpu_impl_out(
output,
offset2bag,
bag_size,
max_indices,
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset);
};
});
} // namespace jit
} // namespace torch