mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Wrap torch::deploy API functions in safe rethrow macros (#58412)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58412 Second try- avoid ctor/dtor handling this time as it is kind of pointless if the rethrow will still terminate(), and upsets -Werror=terminate Original commit changeset: 1775bed18269 Test Plan: existing unit tests and CI Reviewed By: suo Differential Revision: D28478588 fbshipit-source-id: 84191cecc3ef52e23f11bfea07bbb9773ebc5df4
This commit is contained in:
parent
7b73fdf597
commit
8a3fb2689f
|
|
@ -16,25 +16,33 @@ namespace torch {
|
||||||
namespace deploy {
|
namespace deploy {
|
||||||
|
|
||||||
Package InterpreterManager::load_package(const std::string& uri) {
|
Package InterpreterManager::load_package(const std::string& uri) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return Package(uri, this);
|
return Package(uri, this);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
Package InterpreterManager::load_package(
|
Package InterpreterManager::load_package(
|
||||||
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader) {
|
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return Package(reader, this);
|
return Package(reader, this);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
Obj InterpreterSession::from_movable(const ReplicatedObj& obj) {
|
Obj InterpreterSession::from_movable(const ReplicatedObj& obj) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return impl_->unpickle_or_get(obj.pImpl_->object_id_, obj.pImpl_->data_);
|
return impl_->unpickle_or_get(obj.pImpl_->object_id_, obj.pImpl_->data_);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
InterpreterSession ReplicatedObj::acquire_session(
|
InterpreterSession ReplicatedObj::acquire_session(
|
||||||
const Interpreter* on_this_interpreter) const {
|
const Interpreter* on_this_interpreter) const {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
InterpreterSession I = on_this_interpreter
|
InterpreterSession I = on_this_interpreter
|
||||||
? on_this_interpreter->acquire_session()
|
? on_this_interpreter->acquire_session()
|
||||||
: pImpl_->manager_->acquire_one();
|
: pImpl_->manager_->acquire_one();
|
||||||
I.self = I.from_movable(*this);
|
I.self = I.from_movable(*this);
|
||||||
return I;
|
return I;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
InterpreterSession::~InterpreterSession() {
|
InterpreterSession::~InterpreterSession() {
|
||||||
|
|
@ -44,6 +52,7 @@ InterpreterSession::~InterpreterSession() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReplicatedObjImpl::unload(const Interpreter* on_this_interpreter) {
|
void ReplicatedObjImpl::unload(const Interpreter* on_this_interpreter) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
if (!on_this_interpreter) {
|
if (!on_this_interpreter) {
|
||||||
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
||||||
for (auto& interp : manager_->all_instances()) {
|
for (auto& interp : manager_->all_instances()) {
|
||||||
|
|
@ -54,6 +63,7 @@ void ReplicatedObjImpl::unload(const Interpreter* on_this_interpreter) {
|
||||||
|
|
||||||
InterpreterSession I = on_this_interpreter->acquire_session();
|
InterpreterSession I = on_this_interpreter->acquire_session();
|
||||||
I.impl_->unload(object_id_);
|
I.impl_->unload(object_id_);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplicatedObjImpl::~ReplicatedObjImpl() {
|
ReplicatedObjImpl::~ReplicatedObjImpl() {
|
||||||
|
|
@ -61,16 +71,20 @@ ReplicatedObjImpl::~ReplicatedObjImpl() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReplicatedObj::unload(const Interpreter* on_this_interpreter) {
|
void ReplicatedObj::unload(const Interpreter* on_this_interpreter) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
pImpl_->unload(on_this_interpreter);
|
pImpl_->unload(on_this_interpreter);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplicatedObj InterpreterSession::create_movable(Obj obj) {
|
ReplicatedObj InterpreterSession::create_movable(Obj obj) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
manager_,
|
manager_,
|
||||||
"Can only create a movable object when the session was created from an interpreter that is part of a InterpreterManager");
|
"Can only create a movable object when the session was created from an interpreter that is part of a InterpreterManager");
|
||||||
auto pickled = impl_->pickle(self, obj);
|
auto pickled = impl_->pickle(self, obj);
|
||||||
return ReplicatedObj(std::make_shared<ReplicatedObjImpl>(
|
return ReplicatedObj(std::make_shared<ReplicatedObjImpl>(
|
||||||
manager_->next_object_id_++, std::move(pickled), manager_));
|
manager_->next_object_id_++, std::move(pickled), manager_));
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
Interpreter::Interpreter(InterpreterManager* manager)
|
Interpreter::Interpreter(InterpreterManager* manager)
|
||||||
|
|
@ -114,6 +128,7 @@ Interpreter::~Interpreter() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int LoadBalancer::acquire() {
|
int LoadBalancer::acquire() {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
thread_local int last = 0;
|
thread_local int last = 0;
|
||||||
size_t minusers = SIZE_MAX;
|
size_t minusers = SIZE_MAX;
|
||||||
int min_idx = 0;
|
int min_idx = 0;
|
||||||
|
|
@ -147,10 +162,14 @@ int LoadBalancer::acquire() {
|
||||||
// then, so this is only a heuristic).
|
// then, so this is only a heuristic).
|
||||||
__atomic_fetch_add(&uses_[8 * min_idx], 1ULL, __ATOMIC_SEQ_CST);
|
__atomic_fetch_add(&uses_[8 * min_idx], 1ULL, __ATOMIC_SEQ_CST);
|
||||||
return min_idx;
|
return min_idx;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoadBalancer::free(int where) {
|
void LoadBalancer::free(int where) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
||||||
__atomic_fetch_sub(&uses_[8 * where], 1ULL, __ATOMIC_SEQ_CST);
|
__atomic_fetch_sub(&uses_[8 * where], 1ULL, __ATOMIC_SEQ_CST);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace deploy
|
} // namespace deploy
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,48 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
/* Torch Deploy intentionally embeds multiple copies of c++ libraries
|
||||||
|
providing python bindings necessary for torch::deploy users in the same
|
||||||
|
process space in order to provide a multi-python environment. As a result,
|
||||||
|
any exception types defined by these duplicated libraries can't be safely
|
||||||
|
caught or handled outside of the originating dynamic library (.so).
|
||||||
|
|
||||||
|
In practice this means that you must either
|
||||||
|
catch these exceptions inside the torch::deploy API boundary or risk crashing
|
||||||
|
the client application.
|
||||||
|
|
||||||
|
It is safe to throw exception types that are defined once in
|
||||||
|
the context of the client application, such as c10::Error, which is defined
|
||||||
|
in libtorch, which isn't duplicated in torch::deploy interpreters.
|
||||||
|
|
||||||
|
==> Use TORCH_DEPLOY_TRY, _SAFE_CATCH_RETHROW around _ALL_ torch::deploy APIs
|
||||||
|
|
||||||
|
For more information, see
|
||||||
|
https://gcc.gnu.org/wiki/Visibility (section on c++ exceptions)
|
||||||
|
or https://stackoverflow.com/a/14364055
|
||||||
|
or
|
||||||
|
https://stackoverflow.com/questions/14268736/symbol-visibility-exceptions-runtime-error
|
||||||
|
note- this may be only a serious problem on versions of gcc prior to 4.0,
|
||||||
|
but still seems worth sealing off.
|
||||||
|
|
||||||
|
*/
|
||||||
|
#define TORCH_DEPLOY_TRY try {
|
||||||
|
#define TORCH_DEPLOY_SAFE_CATCH_RETHROW \
|
||||||
|
} \
|
||||||
|
catch (std::exception & err) { \
|
||||||
|
throw c10::Error( \
|
||||||
|
std::string( \
|
||||||
|
"Exception Caught inside torch::deploy embedded library: \n") + \
|
||||||
|
err.what(), \
|
||||||
|
""); \
|
||||||
|
} \
|
||||||
|
catch (...) { \
|
||||||
|
throw c10::Error( \
|
||||||
|
std::string( \
|
||||||
|
"Unknown Exception Caught inside torch::deploy embedded library"), \
|
||||||
|
""); \
|
||||||
|
}
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace deploy {
|
namespace deploy {
|
||||||
|
|
||||||
|
|
@ -26,10 +68,14 @@ struct TORCH_API InterpreterSession {
|
||||||
InterpreterSession(InterpreterSession&&) noexcept = default;
|
InterpreterSession(InterpreterSession&&) noexcept = default;
|
||||||
~InterpreterSession();
|
~InterpreterSession();
|
||||||
Obj global(const char* module, const char* name) {
|
Obj global(const char* module, const char* name) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return impl_->global(module, name);
|
return impl_->global(module, name);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
Obj from_ivalue(at::IValue ivalue) {
|
Obj from_ivalue(at::IValue ivalue) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return impl_->from_ivalue(std::move(ivalue));
|
return impl_->from_ivalue(std::move(ivalue));
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
ReplicatedObj create_movable(Obj obj);
|
ReplicatedObj create_movable(Obj obj);
|
||||||
Obj from_movable(const ReplicatedObj& obj);
|
Obj from_movable(const ReplicatedObj& obj);
|
||||||
|
|
@ -55,7 +101,9 @@ class TORCH_API Interpreter {
|
||||||
public:
|
public:
|
||||||
Interpreter(InterpreterManager* manager);
|
Interpreter(InterpreterManager* manager);
|
||||||
InterpreterSession acquire_session() const {
|
InterpreterSession acquire_session() const {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return InterpreterSession(pImpl_->acquire_session(), manager_);
|
return InterpreterSession(pImpl_->acquire_session(), manager_);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
~Interpreter();
|
~Interpreter();
|
||||||
Interpreter(Interpreter&& rhs) noexcept
|
Interpreter(Interpreter&& rhs) noexcept
|
||||||
|
|
@ -75,13 +123,18 @@ class TORCH_API Interpreter {
|
||||||
struct Package;
|
struct Package;
|
||||||
|
|
||||||
struct TORCH_API LoadBalancer {
|
struct TORCH_API LoadBalancer {
|
||||||
LoadBalancer(size_t n) : uses_(new uint64_t[8 * n]), allocated_(n), n_(n) {
|
explicit LoadBalancer(size_t n)
|
||||||
|
: uses_(new uint64_t[8 * n]), allocated_(n), n_(n) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
// 8*... to avoid false sharing of atomics on the same cache line
|
// 8*... to avoid false sharing of atomics on the same cache line
|
||||||
memset(uses_.get(), 0, 8 * n_ * sizeof(uint64_t));
|
memset(uses_.get(), 0, 8 * n_ * sizeof(uint64_t));
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
void setResourceLimit(size_t n) {
|
void setResourceLimit(size_t n) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
TORCH_INTERNAL_ASSERT(n <= allocated_);
|
TORCH_INTERNAL_ASSERT(n <= allocated_);
|
||||||
n_ = n;
|
n_ = n;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
int acquire();
|
int acquire();
|
||||||
void free(int where);
|
void free(int where);
|
||||||
|
|
@ -96,6 +149,7 @@ struct TORCH_API LoadBalancer {
|
||||||
|
|
||||||
struct TORCH_API InterpreterManager {
|
struct TORCH_API InterpreterManager {
|
||||||
InterpreterManager(size_t n_interp = 2) : resources_(n_interp) {
|
InterpreterManager(size_t n_interp = 2) : resources_(n_interp) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
for (size_t i = 0; i < n_interp; ++i) {
|
for (size_t i = 0; i < n_interp; ++i) {
|
||||||
instances_.emplace_back(this);
|
instances_.emplace_back(this);
|
||||||
auto I = instances_.back().acquire_session();
|
auto I = instances_.back().acquire_session();
|
||||||
|
|
@ -104,24 +158,31 @@ struct TORCH_API InterpreterManager {
|
||||||
I.global("torch", "version").attr("__setattr__")({"interp", int(i)});
|
I.global("torch", "version").attr("__setattr__")({"interp", int(i)});
|
||||||
// std::cerr << "Interpreter " << i << " initialized\n";
|
// std::cerr << "Interpreter " << i << " initialized\n";
|
||||||
}
|
}
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
// get a free model, guarenteed that no other user of acquire_one has the same
|
// get a free model, guarenteed that no other user of acquire_one has the same
|
||||||
// model. It _is_ possible that other users will be using the interpreter.
|
// model. It _is_ possible that other users will be using the interpreter.
|
||||||
InterpreterSession acquire_one() {
|
InterpreterSession acquire_one() {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
int where = resources_.acquire();
|
int where = resources_.acquire();
|
||||||
InterpreterSession I = instances_[where].acquire_session();
|
InterpreterSession I = instances_[where].acquire_session();
|
||||||
I.notify_idx_ = where;
|
I.notify_idx_ = where;
|
||||||
return I;
|
return I;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
// use to make sure something gets run on all interpreters, such as loading or
|
// use to make sure something gets run on all interpreters, such as loading or
|
||||||
// unloading a model eagerly
|
// unloading a model eagerly
|
||||||
at::ArrayRef<Interpreter> all_instances() {
|
at::ArrayRef<Interpreter> all_instances() {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
return instances_;
|
return instances_;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
void debugLimitInterpreters(size_t N) {
|
void debugLimitInterpreters(size_t N) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
AT_ASSERT(N <= instances_.size());
|
AT_ASSERT(N <= instances_.size());
|
||||||
resources_.setResourceLimit(N);
|
resources_.setResourceLimit(N);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
Package load_package(const std::string& uri);
|
Package load_package(const std::string& uri);
|
||||||
Package load_package(
|
Package load_package(
|
||||||
|
|
@ -157,21 +218,27 @@ struct TORCH_API ReplicatedObj {
|
||||||
InterpreterSession acquire_session(
|
InterpreterSession acquire_session(
|
||||||
const Interpreter* on_this_interpreter = nullptr) const;
|
const Interpreter* on_this_interpreter = nullptr) const;
|
||||||
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
|
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = acquire_session();
|
auto I = acquire_session();
|
||||||
return I.self(args).toIValue();
|
return I.self(args).toIValue();
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
at::IValue call_kwargs(
|
[[nodiscard]] at::IValue call_kwargs(
|
||||||
std::vector<at::IValue> args,
|
std::vector<at::IValue> args,
|
||||||
std::unordered_map<std::string, c10::IValue> kwargs) const {
|
std::unordered_map<std::string, c10::IValue> kwargs) const {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = acquire_session();
|
auto I = acquire_session();
|
||||||
return I.self.call_kwargs(std::move(args), std::move(kwargs)).toIValue();
|
return I.self.call_kwargs(std::move(args), std::move(kwargs)).toIValue();
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] at::IValue call_kwargs(
|
[[nodiscard]] at::IValue call_kwargs(
|
||||||
std::unordered_map<std::string, c10::IValue> kwargs) const {
|
std::unordered_map<std::string, c10::IValue> kwargs) const {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = acquire_session();
|
auto I = acquire_session();
|
||||||
return I.self.call_kwargs(std::move(kwargs)).toIValue();
|
return I.self.call_kwargs(std::move(kwargs)).toIValue();
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
void unload(const Interpreter* on_this_interpreter = nullptr);
|
void unload(const Interpreter* on_this_interpreter = nullptr);
|
||||||
|
|
@ -190,22 +257,28 @@ struct TORCH_API Package {
|
||||||
ReplicatedObj load_pickle(
|
ReplicatedObj load_pickle(
|
||||||
const std::string& module,
|
const std::string& module,
|
||||||
const std::string& file) {
|
const std::string& file) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = acquire_session();
|
auto I = acquire_session();
|
||||||
auto loaded = I.self.attr("load_pickle")({module, file});
|
auto loaded = I.self.attr("load_pickle")({module, file});
|
||||||
return I.create_movable(loaded);
|
return I.create_movable(loaded);
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string load_text(const std::string& module, const std::string& file) {
|
std::string load_text(const std::string& module, const std::string& file) {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = acquire_session();
|
auto I = acquire_session();
|
||||||
auto loaded = I.self.attr("load_text")({module, file});
|
auto loaded = I.self.attr("load_text")({module, file});
|
||||||
return loaded.toIValue().toStringRef();
|
return loaded.toIValue().toStringRef();
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
InterpreterSession acquire_session() {
|
InterpreterSession acquire_session() {
|
||||||
|
TORCH_DEPLOY_TRY
|
||||||
auto I = manager_->acquire_one();
|
auto I = manager_->acquire_one();
|
||||||
I.self = I.impl_->create_or_get_package_importer_from_container_file(
|
I.self = I.impl_->create_or_get_package_importer_from_container_file(
|
||||||
container_file_);
|
container_file_);
|
||||||
return I;
|
return I;
|
||||||
|
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -134,3 +134,15 @@ TEST(TorchpyTest, ThreadedSimpleModel) {
|
||||||
ASSERT_TRUE(ref_output.equal(outputs[i]));
|
ASSERT_TRUE(ref_output.equal(outputs[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TorchpyTest, ThrowsSafely) {
|
||||||
|
// See explanation in deploy.h
|
||||||
|
torch::deploy::InterpreterManager manager(3);
|
||||||
|
EXPECT_THROW(manager.load_package("some garbage path"), c10::Error);
|
||||||
|
|
||||||
|
torch::deploy::Package p = manager.load_package(path("SIMPLE", simple));
|
||||||
|
EXPECT_THROW(p.load_pickle("some other", "garbage path"), c10::Error);
|
||||||
|
|
||||||
|
auto model = p.load_pickle("model", "model.pkl");
|
||||||
|
EXPECT_THROW(model(at::IValue("unexpected input")), c10::Error);
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user