mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976 Approved by: https://github.com/Skylion007
798 lines
23 KiB
C++
798 lines
23 KiB
C++
#include <torch/csrc/jit/passes/quantization/helper.h>
|
|
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::jit {
|
|
|
|
using graph_rewrite_helper::getFuncName;
|
|
|
|
// Pairing of an op/function name with the index of one of its arguments.
// Used below to express rules of the form "argument `arg_index` of op
// `func_name`" (e.g. which argument is a weight or a bias).
struct FuncArg {
  std::string func_name;
  int arg_index;
};

// Argument-match lists for aten ops and for prim::CallFunction targets.
using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;
|
|
|
|
// Lists of allowed quantizable operators

// prim::CallFunction targets quantizable under static quantization.
std::vector<std::string> _static_quantizable_call_funcs = {
    "conv2d",
    "linear",
    "batch_norm",
    "hardswish",
    "elu",
    "celu",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// aten ops quantizable under static quantization (includes in-place
// variants with trailing underscore).
std::vector<std::string> _static_quantizable_aten_funcs = {
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "linear",
    "hardswish",
    "hardswish_",
    "elu",
    "elu_",
    "celu",
    "celu_",
    "batch_norm",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// prim::CallFunction targets quantizable under dynamic quantization.
std::vector<std::string> _dynamic_quantizable_call_funcs = {
    "linear",
};

// aten ops quantizable under dynamic quantization.
std::vector<std::string> _dynamic_quantizable_aten_funcs = {
    "linear",
};

// Ops whose weights are statically quantized while activations stay fp.
std::vector<std::string> _static_weight_only_quant_aten_funcs = {
    "embedding_bag",
};
std::vector<std::string> _static_weight_only_quant_call_funcs = {
    "embedding_bag",
};
|
|
|
|
// These are the prim::CallFunctions that doesn't require observation and
// have a single input Tensor
// example: `prim::CallFunction(%dropout, %input_tensor, ...)
// so we propagate observed property from %input_tensor to the
// output of the `prim::CallFunction`
// Also these ops doesn't do computation on the value of Tensor, the
// operation only depends on the shape of the Tensor
std::vector<std::string> _single_input_general_shape_call_funcs = {
    "_max_pool1d",
    "_max_pool2d",
    "_max_pool3d",
    "dropout",
    "relu",
};

// Similar to prim::CallFunctions, there are aten ops that doesn't
// require observation and have a single input Tensor
// Also these ops doesn't do computation on the value of Tensor, the
// operation only depends on the shape of the Tensor
// e.g. `aten::flatten(%input_tensor, ...)`
std::vector<std::string> _single_input_general_shape_aten_funcs = {
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "flatten",
    "max",
    "min",
    "dropout",
    "reshape",
    // Non-inplace resize is deprecated
    "resize_",
    "chunk",
    "view",
    "transpose",
    "contiguous",
    "permute",
    "repeat",
    "repeat_interleave",
    "relu",
    "relu_",
    "squeeze",
    "squeeze_",
    "unsqueeze",
    "unsqueeze_",
    "detach",
    "detach_",
    "stack",
    "__getitem__",
};
|
|
|
|
// These are prim::CallFunctions for ops that doesn't require observation and
// have a single input Tensor
// Also these ops do computation on the value of Tensor
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "interpolate",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
    "hardtanh",
    "leaky_relu",
};

// These are aten functions for ops that doesn't require observation and
// have a single input Tensor
// Also these ops do computation on the value of Tensor
// e.g. `aten::avg_pool2d(%input_tensor, ...)`
std::vector<std::string> _single_input_general_value_aten_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "mean",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_linear1d",
    "upsample_bilinear2d",
    "upsample_trilinear3d",
    "upsample_bicubic2d",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
    "hardtanh",
    "hardtanh_",
    "leaky_relu",
    "leaky_relu_",
};

// Clamp-family ops; their scalar min/max arguments get special handling
// (see getClampScalarInputUse).
std::vector<std::string> _clamp_funcs = {
    "hardtanh",
    "hardtanh_",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
};
|
|
|
|
// Fixed quantization parameters for ops whose output range is known a
// priori (so no observer is needed).
const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
    std::make_tuple(
        c10::kPerTensorAffine,
        QParamVector(
            {std::make_pair(".scale", IValue(_asym_scale)),
             std::make_pair(".zero_point", IValue(_asym_zero_point)),
             std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
    c10::kPerTensorAffine,
    QParamVector(
        {std::make_pair(".scale", IValue(_sym_scale)),
         std::make_pair(".zero_point", IValue(_sym_zero_point)),
         std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
    _fixed_qparams_map = {
        {Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("tanh"), _per_tensor_sym_qparam},
        {Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};
|
|
|
|
// Special checks for ops that do not require observers for all input tensors.
// For each operator in this list observers are inserted for the input based
// on the index specified.
AtenFuncArgs _observe_inputs_aten_func = {};
// For a batch_norm CallFunction, only argument 1 (the input tensor, after
// the function value at index 0) is observed.
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};

// Aten functions for getting tensor information
std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};

// Aten functions whose output will be quantized or not quantized depending
// on input tensor
std::vector<std::string> _propagate_quant_single_input_ops = {"cat"};

// Rules are slightly different for binary ops like `aten::add`, for these ops,
// if both of the inputs are Tensor, we'll quantize the output only if both of
// the inputs are quantized
// if the second input is a Scalar, we'll only look at the first input to decide
// if we need to quantize the output
std::vector<std::string> _propagate_quant_binary_ops = {
    "add",
    "add_",
    "mul",
    "mul_"};
|
|
|
|
// Check if `use` is an aten function of name `func_name` and if value
|
|
// `v` is the nth argument (if provided) of the function.
|
|
bool matchAtenFuncToUse(
|
|
const Use& use,
|
|
const std::string& func_name,
|
|
std::optional<int> n) {
|
|
Node* node = use.user;
|
|
return node->kind() == Symbol::aten(func_name) &&
|
|
(!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
|
|
}
|
|
|
|
bool matchCallFuncToUse(
|
|
const Use& use,
|
|
const std::string& func_name,
|
|
std::optional<int> n) {
|
|
Node* node = use.user;
|
|
return node->kind() == prim::CallFunction &&
|
|
getFuncName(node->inputs()[0]) == func_name &&
|
|
(!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
|
|
}
|
|
|
|
// Check any use of `v` matches the aten function call
|
|
// or CallFunction patterns
|
|
// Check any use of `v` matches the aten function call
// or CallFunction patterns
static bool matchArgPattern(
    Value* v,
    const AtenFuncArgs& aten_func_args,
    const CallFuncArgs& call_func_args) {
  // A single matching use is sufficient.
  for (const Use& use : v->uses()) {
    for (const auto& aten_arg : aten_func_args) {
      if (matchAtenFuncToUse(use, aten_arg.func_name, aten_arg.arg_index)) {
        return true;
      }
    }
    for (const auto& call_arg : call_func_args) {
      if (matchCallFuncToUse(use, call_arg.func_name, call_arg.arg_index)) {
        return true;
      }
    }
  }
  return false;
}
|
|
|
|
// TODO add other op signatures.
|
|
bool isWeight(Value* v) {
|
|
bool result = matchArgPattern(
|
|
v,
|
|
// ate::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
|
|
// %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
|
|
AtenFuncArgs(
|
|
{{"conv1d", 1},
|
|
{"conv2d", 1},
|
|
{"conv3d", 1},
|
|
{"conv_transpose1d", 1},
|
|
{"conv_transpose2d", 1},
|
|
{"linear", 1},
|
|
{"embedding_bag", 0}}),
|
|
// embedding_bag - prim::CallFunction(%func, %input.1, %weight,
|
|
// %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
|
|
// %per_sample_weights.1, %include_last_offset)
|
|
CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
|
|
return result;
|
|
}
|
|
|
|
// Whether `v` is used as the bias argument of a conv or linear op.
bool isBiasOfConvOrLinear(Value* v) {
  return matchArgPattern(
      v,
      AtenFuncArgs(
          {{"conv1d", 2},
           {"conv2d", 2},
           {"conv3d", 2},
           {"conv_transpose1d", 2},
           {"conv_transpose2d", 2},
           {"linear", 2}}),
      CallFuncArgs({{"linear", 3}}));
}
|
|
|
|
// Whether `v` is a non-input argument (offsets at index 2 or
// per_sample_weights at index 6) of an aten::embedding_bag call.
bool isEmbeddingBagNonInput(Value* v) {
  return matchArgPattern(
      v,
      AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
      CallFuncArgs({}));
}
|
|
|
|
// If `v` is used as the scalar min (arg 1) or max (arg 2) argument of a
// clamp-family op, return that Use; otherwise std::nullopt.
std::optional<Use> getClampScalarInputUse(Value* v) {
  for (const auto& use : v->uses()) {
    for (const auto& aten_func : _clamp_funcs) {
      const bool is_min_arg = matchAtenFuncToUse(use, aten_func, 1);
      const bool is_max_arg = matchAtenFuncToUse(use, aten_func, 2);
      if (is_min_arg || is_max_arg) {
        return use;
      }
    }
  }
  return std::nullopt;
}
|
|
|
|
// Clones the method named `orig_method_name` on `module` under
// `new_method_name`, deep-copying its graph and copying its schema.
void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name) {
  const Function& method = module.get_method(orig_method_name).function();
  // Deep-copy the graph so the clone can be mutated independently of the
  // original method.
  auto graph = toGraphFunction(method).graph()->copy();
  const auto& schema = method.getSchema();
  const auto this_method_name =
      c10::QualifiedName(*module.type()->name(), new_method_name);
  auto copied = module._ivalue()->compilation_unit()->create_function(
      this_method_name, std::move(graph));
  module.type()->addMethod(copied);
  copied->setSchema(schema);
}
|
|
|
|
// For a value `v` produced by a "pass-through" node, returns the input
// values whose quantization/observed state should be propagated to `v`.
// Returns an empty vector when `v` is not a pass-through output.
std::vector<Value*> getPassThroughInputs(Value* v) {
  Node* n = v->node();
  if (isSingleInputGeneralCallFunction(n)) {
    // For prim::CallFunction input 0 is the function value itself; the
    // Tensor argument is input 1.
    return {n->input(1)};
  } else if (
      isSingleInputGeneralAtenFunction(n) ||
      (n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
    return {n->input(0)};
  } else if (n->kind() == prim::If && n->outputs().size() == 1) {
    std::vector<Value*> inputs;
    for (Block* subblock : n->blocks()) {
      // Branches that always raise never produce a value worth tracking.
      if (alwaysRaisesException(subblock)) {
        continue;
      }
      auto* output = subblock->outputs()[0];
      inputs.push_back(output);
    }
    return inputs;
  } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
    // only propagate dequantize for Tensor
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      return {n->input(0)};
    } else {
      return {};
    }
  } else if (
      n->kind() == prim::ListConstruct &&
      v->type()->isSubtypeOf(*ListType::ofTensors())) {
    std::vector<Value*> inputs;
    for (auto* v : n->inputs()) {
      inputs.push_back(v);
    }
    return inputs;
  } else if (n->kind() == prim::TupleConstruct) {
    std::vector<Value*> inputs;
    // Only the Tensor elements of the tuple participate.
    for (auto* input : n->inputs()) {
      if (input->type()->isSubtypeOf(*TensorType::get())) {
        inputs.push_back(input);
      }
    }
    return inputs;
  } else if (n->kind() == Symbol::aten("append")) {
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      inputs.push_back(input);
    }
    return inputs;
  }

  return {};
}
|
|
|
|
static std::vector<NodeKind> toAtenSymbol(
|
|
const std::vector<std::string>& func_names) {
|
|
std::vector<NodeKind> symbols;
|
|
std::transform(
|
|
func_names.begin(),
|
|
func_names.end(),
|
|
std::back_inserter(symbols),
|
|
Symbol::aten);
|
|
return symbols;
|
|
}
|
|
|
|
// True when the kind of `n` appears in `aten_funcs`.
static bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) {
  for (const auto& kind : aten_funcs) {
    if (kind == n->kind()) {
      return true;
    }
  }
  return false;
}
|
|
|
|
// Convenience overload: translates the names to symbols, then matches.
static bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) {
  return isAtenFunc(n, toAtenSymbol(aten_funcs));
}
|
|
|
|
// TODO: factor out isCallFunc
|
|
static bool isFunctionNode(
|
|
Node* n,
|
|
const std::vector<std::string>& call_funcs,
|
|
const std::vector<std::string>& aten_funcs) {
|
|
bool is_func_node = isAtenFunc(n, aten_funcs);
|
|
if (n->kind() == prim::CallFunction) {
|
|
auto func_name = getFuncName(n->inputs()[0]);
|
|
is_func_node |=
|
|
std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
|
|
call_funcs.end();
|
|
}
|
|
return is_func_node;
|
|
}
|
|
|
|
// Aten op whose output's observed state follows its single Tensor input
// and whose result depends only on the input's shape.
bool isSingleInputGeneralShapeAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_shape_aten_funcs);
}

// Aten op whose output's observed state follows its single Tensor input
// and whose result depends on the input's values; binary ops with a Scalar
// second operand also qualify.
bool isSingleInputGeneralValueAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
      isBinaryOpWithScalarInput(n);
}
|
|
|
|
// True when `n` is a prim::CallFunction to a single-Tensor-input function
// from either the general-shape or general-value list.
bool isSingleInputGeneralCallFunction(Node* n) {
  // Build the merged list exactly once. The previous implementation
  // std::copy'ed both source lists into the `static` vector on EVERY call,
  // so it grew without bound with duplicate entries.
  static const std::vector<std::string> single_input_general_call_funcs = [] {
    std::vector<std::string> funcs;
    funcs.reserve(
        _single_input_general_shape_call_funcs.size() +
        _single_input_general_value_call_funcs.size());
    funcs.insert(
        funcs.end(),
        _single_input_general_shape_call_funcs.begin(),
        _single_input_general_shape_call_funcs.end());
    funcs.insert(
        funcs.end(),
        _single_input_general_value_call_funcs.begin(),
        _single_input_general_value_call_funcs.end());
    return funcs;
  }();
  return isFunctionNode(
      n,
      /* call_funcs = */ single_input_general_call_funcs,
      /* aten_funcs = */ {});
}
|
|
|
|
// True when `n` is any aten op whose output follows its single Tensor
// input: general-shape, general-value, or fixed-qparams ops.
bool isSingleInputGeneralAtenFunction(Node* n) {
  // Collect the fixed-qparams op symbols exactly once. The previous
  // implementation std::transform'ed the whole map into the `static`
  // vector on EVERY call, growing it without bound with duplicates.
  static const std::vector<NodeKind> fixed_qparams_aten_funcs = [] {
    std::vector<NodeKind> funcs;
    funcs.reserve(_fixed_qparams_map.size());
    for (const auto& entry : _fixed_qparams_map) {
      funcs.push_back(entry.first);
    }
    return funcs;
  }();

  return isSingleInputGeneralValueAtenFunction(n) ||
      isSingleInputGeneralShapeAtenFunction(n) ||
      isAtenFunc(n, fixed_qparams_aten_funcs);
}
|
|
|
|
// True for clamp-family ops (hardtanh / clamp).
bool isClamp(Node* n) {
  return isAtenFunc(n, _clamp_funcs);
}

// True for ops that only query tensor metadata (size/len/dim/numel).
bool isTensorInfoNode(Node* n) {
  return isAtenFunc(n, _tensor_info_funcs);
}

// True for single-input ops (e.g. cat) whose output quantization follows
// the input tensor.
bool isPropagateQuantSingleInputOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_single_input_ops);
}

// True for binary ops (add/mul and in-place variants) whose output
// quantization follows the inputs.
bool isPropagateQuantBinaryOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_binary_ops);
}

// Union of the two propagate-quant categories above.
bool isPropagateQuantOp(Node* n) {
  return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
}

// True for a propagate-quant binary op whose second operand is a Scalar
// (including a constant 0-dim tensor; see isScalar).
bool isBinaryOpWithScalarInput(Node* n) {
  return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
}
|
|
|
|
// Returns the fixed quantization parameters for `n` if its op has a fixed
// output range (sigmoid/tanh family), otherwise std::nullopt.
std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
  // Collect the op symbols exactly once. The previous implementation
  // std::transform'ed the whole map into the `static` vector on EVERY
  // call, growing it without bound with duplicate entries.
  static const std::vector<NodeKind> fixed_qparam_funcs = [] {
    std::vector<NodeKind> funcs;
    funcs.reserve(_fixed_qparams_map.size());
    for (const auto& entry : _fixed_qparams_map) {
      funcs.push_back(entry.first);
    }
    return funcs;
  }();
  if (isAtenFunc(n, fixed_qparam_funcs)) {
    return _fixed_qparams_map.at(n->kind());
  }
  return std::nullopt;
}
|
|
|
|
// A prim::CallFunction that is neither a recognized single-input general
// function nor a statically quantizable one, i.e. a user-defined function
// that quantization passes need to step into.
bool userDefinedCallFunction(Node* n) {
  return n->kind() == prim::CallFunction &&
      !isSingleInputGeneralCallFunction(n) &&
      !isFunctionNode(n, _static_quantizable_call_funcs, {});
}

