[22/N] Fix clang-tidy warnings in jit (#135319)
Follows #134537

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135319
Approved by: https://github.com/titaiwangms
parent cfc227ad43
commit 2196f32475
@@ -164,7 +164,6 @@ void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
 std::shared_ptr<Graph> ToONNX(
     std::shared_ptr<Graph>& graph,
     ::torch::onnx::OperatorExportTypes operator_export_type) {
-  auto constant_value_map = ConstantValueMap::getInstance();
   ConstantValueMap::ClearMaps();
   auto new_graph = std::make_shared<Graph>(graph->current_scope());
   py::dict env;
@@ -1,8 +1,7 @@
 #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
 #include <torch/csrc/jit/passes/onnx/helper.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 namespace onnx {
 using namespace ::c10::onnx;
 }
@@ -70,5 +69,4 @@ void CastAllConstantToFloating(Block* block) {
 void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph) {
   CastAllConstantToFloating(graph->block());
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
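Most hunks in this PR apply clang-tidy's modernize-concat-nested-namespaces fix: since C++17 a nested namespace can be opened with a single declaration, so the split torch/jit blocks and their two closing braces collapse into one. A minimal standalone sketch of the before and after (illustrative only, not PyTorch code):

// Before (C++14 style): each level opened and closed separately.
namespace torch {
namespace jit {
void run_pass() {}
} // namespace jit
} // namespace torch

// After (C++17): one declaration, one closing brace, identical semantics.
// This is exactly the rewrite modernize-concat-nested-namespaces suggests.
namespace torch::jit {
void run_pass_v2() {}
} // namespace torch::jit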
@@ -4,9 +4,7 @@
 
 #include <memory>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 // see .cpp for docs
 TORCH_API void CastAllConstantToFloating(const std::shared_ptr<Graph>& graph);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -9,8 +9,7 @@
 #include <algorithm>
 #include <optional>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -707,5 +706,4 @@ void ConstantFoldONNX(
   GRAPH_DUMP("After ConstantFoldONNX:", g);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -5,8 +5,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <optional>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 const int ONNX_OPSET_9 = 9;
 const int ONNX_OPSET_10 = 10;
@@ -30,6 +29,4 @@ void ConstantFoldONNX(
     std::map<std::string, IValue>& paramDict,
     int opset_version);
 
-} // namespace jit
-
-} // namespace torch
+} // namespace torch::jit
@@ -7,12 +7,7 @@
 #include <string>
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
-
-namespace onnx {
-using namespace ::c10::onnx;
-}
+namespace torch::jit {
 
 // Meyer’s Singleton for C++ 14
 ConstantValueMap& ConstantValueMap::getInstance() {
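The getInstance() referenced in this hunk is the classic Meyer's singleton: a function-local static whose one-time, thread-safe initialization is guaranteed by C++11 and later. A generic sketch of the pattern (illustrative only, not the actual ConstantValueMap implementation):

#include <string>
#include <unordered_map>

class Registry {
 public:
  // Meyer's singleton: the local static is constructed exactly once, on
  // first call; C++11 guarantees the initialization is thread-safe.
  static Registry& getInstance() {
    static Registry instance;
    return instance;
  }

  // Non-copyable: there is only ever the one instance.
  Registry(const Registry&) = delete;
  Registry& operator=(const Registry&) = delete;

  std::unordered_map<std::string, int> rankMap;

 private:
  Registry() = default;
};

int main() {
  Registry::getInstance().rankMap["x"] = 4;
  return Registry::getInstance().rankMap.at("x") == 4 ? 0 : 1;
}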
@@ -290,7 +285,7 @@ void ConstantValueMap::ClearMaps() {
 
 // For debug only.
 void ConstantValueMap::PrintMaps() {
-  std::cout << "Rank/Shape Map:" << std::endl;
+  std::cout << "Rank/Shape Map:" << '\n';
   for (const auto& x : ConstantValueMap::getInstance().rankMap) {
     std::stringstream ss;
     if (ConstantValueMap::getInstance().shapeMap.find(x.first) !=
@@ -308,45 +303,45 @@ void ConstantValueMap::PrintMaps() {
       }
     }
     ss << " (rank = " << x.second << ")";
-    std::cout << "node " << x.first << ": " << ss.str() << std::endl;
+    std::cout << "node " << x.first << ": " << ss.str() << '\n';
   }
-  std::cout << std::endl;
-  std::cout << "Value Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "Value Map:" << '\n';
   for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
-    std::cout << "node " << x.first << ": " << x.second << std::endl;
+    std::cout << "node " << x.first << ": " << x.second << '\n';
   }
-  std::cout << std::endl;
-  std::cout << "TypeReliable Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "TypeReliable Map:" << '\n';
   size_t count = 0;
   for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
     std::cout << "(node " << x.first << ": " << x.second << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
     }
   }
-  std::cout << std::endl;
-  std::cout << "UseInferredType Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "UseInferredType Map:" << '\n';
   count = 0;
   for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
     std::cout << "(node " << x.first << ": " << x.second << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
    }
   }
-  std::cout << std::endl;
-  std::cout << "ShapeValue Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "ShapeValue Map:" << '\n';
   count = 0;
   for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) {
     std::cout << "(node " << x.first << ": " << x.second << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
     }
   }
-  std::cout << std::endl;
-  std::cout << "InferredShape Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "InferredShape Map:" << '\n';
   count = 0;
   for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) {
     std::cout << "(node " << x.first << ": ";
@@ -360,29 +355,28 @@ void ConstantValueMap::PrintMaps() {
     std::cout << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
    }
   }
-  std::cout << std::endl;
-  std::cout << "SymbolDim Map:" << std::endl;
+  std::cout << '\n';
+  std::cout << "SymbolDim Map:" << '\n';
   count = 0;
   for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) {
     std::cout << "(" << x.first << ": " << x.second << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
    }
   }
-  std::cout << "DimSymbol Map:" << std::endl;
+  std::cout << "DimSymbol Map:" << '\n';
   count = 0;
   for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
     std::cout << "(" << x.first << ": " << x.second << "), ";
     count++;
     if (count % 10 == 0) {
-      std::cout << std::endl;
+      std::cout << '\n';
     }
   }
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
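The std::endl to '\n' rewrites throughout PrintMaps come from clang-tidy's performance-avoid-endl check: std::endl writes a newline and then flushes the stream, which is almost never what you want inside a tight debug-print loop. A small self-contained illustration (not PyTorch code):

#include <iostream>

int main() {
  // std::endl == '\n' plus an explicit flush; in a loop this forces one
  // flush per iteration and can dominate the cost of the output.
  for (int i = 0; i < 3; ++i) {
    std::cout << "row " << i << std::endl; // flushes every time
  }

  // '\n' only appends a newline; the stream flushes when its buffer
  // fills or when it is torn down, which is the sensible default.
  for (int i = 0; i < 3; ++i) {
    std::cout << "row " << i << '\n'; // no forced flush
  }
  return 0; // std::cout is flushed at program exit anyway
}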
@@ -13,8 +13,7 @@ C10_DIAGNOSTIC_POP()
 #include <mutex>
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 using ShapeDataMap =
     std::unordered_map<std::string, ::ONNX_NAMESPACE::TensorShapeProto>;
@@ -112,8 +111,7 @@ class ConstantValueMap {
   // Stores if all graph-level inputs have static shape
   std::optional<bool> allGraphInputsStatic;
  // True if reliable has been computed for all graph inputs
-  bool allGraphInputsReliableComputed;
+  bool allGraphInputsReliableComputed{};
 };
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
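Adding {} to allGraphInputsReliableComputed gives the member a default initializer (value-initialization to false), which addresses clang-tidy's member-init warnings: without it, any constructor that forgets the field leaves it uninitialized. A minimal sketch of the difference:

struct Flags {
  bool computed;    // indeterminate unless a constructor sets it
  bool computed2{}; // default member initializer: always starts as false
};

int main() {
  Flags f;            // f.computed is uninitialized (reading it is UB);
                      // f.computed2 is guaranteed to be false.
  return f.computed2; // well-defined: returns 0
}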
@@ -4,8 +4,7 @@
 
 #include <c10/util/irange.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -99,5 +98,4 @@ void DeduplicateInitializers(
   buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -4,14 +4,11 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 void DeduplicateInitializers(
     std::shared_ptr<Graph>& g,
     std::map<std::string, IValue>& paramsDict,
     bool is_train);
 
-} // namespace jit
-
-} // namespace torch
+} // namespace torch::jit
@@ -1,8 +1,7 @@
 #include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
 #include <torch/csrc/jit/passes/onnx/helper.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -16,5 +15,4 @@ void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) {
   return;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // EliminateUnusedItemsONNX pass is removing unused
 // initializers and inputs, this is needed because
@@ -12,6 +11,4 @@ void EliminateUnusedItemsONNX(
     Block* b,
     std::map<std::string, IValue>& paramDict);
 
-} // namespace jit
-
-} // namespace torch
+} // namespace torch::jit
@@ -6,8 +6,7 @@
 #include <c10/util/irange.h>
 #include <algorithm>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -152,5 +151,4 @@ void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
   GRAPH_DUMP("After EvalPeepholeONNX:", g);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -4,13 +4,10 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 void EvalPeepholeONNX(
     std::shared_ptr<Graph>& g,
     std::map<std::string, IValue>& paramDict);
 
-} // namespace jit
-
-} // namespace torch
+} // namespace torch::jit
@@ -8,19 +8,14 @@
 #include <torch/csrc/jit/passes/onnx/peephole.h>
 #include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
 
-namespace torch {
-namespace jit {
-
-namespace onnx {
-using namespace ::c10::onnx;
-}
+namespace torch::jit {
 
 namespace {
 const int ONNX_OPSET_13 = 13;
 const int ONNX_TYPE_BOOL = 9;
 
 Node* CreateCastToBoolNode(Value* val, Graph* graph) {
-  Node* cast_node = graph->create(onnx::Cast);
+  Node* cast_node = graph->create(c10::onnx::Cast);
   cast_node->addInput(val);
   cast_node->i_(attr::to, ONNX_TYPE_BOOL);
   cast_node->output()->setType(BoolType::get());
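Note why onnx::Cast becomes c10::onnx::Cast in this file: the same hunk deletes the local `namespace onnx { using namespace ::c10::onnx; }` re-export, so unqualified `onnx::` lookup from inside torch::jit would no longer find the c10::onnx symbol names. A toy illustration of the lookup situation (hypothetical names, purely for flavor):

namespace c10::onnx {
constexpr int Cast = 1; // stand-in for the interned Symbol
}

namespace torch::jit {
// Previously a nested `namespace onnx { using namespace ::c10::onnx; }`
// let code here write `onnx::Cast`. With that alias block removed, the
// fully qualified name is the unambiguous spelling:
int make_cast() {
  return c10::onnx::Cast;
}
} // namespace torch::jit

int main() {
  return torch::jit::make_cast();
}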
@@ -149,7 +144,7 @@ std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
       // Split the added scan_output back to expected tensor sequence.
       auto loop_output = loop_node->output(i - 2);
       Node* split_node =
-          loop_node->owningGraph()->create(onnx::SplitToSequence);
+          loop_node->owningGraph()->create(c10::onnx::SplitToSequence);
       loop_output->replaceAllUsesWith(split_node->output());
       split_node->i_(attr::keepdims, 0);
       split_node->addInput(loop_output);
@@ -191,7 +186,7 @@ std::vector<Value*> ConvertSequenceDependencies(Node* node, int opset_version) {
   return new_outputs;
 }
 
-Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) {
+Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) {
   TORCH_INTERNAL_ASSERT(opt_type);
   TypePtr elem_type = opt_type->getElementType();
   Node* opt_node = g->create(::c10::onnx::Optional, 1);
@@ -208,7 +203,7 @@ Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) {
 // 2. Loop Op: insert Optional node before output, if input is Optional type
 // or output type is None.
 void ReplaceBlockOutputWithOptional(
-    OptionalTypePtr opt_type,
+    const OptionalTypePtr& opt_type,
     Block* block,
     size_t i) {
   Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph());
@@ -235,9 +230,9 @@ void FixupONNXSubblockOutputs(Node* n) {
     // Identity(None). Also enables shape inference later on, since
     // ONNX shape inference doesn't handle None.
     if (output->type()->cast<NoneType>()) {
-      id_node = block->owningGraph()->create(onnx::Optional);
+      id_node = block->owningGraph()->create(c10::onnx::Optional);
     } else {
-      id_node = block->owningGraph()->create(onnx::Identity);
+      id_node = block->owningGraph()->create(c10::onnx::Identity);
       id_node->addInput(output);
     }
     id_node->insertBefore(block->return_node());
@@ -741,5 +736,4 @@ void FixupONNXControlflowNodeOutputs(Node* n) {
   }
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,11 +2,9 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version);
 void FixupONNXControlflowNodeOutputs(Node* n);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,9 +2,7 @@
 #include <torch/csrc/jit/passes/onnx/function_extraction.h>
 #include <torch/csrc/jit/passes/onnx/naming.h>
 
-namespace torch {
-namespace jit {
-namespace onnx {
+namespace torch::jit::onnx {
 
 namespace {
 
@@ -75,9 +73,9 @@ struct FunctionExtractor {
   using FunctionCtxPtr = FunctionContext*;
   using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;
 
-  static bool IsValidScope(ScopePtr s);
+  static bool IsValidScope(const ScopePtr& s);
   static std::optional<ScopePtr> InferScope(Node* n);
-  static bool IsAncestor(ScopePtr parent, ScopePtr child);
+  static bool IsAncestor(const ScopePtr& parent, ScopePtr child);
   static std::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
   static std::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
   std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);
@@ -88,7 +86,9 @@ struct FunctionExtractor {
       scope_ctx_map& scope_ctxs,
       const std::shared_ptr<Graph>& graph);
 
-  static void HandleNoScopeNodes(scope_ctx_map&, node_list no_scope_nlist);
+  static void HandleNoScopeNodes(
+      scope_ctx_map&,
+      const node_list& no_scope_nlist);
   std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
   scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
   static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
@@ -279,11 +279,11 @@ void FunctionExtractor::DebugPrintGraphWithFunction(
   GRAPH_UPDATE("Main graph: ", g->toString());
 }
 
-bool FunctionExtractor::IsValidScope(ScopePtr s) {
+bool FunctionExtractor::IsValidScope(const ScopePtr& s) {
   return !s->isRoot() && !s->isBlank();
 }
 
-bool FunctionExtractor::IsAncestor(ScopePtr parent, ScopePtr child) {
+bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) {
   if (!IsValidScope(parent) || !IsValidScope(child) ||
       parent->getDepth() >= child->getDepth()) {
     return false;
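The ScopePtr-to-const-ScopePtr& changes here (and the shared_ptr<Graph> ones elsewhere in the PR) are clang-tidy's performance-unnecessary-value-param: these are reference-counted smart pointers, so taking one by value copies the handle and bumps the refcount on every call, while a const reference is free whenever the callee only reads through it. A generic sketch of the rule (standalone, using std::shared_ptr for illustration):

#include <memory>
#include <string>

struct Scope {
  std::string name;
};
using ScopePtr = std::shared_ptr<Scope>;

// By value: each call copies the shared_ptr, doing an atomic increment
// on entry and an atomic decrement when the parameter is destroyed.
bool isValidByValue(ScopePtr s) {
  return s && !s->name.empty();
}

// By const reference: no refcount traffic; correct whenever the callee
// neither stores the pointer nor outlives the caller's copy.
bool isValidByRef(const ScopePtr& s) {
  return s && !s->name.empty();
}

int main() {
  auto scope = std::make_shared<Scope>();
  scope->name = "model.layer1";
  return isValidByValue(scope) && isValidByRef(scope) ? 0 : 1;
}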
@@ -376,7 +376,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
       std::all_of(
           output_scopes.begin(),
           output_scopes.end(),
-          [&output_scopes](ScopePtr scope) -> bool {
+          [&output_scopes](const ScopePtr& scope) -> bool {
             return IsValidScope(scope) && scope == output_scopes.at(0);
           })) {
     return output_scopes.at(0);
@@ -385,7 +385,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
       std::all_of(
           input_scopes.begin(),
           input_scopes.end(),
-          [&input_scopes](ScopePtr scope) -> bool {
+          [&input_scopes](const ScopePtr& scope) -> bool {
             return IsValidScope(scope) && scope == input_scopes.at(0);
           })) {
     return input_scopes.at(0);
@@ -822,7 +822,7 @@ void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
 
 void FunctionExtractor::HandleNoScopeNodes(
     scope_ctx_map& scope_ctxs,
-    node_list no_scope_nlist) {
+    const node_list& no_scope_nlist) {
   GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size());
   for (auto n : no_scope_nlist) {
     TORCH_WARN(
@@ -1181,6 +1181,4 @@ void ONNXTrackScopeAttributes(
   }
 }
 
-} // namespace onnx
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -2,9 +2,6 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
-
 // This api will be used by serialization/export.cpp to extract function
 // information. It should do conversion on graph to
 // 1. Extract subgraph pattern of functions and define as local function
@@ -15,7 +12,7 @@ namespace jit {
 // represent these info inside Graph object.
 // export.cpp will serialize the ONNX model with function_proto with
 // above information.
-namespace onnx {
+namespace torch::jit::onnx {
 
 // The following return types are used to track information regarding function
 // attributes, that are unable to be traced through Torch IR.
@@ -64,7 +61,4 @@ TORCH_API void ONNXTrackScopeAttributes(
     std::shared_ptr<Graph>& graph,
     std::map<std::string, IValue>& attributes);
 
-} // namespace onnx
-
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/onnx/naming.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -193,5 +192,4 @@ void ONNXFunctionCallSubstitution(Graph& graph) {
   GRAPH_DUMP("After function call substitution calls: ", &graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void ONNXFunctionCallSubstitution(Graph& graph);
 
-}
-} // namespace torch
+} // namespace torch::jit
@@ -12,8 +12,7 @@
 
 #include <onnx/onnx_pb.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 namespace onnx {
 using namespace ::c10::onnx;
 
@@ -296,5 +295,4 @@ void ONNXLintGraph(const std::shared_ptr<Graph>& graph) {
       " constants.");
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -3,8 +3,7 @@
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Utility functions for PyTorch to ONNX conversion.
 
@@ -73,5 +72,4 @@ class ScalarTypeHashFunction {
   }
 };
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -191,5 +190,4 @@ std::pair<Module, std::vector<IValue>> list_module_parameters(
   return std::make_pair(moduleClone, parameterIValues);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -3,11 +3,9 @@
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API std::pair<Module, std::vector<IValue>> list_module_parameters(
     const Module& module);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -3,9 +3,7 @@
 
 #include <utility>
 
-namespace torch {
-namespace jit {
-namespace onnx {
+namespace torch::jit::onnx {
 
 namespace ONNXScopeName {
 
@@ -16,7 +14,7 @@ const std::string name_separator = "::";
 namespace {
 
 std::string nameFromRoot(
-    torch::jit::ScopePtr scope,
+    const torch::jit::ScopePtr& scope,
     const std::string& layer_separator,
     NameFunc name_func) {
   std::string out = (*name_func)(scope);
@@ -32,7 +30,7 @@ std::string nameFromRoot(
 }
 
 std::pair<std::string, std::string> parseNameFromScope(
-    torch::jit::ScopePtr scope) {
+    const torch::jit::ScopePtr& scope) {
   std::string full_name = scope->name().toUnqualString();
   auto pos = full_name.find(name_separator);
   TORCH_CHECK(
@@ -55,7 +53,7 @@ std::string variableName(torch::jit::ScopePtr scope) {
 }
 
 std::string variableNameFromRoot(
-    torch::jit::ScopePtr scope,
+    const torch::jit::ScopePtr& scope,
     const std::string& layer_separator) {
   return nameFromRoot(scope, layer_separator, &variableName);
 }
@@ -65,12 +63,12 @@ std::string className(torch::jit::ScopePtr scope) {
 }
 
 std::string classNameFromRoot(
-    torch::jit::ScopePtr scope,
+    const torch::jit::ScopePtr& scope,
     const std::string& layer_separator) {
   return nameFromRoot(scope, layer_separator, &className);
 }
 
-bool isCompatibleScope(torch::jit::ScopePtr scope) {
+bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
   return !scope->isRoot() && !scope->isBlank() &&
       (std::string(scope->name().toUnqualString()).find(name_separator) !=
        std::string::npos);
@@ -89,7 +87,7 @@ class NodeNameGenerator {
   virtual void CreateNodeName(Node* n) = 0;
   void PopulateNodeNames(Block*);
   void UpdateOutputsNames(Node* n);
-  bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph> graph) const;
+  bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;
 
  protected:
   std::string CreateUniqueName(
@@ -99,19 +97,21 @@ class NodeNameGenerator {
   std::unordered_map<const Node*, std::string> node_names_;
   std::unordered_map<std::string, size_t> base_node_name_counts_;
   std::shared_ptr<Graph> graph_;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const std::string layer_separator_ = "/";
 };
 NodeNameGenerator::~NodeNameGenerator() = default;
 
 class ScopedNodeNameGenerator : public NodeNameGenerator {
  public:
-  ScopedNodeNameGenerator(std::shared_ptr<Graph> g) : NodeNameGenerator(g){};
+  ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
+      : NodeNameGenerator(std::move(g)){};
 
  protected:
  void CreateNodeName(Node* n) override;
 
 private:
-  std::string GetFullScopeName(ScopePtr scope);
+  std::string GetFullScopeName(const ScopePtr& scope);
   std::unordered_map<ScopePtr, std::string> full_scope_names_;
   std::unordered_map<std::string, size_t> base_scope_name_counts_;
 };
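The constructor change is the standard sink-argument pattern: when a parameter's only job is to be stored (or forwarded to something that stores it), take it by value and std::move it along, so callers passing an rvalue pay zero shared_ptr copies and callers passing an lvalue pay exactly one. A minimal sketch (assumed simplification of the pattern, not the real classes):

#include <memory>
#include <utility>

struct Graph {};

class Base {
 public:
  explicit Base(std::shared_ptr<Graph> g) : graph_(std::move(g)) {}

 private:
  std::shared_ptr<Graph> graph_;
};

class Derived : public Base {
 public:
  // Forwarding `g` without std::move would copy the shared_ptr (an extra
  // atomic refcount bump); moving it hands ownership down for free.
  explicit Derived(std::shared_ptr<Graph> g) : Base(std::move(g)) {}
};

int main() {
  auto g = std::make_shared<Graph>();
  Derived d1(g);            // one copy: lvalue into the by-value param
  Derived d2(std::move(g)); // zero copies: moved all the way down
  return 0;
}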
@@ -131,7 +131,7 @@ std::string NodeNameGenerator::CreateUniqueName(
 
 bool NodeNameGenerator::IsGraphOutput(
     const Value* v,
-    const std::shared_ptr<Graph> graph) const {
+    const std::shared_ptr<Graph>& graph) const {
   for (const auto* graph_output : graph->outputs()) {
     if (v == graph_output) {
       return true;
@@ -185,7 +185,7 @@ void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
   n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]);
 }
 
-std::string ScopedNodeNameGenerator::GetFullScopeName(ScopePtr scope) {
+std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
   if (full_scope_names_.find(scope) == full_scope_names_.end()) {
     auto full_scope_name =
         ONNXScopeName::variableNameFromRoot(scope, layer_separator_);
@@ -202,6 +202,4 @@ void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
   node_name_generator->PopulateNodeNames();
 }
 
-} // namespace onnx
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -2,9 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
-namespace onnx {
+namespace torch::jit::onnx {
 
 namespace ONNXScopeName {
 
@@ -13,18 +11,16 @@ std::string createFullScopeName(
     const std::string& variable_name);
 std::string variableName(torch::jit::ScopePtr scope);
 std::string variableNameFromRoot(
-    torch::jit::ScopePtr scope,
+    const torch::jit::ScopePtr& scope,
     const std::string& layer_separator);
 std::string className(torch::jit::ScopePtr scope);
 std::string classNameFromRoot(
-    torch::jit::ScopePtr scope,
+    const torch::jit::ScopePtr& scope,
     const std::string& layer_separator);
-bool isCompatibleScope(torch::jit::ScopePtr scope);
+bool isCompatibleScope(const torch::jit::ScopePtr& scope);
 
 } // namespace ONNXScopeName
 
 TORCH_API void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph);
 
-} // namespace onnx
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -1,9 +1,7 @@
 #include <torch/csrc/jit/passes/onnx/onnx_log.h>
 #include <iostream>
 
-namespace torch {
-namespace jit {
-namespace onnx {
+namespace torch::jit::onnx {
 
 namespace {
 bool log_enabled = false;
@@ -26,6 +24,4 @@ std::ostream& _get_log_output_stream() {
   return out ? *out : std::cout;
 }
 
-} // namespace onnx
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -4,9 +4,7 @@
 #include <ostream>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace onnx {
+namespace torch::jit::onnx {
 
 TORCH_API bool is_log_enabled();
 
@@ -22,6 +20,4 @@ TORCH_API std::ostream& _get_log_output_stream();
         << ::c10::str(__VA_ARGS__) << std::endl; \
   }
 
-} // namespace onnx
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::onnx
@@ -23,8 +23,7 @@
 typedef SSIZE_T ssize_t;
 #endif
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -1068,5 +1067,4 @@ void PeepholeOptimizeONNX(
   GRAPH_DUMP("After PeepholeOptimizeONNX", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,13 +2,11 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 void PeepholeOptimizeONNX(
     std::shared_ptr<Graph>& graph,
     int opset_version,
     bool fixed_batch_size);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -3,8 +3,7 @@
 #include <torch/csrc/jit/ir/constants.h>
 #include <torch/csrc/jit/jit_log.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0,
 // so before converting the ints to tensors we need to cast them to floats.
@@ -43,5 +42,4 @@ void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After PrepareDivisionForONNX: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Prepare division ops for ONNX export. This is necessary for and only used
 // by ONNX export.
@@ -15,5 +14,4 @@ namespace jit {
 //
 TORCH_API void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -6,8 +6,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/onnx/helper.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -229,5 +228,4 @@ void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After fuseListAndListUnpack: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 void PreprocessForONNX(std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -13,8 +13,7 @@
 
 #include <limits>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -368,7 +367,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
   }
 }
 
-static void PrepareForRemoveMutations(std::shared_ptr<Graph> graph) {
+static void PrepareForRemoveMutations(const std::shared_ptr<Graph>& graph) {
   MutationRemover mr(graph);
   PrepareForRemoveMutations(mr, graph->block());
   GRAPH_DUMP("After PrepareForRemoveMutations: ", graph);
@@ -438,23 +437,23 @@ std::string InplaceConverter::ValueTracker::toString() const {
 
   // ss << "Current graph: " << graph_->toString() << std::endl;
   ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
-     << std::endl;
-  ss << "value_to_sorted_aliases_: " << std::endl;
+     << '\n';
+  ss << "value_to_sorted_aliases_: " << '\n';
   size_t idx = 0;
   for (const auto& it : value_to_sorted_aliases_) {
-    ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl;
+    ss << "Value[" << idx << "]: " << it.first->debugName() << '\n';
     ss << "  Mapping to ";
     for (auto v : it.second) {
      ss << v->debugName() << " ";
    }
-    ss << std::endl;
+    ss << '\n';
    idx++;
  }
 
-  ss << "alias_to_value_: " << std::endl;
+  ss << "alias_to_value_: " << '\n';
   for (auto it : alias_to_value_) {
     ss << "  Alias " << it.first->debugName();
-    ss << " map to " << it.second->debugName() << std::endl;
+    ss << " map to " << it.second->debugName() << '\n';
   }
 
   return ss.str();
@@ -890,5 +889,4 @@ void RemoveInplaceOpsForONNX(
   ic.convertMutationForONNX();
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,12 +2,10 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void RemoveInplaceOpsForONNX(
     const std::shared_ptr<Graph>& graph,
     Module* model);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace onnx {
 using namespace ::c10::onnx;
@@ -479,5 +478,4 @@ void ScalarTypeAnalysisNodeForONNX(Node* n) {
   ImplicitCastNodeForONNX(n);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void ScalarTypeAnalysisForONNX(
     const std::shared_ptr<Graph>& graph,
@@ -11,5 +10,4 @@ TORCH_API void ScalarTypeAnalysisForONNX(
     int opset_version);
 void ScalarTypeAnalysisNodeForONNX(Node* n);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -22,16 +22,15 @@
 #include <unordered_set>
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 inline bool PyNone_Check(PyObject* o) {
   return o == Py_None;
 }
 
 std::pair<TypePtr, bool> MergeInferredType(
-    TypePtr existing_type,
-    TypePtr inferred_type) {
+    const TypePtr& existing_type,
+    const TypePtr& inferred_type) {
   auto new_list_type = inferred_type->cast<ListType>();
   auto use_inferred_type = false;
   if (new_list_type) {
@@ -75,8 +74,8 @@ std::pair<TypePtr, bool> MergeInferredType(
 
 void MergeInferredTypeAndSetMap(
     Value* dest_v,
-    TypePtr existing_type,
-    TypePtr inferred_type) {
+    const TypePtr& existing_type,
+    const TypePtr& inferred_type) {
   auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type);
   dest_v->setType(mergedType);
   ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
@@ -256,7 +255,7 @@ bool CustomSettype(Node* node) {
 
 Value* CloneValueFromListConstruct(
     Value* v,
-    std::shared_ptr<Graph> n_graph,
+    const std::shared_ptr<Graph>& n_graph,
     int opset_version) {
   auto lc_node = v->node();
   TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
@@ -355,7 +354,7 @@ Node* CloneNodeToGraph(
   return clone_node;
 }
 
-bool HasValidType(TypePtr type, std::string name) {
+bool HasValidType(const TypePtr& type, const std::string& name) {
   if (auto t_type = type->cast<TensorType>()) {
     if (!t_type->scalarType().has_value()) {
       GRAPH_UPDATE("Input ", name, " is missing tensor datatype.");
@@ -371,7 +370,7 @@ bool HasValidType(TypePtr type, std::string name) {
   return true;
 }
 
-bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
+bool IsGraphValidForInference(const std::shared_ptr<Graph>& graph) {
   // Verify if every input has type (either Tensor, Sequence or Optional) and
   // scalar type. This is a requirement for ONNX graph inputs.
   for (auto in : graph->inputs()) {
@@ -381,7 +380,7 @@ bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
 }
 
 void ConvertGraphToONNXProto(
-    std::shared_ptr<Graph> graph,
+    const std::shared_ptr<Graph>& graph,
     std::shared_ptr<onnx::ModelProto>& model_proto,
     SymbolDimMap& symbol_dim_map,
     DimSymbolMap& dim_symbol_map,
@@ -1652,7 +1651,8 @@ void SpecialPostProcess(Node* n) {
       auto seq_node = n->input(0)->node();
       auto t_type = n->input(1)->type()->cast<TensorType>();
 
-      auto update_sequence_empty_dtype = [](Node* n, TensorTypePtr t_type) {
+      auto update_sequence_empty_dtype = [](Node* n,
+                                            const TensorTypePtr& t_type) {
        TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty);
        TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value());
        auto scalar_type = t_type->scalarType().value();
@@ -1711,7 +1711,7 @@ void SpecialPostProcess(Node* n) {
           return nullptr;
         };
         return find_sequence_empty_impl(
-            input, t_type, find_sequence_empty_impl);
+            input, std::move(t_type), find_sequence_empty_impl);
       };
 
       if (seq_node && t_type && t_type->scalarType()) {
@@ -2122,9 +2122,9 @@ void ONNXShapeTypeInference(
       case ::c10::onnx::Gather: {
         auto* schema_registry = onnx::OpSchemaRegistry::Instance();
         onnx::ShapeInferenceOptions options{
-            /*check_type=*/false,
-            /*error_mode=*/false,
-            /*enable_data_propagation=*/true};
+            /*check_type_val=*/false,
+            /*strict_mode_val=*/0,
+            /*data_prop_val=*/true};
         onnx::shape_inference::InferShapes(
             *model_proto, schema_registry, options, &inferred_shape_data);
         break;
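The /*name=*/ edits above fix clang-tidy's bugprone-argument-comment: such inline comments are only useful (and only checked) when they exactly match the callee's parameter names, here the ShapeInferenceOptions constructor parameters check_type_val, strict_mode_val, and data_prop_val (note the second one is an int, hence false becoming 0). A self-contained sketch of the idea, using a hypothetical options struct rather than the ONNX one:

struct InferenceOptions {
  bool check_type_val;
  int strict_mode_val;
  bool data_prop_val;

  InferenceOptions(bool check_type_val, int strict_mode_val, bool data_prop_val)
      : check_type_val(check_type_val),
        strict_mode_val(strict_mode_val),
        data_prop_val(data_prop_val) {}
};

int main() {
  // bugprone-argument-comment warns when the comment does not name the
  // actual parameter, e.g. writing /*check_type=*/ against check_type_val.
  InferenceOptions options{
      /*check_type_val=*/false,
      /*strict_mode_val=*/0,
      /*data_prop_val=*/true};
  return options.strict_mode_val;
}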
@@ -2509,5 +2509,4 @@ void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) {
   }
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -6,8 +6,7 @@
 
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Merges existing_type and inferred_type.
 // Returns {merged type, whether or not inferred_type was used}.
@@ -28,13 +27,13 @@ namespace jit {
 // ONNX represents list of scalars by 1-d Tensor. Return inferred type since
 // it is more compatible with ONNX.
 std::pair<TypePtr, bool> MergeInferredType(
-    TypePtr existing_type,
-    TypePtr inferred_type);
+    const TypePtr& existing_type,
+    const TypePtr& inferred_type);
 
 void MergeInferredTypeAndSetMap(
     Value* dest_v,
-    TypePtr existing_type,
-    TypePtr inferred_type);
+    const TypePtr& existing_type,
+    const TypePtr& inferred_type);
 
 // Update graph input types with dynamic axes info.
 // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
@@ -96,5 +95,4 @@ void UpdateReliable(
 void UpdateReliable(torch::jit::Node* n);
 void UpdateShapeConstantIfReliable(torch::jit::Value* output);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -14,8 +14,8 @@
 #include <ATen/Functions.h>
 
 using ::c10::Dispatcher;
-namespace torch {
-namespace jit {
+
+namespace torch::jit {
 namespace onnx {
 using namespace ::c10::onnx;
 
@@ -765,5 +765,4 @@ void insertPermutes(
   GRAPH_DUMP("After insertPermutes: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -6,8 +6,7 @@
 
 #include <memory>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void UnpackQuantizedWeights(
     std::shared_ptr<Graph>& graph,
@@ -15,5 +14,4 @@ TORCH_API void UnpackQuantizedWeights(
 TORCH_API void insertPermutes(
     std::shared_ptr<Graph>& graph,
     std::map<std::string, IValue>& paramsDict);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit