mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Follows #132209 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132411 Approved by: https://github.com/Skylion007
46 lines
1.3 KiB
C++
46 lines
1.3 KiB
C++
// This file defines the types and registration helpers for standard lowerings
// from JIT IR to TensorExpr (TE) IR.
|
|
#pragma once
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <variant>
#include <vector>
|
|
|
|
namespace torch::jit::tensorexpr {
|
|
|
|
using ArgNone = std::monostate;
|
|
using BufList = std::vector<tensorexpr::BufHandle>;
|
|
using DoubleList = std::vector<double>;
|
|
using IntList = std::vector<int64_t>;
|
|
using ArgValue = std::variant<
|
|
tensorexpr::BufHandle,
|
|
tensorexpr::VarHandle,
|
|
double,
|
|
int64_t,
|
|
bool,
|
|
BufList,
|
|
DoubleList,
|
|
IntList,
|
|
std::string,
|
|
ArgNone>;
|
|
|
|
using NNCLoweringFunction = std::function<Tensor(
|
|
const std::vector<ArgValue>&,
|
|
const std::vector<ExprHandle>&,
|
|
const std::vector<ExprHandle>&,
|
|
const std::optional<ScalarType>&,
|
|
at::Device)>;
|
|
|
|
TORCH_API FunctionSchemaMap<NNCLoweringFunction>& getNNCLoweringRegistry();
|
|
TORCH_API NNCLoweringFunction getStandardLoweringFor(const std::string& op);
|
|
|
|
struct RegisterNNCLoweringsFunction {
|
|
RegisterNNCLoweringsFunction(
|
|
const std::vector<std::string>& schemas,
|
|
const NNCLoweringFunction& fn);
|
|
};
|
|
|
|
} // namespace torch::jit::tensorexpr
|