[22/N] Fix clang-tidy warnings in jit (#135319)

Follows #134537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135319
Approved by: https://github.com/titaiwangms
This commit is contained in:
cyy 2024-09-08 17:18:29 +00:00 committed by PyTorch MergeBot
parent cfc227ad43
commit 2196f32475
41 changed files with 159 additions and 260 deletions

View File

@ -164,7 +164,6 @@ void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
std::shared_ptr<Graph> ToONNX(
std::shared_ptr<Graph>& graph,
::torch::onnx::OperatorExportTypes operator_export_type) {
auto constant_value_map = ConstantValueMap::getInstance();
ConstantValueMap::ClearMaps();
auto new_graph = std::make_shared<Graph>(graph->current_scope());
py::dict env;

View File

@ -1,8 +1,7 @@
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
}
@ -70,5 +69,4 @@ void CastAllConstantToFloating(Block* block) {
void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) {
CastAllConstantToFloating(graph->block());
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,9 +4,7 @@
#include <memory>
namespace torch {
namespace jit {
namespace torch::jit {
// see .cpp for docs
TORCH_API void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -9,8 +9,7 @@
#include <algorithm>
#include <optional>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -707,5 +706,4 @@ void ConstantFoldONNX(
GRAPH_DUMP("After ConstantFoldONNX:", g);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -5,8 +5,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <optional>
namespace torch {
namespace jit {
namespace torch::jit {
const int ONNX_OPSET_9 = 9;
const int ONNX_OPSET_10 = 10;
@ -30,6 +29,4 @@ void ConstantFoldONNX(
std::map<std::string, IValue>& paramDict,
int opset_version);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -7,12 +7,7 @@
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
namespace onnx {
using namespace ::c10::onnx;
}
namespace torch::jit {
// Meyers Singleton for C++ 14
ConstantValueMap& ConstantValueMap::getInstance() {
@ -290,7 +285,7 @@ void ConstantValueMap::ClearMaps() {
// For debug only.
void ConstantValueMap::PrintMaps() {
std::cout << "Rank/Shape Map:" << std::endl;
std::cout << "Rank/Shape Map:" << '\n';
for (const auto& x : ConstantValueMap::getInstance().rankMap) {
std::stringstream ss;
if (ConstantValueMap::getInstance().shapeMap.find(x.first) !=
@ -308,45 +303,45 @@ void ConstantValueMap::PrintMaps() {
}
}
ss << " (rank = " << x.second << ")";
std::cout << "node " << x.first << ": " << ss.str() << std::endl;
std::cout << "node " << x.first << ": " << ss.str() << '\n';
}
std::cout << std::endl;
std::cout << "Value Map:" << std::endl;
std::cout << '\n';
std::cout << "Value Map:" << '\n';
for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
std::cout << "node " << x.first << ": " << x.second << std::endl;
std::cout << "node " << x.first << ": " << x.second << '\n';
}
std::cout << std::endl;
std::cout << "TypeReliable Map:" << std::endl;
std::cout << '\n';
std::cout << "TypeReliable Map:" << '\n';
size_t count = 0;
for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
std::cout << "(node " << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
std::cout << std::endl;
std::cout << "UseInferredType Map:" << std::endl;
std::cout << '\n';
std::cout << "UseInferredType Map:" << '\n';
count = 0;
for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
std::cout << "(node " << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
std::cout << std::endl;
std::cout << "ShapeValue Map:" << std::endl;
std::cout << '\n';
std::cout << "ShapeValue Map:" << '\n';
count = 0;
for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) {
std::cout << "(node " << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
std::cout << std::endl;
std::cout << "InferredShape Map:" << std::endl;
std::cout << '\n';
std::cout << "InferredShape Map:" << '\n';
count = 0;
for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) {
std::cout << "(node " << x.first << ": ";
@ -360,29 +355,28 @@ void ConstantValueMap::PrintMaps() {
std::cout << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
std::cout << std::endl;
std::cout << "SymbolDim Map:" << std::endl;
std::cout << '\n';
std::cout << "SymbolDim Map:" << '\n';
count = 0;
for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) {
std::cout << "(" << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
std::cout << "DimSymbol Map:" << std::endl;
std::cout << "DimSymbol Map:" << '\n';
count = 0;
for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
std::cout << "(" << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
std::cout << '\n';
}
}
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -13,8 +13,7 @@ C10_DIAGNOSTIC_POP()
#include <mutex>
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
using ShapeDataMap =
std::unordered_map<std::string, ::ONNX_NAMESPACE::TensorShapeProto>;
@ -112,8 +111,7 @@ class ConstantValueMap {
// Stores if all graph-level inputs have static shape
std::optional<bool> allGraphInputsStatic;
// True if reliable has been computed for all graph inputs
bool allGraphInputsReliableComputed;
bool allGraphInputsReliableComputed{};
};
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <c10/util/irange.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -99,5 +98,4 @@ void DeduplicateInitializers(
buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,14 +4,11 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
void DeduplicateInitializers(
std::shared_ptr<Graph>& g,
std::map<std::string, IValue>& paramsDict,
bool is_train);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -1,8 +1,7 @@
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -16,5 +15,4 @@ void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) {
return;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// EliminateUnusedItemsONNX pass is removing unused
// initializers and inputs, this is needed because
@ -12,6 +11,4 @@ void EliminateUnusedItemsONNX(
Block* b,
std::map<std::string, IValue>& paramDict);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <c10/util/irange.h>
#include <algorithm>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -152,5 +151,4 @@ void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
GRAPH_DUMP("After EvalPeepholeONNX:", g);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,13 +4,10 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
void EvalPeepholeONNX(
std::shared_ptr<Graph>& g,
std::map<std::string, IValue>& paramDict);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -8,19 +8,14 @@
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
namespace torch {
namespace jit {
namespace onnx {
using namespace ::c10::onnx;
}
namespace torch::jit {
namespace {
const int ONNX_OPSET_13 = 13;
const int ONNX_TYPE_BOOL = 9;
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
Node* cast_node = graph->create(onnx::Cast);
Node* cast_node = graph->create(c10::onnx::Cast);
cast_node->addInput(val);
cast_node->i_(attr::to, ONNX_TYPE_BOOL);
cast_node->output()->setType(BoolType::get());
@ -149,7 +144,7 @@ std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
// Split the added scan_output back to expected tensor sequence.
auto loop_output = loop_node->output(i - 2);
Node* split_node =
loop_node->owningGraph()->create(onnx::SplitToSequence);
loop_node->owningGraph()->create(c10::onnx::SplitToSequence);
loop_output->replaceAllUsesWith(split_node->output());
split_node->i_(attr::keepdims, 0);
split_node->addInput(loop_output);
@ -191,7 +186,7 @@ std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
return new_outputs;
}
Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) {
Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) {
TORCH_INTERNAL_ASSERT(opt_type);
TypePtr elem_type = opt_type->getElementType();
Node* opt_node = g->create(::c10::onnx::Optional, 1);
@ -208,7 +203,7 @@ Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) {
// 2. Loop Op: insert Optional node before output, if input is Optional type
// or output type is None.
void ReplaceBlockOutputWithOptional(
OptionalTypePtr opt_type,
const OptionalTypePtr& opt_type,
Block* block,
size_t i) {
Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph());
@ -235,9 +230,9 @@ void FixupONNXSubblockOutputs(Node* n) {
// Identity(None). Also enables shape inference later on, since
// ONNX shape inference doesn't handle None.
if (output->type()->cast<NoneType>()) {
id_node = block->owningGraph()->create(onnx::Optional);
id_node = block->owningGraph()->create(c10::onnx::Optional);
} else {
id_node = block->owningGraph()->create(onnx::Identity);
id_node = block->owningGraph()->create(c10::onnx::Identity);
id_node->addInput(output);
}
id_node->insertBefore(block->return_node());
@ -741,5 +736,4 @@ void FixupONNXControlflowNodeOutputs(Node* n) {
}
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,11 +2,9 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version);
void FixupONNXControlflowNodeOutputs(Node* n);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,9 +2,7 @@
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
#include <torch/csrc/jit/passes/onnx/naming.h>
namespace torch {
namespace jit {
namespace onnx {
namespace torch::jit::onnx {
namespace {
@ -75,9 +73,9 @@ struct FunctionExtractor {
using FunctionCtxPtr = FunctionContext*;
using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;
static bool IsValidScope(ScopePtr s);
static bool IsValidScope(const ScopePtr& s);
static std::optional<ScopePtr> InferScope(Node* n);
static bool IsAncestor(ScopePtr parent, ScopePtr child);
static bool IsAncestor(const ScopePtr& parent, ScopePtr child);
static std::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
static std::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);
@ -88,7 +86,9 @@ struct FunctionExtractor {
scope_ctx_map& scope_ctxs,
const std::shared_ptr<Graph>& graph);
static void HandleNoScopeNodes(scope_ctx_map&, node_list no_scope_nlist);
static void HandleNoScopeNodes(
scope_ctx_map&,
const node_list& no_scope_nlist);
std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
@ -279,11 +279,11 @@ void FunctionExtractor::DebugPrintGraphWithFunction(
GRAPH_UPDATE("Main graph: ", g->toString());
}
bool FunctionExtractor::IsValidScope(ScopePtr s) {
bool FunctionExtractor::IsValidScope(const ScopePtr& s) {
return !s->isRoot() && !s->isBlank();
}
bool FunctionExtractor::IsAncestor(ScopePtr parent, ScopePtr child) {
bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) {
if (!IsValidScope(parent) || !IsValidScope(child) ||
parent->getDepth() >= child->getDepth()) {
return false;
@ -376,7 +376,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
std::all_of(
output_scopes.begin(),
output_scopes.end(),
[&output_scopes](ScopePtr scope) -> bool {
[&output_scopes](const ScopePtr& scope) -> bool {
return IsValidScope(scope) && scope == output_scopes.at(0);
})) {
return output_scopes.at(0);
@ -385,7 +385,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
std::all_of(
input_scopes.begin(),
input_scopes.end(),
[&input_scopes](ScopePtr scope) -> bool {
[&input_scopes](const ScopePtr& scope) -> bool {
return IsValidScope(scope) && scope == input_scopes.at(0);
})) {
return input_scopes.at(0);
@ -822,7 +822,7 @@ void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
void FunctionExtractor::HandleNoScopeNodes(
scope_ctx_map& scope_ctxs,
node_list no_scope_nlist) {
const node_list& no_scope_nlist) {
GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size());
for (auto n : no_scope_nlist) {
TORCH_WARN(
@ -1181,6 +1181,4 @@ void ONNXTrackScopeAttributes(
}
}
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -2,9 +2,6 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// This api will be used by serialization/export.cpp to extract function
// information. It should do conversion on graph to
// 1. Extract subgraph pattern of functions and define as local function
@ -15,7 +12,7 @@ namespace jit {
// represent these info inside Graph object.
// export.cpp will serialize the ONNX model with function_proto with
// above information.
namespace onnx {
namespace torch::jit::onnx {
// The following return types are used to track information regarding function
// attributes, that are unable to be traced through Torch IR.
@ -64,7 +61,4 @@ TORCH_API void ONNXTrackScopeAttributes(
std::shared_ptr<Graph>& graph,
std::map<std::string, IValue>& attributes);
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/naming.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -193,5 +192,4 @@ void ONNXFunctionCallSubstitution(Graph& graph) {
GRAPH_DUMP("After function call substitution calls: ", &graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void ONNXFunctionCallSubstitution(Graph& graph);
}
} // namespace torch

View File

@ -12,8 +12,7 @@
#include <onnx/onnx_pb.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -296,5 +295,4 @@ void ONNXLintGraph(const std::shared_ptr<Graph>& graph) {
" constants.");
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,8 +3,7 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Utility functions for PyTorch to ONNX conversion.
@ -73,5 +72,4 @@ class ScalarTypeHashFunction {
}
};
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -191,5 +190,4 @@ std::pair<Module, std::vector<IValue>> list_module_parameters(
return std::make_pair(moduleClone, parameterIValues);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,11 +3,9 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API std::pair<Module, std::vector<IValue>> list_module_parameters(
const Module& module);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,9 +3,7 @@
#include <utility>
namespace torch {
namespace jit {
namespace onnx {
namespace torch::jit::onnx {
namespace ONNXScopeName {
@ -16,7 +14,7 @@ const std::string name_separator = "::";
namespace {
std::string nameFromRoot(
torch::jit::ScopePtr scope,
const torch::jit::ScopePtr& scope,
const std::string& layer_separator,
NameFunc name_func) {
std::string out = (*name_func)(scope);
@ -32,7 +30,7 @@ std::string nameFromRoot(
}
std::pair<std::string, std::string> parseNameFromScope(
torch::jit::ScopePtr scope) {
const torch::jit::ScopePtr& scope) {
std::string full_name = scope->name().toUnqualString();
auto pos = full_name.find(name_separator);
TORCH_CHECK(
@ -55,7 +53,7 @@ std::string variableName(torch::jit::ScopePtr scope) {
}
std::string variableNameFromRoot(
torch::jit::ScopePtr scope,
const torch::jit::ScopePtr& scope,
const std::string& layer_separator) {
return nameFromRoot(scope, layer_separator, &variableName);
}
@ -65,12 +63,12 @@ std::string className(torch::jit::ScopePtr scope) {
}
std::string classNameFromRoot(
torch::jit::ScopePtr scope,
const torch::jit::ScopePtr& scope,
const std::string& layer_separator) {
return nameFromRoot(scope, layer_separator, &className);
}
bool isCompatibleScope(torch::jit::ScopePtr scope) {
bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
return !scope->isRoot() && !scope->isBlank() &&
(std::string(scope->name().toUnqualString()).find(name_separator) !=
std::string::npos);
@ -89,7 +87,7 @@ class NodeNameGenerator {
virtual void CreateNodeName(Node* n) = 0;
void PopulateNodeNames(Block*);
void UpdateOutputsNames(Node* n);
bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph> graph) const;
bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;
protected:
std::string CreateUniqueName(
@ -99,19 +97,21 @@ class NodeNameGenerator {
std::unordered_map<const Node*, std::string> node_names_;
std::unordered_map<std::string, size_t> base_node_name_counts_;
std::shared_ptr<Graph> graph_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string layer_separator_ = "/";
};
NodeNameGenerator::~NodeNameGenerator() = default;
class ScopedNodeNameGenerator : public NodeNameGenerator {
public:
ScopedNodeNameGenerator(std::shared_ptr<Graph> g) : NodeNameGenerator(g){};
ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
: NodeNameGenerator(std::move(g)){};
protected:
void CreateNodeName(Node* n) override;
private:
std::string GetFullScopeName(ScopePtr scope);
std::string GetFullScopeName(const ScopePtr& scope);
std::unordered_map<ScopePtr, std::string> full_scope_names_;
std::unordered_map<std::string, size_t> base_scope_name_counts_;
};
@ -131,7 +131,7 @@ std::string NodeNameGenerator::CreateUniqueName(
bool NodeNameGenerator::IsGraphOutput(
const Value* v,
const std::shared_ptr<Graph> graph) const {
const std::shared_ptr<Graph>& graph) const {
for (const auto* graph_output : graph->outputs()) {
if (v == graph_output) {
return true;
@ -185,7 +185,7 @@ void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]);
}
std::string ScopedNodeNameGenerator::GetFullScopeName(ScopePtr scope) {
std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
if (full_scope_names_.find(scope) == full_scope_names_.end()) {
auto full_scope_name =
ONNXScopeName::variableNameFromRoot(scope, layer_separator_);
@ -202,6 +202,4 @@ void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
node_name_generator->PopulateNodeNames();
}
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -2,9 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace onnx {
namespace torch::jit::onnx {
namespace ONNXScopeName {
@ -13,18 +11,16 @@ std::string createFullScopeName(
const std::string& variable_name);
std::string variableName(torch::jit::ScopePtr scope);
std::string variableNameFromRoot(
torch::jit::ScopePtr scope,
const torch::jit::ScopePtr& scope,
const std::string& layer_separator);
std::string className(torch::jit::ScopePtr scope);
std::string classNameFromRoot(
torch::jit::ScopePtr scope,
const torch::jit::ScopePtr& scope,
const std::string& layer_separator);
bool isCompatibleScope(torch::jit::ScopePtr scope);
bool isCompatibleScope(const torch::jit::ScopePtr& scope);
} // namespace ONNXScopeName
TORCH_API void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph);
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -1,9 +1,7 @@
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <iostream>
namespace torch {
namespace jit {
namespace onnx {
namespace torch::jit::onnx {
namespace {
bool log_enabled = false;
@ -26,6 +24,4 @@ std::ostream& _get_log_output_stream() {
return out ? *out : std::cout;
}
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -4,9 +4,7 @@
#include <ostream>
#include <string>
namespace torch {
namespace jit {
namespace onnx {
namespace torch::jit::onnx {
TORCH_API bool is_log_enabled();
@ -22,6 +20,4 @@ TORCH_API std::ostream& _get_log_output_stream();
<< ::c10::str(__VA_ARGS__) << std::endl; \
}
} // namespace onnx
} // namespace jit
} // namespace torch
} // namespace torch::jit::onnx

View File

@ -23,8 +23,7 @@
typedef SSIZE_T ssize_t;
#endif
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -1068,5 +1067,4 @@ void PeepholeOptimizeONNX(
GRAPH_DUMP("After PeepholeOptimizeONNX", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,13 +2,11 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
void PeepholeOptimizeONNX(
std::shared_ptr<Graph>& graph,
int opset_version,
bool fixed_batch_size);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,8 +3,7 @@
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
namespace torch::jit {
// onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0,
// so before converting the ints to tensors we need to cast them to floats.
@ -43,5 +42,4 @@ void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("After PrepareDivisionForONNX: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Prepare division ops for ONNX export. This is necessary for and only used
// by ONNX export.
@ -15,5 +14,4 @@ namespace jit {
//
TORCH_API void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -229,5 +228,4 @@ void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("After fuseListAndListUnpack: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
void PreprocessForONNX(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -13,8 +13,7 @@
#include <limits>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -368,7 +367,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
}
}
static void PrepareForRemoveMutations(std::shared_ptr<Graph> graph) {
static void PrepareForRemoveMutations(const std::shared_ptr<Graph>& graph) {
MutationRemover mr(graph);
PrepareForRemoveMutations(mr, graph->block());
GRAPH_DUMP("After PrepareForRemoveMutations: ", graph);
@ -438,23 +437,23 @@ std::string InplaceConverter::ValueTracker::toString() const {
// ss << "Current graph: " << graph_->toString() << std::endl;
ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
<< std::endl;
ss << "value_to_sorted_aliases_: " << std::endl;
<< '\n';
ss << "value_to_sorted_aliases_: " << '\n';
size_t idx = 0;
for (const auto& it : value_to_sorted_aliases_) {
ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl;
ss << "Value[" << idx << "]: " << it.first->debugName() << '\n';
ss << " Mapping to ";
for (auto v : it.second) {
ss << v->debugName() << " ";
}
ss << std::endl;
ss << '\n';
idx++;
}
ss << "alias_to_value_: " << std::endl;
ss << "alias_to_value_: " << '\n';
for (auto it : alias_to_value_) {
ss << " Alias " << it.first->debugName();
ss << " map to " << it.second->debugName() << std::endl;
ss << " map to " << it.second->debugName() << '\n';
}
return ss.str();
@ -890,5 +889,4 @@ void RemoveInplaceOpsForONNX(
ic.convertMutationForONNX();
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,12 +2,10 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void RemoveInplaceOpsForONNX(
const std::shared_ptr<Graph>& graph,
Module* model);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -479,5 +478,4 @@ void ScalarTypeAnalysisNodeForONNX(Node* n) {
ImplicitCastNodeForONNX(n);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void ScalarTypeAnalysisForONNX(
const std::shared_ptr<Graph>& graph,
@ -11,5 +10,4 @@ TORCH_API void ScalarTypeAnalysisForONNX(
int opset_version);
void ScalarTypeAnalysisNodeForONNX(Node* n);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -22,16 +22,15 @@
#include <unordered_set>
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
inline bool PyNone_Check(PyObject* o) {
return o == Py_None;
}
std::pair<TypePtr, bool> MergeInferredType(
TypePtr existing_type,
TypePtr inferred_type) {
const TypePtr& existing_type,
const TypePtr& inferred_type) {
auto new_list_type = inferred_type->cast<ListType>();
auto use_inferred_type = false;
if (new_list_type) {
@ -75,8 +74,8 @@ std::pair<TypePtr, bool> MergeInferredType(
void MergeInferredTypeAndSetMap(
Value* dest_v,
TypePtr existing_type,
TypePtr inferred_type) {
const TypePtr& existing_type,
const TypePtr& inferred_type) {
auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type);
dest_v->setType(mergedType);
ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
@ -256,7 +255,7 @@ bool CustomSettype(Node* node) {
Value* CloneValueFromListConstruct(
Value* v,
std::shared_ptr<Graph> n_graph,
const std::shared_ptr<Graph>& n_graph,
int opset_version) {
auto lc_node = v->node();
TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
@ -355,7 +354,7 @@ Node* CloneNodeToGraph(
return clone_node;
}
bool HasValidType(TypePtr type, std::string name) {
bool HasValidType(const TypePtr& type, const std::string& name) {
if (auto t_type = type->cast<TensorType>()) {
if (!t_type->scalarType().has_value()) {
GRAPH_UPDATE("Input ", name, " is missing tensor datatype.");
@ -371,7 +370,7 @@ bool HasValidType(TypePtr type, std::string name) {
return true;
}
bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
bool IsGraphValidForInference(const std::shared_ptr<Graph>& graph) {
// Verify if every input has type (either Tensor, Sequence or Optional) and
// scalar type. This is a requirement for ONNX graph inputs.
for (auto in : graph->inputs()) {
@ -381,7 +380,7 @@ bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
}
void ConvertGraphToONNXProto(
std::shared_ptr<Graph> graph,
const std::shared_ptr<Graph>& graph,
std::shared_ptr<onnx::ModelProto>& model_proto,
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map,
@ -1652,7 +1651,8 @@ void SpecialPostProcess(Node* n) {
auto seq_node = n->input(0)->node();
auto t_type = n->input(1)->type()->cast<TensorType>();
auto update_sequence_empty_dtype = [](Node* n, TensorTypePtr t_type) {
auto update_sequence_empty_dtype = [](Node* n,
const TensorTypePtr& t_type) {
TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty);
TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value());
auto scalar_type = t_type->scalarType().value();
@ -1711,7 +1711,7 @@ void SpecialPostProcess(Node* n) {
return nullptr;
};
return find_sequence_empty_impl(
input, t_type, find_sequence_empty_impl);
input, std::move(t_type), find_sequence_empty_impl);
};
if (seq_node && t_type && t_type->scalarType()) {
@ -2122,9 +2122,9 @@ void ONNXShapeTypeInference(
case ::c10::onnx::Gather: {
auto* schema_registry = onnx::OpSchemaRegistry::Instance();
onnx::ShapeInferenceOptions options{
/*check_type=*/false,
/*error_mode=*/false,
/*enable_data_propagation=*/true};
/*check_type_val=*/false,
/*strict_mode_val=*/0,
/*data_prop_val=*/true};
onnx::shape_inference::InferShapes(
*model_proto, schema_registry, options, &inferred_shape_data);
break;
@ -2509,5 +2509,4 @@ void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) {
}
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
// Merges existing_type and inferred_type.
// Returns {merged type, whether or not inferred_type was used}.
@ -28,13 +27,13 @@ namespace jit {
// ONNX represents list of scalars by 1-d Tensor. Return inferred type since
// it is more compatible with ONNX.
std::pair<TypePtr, bool> MergeInferredType(
TypePtr existing_type,
TypePtr inferred_type);
const TypePtr& existing_type,
const TypePtr& inferred_type);
void MergeInferredTypeAndSetMap(
Value* dest_v,
TypePtr existing_type,
TypePtr inferred_type);
const TypePtr& existing_type,
const TypePtr& inferred_type);
// Update graph input types with dynamic axes info.
// Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
@ -96,5 +95,4 @@ void UpdateReliable(
void UpdateReliable(torch::jit::Node* n);
void UpdateShapeConstantIfReliable(torch::jit::Value* output);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -14,8 +14,8 @@
#include <ATen/Functions.h>
using ::c10::Dispatcher;
namespace torch {
namespace jit {
namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;
@ -765,5 +765,4 @@ void insertPermutes(
GRAPH_DUMP("After insertPermutes: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <memory>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void UnpackQuantizedWeights(
std::shared_ptr<Graph>& graph,
@ -15,5 +14,4 @@ TORCH_API void UnpackQuantizedWeights(
TORCH_API void insertPermutes(
std::shared_ptr<Graph>& graph,
std::map<std::string, IValue>& paramsDict);
} // namespace jit
} // namespace torch
} // namespace torch::jit