// True for ops (currently embedding_bag) that are quantized weight-only
// under static quantization.
bool isWeightOnlyStaticQuantOp(Node* n) {
  return isFunctionNode(
      n,
      _static_weight_only_quant_call_funcs,
      _static_weight_only_quant_aten_funcs);
}
|
|
|
|
// Whether node `n` is quantizable under the given quantization type
// (static vs dynamic use different allow-lists).
bool nodeQuantizable(Node* n, QuantType quant_type) {
  const bool is_dynamic = quant_type == QuantType::DYNAMIC;
  const auto& call_funcs = is_dynamic ? _dynamic_quantizable_call_funcs
                                      : _static_quantizable_call_funcs;
  const auto& aten_funcs = is_dynamic ? _dynamic_quantizable_aten_funcs
                                      : _static_quantizable_aten_funcs;
  return isFunctionNode(n, call_funcs, aten_funcs);
}
|
|
|
|
bool useQuantizable(const Use& use, QuantType quant_type) {
|
|
if (quant_type == QuantType::STATIC) {
|
|
for (const auto& func_input : _observe_inputs_aten_func) {
|
|
if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) {
|
|
return use.offset == static_cast<size_t>(func_input.arg_index);
|
|
}
|
|
}
|
|
|
|
for (const auto& func_input : _observe_inputs_call_func) {
|
|
if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) {
|
|
return use.offset == static_cast<size_t>(func_input.arg_index);
|
|
}
|
|
}
|
|
}
|
|
|
|
return nodeQuantizable(use.user, quant_type);
|
|
}
|
|
|
|
// Resolves the graph of the function invoked by a prim::CallFunction node.
// Throws when the callee is not a graph function.
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
  auto* producer = n->input(0)->node();
  auto func = producer->output()->type()->expectRef<FunctionType>().function();
  auto graph_func = tryToGraphFunction(*func);
  TORCH_CHECK(graph_func, "Quantization only works for graph function");
  return graph_func->graph();
}
|
|
|
|
// Block helper functions
|
|
bool alwaysRaisesException(Block* block) {
|
|
for (Node* n : block->nodes()) {
|
|
if (n->kind() == prim::RaiseException) {
|
|
return true;
|
|
}
|
|
if (n->kind() == prim::If) {
|
|
bool exception = true;
|
|
for (Block* b : n->blocks()) {
|
|
exception &= alwaysRaisesException(b);
|
|
}
|
|
if (exception) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Check if a value in the graph is a Scalar value
|
|
bool isScalar(Value* v) {
|
|
auto iv = toIValue(v);
|
|
return v->type()->isSubtypeOf(*NumberType::get()) ||
|
|
(v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
|
|
iv->toTensor().dim() == 0);
|
|
}
|
|
|
|
// =================== Graph/Module analysis helper functions ============
|
|
// Check if value is the input of the graph
|
|
bool hitGraphInput(Value* value) {
|
|
Graph* graph = value->owningGraph();
|
|
const auto& inputs = graph->inputs();
|
|
return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
|
|
}
|
|
|
|
// Get the module access path for a Value representing a module instance
|
|
// by tracing back the GetAttr nodes and recording all the attribute
|
|
// names along the way.
|
|
// Assuming 'self.sub.basic_block.conv1',
|
|
// Input1: Value instance of conv1
|
|
// Input2: Value instance of self
|
|
// Output: ['sub', 'basic_block', 'conv1']
|
|
std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
|
|
std::vector<std::string> path;
|
|
// Iterator to traverse back the GetAttr calls
|
|
Value* iter = instance;
|
|
// trace back the instance to recover the path of the submodule
|
|
while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
|
|
Node* get_attr = iter->node();
|
|
// record the name of GetAttr
|
|
path.push_back(get_attr->s(attr::name));
|
|
// trace back the chain of GetAttr
|
|
iter = get_attr->inputs()[0];
|
|
}
|
|
TORCH_CHECK(
|
|
iter == self,
|
|
"Can't handle the access pattern of GetAttr "
|
|
" in getModuleAccessPath, traced back to:",
|
|
iter->debugName(),
|
|
" which is not self:",
|
|
self->debugName());
|
|
std::reverse(path.begin(), path.end());
|
|
return path;
|
|
}
|
|
|
|
// Assuming self.foo.bar.conv1,
|
|
// Input1: Module instance of self
|
|
// Input2: ['foo', 'bar', 'conv1']
|
|
// Output: Module instance of conv1
|
|
Module findChildModule(
|
|
const Module& module,
|
|
const std::vector<std::string>& path) {
|
|
Module m = module;
|
|
for (const auto& p : path) {
|
|
m = m.attr(p).toModule();
|
|
}
|
|
return m;
|
|
}
|
|
|
|
// Resolves the module instance that node `n` is invoked on (its input 0),
// starting from `self`. Throws (via getModuleAccessPath / toModule) when
// the access pattern cannot be resolved.
Module getInvokedModule(Module& module, Node* n, Value* self) {
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  return findChildModule(module, path);
}

// Same as getInvokedModule, but returns std::nullopt when an attribute on
// the path is not a module instead of throwing.
std::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self) {
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  Module m = module;
  for (const auto& p : path) {
    if (m.attr(p).isModule()) {
      m = m.attr(p).toModule();
    } else {
      return std::nullopt;
    }
  }
  return m;
}
|
|
|
|
// ==================== filter functions for matches ==============
|
|
bool is_int_constant(
|
|
const Match& match,
|
|
const std::unordered_map<std::string, Value*>& vmap,
|
|
const std::string& vname,
|
|
int value) {
|
|
const auto& match_vmap = match.values_map;
|
|
auto v = toIValue(match_vmap.at(vmap.at(vname)));
|
|
return v && v->isInt() && v->toInt() == value;
|
|
}
|
|
|
|
static bool is_functional(
|
|
const Match& match,
|
|
const std::unordered_map<std::string, Value*>& vmap,
|
|
const std::string& vname,
|
|
const std::string& functional) {
|
|
const auto& match_vmap = match.values_map;
|
|
Value* v = match_vmap.at(vmap.at(vname));
|
|
return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
|
|
}
|
|
|
|
// Strips compiler-inserted ".___torch_mangle_<N>" components from a
// qualified name so mangled and unmangled names compare equal.
std::string removeTorchMangle(const std::string& orig_name) {
  static std::regex mangle_re("\\.___torch_mangle_\\d+");
  return std::regex_replace(orig_name, mangle_re, "");
}
|
|
|
|
// Returns the demangled qualified class name of `value`'s type, or
// std::nullopt when the value is not an instance of a named class.
std::optional<std::string> getModuleName(Value* value) {
  auto type = value->type()->cast<ClassType>();
  if (type && type->name()) {
    return removeTorchMangle(type->name()->qualifiedName());
  }
  return std::nullopt;
}
|
|
|
|
static bool is_module(
|
|
const Match& match,
|
|
const std::unordered_map<std::string, Value*>& vmap,
|
|
const std::string& vname,
|
|
const std::string& module_qualified_name) {
|
|
const auto& match_vmap = match.values_map;
|
|
Value* v = match_vmap.at(vmap.at(vname));
|
|
auto module_name = getModuleName(v);
|
|
if (module_name.has_value()) {
|
|
return module_name.value() == module_qualified_name;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
// True when the `alpha` pattern variable of an aten::add match is the
// constant 1 (the only case that fuses cleanly).
bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_int_constant(match, vmap, "alpha", 1);
}

// True when the matched `relu` value is the functional relu.
bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_functional(match, vmap, "relu", "relu");
}

// True when the matched `relu` value is an nn.ReLU module instance.
bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
}

// True when the matched `linear` value is an nn.Linear module instance.
bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
}
|
|
|
|
// True when the matched `conv` value is an nn.Conv1d module instance.
bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
}

// True when the matched `conv` value is an nn.Conv2d module instance.
bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
}

// True when the matched `conv` value is an nn.Conv3d module instance.
bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
}

// True when the matched `conv` value is an nn.ConvTranspose1d instance.
bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
}

// True when the matched `conv` value is an nn.ConvTranspose2d instance.
bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
}
|
|
|
|
// True when the matched `batchnorm` value is an nn.BatchNorm2d instance,
// or the mobile_cv NaiveSyncBatchNorm drop-in replacement.
bool is_batchnorm2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  bool regnorm = is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.torch.nn.modules.batchnorm.BatchNorm2d");
  bool naivenorm = is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm");
  return (regnorm || naivenorm);
}

// True when the matched `batchnorm` value is an nn.BatchNorm3d instance.
bool is_batchnorm3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
}
|
|
|
|
} // namespace torch::jit
|