[19/N] Fix clang-tidy warnings in jit (#133067)
Follows #132963

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133067
Approved by: https://github.com/Skylion007
This commit is contained in:
parent 2e7d67e6af
commit 2f30473fba
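Most hunks below apply the same mechanical cleanups: C++17 nested namespace definitions, const-reference parameters for read-only arguments, internal linkage for file-local helpers, '\n' in place of std::endl, and removal of redundant .get() calls and stale NOLINT suppressions. A minimal sketch of the namespace consolidation, the most frequent change (the check name modernize-concat-nested-namespaces is an assumption from clang-tidy's catalog; the commit message does not name individual checks):

// Pre-C++17 spelling, the form clang-tidy's
// modernize-concat-nested-namespaces flags (assumed check name):
// namespace torch {
// namespace jit {
// class Function;
// } // namespace jit
// } // namespace torch

// C++17 nested namespace definition, the form this commit adopts:
namespace torch::jit {
class Function;
} // namespace torch::jit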
@@ -13,7 +13,6 @@ using DebugHandle = int64_t;
 
 class Function;
 
-// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct Code {
   std::vector<Instruction> instructions_;
   std::vector<DebugHandle> debug_handles_;
@@ -7,8 +7,7 @@
 
 #include <string>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 using caffe2::serialize::IStreamAdapter;
 using caffe2::serialize::PyTorchStreamWriter;

@@ -87,5 +86,4 @@ bool _backport_for_mobile_impl(
   return backportManager.backport(oss, writer, from_version, to_version);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,17 +2,8 @@
 
 #include <c10/macros/Export.h>
 #include <istream>
 #include <memory>
 
-namespace caffe2 {
-namespace serialize {
-class ReadAdapterInterface;
-class PyTorchStreamWriter;
-} // namespace serialize
-} // namespace caffe2
-
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API bool _backport_for_mobile(
     std::istream& in,

@@ -34,5 +25,4 @@ TORCH_API bool _backport_for_mobile(
     const std::string& output_filename,
     const int64_t to_version);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -12,10 +12,8 @@
 #include <cstddef>
 #include <sstream>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
-using caffe2::serialize::IStreamAdapter;
 using caffe2::serialize::PyTorchStreamReader;
 using caffe2::serialize::PyTorchStreamWriter;

@@ -117,7 +115,6 @@ void write_archive_current(
   data_pickle.stop();
   // write out tensor data
-  size_t i = 0;
   std::string prefix = archive_name + "/";
 
   TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
   const std::unordered_set<std::string>& pre_serialized_files =

@@ -388,8 +385,8 @@ Thus, the backport is necessary such that the bytecode operator table contains
 number of specified arguments.
 */
 std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
-  std::shared_ptr<IStreamAdapter> rai =
-      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto rai =
+      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
   auto reader = std::make_shared<PyTorchStreamReader>(rai);
 
   // If there are debug info files in the original model file, it should also

@@ -453,11 +450,11 @@ push in the stack. Thus, the backport is necessary such that the bytecode
 contains all the arguments as before.
 */
 std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
-  std::shared_ptr<IStreamAdapter> rai =
-      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto rai =
+      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
   auto reader = std::make_shared<PyTorchStreamReader>(rai);
   auto constants_values =
-      std::move(*readArchive(kArchiveNameConstants, *reader.get()).toTuple())
+      std::move(*readArchive(kArchiveNameConstants, *reader).toTuple())
           .elements();
 
   // If there are debug info files in the original model file, it should also
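The *reader.get() to *reader change above is the usual fix for readability-redundant-smartptr-get (check name assumed): operator* on a shared_ptr already yields the pointee, so routing through the raw pointer adds nothing. A self-contained sketch:

#include <memory>
#include <string>

int main() {
  auto reader = std::make_shared<std::string>("archive");
  std::string& a = *reader.get(); // redundant: .get() then dereference
  std::string& b = *reader;       // equivalent, and what the fix uses
  return (&a == &b) ? 0 : 1;      // same object, exits 0
}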
@@ -526,8 +523,8 @@ std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) {
 }
 
 std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
-  std::shared_ptr<IStreamAdapter> rai =
-      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto rai =
+      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
   auto reader = std::make_shared<PyTorchStreamReader>(rai);
   // extra_files are kept
   auto records = reader->getAllRecords();

@@ -696,5 +693,4 @@ bool BackportManager::backport(
   return backport_success;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -8,17 +8,11 @@ namespace c10 {
 struct IValue;
 }
 
-namespace caffe2 {
-namespace serialize {
-class IStreamAdapter;
-class ReadAdapterInterface;
+namespace caffe2::serialize {
 class PyTorchStreamWriter;
 class PyTorchStreamReader;
-} // namespace serialize
-} // namespace caffe2
+} // namespace caffe2::serialize
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 /*
 BackportManager manages a list of backport from n to n-1 function, and provides

@@ -51,5 +45,4 @@ class BackportManager final {
       backport_function);
 };
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -20,8 +20,7 @@ namespace c10 {
 TypePtr parseType(const std::string& pythonStr);
 } // namespace c10
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 using caffe2::serialize::FileAdapter;
 using caffe2::serialize::IStreamAdapter;

@@ -43,7 +42,7 @@ c10::IValue readArchive(
 
   std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
       std::make_shared<mobile::CompilationUnit>();
-  auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) {
+  auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
     return objLoaderMobile(type, input, *mobile_compilation_unit);
   };
   bool bytecode_tensor_in_constants_archive =
@@ -86,7 +85,7 @@ uint64_t _get_model_bytecode_version(const std::string& filename) {
 }
 
 uint64_t _get_model_bytecode_version(
-    std::shared_ptr<ReadAdapterInterface> rai) {
+    const std::shared_ptr<ReadAdapterInterface>& rai) {
   auto [data, size] = get_rai_content(rai.get());
   return _get_model_bytecode_version_from_bytes(data.get(), size);
 }

@@ -345,7 +344,7 @@ ModelCompatibilityInfo ModelCompatibilityInfo::get(
 
 ModelCompatCheckResult is_compatible(
     RuntimeCompatibilityInfo runtime_info,
-    ModelCompatibilityInfo model_info) {
+    const ModelCompatibilityInfo& model_info) {
   ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
   // Check that the models bytecode version is less than or equal to
   // kMaxSupportedBytecodeVersion from the runtime

@@ -444,5 +443,4 @@ ModelCompatCheckResult is_compatible(
   return result;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
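The is_compatible and _get_model_bytecode_version signatures above switch read-only parameters from pass-by-value to pass-by-const-reference, the standard fix for performance-unnecessary-value-param (check name assumed). A sketch with a hypothetical Info type standing in for a struct like ModelCompatibilityInfo:

#include <string>
#include <vector>

// Hypothetical stand-in; not a type from the diff.
struct Info {
  std::vector<std::string> ops;
};

// Before: `bool check(Info info)` copied the vector on every call.
// After: a const reference binds to the caller's object with no copy.
static bool check(const Info& info) {
  return !info.ops.empty();
}

int main() {
  Info model{{"aten::add", "aten::mul"}};
  return check(model) ? 0 : 1;
}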
@@ -1,21 +1,20 @@
 #pragma once
 
 #include <ATen/core/ivalue.h>
 #include <c10/macros/Export.h>
 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
 
 #include <istream>
 #include <memory>
 #include <unordered_map>
 #include <vector>
 
-namespace caffe2 {
-namespace serialize {
+namespace caffe2::serialize {
 class PyTorchStreamReader;
 class ReadAdapterInterface;
-} // namespace serialize
-} // namespace caffe2
+} // namespace caffe2::serialize
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // The family of methods below to get bytecode version from a model
 // Throws if not passed in a well formed model

@@ -24,7 +23,7 @@ TORCH_API uint64_t _get_model_bytecode_version(std::istream& in);
 TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename);
 
 TORCH_API uint64_t _get_model_bytecode_version(
-    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
+    const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai);
 
 uint64_t _get_model_bytecode_version(
     const std::vector<c10::IValue>& bytecode_ivalues);
@@ -94,13 +93,12 @@ enum ModelCompatibilityStatus {
 
 struct ModelCompatCheckResult {
   ModelCompatibilityStatus status;
-  std::vector<std::string> errors;
+  std::vector<std::string> errors{};
 };
 // Takes in information about a runtime and a model and returns if the two are
 // compatible with one another.
 TORCH_API ModelCompatCheckResult is_compatible(
     RuntimeCompatibilityInfo runtime_info,
-    ModelCompatibilityInfo model_info);
+    const ModelCompatibilityInfo& model_info);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
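The errors{} change above adds an explicit default member initializer. For a std::vector it is semantically a no-op (its default constructor already runs), so this reads as satisfying a member-initialization style check such as cppcoreguidelines-pro-type-member-init (an assumption); the initializer matters materially for built-in members. Sketch:

#include <string>
#include <vector>

struct Result {
  int status = 0;                    // built-in type: initializer matters
  std::vector<std::string> errors{}; // explicit, appeases member-init checks
};

int main() {
  Result r;
  return (r.status == 0 && r.errors.empty()) ? 0 : 1;
}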
@@ -11,8 +11,7 @@ namespace c10 {
 TypePtr parseType(const std::string& pythonStr);
 } // namespace c10
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 uint64_t _get_runtime_bytecode_version() {
   return caffe2::serialize::kMaxSupportedBytecodeVersion;

@@ -40,7 +39,7 @@ std::unordered_map<std::string, OperatorInfo> _get_runtime_ops_and_info() {
   auto nonDispatcherOperators = torch::jit::getAllOperators();
   for (const auto& full_op : nonDispatcherOperators) {
     auto op = full_op->schema();
-    int num_schema_args = op.arguments().size();
+    auto num_schema_args = op.arguments().size();
     auto op_name = op.name();
     if (!op.overload_name().empty()) {
       op_name += ("." + op.overload_name());

@@ -94,5 +93,4 @@ TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes() {
   return torch::getAllCustomClassesNames();
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -7,8 +7,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Struct storing metadata of an operator that can be useful for versioning
 struct OperatorInfo {

@@ -40,5 +39,4 @@ TORCH_API std::unordered_set<std::string> _get_mobile_supported_types();
 
 TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes();
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -9,7 +9,7 @@
 #include <c10/util/irange.h>
 #include <caffe2/serialize/in_memory_adapter.h>
 #include <caffe2/serialize/inline_container.h>
-#include <caffe2/serialize/read_adapter_interface.h>
+#include <caffe2/serialize/istream_adapter.h>
 #include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/mobile/file_format.h>

@@ -318,9 +318,7 @@ void BytecodeDeserializer::parseMethods(
     method_i_start = 1;
   }
   TORCH_CHECK(
-      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
       caffe2::serialize::kMinSupportedBytecodeVersion <= bytecode_version_ &&
-          // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
          bytecode_version_ <= caffe2::serialize::kMaxSupportedBytecodeVersion,
      "Lite Interpreter version number does not match. ",
      "The model version must be between ",
@@ -630,7 +628,7 @@ mobile::Module _load_for_mobile(
     return _load_mobile_from_bytes(
         data, size, device, extra_files, module_load_options);
   }
-  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
+  auto rai = std::make_unique<caffe2::serialize::IStreamAdapter>(&in);
   auto module = _load_for_mobile_impl(
       std::move(rai), device, extra_files, module_load_options);
   return module;

@@ -661,7 +659,7 @@ mobile::Module _load_for_mobile(
         data, size, device, extra_files, module_load_options);
   }
 
-  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
+  auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
   return _load_for_mobile_impl(
       std::move(rai), device, extra_files, module_load_options);
 }

@@ -691,8 +689,7 @@ void _load_extra_only_for_mobile(
   auto format = getFileFormat(filename);
   switch (format) {
     case FileFormat::ZipFileFormat: {
-      std::unique_ptr<FileAdapter> rai =
-          std::make_unique<FileAdapter>(filename);
+      auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
       auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
       BytecodeDeserializer deserializer(std::move(reader));
       deserializer.deserialize_only_extra(device, extra_files);
@@ -8,8 +8,6 @@
 #include <caffe2/serialize/file_adapter.h>
 
 namespace torch::jit {
-using caffe2::serialize::FileAdapter;
-using caffe2::serialize::IStreamAdapter;
 using caffe2::serialize::ReadAdapterInterface;
 using ExtraFilesMap = std::unordered_map<std::string, std::string>;
@@ -3,8 +3,7 @@
 #include <ATen/Functions.h>
 #include <ATen/core/ivalue.h>
 #include <c10/util/irange.h>
-#include <caffe2/serialize/file_adapter.h>
 #include <caffe2/serialize/inline_container.h>
 
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/mobile/file_format.h>
 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>

@@ -22,9 +21,7 @@
 #include <vector>
 
 namespace torch::jit {
-using caffe2::serialize::MemoryReadAdapter;
-using caffe2::serialize::PyTorchStreamReader;
 using caffe2::serialize::ReadAdapterInterface;
 
 namespace {
@@ -1,9 +1,7 @@
 #include <torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h>
 #include <mutex>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 BuildFeatureTracer::BuildFeatureTracer() {
   auto recorder_cb =
       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {

@@ -26,6 +24,4 @@ c10::Synchronized<BuildFeatureTracer::build_feature_type>& BuildFeatureTracer::
   return build_features;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -6,9 +6,7 @@
 #include <set>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 /* The BuildFeatureTracer class handles the attachment and removal of a
  * recording callback that traces the invocation of code that handles executing

@@ -36,6 +34,4 @@ struct BuildFeatureTracer final {
   }
 };
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,9 +1,7 @@
 #include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
 #include <mutex>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 CustomClassTracer::CustomClassTracer() {
   auto recorder_cb =
       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {

@@ -25,6 +23,4 @@ c10::Synchronized<CustomClassTracer::custom_classes_type>& CustomClassTracer::
   return loaded_classes;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -6,9 +6,7 @@
 #include <set>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 /* The CustomClassTracer class handles the attachment and removal of a recording
  * callback that traces the invocation of code that handles loading custom

@@ -36,6 +34,4 @@ struct CustomClassTracer final {
   }
 };
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -4,9 +4,7 @@
 #include <set>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 KernelDTypeTracer::KernelDTypeTracer() {
   auto recorder_cb =
       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {

@@ -32,6 +30,4 @@ c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer::
   return called_kernel_tags;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -6,9 +6,7 @@
 #include <set>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 /* The KernelDTypeTracer class handles the attachment and removal of a recording
  * callback that traces the invocation of code that handles specific dtypes in
  * kernel function implementations that are tagged with specific tags.

@@ -36,6 +34,4 @@ struct KernelDTypeTracer final {
     at::removeCallback(handle_);
   }
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,9 +1,7 @@
 #include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 std::vector<std::vector<at::IValue>> MobileModelRunner::
     ivalue_to_bundled_inputs(const c10::IValue& bundled_inputs) {

@@ -53,8 +51,8 @@ std::unordered_map<std::string, std::string> MobileModelRunner::
 
   std::unordered_map<std::string, std::string> ret;
   for (auto& input : all_inputs) {
-    at::IValue function_name = input.key();
-    at::IValue nested_dict = input.value();
+    const at::IValue& function_name = input.key();
+    const at::IValue& nested_dict = input.value();
     CAFFE_ENFORCE(
         function_name.isString(),
         "Expected function with inputs to be a string ",

@@ -74,8 +72,8 @@ std::unordered_map<std::string, std::string> MobileModelRunner::
   std::unordered_map<std::string, std::vector<std::string>>
       function_and_info_dict;
   for (auto& entry : function_and_info_ival_dict) {
-    at::IValue key = entry.key();
-    at::IValue value = entry.value();
+    const at::IValue& key = entry.key();
+    const at::IValue& value = entry.value();
     CAFFE_ENFORCE(
         key.isString(),
         "Expected extra information key to be a string ",

@@ -232,6 +230,4 @@ void MobileModelRunner::for_each_tensor_in_bundled_inputs(
     for_each_tensor_in_ivalue(iv, func);
   }
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -9,9 +9,7 @@
 #include <torch/csrc/jit/serialization/export.h>
 #include <torch/script.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 class MobileModelRunner {
   std::shared_ptr<torch::jit::mobile::Module> module_;

@@ -145,6 +143,4 @@ class MobileModelRunner {
   void run_argless_functions(const std::vector<std::string>& functions);
 };
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,8 +1,6 @@
 #include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 OperatorCallTracer::OperatorCallTracer() {
   getCalledOperators().withLock([](std::set<std::string>& called_operators) {
     called_operators.clear();

@@ -24,6 +22,4 @@ OperatorCallTracer::OperatorCallTracer() {
           .scopes({at::RecordScope::FUNCTION}));
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -3,9 +3,7 @@
 #include <ATen/record_function.h>
 #include <c10/util/Synchronized.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 /* The OperatorCallTracer class handles the attachment and removal of a
  * recording callback that traces invocation of ATen (and other) PyTorch
  * operators that get called via the Dispatcher.

@@ -31,6 +29,4 @@ struct OperatorCallTracer final {
     at::removeCallback(handle_);
   }
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,9 +1,7 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 void for_each_tensor_in_ivalue(
     const c10::IValue& iv,
     std::function<void(const ::at::Tensor&)> const& func) {

@@ -37,6 +35,4 @@ void for_each_tensor_in_ivalue(
     AT_ERROR("Unhandled type of IValue. Got ", iv.tagKind());
   }
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -2,9 +2,7 @@
 
 #include <ATen/core/ivalue.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 /**
  * Recursively scan the IValue object, traversing lists, tuples, dicts, and stop
  * and call the user provided callback function 'func' when a Tensor is found.

@@ -12,6 +10,4 @@ namespace mobile {
 void for_each_tensor_in_ivalue(
     const ::c10::IValue& iv,
     std::function<void(const ::at::Tensor&)> const& func);
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -14,9 +14,7 @@
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/script.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 // Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
 // Diffusion Link: https://fburl.com/diffusion/atwwmax2

@@ -49,7 +47,7 @@ const std::vector<std::string> gpu_metal_operators = {
  * If/When this list becomes too long, we can consider making it a
  * per-model list.
  */
-void call_setup_methods() {
+static void call_setup_methods() {
   at::zeros({2, 2});
   at::ones({2, 2});
   at::Tensor t1 = at::empty({7, 7});

@@ -98,7 +96,7 @@ void call_setup_methods() {
  * under certain conditions but may avoid getting called in the trace due to the
 * narrow nature of bundled inputs
 */
-void call_dependent_methods(std::set<std::string>& root_ops) {
+static void call_dependent_methods(std::set<std::string>& root_ops) {
   bool is_training = false;
   bool has_batchnorm = false;
   bool has_dropout = false;

@@ -135,12 +133,12 @@ void call_dependent_methods(std::set<std::string>& root_ops) {
  * Call methods on the Tensor object that we expect to be called
 * in production on this Tensor.
 */
-void consume_tensor(const at::Tensor& t) {
+static void consume_tensor(const at::Tensor& t) {
   const at::Tensor& c = t;
   c.copy_(t.cpu());
 }
 
-std::unordered_map<std::string, c10::FunctionSchema>
+static std::unordered_map<std::string, c10::FunctionSchema>
 _get_runtime_ops_and_schema() {
   std::unordered_map<std::string, c10::FunctionSchema> result;
 
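The tracer helpers above (call_setup_methods, consume_tensor, and friends) gain static, giving them internal linkage so each translation unit keeps its file-local helpers private; this is what checks like misc-use-internal-linkage ask for (check name assumed). A sketch of the two idiomatic spellings:

#include <iostream>

// `static` at namespace scope: the symbol is private to this .cpp file,
// so an identically named helper elsewhere cannot collide with it.
static void log_step(const char* what) {
  std::cout << what << '\n';
}

// The unnamed-namespace alternative has the same effect.
namespace {
void log_done() {
  std::cout << "done\n";
}
} // namespace

int main() {
  log_step("tracing");
  log_done();
  return 0;
}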
@@ -182,7 +180,7 @@ _get_runtime_ops_and_schema() {
 * Scalar? output_min=None, Scalar? output_max=None) ->
 * __torch__.torch.classes.xnnpack.LinearOpContext"
 */
-void recordCustomClassesFromOpSchemas(
+static void recordCustomClassesFromOpSchemas(
     std::set<std::string>& root_ops,
     std::set<std::string>& traced_ops,
     std::set<std::string>& loaded_classes) {

@@ -191,7 +189,7 @@ void recordCustomClassesFromOpSchemas(
   ops.insert(traced_ops.begin(), traced_ops.end());
   auto ops_and_schemas = _get_runtime_ops_and_schema();
 
-  auto record_if_class = [&](std::string type_name) {
+  auto record_if_class = [&](const std::string& type_name) {
     // All custom class types start with __torch__ not sure if this is by
     // chance or guaranteed
     if (type_name.find("__torch__") != std::string::npos) {

@@ -225,7 +223,7 @@ void recordCustomClassesFromOpSchemas(
   }
 }
 
-void run_model(
+static void run_model(
     const std::string& input_module_path,
     std::set<std::string>& root_ops,
     std::set<std::string>& enabled_backends,

@@ -236,11 +234,11 @@ void run_model(
   // TorchBind objects can be traced by the model tracer.
   torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
   root_ops = module_runner.get_root_operators();
-  std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
+  std::cout << "Got " << root_ops.size() << " Root Operators." << '\n';
 
   if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
           root_ops)) {
-    std::cout << "Inferred Metal GPU Model." << std::endl;
+    std::cout << "Inferred Metal GPU Model." << '\n';
     root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
     called_kernel_tags["__unused__"] = {"Float"};
     enabled_backends.insert("Metal GPU");

@@ -251,7 +249,7 @@ void run_model(
     // memory via a call to .metal()).
     module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
   } else {
-    std::cout << "Inferred CPU Model." << std::endl;
+    std::cout << "Inferred CPU Model." << '\n';
     enabled_backends.insert("CPU");
     torch::jit::mobile::MobileModelRunner mobile_module_runner(
         input_module_path);

@@ -341,7 +339,7 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
   } catch (std::exception& ex) {
     std::cerr
         << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode"
-        << ex.what() << "\n Skipping FBGEMM execution" << std::endl;
+        << ex.what() << "\n Skipping FBGEMM execution" << '\n';
   }
   try {
     at::globalContext().setQEngine(at::QEngine::QNNPACK);

@@ -351,7 +349,7 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
   } catch (std::exception& ex) {
     std::cerr
         << "ModelTracer encountered an error while attempting to run the model under an inference guard"
-        << ex.what() << "\n Skipping inference guard execution" << std::endl;
+        << ex.what() << "\n Skipping inference guard execution" << '\n';
   }
 }

@@ -393,6 +391,4 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
   return tracer_result;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -9,9 +9,7 @@
 #include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 const std::vector<std::string> always_included_traced_ops = {
     // The following are called from setup sections.

@@ -38,6 +36,4 @@ TracerResult trace_run(const std::string& input_module_path);
 */
 TracerResult trace_run(const std::vector<std::string>& input_module_paths);
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -53,26 +53,25 @@ C10_DEFINE_string(
     return 1; \
   }
 
-void printOpYAML(
+static void printOpYAML(
     std::ostream& out,
     int indent,
     const std::string& op_name,
     bool is_used_for_training,
     bool is_root_operator,
     bool include_all_overloads) {
-  out << std::string(indent, ' ') << op_name << ":" << std::endl;
+  out << std::string(indent, ' ') << op_name << ":" << '\n';
   out << std::string(indent + 2, ' ')
       << "is_used_for_training: " << (is_used_for_training ? "true" : "false")
-      << std::endl;
+      << '\n';
   out << std::string(indent + 2, ' ')
-      << "is_root_operator: " << (is_root_operator ? "true" : "false")
-      << std::endl;
+      << "is_root_operator: " << (is_root_operator ? "true" : "false") << '\n';
   out << std::string(indent + 2, ' ')
       << "include_all_overloads: " << (include_all_overloads ? "true" : "false")
-      << std::endl;
+      << '\n';
 }
 
-void printOpsYAML(
+static void printOpsYAML(
     std::ostream& out,
     const std::set<std::string>& operator_list,
     bool is_used_for_training,

@@ -83,19 +82,19 @@ void printOpsYAML(
   }
 }
 
-void printDTypeYAML(
+static void printDTypeYAML(
     std::ostream& out,
     int indent,
     const std::string& kernel_tag_name,
-    const std::set<std::string> dtypes) {
+    const std::set<std::string>& dtypes) {
   std::string indent_str = std::string(indent, ' ');
-  out << indent_str << kernel_tag_name << ":" << std::endl;
+  out << indent_str << kernel_tag_name << ":" << '\n';
   for (auto& dtype : dtypes) {
-    out << indent_str << "- " << dtype << std::endl;
+    out << indent_str << "- " << dtype << '\n';
   }
 }
 
-void printDTypesYAML(
+static void printDTypesYAML(
     std::ostream& out,
     const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
         kernel_tags) {

@@ -104,12 +103,12 @@ void printDTypesYAML(
   }
 }
 
-void printCustomClassesYAML(
+static void printCustomClassesYAML(
     std::ostream& out,
     const torch::jit::mobile::CustomClassTracer::custom_classes_type&
         loaded_classes) {
   for (auto& class_name : loaded_classes) {
-    out << "- " << class_name << std::endl;
+    out << "- " << class_name << '\n';
   }
 }
 

@@ -120,7 +119,7 @@ void printCustomClassesYAML(
 */
 int main(int argc, char* argv[]) {
   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
-    std::cerr << "Failed to parse command line flags!" << std::endl;
+    std::cerr << "Failed to parse command line flags!" << '\n';
     return 1;
   }
 

@@ -130,13 +129,13 @@ int main(int argc, char* argv[]) {
   std::istringstream sin(FLAGS_model_input_path);
   std::ofstream yaml_out(FLAGS_build_yaml_path);
 
-  std::cout << "Output: " << FLAGS_build_yaml_path << std::endl;
+  std::cout << "Output: " << FLAGS_build_yaml_path << '\n';
   torch::jit::mobile::TracerResult tracer_result;
   std::vector<std::string> model_input_paths;
 
   for (std::string model_input_path;
        std::getline(sin, model_input_path, ',');) {
-    std::cout << "Processing: " << model_input_path << std::endl;
+    std::cout << "Processing: " << model_input_path << '\n';
     model_input_paths.push_back(model_input_path);
   }
 

@@ -147,7 +146,7 @@ int main(int argc, char* argv[]) {
         << "ModelTracer has not been able to load the module for the following reasons:\n"
         << ex.what()
         << "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
-        << "with the detailed error message." << std::endl;
+        << "with the detailed error message." << '\n';
 
     throw ex;
   }

@@ -161,7 +160,7 @@ int main(int argc, char* argv[]) {
         ". Expected the traced operator list to be bigger then the default size ",
         torch::jit::mobile::always_included_traced_ops.size(),
         ". Please report a bug in PyTorch.")
-        << std::endl;
+        << '\n';
   }
 
   // If the op exist in both traced_ops and root_ops, leave it in root_ops only

@@ -172,9 +171,9 @@ int main(int argc, char* argv[]) {
     }
   }
 
-  yaml_out << "include_all_non_op_selectives: false" << std::endl;
-  yaml_out << "build_features: []" << std::endl;
-  yaml_out << "operators:" << std::endl;
+  yaml_out << "include_all_non_op_selectives: false" << '\n';
+  yaml_out << "build_features: []" << '\n';
+  yaml_out << "operators:" << '\n';
   printOpsYAML(
       yaml_out,
       tracer_result.root_ops,

@@ -192,14 +191,14 @@ int main(int argc, char* argv[]) {
   if (tracer_result.called_kernel_tags.empty()) {
     yaml_out << " []";
   }
-  yaml_out << std::endl;
+  yaml_out << '\n';
   printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);
 
   yaml_out << "custom_classes:";
   if (tracer_result.loaded_classes.empty()) {
     yaml_out << " []";
   }
-  yaml_out << std::endl;
+  yaml_out << '\n';
   printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);
 
   return 0;
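Every std::endl in the YAML printers above becomes '\n', matching performance-avoid-endl (check name assumed): both forms end the line, but std::endl also flushes the stream each time, which is wasted work when emitting many lines. A small comparison:

#include <iostream>

int main() {
  // Flushes the stream after every iteration.
  for (int i = 0; i < 3; ++i) {
    std::cout << "op_" << i << std::endl;
  }
  // Identical output; the buffer flushes normally, e.g. when the
  // program exits or the buffer fills.
  for (int i = 0; i < 3; ++i) {
    std::cout << "op_" << i << '\n';
  }
  return 0;
}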
@@ -5,7 +5,6 @@
 #include <torch/csrc/jit/mobile/observer.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/runtime/jit_exception.h>
-#include <exception>
 
 #include <ATen/record_function.h>
 #include <c10/util/ScopeExit.h>

@@ -30,7 +29,7 @@ const Function* CompilationUnit::find_function(
 }
 
 Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
-  // NOLINTNEXTLINE
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   return const_cast<Function*>(
       static_cast<const CompilationUnit*>(this)->find_function(qn));
 }
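Narrowing the bare // NOLINTNEXTLINE above to // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) is a real improvement: an unqualified NOLINT silences every check on the next line, while the qualified form suppresses only the named one. The const_cast it guards is the standard trick for reusing a const overload, sketched here with a hypothetical Registry class:

#include <string>

class Registry {
  std::string name_{"jit"};

 public:
  const std::string* find() const { return &name_; }

  // Reuse the const overload; casting away const is safe because
  // `this` is known to be non-const in this overload.
  std::string* find() {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    return const_cast<std::string*>(
        static_cast<const Registry*>(this)->find());
  }
};

int main() {
  Registry r;
  return r.find()->empty() ? 1 : 0;
}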
@@ -73,7 +73,6 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
   for (size_t i = 0; i < code.instructions_.size(); i++) {
     Instruction& inst = code.instructions_[i];
     if (inst.op == OpCode::OP) {
-      std::string op_name = code.op_names_[inst.X].name;
       std::string operator_name = code.op_names_[inst.X].name +
           (code.op_names_[inst.X].overload_name.empty()
                ? ""
@@ -7,7 +7,8 @@ namespace torch::jit {
 void tupleIndex(Stack& stack) {
   int64_t index = pop(stack).toInt();
   auto tuple = pop(stack).toTuple();
-  auto norm_index = normalizeIndex(index, tuple->elements().size());
+  auto norm_index =
+      normalizeIndex(index, static_cast<int64_t>(tuple->elements().size()));
   if (norm_index < 0 ||
       norm_index >= static_cast<int64_t>(tuple->elements().size())) {
     throw std::out_of_range("Tuple list index out of range");
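tupleIndex above mixes a signed int64_t index with the unsigned value that elements().size() returns; the added static_cast<int64_t> keeps the arithmetic and comparisons in one signed type and silences sign-compare warnings. A minimal reconstruction of the bounds logic (normalizeIndex is inlined here as an assumption about its behavior):

#include <cstdint>
#include <vector>

int main() {
  std::vector<int> elements{1, 2, 3};
  int64_t index = -1; // Python-style negative index

  // Cast once so all arithmetic and comparisons stay signed.
  const auto size = static_cast<int64_t>(elements.size());
  const int64_t norm_index = index < 0 ? index + size : index;

  return (norm_index >= 0 && norm_index < size) ? 0 : 1; // in range -> 0
}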
@@ -15,8 +15,7 @@
 #include <string>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 namespace mobile {
 
 char const* toString(OpCode op);

@@ -150,5 +149,4 @@ void _save_parameters(
   _save_parameters(map, ifile, use_flatbuffer);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/mobile/module.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 /**
  * Serializes the provided tensor map to the provided stream.

@@ -49,5 +48,4 @@ extern void (*_save_mobile_module_to)(
     const mobile::Module& module,
     const std::function<size_t(const void*, size_t)>& writer_func);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -5,11 +5,7 @@
 
 #include <ATen/ATen.h>
 
 #include <functional>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 bool SGDParamGroup::has_options() const {
   return options_ != nullptr;

@@ -17,12 +13,12 @@ bool SGDParamGroup::has_options() const {
 
 SGDOptions& SGDParamGroup::options() {
   TORCH_CHECK(has_options());
-  return *options_.get();
+  return *options_;
 }
 
 const SGDOptions& SGDParamGroup::options() const {
   TORCH_CHECK(has_options());
-  return *options_.get();
+  return *options_;
 }
 
 void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {

@@ -126,6 +122,4 @@ Tensor SGD::step(const LossClosure& closure) {
   }
   return loss;
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,17 +1,12 @@
 #pragma once
 
 #include <torch/arg.h>
 #include <torch/nn/module.h>
 #include <torch/serialize/archive.h>
 #include <torch/types.h>
 
 #include <cstddef>
 #include <utility>
 #include <vector>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 class SGDParamState {
   TORCH_ARG(torch::Tensor, momentum_buffer);

@@ -127,6 +122,4 @@ class TORCH_API SGD {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::unique_ptr<SGDOptions> options_;
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -5,9 +5,7 @@
 #include <cstddef>
 #include <vector>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
     : indices_(torch::randperm(size, index_dtype)) {}

@@ -18,7 +16,7 @@ void RandomSampler::reset(std::optional<size_t> new_size) {
   // This allocates a new chunk of memory every time (just FYI). It should be
   // amortized over the entire epoch hopefully.
   const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
-  indices_ = torch::randperm(size, indices_.options());
+  indices_ = torch::randperm(static_cast<int64_t>(size), indices_.options());
   index_ = 0;
 }
 

@@ -38,7 +36,7 @@ std::optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
   slice = slice.to(torch::kInt64);
   const auto* data = slice.const_data_ptr<int64_t>();
   std::copy(data, data + index_batch.size(), index_batch.begin());
-  index_ += index_batch.size();
+  index_ += static_cast<int64_t>(index_batch.size());
   return index_batch;
 }
 

@@ -54,6 +52,4 @@ size_t RandomSampler::index() const noexcept {
   return index_;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -7,16 +7,12 @@
 #include <cstddef>
 #include <vector>
 
-namespace torch {
-namespace serialize {
+namespace torch::serialize {
 class OutputArchive;
 class InputArchive;
-} // namespace serialize
-} // namespace torch
+} // namespace torch::serialize
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 /// A lighter `Sampler` that returns indices randomly and cannot be
 /// serialized.

@@ -51,6 +47,4 @@ class TORCH_API RandomSampler : public torch::data::samplers::Sampler<> {
   int64_t index_ = 0;
 };
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -5,9 +5,7 @@
 #include <cstddef>
 #include <vector>
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
 
 void SequentialSampler::reset(std::optional<size_t> new_size) {

@@ -43,6 +41,4 @@ size_t SequentialSampler::index() const noexcept {
   return index_;
 }
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -7,16 +7,12 @@
 #include <cstddef>
 #include <vector>
 
-namespace torch {
-namespace serialize {
+namespace torch::serialize {
 class OutputArchive;
 class InputArchive;
-} // namespace serialize
-} // namespace torch
+} // namespace torch::serialize
 
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 
 /// A lighter `Sampler` that returns indices sequentially and cannot be
 /// serialized.

@@ -46,6 +42,4 @@ class TORCH_API SequentialSampler : public torch::data::samplers::Sampler<> {
   size_t index_{0};
 };
 
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile
@@ -1,12 +1,10 @@
 #pragma once
 
-// #include <ATen/core/ivalue.h>
 #include <ATen/core/ivalue_inl.h>
-
 #include <torch/csrc/jit/mobile/code.h>
 #include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>