[19/N] Fix clang-tidy warnings in jit (#133067)

Follows  #132963
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133067
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy 2024-08-13 15:59:43 +00:00 committed by PyTorch MergeBot
parent 2e7d67e6af
commit 2f30473fba
39 changed files with 135 additions and 272 deletions

View File

@ -13,7 +13,6 @@ using DebugHandle = int64_t;
class Function;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct Code {
std::vector<Instruction> instructions_;
std::vector<DebugHandle> debug_handles_;

View File

@ -7,8 +7,7 @@
#include <string>
namespace torch {
namespace jit {
namespace torch::jit {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamWriter;
@ -87,5 +86,4 @@ bool _backport_for_mobile_impl(
return backportManager.backport(oss, writer, from_version, to_version);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,17 +2,8 @@
#include <c10/macros/Export.h>
#include <istream>
#include <memory>
namespace caffe2 {
namespace serialize {
class ReadAdapterInterface;
class PyTorchStreamWriter;
} // namespace serialize
} // namespace caffe2
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API bool _backport_for_mobile(
std::istream& in,
@ -34,5 +25,4 @@ TORCH_API bool _backport_for_mobile(
const std::string& output_filename,
const int64_t to_version);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -12,10 +12,8 @@
#include <cstddef>
#include <sstream>
namespace torch {
namespace jit {
namespace torch::jit {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
@ -117,7 +115,6 @@ void write_archive_current(
data_pickle.stop();
// write out tensor data
size_t i = 0;
std::string prefix = archive_name + "/";
TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
const std::unordered_set<std::string>& pre_serialized_files =
@ -388,8 +385,8 @@ Thus, the backport is necessary such that the bytecode operator table contains
number of specified arguments.
*/
std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
std::shared_ptr<IStreamAdapter> rai =
std::make_shared<IStreamAdapter>(&input_model_stream);
auto rai =
std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
// If there are debug info files in the original model file, it should also
@ -453,11 +450,11 @@ push in the stack. Thus, the backport is necessary such that the bytecode
contains all the arguments as before.
*/
std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
std::shared_ptr<IStreamAdapter> rai =
std::make_shared<IStreamAdapter>(&input_model_stream);
auto rai =
std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
auto constants_values =
std::move(*readArchive(kArchiveNameConstants, *reader.get()).toTuple())
std::move(*readArchive(kArchiveNameConstants, *reader).toTuple())
.elements();
// If there are debug info files in the original model file, it should also
@ -526,8 +523,8 @@ std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) {
}
std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
std::shared_ptr<IStreamAdapter> rai =
std::make_shared<IStreamAdapter>(&input_model_stream);
auto rai =
std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
// extra_files are kept
auto records = reader->getAllRecords();
@ -696,5 +693,4 @@ bool BackportManager::backport(
return backport_success;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -8,17 +8,11 @@ namespace c10 {
struct IValue;
}
namespace caffe2 {
namespace serialize {
class IStreamAdapter;
class ReadAdapterInterface;
namespace caffe2::serialize {
class PyTorchStreamWriter;
class PyTorchStreamReader;
} // namespace serialize
} // namespace caffe2
} // namespace caffe2::serialize
namespace torch {
namespace jit {
namespace torch::jit {
/*
BackportManager manages a list of backport from n to n-1 function, and provides
@ -51,5 +45,4 @@ class BackportManager final {
backport_function);
};
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -20,8 +20,7 @@ namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch {
namespace jit {
namespace torch::jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
@ -43,7 +42,7 @@ c10::IValue readArchive(
std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
std::make_shared<mobile::CompilationUnit>();
auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) {
auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
return objLoaderMobile(type, input, *mobile_compilation_unit);
};
bool bytecode_tensor_in_constants_archive =
@ -86,7 +85,7 @@ uint64_t _get_model_bytecode_version(const std::string& filename) {
}
uint64_t _get_model_bytecode_version(
std::shared_ptr<ReadAdapterInterface> rai) {
const std::shared_ptr<ReadAdapterInterface>& rai) {
auto [data, size] = get_rai_content(rai.get());
return _get_model_bytecode_version_from_bytes(data.get(), size);
}
@ -345,7 +344,7 @@ ModelCompatibilityInfo ModelCompatibilityInfo::get(
ModelCompatCheckResult is_compatible(
RuntimeCompatibilityInfo runtime_info,
ModelCompatibilityInfo model_info) {
const ModelCompatibilityInfo& model_info) {
ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
// Check that the models bytecode version is less than or equal to
// kMaxSupportedBytecodeVersion from the runtime
@ -444,5 +443,4 @@ ModelCompatCheckResult is_compatible(
return result;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -1,21 +1,20 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <istream>
#include <memory>
#include <unordered_map>
#include <vector>
namespace caffe2 {
namespace serialize {
namespace caffe2::serialize {
class PyTorchStreamReader;
class ReadAdapterInterface;
} // namespace serialize
} // namespace caffe2
} // namespace caffe2::serialize
namespace torch {
namespace jit {
namespace torch::jit {
// The family of methods below to get bytecode version from a model
// Throws if not passed in a well formed model
@ -24,7 +23,7 @@ TORCH_API uint64_t _get_model_bytecode_version(std::istream& in);
TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename);
TORCH_API uint64_t _get_model_bytecode_version(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai);
uint64_t _get_model_bytecode_version(
const std::vector<c10::IValue>& bytecode_ivalues);
@ -94,13 +93,12 @@ enum ModelCompatibilityStatus {
struct ModelCompatCheckResult {
ModelCompatibilityStatus status;
std::vector<std::string> errors;
std::vector<std::string> errors{};
};
// Takes in information about a runtime and a model and returns if the two are
// compatible with one another.
TORCH_API ModelCompatCheckResult is_compatible(
RuntimeCompatibilityInfo runtime_info,
ModelCompatibilityInfo model_info);
const ModelCompatibilityInfo& model_info);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -11,8 +11,7 @@ namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch {
namespace jit {
namespace torch::jit {
uint64_t _get_runtime_bytecode_version() {
return caffe2::serialize::kMaxSupportedBytecodeVersion;
@ -40,7 +39,7 @@ std::unordered_map<std::string, OperatorInfo> _get_runtime_ops_and_info() {
auto nonDispatcherOperators = torch::jit::getAllOperators();
for (const auto& full_op : nonDispatcherOperators) {
auto op = full_op->schema();
int num_schema_args = op.arguments().size();
auto num_schema_args = op.arguments().size();
auto op_name = op.name();
if (!op.overload_name().empty()) {
op_name += ("." + op.overload_name());
@ -94,5 +93,4 @@ TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes() {
return torch::getAllCustomClassesNames();
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -7,8 +7,7 @@
#include <unordered_map>
#include <unordered_set>
namespace torch {
namespace jit {
namespace torch::jit {
// Struct storing metadata of an operator that can be useful for versioning
struct OperatorInfo {
@ -40,5 +39,4 @@ TORCH_API std::unordered_set<std::string> _get_mobile_supported_types();
TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes();
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -9,7 +9,7 @@
#include <c10/util/irange.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
@ -318,9 +318,7 @@ void BytecodeDeserializer::parseMethods(
method_i_start = 1;
}
TORCH_CHECK(
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
caffe2::serialize::kMinSupportedBytecodeVersion <= bytecode_version_ &&
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
bytecode_version_ <= caffe2::serialize::kMaxSupportedBytecodeVersion,
"Lite Interpreter version number does not match. ",
"The model version must be between ",
@ -630,7 +628,7 @@ mobile::Module _load_for_mobile(
return _load_mobile_from_bytes(
data, size, device, extra_files, module_load_options);
}
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
auto rai = std::make_unique<caffe2::serialize::IStreamAdapter>(&in);
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
return module;
@ -661,7 +659,7 @@ mobile::Module _load_for_mobile(
data, size, device, extra_files, module_load_options);
}
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
return _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
}
@ -691,8 +689,7 @@ void _load_extra_only_for_mobile(
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
BytecodeDeserializer deserializer(std::move(reader));
deserializer.deserialize_only_extra(device, extra_files);

View File

@ -8,8 +8,6 @@
#include <caffe2/serialize/file_adapter.h>
namespace torch::jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::ReadAdapterInterface;
using ExtraFilesMap = std::unordered_map<std::string, std::string>;

View File

@ -3,8 +3,7 @@
#include <ATen/Functions.h>
#include <ATen/core/ivalue.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
@ -22,9 +21,7 @@
#include <vector>
namespace torch::jit {
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
namespace {

View File

@ -1,9 +1,7 @@
#include <torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h>
#include <mutex>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
BuildFeatureTracer::BuildFeatureTracer() {
auto recorder_cb =
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
@ -26,6 +24,4 @@ c10::Synchronized<BuildFeatureTracer::build_feature_type>& BuildFeatureTracer::
return build_features;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -6,9 +6,7 @@
#include <set>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/* The BuildFeatureTracer class handles the attachment and removal of a
* recording callback that traces the invocation of code that handles executing
@ -36,6 +34,4 @@ struct BuildFeatureTracer final {
}
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,9 +1,7 @@
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <mutex>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
CustomClassTracer::CustomClassTracer() {
auto recorder_cb =
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
@ -25,6 +23,4 @@ c10::Synchronized<CustomClassTracer::custom_classes_type>& CustomClassTracer::
return loaded_classes;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -6,9 +6,7 @@
#include <set>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/* The CustomClassTracer class handles the attachment and removal of a recording
* callback that traces the invocation of code that handles loading custom
@ -36,6 +34,4 @@ struct CustomClassTracer final {
}
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -4,9 +4,7 @@
#include <set>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
KernelDTypeTracer::KernelDTypeTracer() {
auto recorder_cb =
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
@ -32,6 +30,4 @@ c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer::
return called_kernel_tags;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -6,9 +6,7 @@
#include <set>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/* The KernelDTypeTracer class handles the attachment and removal of a recording
* callback that traces the invocation of code that handles specific dtypes in
* kernel function implementations that are tagged with specific tags.
@ -36,6 +34,4 @@ struct KernelDTypeTracer final {
at::removeCallback(handle_);
}
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,9 +1,7 @@
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
std::vector<std::vector<at::IValue>> MobileModelRunner::
ivalue_to_bundled_inputs(const c10::IValue& bundled_inputs) {
@ -53,8 +51,8 @@ std::unordered_map<std::string, std::string> MobileModelRunner::
std::unordered_map<std::string, std::string> ret;
for (auto& input : all_inputs) {
at::IValue function_name = input.key();
at::IValue nested_dict = input.value();
const at::IValue& function_name = input.key();
const at::IValue& nested_dict = input.value();
CAFFE_ENFORCE(
function_name.isString(),
"Expected function with inputs to be a string ",
@ -74,8 +72,8 @@ std::unordered_map<std::string, std::string> MobileModelRunner::
std::unordered_map<std::string, std::vector<std::string>>
function_and_info_dict;
for (auto& entry : function_and_info_ival_dict) {
at::IValue key = entry.key();
at::IValue value = entry.value();
const at::IValue& key = entry.key();
const at::IValue& value = entry.value();
CAFFE_ENFORCE(
key.isString(),
"Expected extra information key to be a string ",
@ -232,6 +230,4 @@ void MobileModelRunner::for_each_tensor_in_bundled_inputs(
for_each_tensor_in_ivalue(iv, func);
}
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -9,9 +9,7 @@
#include <torch/csrc/jit/serialization/export.h>
#include <torch/script.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
class MobileModelRunner {
std::shared_ptr<torch::jit::mobile::Module> module_;
@ -145,6 +143,4 @@ class MobileModelRunner {
void run_argless_functions(const std::vector<std::string>& functions);
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,8 +1,6 @@
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
OperatorCallTracer::OperatorCallTracer() {
getCalledOperators().withLock([](std::set<std::string>& called_operators) {
called_operators.clear();
@ -24,6 +22,4 @@ OperatorCallTracer::OperatorCallTracer() {
.scopes({at::RecordScope::FUNCTION}));
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -3,9 +3,7 @@
#include <ATen/record_function.h>
#include <c10/util/Synchronized.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/* The OperatorCallTracer class handles the attachment and removal of a
* recording callback that traces invocation of ATen (and other) PyTorch
* operators that get called via the Dispatcher.
@ -31,6 +29,4 @@ struct OperatorCallTracer final {
at::removeCallback(handle_);
}
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,9 +1,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
void for_each_tensor_in_ivalue(
const c10::IValue& iv,
std::function<void(const ::at::Tensor&)> const& func) {
@ -37,6 +35,4 @@ void for_each_tensor_in_ivalue(
AT_ERROR("Unhandled type of IValue. Got ", iv.tagKind());
}
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -2,9 +2,7 @@
#include <ATen/core/ivalue.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/**
* Recursively scan the IValue object, traversing lists, tuples, dicts, and stop
* and call the user provided callback function 'func' when a Tensor is found.
@ -12,6 +10,4 @@ namespace mobile {
void for_each_tensor_in_ivalue(
const ::c10::IValue& iv,
std::function<void(const ::at::Tensor&)> const& func);
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -14,9 +14,7 @@
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
@ -49,7 +47,7 @@ const std::vector<std::string> gpu_metal_operators = {
* If/When this list becomes too long, we can consider making it a
* per-model list.
*/
void call_setup_methods() {
static void call_setup_methods() {
at::zeros({2, 2});
at::ones({2, 2});
at::Tensor t1 = at::empty({7, 7});
@ -98,7 +96,7 @@ void call_setup_methods() {
* under certain conditions but may avoid getting called in the trace due to the
* narrow nature of bundled inputs
*/
void call_dependent_methods(std::set<std::string>& root_ops) {
static void call_dependent_methods(std::set<std::string>& root_ops) {
bool is_training = false;
bool has_batchnorm = false;
bool has_dropout = false;
@ -135,12 +133,12 @@ void call_dependent_methods(std::set<std::string>& root_ops) {
* Call methods on the Tensor object that we expect to be called
* in production on this Tensor.
*/
void consume_tensor(const at::Tensor& t) {
static void consume_tensor(const at::Tensor& t) {
const at::Tensor& c = t;
c.copy_(t.cpu());
}
std::unordered_map<std::string, c10::FunctionSchema>
static std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
std::unordered_map<std::string, c10::FunctionSchema> result;
@ -182,7 +180,7 @@ _get_runtime_ops_and_schema() {
* Scalar? output_min=None, Scalar? output_max=None) ->
* __torch__.torch.classes.xnnpack.LinearOpContext"
*/
void recordCustomClassesFromOpSchemas(
static void recordCustomClassesFromOpSchemas(
std::set<std::string>& root_ops,
std::set<std::string>& traced_ops,
std::set<std::string>& loaded_classes) {
@ -191,7 +189,7 @@ void recordCustomClassesFromOpSchemas(
ops.insert(traced_ops.begin(), traced_ops.end());
auto ops_and_schemas = _get_runtime_ops_and_schema();
auto record_if_class = [&](std::string type_name) {
auto record_if_class = [&](const std::string& type_name) {
// All custom class types start with __torch__ not sure if this is by
// chance or guaranteed
if (type_name.find("__torch__") != std::string::npos) {
@ -225,7 +223,7 @@ void recordCustomClassesFromOpSchemas(
}
}
void run_model(
static void run_model(
const std::string& input_module_path,
std::set<std::string>& root_ops,
std::set<std::string>& enabled_backends,
@ -236,11 +234,11 @@ void run_model(
// TorchBind objects can be traced by the model tracer.
torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
root_ops = module_runner.get_root_operators();
std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
std::cout << "Got " << root_ops.size() << " Root Operators." << '\n';
if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
root_ops)) {
std::cout << "Inferred Metal GPU Model." << std::endl;
std::cout << "Inferred Metal GPU Model." << '\n';
root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
called_kernel_tags["__unused__"] = {"Float"};
enabled_backends.insert("Metal GPU");
@ -251,7 +249,7 @@ void run_model(
// memory via a call to .metal()).
module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
} else {
std::cout << "Inferred CPU Model." << std::endl;
std::cout << "Inferred CPU Model." << '\n';
enabled_backends.insert("CPU");
torch::jit::mobile::MobileModelRunner mobile_module_runner(
input_module_path);
@ -341,7 +339,7 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
} catch (std::exception& ex) {
std::cerr
<< "ModelTracer encountered an error while attempting to run the model in FBGEMM mode"
<< ex.what() << "\n Skipping FBGEMM execution" << std::endl;
<< ex.what() << "\n Skipping FBGEMM execution" << '\n';
}
try {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
@ -351,7 +349,7 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
} catch (std::exception& ex) {
std::cerr
<< "ModelTracer encountered an error while attempting to run the model under an inference guard"
<< ex.what() << "\n Skipping inference guard execution" << std::endl;
<< ex.what() << "\n Skipping inference guard execution" << '\n';
}
}
@ -393,6 +391,4 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
return tracer_result;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -9,9 +9,7 @@
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
const std::vector<std::string> always_included_traced_ops = {
// The following are called from setup sections.
@ -38,6 +36,4 @@ TracerResult trace_run(const std::string& input_module_path);
*/
TracerResult trace_run(const std::vector<std::string>& input_module_paths);
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -53,26 +53,25 @@ C10_DEFINE_string(
return 1; \
}
void printOpYAML(
static void printOpYAML(
std::ostream& out,
int indent,
const std::string& op_name,
bool is_used_for_training,
bool is_root_operator,
bool include_all_overloads) {
out << std::string(indent, ' ') << op_name << ":" << std::endl;
out << std::string(indent, ' ') << op_name << ":" << '\n';
out << std::string(indent + 2, ' ')
<< "is_used_for_training: " << (is_used_for_training ? "true" : "false")
<< std::endl;
<< '\n';
out << std::string(indent + 2, ' ')
<< "is_root_operator: " << (is_root_operator ? "true" : "false")
<< std::endl;
<< "is_root_operator: " << (is_root_operator ? "true" : "false") << '\n';
out << std::string(indent + 2, ' ')
<< "include_all_overloads: " << (include_all_overloads ? "true" : "false")
<< std::endl;
<< '\n';
}
void printOpsYAML(
static void printOpsYAML(
std::ostream& out,
const std::set<std::string>& operator_list,
bool is_used_for_training,
@ -83,19 +82,19 @@ void printOpsYAML(
}
}
void printDTypeYAML(
static void printDTypeYAML(
std::ostream& out,
int indent,
const std::string& kernel_tag_name,
const std::set<std::string> dtypes) {
const std::set<std::string>& dtypes) {
std::string indent_str = std::string(indent, ' ');
out << indent_str << kernel_tag_name << ":" << std::endl;
out << indent_str << kernel_tag_name << ":" << '\n';
for (auto& dtype : dtypes) {
out << indent_str << "- " << dtype << std::endl;
out << indent_str << "- " << dtype << '\n';
}
}
void printDTypesYAML(
static void printDTypesYAML(
std::ostream& out,
const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
kernel_tags) {
@ -104,12 +103,12 @@ void printDTypesYAML(
}
}
void printCustomClassesYAML(
static void printCustomClassesYAML(
std::ostream& out,
const torch::jit::mobile::CustomClassTracer::custom_classes_type&
loaded_classes) {
for (auto& class_name : loaded_classes) {
out << "- " << class_name << std::endl;
out << "- " << class_name << '\n';
}
}
@ -120,7 +119,7 @@ void printCustomClassesYAML(
*/
int main(int argc, char* argv[]) {
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
std::cerr << "Failed to parse command line flags!" << std::endl;
std::cerr << "Failed to parse command line flags!" << '\n';
return 1;
}
@ -130,13 +129,13 @@ int main(int argc, char* argv[]) {
std::istringstream sin(FLAGS_model_input_path);
std::ofstream yaml_out(FLAGS_build_yaml_path);
std::cout << "Output: " << FLAGS_build_yaml_path << std::endl;
std::cout << "Output: " << FLAGS_build_yaml_path << '\n';
torch::jit::mobile::TracerResult tracer_result;
std::vector<std::string> model_input_paths;
for (std::string model_input_path;
std::getline(sin, model_input_path, ',');) {
std::cout << "Processing: " << model_input_path << std::endl;
std::cout << "Processing: " << model_input_path << '\n';
model_input_paths.push_back(model_input_path);
}
@ -147,7 +146,7 @@ int main(int argc, char* argv[]) {
<< "ModelTracer has not been able to load the module for the following reasons:\n"
<< ex.what()
<< "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
<< "with the detailed error message." << std::endl;
<< "with the detailed error message." << '\n';
throw ex;
}
@ -161,7 +160,7 @@ int main(int argc, char* argv[]) {
". Expected the traced operator list to be bigger then the default size ",
torch::jit::mobile::always_included_traced_ops.size(),
". Please report a bug in PyTorch.")
<< std::endl;
<< '\n';
}
// If the op exist in both traced_ops and root_ops, leave it in root_ops only
@ -172,9 +171,9 @@ int main(int argc, char* argv[]) {
}
}
yaml_out << "include_all_non_op_selectives: false" << std::endl;
yaml_out << "build_features: []" << std::endl;
yaml_out << "operators:" << std::endl;
yaml_out << "include_all_non_op_selectives: false" << '\n';
yaml_out << "build_features: []" << '\n';
yaml_out << "operators:" << '\n';
printOpsYAML(
yaml_out,
tracer_result.root_ops,
@ -192,14 +191,14 @@ int main(int argc, char* argv[]) {
if (tracer_result.called_kernel_tags.empty()) {
yaml_out << " []";
}
yaml_out << std::endl;
yaml_out << '\n';
printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);
yaml_out << "custom_classes:";
if (tracer_result.loaded_classes.empty()) {
yaml_out << " []";
}
yaml_out << std::endl;
yaml_out << '\n';
printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);
return 0;

View File

@ -5,7 +5,6 @@
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <exception>
#include <ATen/record_function.h>
#include <c10/util/ScopeExit.h>
@ -30,7 +29,7 @@ const Function* CompilationUnit::find_function(
}
Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
// NOLINTNEXTLINE
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return const_cast<Function*>(
static_cast<const CompilationUnit*>(this)->find_function(qn));
}

View File

@ -73,7 +73,6 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
for (size_t i = 0; i < code.instructions_.size(); i++) {
Instruction& inst = code.instructions_[i];
if (inst.op == OpCode::OP) {
std::string op_name = code.op_names_[inst.X].name;
std::string operator_name = code.op_names_[inst.X].name +
(code.op_names_[inst.X].overload_name.empty()
? ""

View File

@ -7,7 +7,8 @@ namespace torch::jit {
void tupleIndex(Stack& stack) {
int64_t index = pop(stack).toInt();
auto tuple = pop(stack).toTuple();
auto norm_index = normalizeIndex(index, tuple->elements().size());
auto norm_index =
normalizeIndex(index, static_cast<int64_t>(tuple->elements().size()));
if (norm_index < 0 ||
norm_index >= static_cast<int64_t>(tuple->elements().size())) {
throw std::out_of_range("Tuple list index out of range");

View File

@ -15,8 +15,7 @@
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
namespace mobile {
char const* toString(OpCode op);
@ -150,5 +149,4 @@ void _save_parameters(
_save_parameters(map, ifile, use_flatbuffer);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/mobile/module.h>
namespace torch {
namespace jit {
namespace torch::jit {
/**
* Serializes the provided tensor map to the provided stream.
@ -49,5 +48,4 @@ extern void (*_save_mobile_module_to)(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -5,11 +5,7 @@
#include <ATen/ATen.h>
#include <functional>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
bool SGDParamGroup::has_options() const {
return options_ != nullptr;
@ -17,12 +13,12 @@ bool SGDParamGroup::has_options() const {
SGDOptions& SGDParamGroup::options() {
TORCH_CHECK(has_options());
return *options_.get();
return *options_;
}
const SGDOptions& SGDParamGroup::options() const {
TORCH_CHECK(has_options());
return *options_.get();
return *options_;
}
void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
@ -126,6 +122,4 @@ Tensor SGD::step(const LossClosure& closure) {
}
return loss;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,17 +1,12 @@
#pragma once
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>
#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
class SGDParamState {
TORCH_ARG(torch::Tensor, momentum_buffer);
@ -127,6 +122,4 @@ class TORCH_API SGD {
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unique_ptr<SGDOptions> options_;
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -5,9 +5,7 @@
#include <cstddef>
#include <vector>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
: indices_(torch::randperm(size, index_dtype)) {}
@ -18,7 +16,7 @@ void RandomSampler::reset(std::optional<size_t> new_size) {
// This allocates a new chunk of memory every time (just FYI). It should be
// amortized over the entire epoch hopefully.
const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
indices_ = torch::randperm(size, indices_.options());
indices_ = torch::randperm(static_cast<int64_t>(size), indices_.options());
index_ = 0;
}
@ -38,7 +36,7 @@ std::optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
slice = slice.to(torch::kInt64);
const auto* data = slice.const_data_ptr<int64_t>();
std::copy(data, data + index_batch.size(), index_batch.begin());
index_ += index_batch.size();
index_ += static_cast<int64_t>(index_batch.size());
return index_batch;
}
@ -54,6 +52,4 @@ size_t RandomSampler::index() const noexcept {
return index_;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -7,16 +7,12 @@
#include <cstddef>
#include <vector>
namespace torch {
namespace serialize {
namespace torch::serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
} // namespace torch::serialize
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/// A lighter `Sampler` that returns indices randomly and cannot be
/// serialized.
@ -51,6 +47,4 @@ class TORCH_API RandomSampler : public torch::data::samplers::Sampler<> {
int64_t index_ = 0;
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -5,9 +5,7 @@
#include <cstddef>
#include <vector>
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
void SequentialSampler::reset(std::optional<size_t> new_size) {
@ -43,6 +41,4 @@ size_t SequentialSampler::index() const noexcept {
return index_;
}
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -7,16 +7,12 @@
#include <cstddef>
#include <vector>
namespace torch {
namespace serialize {
namespace torch::serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
} // namespace torch::serialize
namespace torch {
namespace jit {
namespace mobile {
namespace torch::jit::mobile {
/// A lighter `Sampler` that returns indices sequentially and cannot be
/// serialized.
@ -46,6 +42,4 @@ class TORCH_API SequentialSampler : public torch::data::samplers::Sampler<> {
size_t index_{0};
};
} // namespace mobile
} // namespace jit
} // namespace torch
} // namespace torch::jit::mobile

View File

@ -1,12 +1,10 @@
#pragma once
// #include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>
#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>