pytorch/torch/csrc/jit/runtime/static/ops.h
Yuanyuan Chen 36871622f1 [2/N] Mark unused parameters in C++ code (#165121)
This is a follow-up to #164912 that marks unused C++ parameters to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121
Approved by: https://github.com/Skylion007
2025-10-15 03:04:39 +00:00

188 lines
5.4 KiB
C++

#pragma once
#include <ATen/Utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
// Declarations of two at::native helpers that take their destination tensor
// `out` as the first parameter and return an at::Tensor&. Only the
// declarations live in this header; the definitions are elsewhere.
namespace at::native {
// NOTE(review): presumably writes `self`, viewed with `proposed_shape`, into
// `out` and returns `out` -- confirm against the definition. `infer_size`
// defaults to true.
at::Tensor& reshape_copy_out(
at::Tensor& out,
const at::Tensor& self,
const at::DimVector& proposed_shape,
bool infer_size = true);
// NOTE(review): presumably the out-variant of a `to`-style copy of `self` into
// `out`; the exact meaning of `copy_strides` and of an unset `memory_format`
// is not visible here -- confirm against the definition.
at::Tensor& to_copy_out(
Tensor& out,
const Tensor& self,
bool non_blocking,
bool copy_strides,
std::optional<MemoryFormat> memory_format);
} // namespace at::native
namespace torch::jit {
// Plain function-pointer type: takes a Node* and returns an SROperator.
using SROpFunctor = SROperator (*)(Node* n);
// Polymorphic base for operator functors. The default Generate() ignores the
// node it is given and hands back a default-constructed SROperator; derived
// types override Generate() to produce a real operator for the node.
struct SROperatorFunctor {
  virtual SROperator Generate(Node* /*unused*/) {
    return SROperator{};
  }

  virtual ~SROperatorFunctor() = default;
};
// Declares the registry `SROperatorRegistry` over SROperatorFunctor.
TORCH_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
// REGISTER_OPERATOR_FUNCTOR(name, id, fn):
//   - defines `struct SROperatorFunctor_<id>` deriving from SROperatorFunctor,
//     holding an SROpFunctor member initialized from the trailing macro
//     arguments, with Generate(n) forwarding to that function pointer;
//   - registers that struct in SROperatorRegistry under `name` via
//     C10_REGISTER_CLASS.
// The callable is taken through __VA_ARGS__, presumably so that expressions
// containing top-level commas still arrive as a single initializer.
// (Comments are kept outside the macro body: a `//` comment on a continued
// line would swallow the following lines.)
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
struct SROperatorFunctor_##id : public SROperatorFunctor { \
SROpFunctor fn = __VA_ARGS__; \
SROperator Generate(Node* n) override { \
return fn(n); \
} \
}; \
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id)
// Declares the registry `SRNativeOperatorRegistry` over SROperatorFunctor.
TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
// REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, fn):
//   - defines `struct SRNativeOperatorFunctor_<id>` deriving from
//     SROperatorFunctor, holding an SROpFunctor member initialized from the
//     trailing macro arguments, with Generate(n) forwarding to it;
//   - registers that struct in SRNativeOperatorRegistry under `name` via
//     C10_REGISTER_CLASS.
// (Comments are kept outside the macro body: a `//` comment on a continued
// line would swallow the following lines.)
#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...) \
struct SRNativeOperatorFunctor_##id : public SROperatorFunctor { \
SROpFunctor fn = __VA_ARGS__; \
SROperator Generate(Node* n) override { \
return fn(n); \
} \
}; \
C10_REGISTER_CLASS( \
SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id)
// Creates a tensor of size {0} through at::detail::empty_cpu, taking the
// dtype, layout and device from `t`. The two trailing optional arguments are
// left as std::nullopt.
inline at::Tensor create_empty_from(const at::Tensor& t) {
  const auto scalar_type = c10::typeMetaToScalarType(t.dtype());
  const auto layout = t.layout();
  const auto device = t.device();
  return at::detail::empty_cpu(
      {0}, scalar_type, layout, device, std::nullopt, std::nullopt);
}
// Creates a tensor of the requested `sizes` through at::detail::empty_cpu,
// taking the dtype, layout and device from `t`. The two trailing optional
// arguments are left as std::nullopt.
inline at::Tensor create_empty_from(
    at::IntArrayRef sizes,
    const at::Tensor& t) {
  const auto scalar_type = c10::typeMetaToScalarType(t.dtype());
  const auto layout = t.layout();
  const auto device = t.device();
  return at::detail::empty_cpu(
      sizes, scalar_type, layout, device, std::nullopt, std::nullopt);
}
// Creates a tensor of size {0} with the given `dtype` through
// at::detail::empty_cpu; every other option is passed as std::nullopt.
inline at::Tensor create_empty(c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0},
      dtype,
      /*layout_opt=*/std::nullopt,
      /*device_opt=*/std::nullopt,
      /*pin_memory_opt=*/std::nullopt,
      /*memory_format_opt=*/std::nullopt);
}
// Creates a tensor of size {0} through at::detail::empty_cpu with an explicit
// `dtype`, taking layout and device from `t`. The two trailing optional
// arguments are left as std::nullopt.
inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  const auto layout = t.layout();
  const auto device = t.device();
  return at::detail::empty_cpu(
      {0}, dtype, layout, device, std::nullopt, std::nullopt);
}
// Creates a tensor of size {0} through at::detail::empty_cpu with an explicit
// `layout`, taking dtype and device from `t`. The two trailing optional
// arguments are left as std::nullopt.
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) {
  const auto scalar_type = c10::typeMetaToScalarType(t.dtype());
  const auto device = t.device();
  return at::detail::empty_cpu(
      {0}, scalar_type, layout, device, std::nullopt, std::nullopt);
}
// Creates a tensor of size {0} through at::detail::empty_cpu with an explicit
// `device`, taking dtype and layout from `t`. The two trailing optional
// arguments are left as std::nullopt.
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) {
  const auto scalar_type = c10::typeMetaToScalarType(t.dtype());
  const auto layout = t.layout();
  return at::detail::empty_cpu(
      {0}, scalar_type, layout, device, std::nullopt, std::nullopt);
}
// Creates a tensor of size {0} through at::detail::empty_cpu with an explicit
// `memory_format`, taking dtype, layout and device from `t`. The remaining
// optional argument is left as std::nullopt.
inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::MemoryFormat memory_format) {
  const auto scalar_type = c10::typeMetaToScalarType(t.dtype());
  const auto layout = t.layout();
  const auto device = t.device();
  return at::detail::empty_cpu(
      {0}, scalar_type, layout, device, std::nullopt, memory_format);
}
// Creates a tensor of size {0} through at::detail::empty_cpu with an explicit
// `dtype` and `memory_format`, taking layout and device from `t`. The
// remaining optional argument is left as std::nullopt.
inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype,
    c10::MemoryFormat memory_format) {
  const auto layout = t.layout();
  const auto device = t.device();
  return at::detail::empty_cpu(
      {0}, dtype, layout, device, std::nullopt, memory_format);
}
// Resizes `t` to {0} with resize_ and reports whether data_ptr() returns the
// same pointer before and after that resize. Note that this mutates `t`.
inline bool checkResizedDataPtr(at::Tensor& t) {
  const void* const ptr_before = t.data_ptr();
  t.resize_({0});
  const void* const ptr_after = t.data_ptr();
  return ptr_before == ptr_after;
}
// Sets the sizes of `t` to {0} directly on its TensorImpl via
// set_sizes_contiguous instead of calling resize_. A debug-only internal
// assertion then runs checkResizedDataPtr on `t`.
inline void fastResizeToZero(at::Tensor& t) {
  auto* const impl = t.unsafeGetTensorImpl();
  impl->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}
// Checks whether an op has an out variant registered in Static Runtime.
bool opIsRegistered(const c10::Symbol& op_name);
// Checks whether Static Runtime can run an op natively.
// Prim ops that are implemented directly in the JIT interpreter are
// implemented as native ops in Static Runtime.
bool nativeOpIsRegistered(const c10::Symbol& op_name);
// NOTE(review): defined outside this header. `node_has_out_variant` presumably
// maps each node to whether it has an out variant -- confirm against the
// definition before relying on it.
bool canReuseInputsOutputs(
Node* n,
const c10::FastMap<Node*, bool>& node_has_out_variant);
// NOTE(review): defined outside this header; takes the same
// `node_has_out_variant` map as canReuseInputsOutputs. The criteria for
// "optimizable" are not visible here.
bool isOptimizableContainerType(
Node* n,
const c10::FastMap<Node*, bool>& node_has_out_variant);
// Return the SROperator for node `n`; presumably looked up from the
// out-variant and native operator registries respectively -- confirm against
// the definitions.
SROperator getOutOfPlaceOperation(Node* n);
SROperator getNativeOperation(Node* n);
// NOTE(review): defined outside this header; presumably reports whether `n`
// takes a variable number of arguments -- confirm.
bool hasVarArgs(Node* n);
// Renders `node` into a std::string by printing it to an in-memory stream
// (indent level 0, no groups vector, source locations disabled).
inline std::string PrintNode(const Node* node) {
  std::ostringstream buffer;
  node->print(
      buffer,
      /*level=*/0,
      /*groups=*/nullptr,
      /*print_source_locations=*/false);
  return buffer.str();
}
// Logs the schema of `node` at verbosity level 1 with a "schema mismatch"
// prefix.
// NOTE(review): node->schema() is deliberately left inside the VLOG stream
// expression; it is presumably only evaluated when VLOG(1) is enabled, so
// hoisting it into a local could change when it runs -- confirm VLOG
// semantics before refactoring.
inline void LogAndDumpSchema(const Node* node) {
VLOG(1) << "Found schema mismatch for: " << node->schema();
}
// Zero-schema overload: ignores the node and always returns true.
// NOTE(review): because the result is unconditionally true, OR-ing it onto a
// chain of per-schema match results makes the whole chain true regardless of
// whether any schema matched. Callers should not use it as the terminator of
// an "any schema matches" fold.
inline bool sr_schema_check(torch::jit::Node* /*unused*/) {
return true;
}
// Returns true iff `node` matches at least one of the supplied schemas.
//
// Schemas are tried left to right through Node::matches and evaluation stops
// at the first hit. When none matches, the mismatch is logged once through
// LogAndDumpSchema and false is returned.
//
// The result is computed with a fold over the schema pack instead of
// recursing into the zero-schema overload: that overload returns true, so
// OR-ing it onto the chain made every call succeed and left the logging
// branch below unreachable.
template <typename Schema, typename... Schemas>
bool sr_schema_check(
    torch::jit::Node* node,
    Schema&& first,
    Schemas&&... rest) {
  // Binary left fold: an empty `rest` pack contributes nothing, so the result
  // is exactly node->matches(first) in the single-schema case.
  const bool is_match =
      (node->matches(first) || ... || node->matches(rest));
  if (!is_match) {
    torch::jit::LogAndDumpSchema(node);
  }
  return is_match;
}
// NOTE(review): defined outside this header; presumably checks that `node` is
// of kind `node_kind` -- confirm against the definition.
bool sr_schema_check_kind(torch::jit::Node* node, c10::Symbol node_kind);
} // namespace torch::jit
// Declares the boolean flag `static_runtime_enable_fast_math`; its definition
// and default value live outside this header.
C10_DECLARE_bool(static_runtime_enable_fast_math);