mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68795 This change improves static runtime exception safety. Added a scope exit guard that invokes `MemoryPlanner::deallocate` in its destructor. Caveat: we have to be really careful with the exception behavior of `MemoryPlanner::deallocate` and `MemoryPlanner`'s constructor, because they're now both potentially called in the destructor of the scope exit guard. Letting exceptions potentially escape destructors is playing with fire because 1) the destructor of `Deallocator` is (implicitly) `noexcept`, and 2) even if it weren't, `std::terminate` will be called if an exception escapes while the stack is already unwinding. To get around this, we wrap the deallocation logic in a try/catch. If deallocation throws, then we simply reset all of the memory planner state and carry on. There's a catch: the code path that we take when handling the deallocation exception can't throw. However, this code path is much simpler than memory planner construction/deallocation, so it's much easier to manually audit the correctness here. Test Plan: **New unit tests** `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest` Reviewed By: hlu1 Differential Revision: D32609915 fbshipit-source-id: 71fbe6994fd573ca6b7dd859b2e6fbd7eeabcd9e
55 lines
1.3 KiB
C++
55 lines
1.3 KiB
C++
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
|
|
|
|
#pragma once
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/runtime/static/impl.h>
|
|
|
|
// Forward declaration of c10::IValue, the element type of the
// std::vector<c10::IValue> parameters of the test helpers declared below.
namespace c10 { struct IValue; } // namespace c10
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
struct Node;
|
|
class StaticModule;
|
|
|
|
namespace test {
|
|
|
|
// Given a model/function in jit or IR script, run the model/function
|
|
// with the jit interpreter and static runtime, and compare the results
|
|
void testStaticRuntime(
|
|
const std::string& source,
|
|
const std::vector<c10::IValue>& args,
|
|
const std::vector<c10::IValue>& args2 = {},
|
|
const bool use_allclose = false,
|
|
const bool use_equalnan = false,
|
|
const bool check_resize = true);
|
|
|
|
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
|
|
|
|
bool hasProcessedNodeWithName(
|
|
torch::jit::StaticModule& smodule,
|
|
const char* name);
|
|
|
|
at::Tensor getTensor(const at::IValue& ival);
|
|
|
|
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
|
|
|
|
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
|
|
|
|
void compareResultsWithJIT(
|
|
StaticRuntime& runtime,
|
|
const std::shared_ptr<Graph>& graph,
|
|
const std::vector<c10::IValue>& args,
|
|
const bool use_allclose = false,
|
|
const bool use_equalnan = false);
|
|
|
|
} // namespace test
|
|
} // namespace jit
|
|
} // namespace torch
|