pytorch/benchmarks/static_runtime/test_utils.h
Mike Iovine 1c43b1602c [SR] Scope exit guard for memory planner deallocation (#68795)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68795

This change improves static runtime exception safety. Added a scope exit guard that invokes `MemoryPlanner::deallocate` in its destructor.

Caveat: we have to be very careful with the exception behavior of `MemoryPlanner::deallocate` and `MemoryPlanner`'s constructor, because both can now be called from the destructor of the scope exit guard. Letting exceptions escape a destructor is playing with fire, since 1) the destructor of `Deallocator` is (implicitly) `noexcept`, so an escaping exception calls `std::terminate`, and 2) even if it weren't, `std::terminate` is called when an exception escapes while the stack is already unwinding. To get around this, we wrap the deallocation logic in a try/catch. If deallocation throws, we simply reset all of the memory planner state and carry on.
There's a catch: the code path taken when handling the deallocation exception must not throw. However, this path is much simpler than memory planner construction/deallocation, so its correctness is much easier to audit manually.
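A minimal, self-contained sketch of the guard pattern described above. It is an illustration under simplifying assumptions, not the static runtime's code: `MemoryPlanner`, `Runtime`, `resetMemory`, and `runInference` are hypothetical stand-ins, and only the names `Deallocator` and `MemoryPlanner::deallocate` come from this commit. The guard's destructor catches everything thrown by deallocation and falls back to a `noexcept` reset.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for MemoryPlanner; the real class manages pooled
// storage for intermediate tensors. Here deallocate() can be told to throw.
class MemoryPlanner {
 public:
  explicit MemoryPlanner(bool fail_on_deallocate = false)
      : fail_on_deallocate_(fail_on_deallocate) {}

  void deallocate() {
    if (fail_on_deallocate_) {
      throw std::runtime_error("deallocate failed");
    }
    ++deallocations;
  }

  int deallocations = 0;

 private:
  bool fail_on_deallocate_;
};

// Hypothetical runtime that owns the planner.
struct Runtime {
  std::unique_ptr<MemoryPlanner> planner;
  bool reset_called = false;

  // Fallback path: must never throw. unique_ptr::reset is noexcept, and
  // assigning a bool cannot throw, so this is easy to audit.
  void resetMemory() noexcept {
    planner.reset();
    reset_called = true;
  }
};

// Scope exit guard: its destructor (implicitly noexcept) runs whether the
// enclosing scope exits normally or because an exception is propagating.
class Deallocator {
 public:
  explicit Deallocator(Runtime& rt) : rt_(rt) {}
  Deallocator(const Deallocator&) = delete;
  Deallocator& operator=(const Deallocator&) = delete;

  ~Deallocator() {
    try {
      if (rt_.planner) {
        rt_.planner->deallocate();
      }
    } catch (...) {
      // Never let the exception leave the destructor: that would call
      // std::terminate. Reset the planner state instead and carry on.
      rt_.resetMemory();
    }
  }

 private:
  Runtime& rt_;
};

// Hypothetical inference entry point; the guard deallocates on every exit.
int runInference(Runtime& rt, bool throw_in_body) {
  Deallocator guard(rt);
  if (throw_in_body) {
    throw std::runtime_error("op failed");
  }
  return 42;
}
```

The design keeps all the complicated, potentially throwing work inside the `try`, so the only code that must be hand-audited for exception safety is the small `resetMemory` fallback.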

Test Plan:
**New unit tests**

`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: hlu1

Differential Revision: D32609915

fbshipit-source-id: 71fbe6994fd573ca6b7dd859b2e6fbd7eeabcd9e
2021-12-08 16:41:52 -08:00


// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#pragma once

#include <memory>
#include <string>
#include <vector>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace c10 {
struct IValue;
}
namespace torch {
namespace jit {

struct Node;
class StaticModule;

namespace test {

// Given a model/function as TorchScript source or IR, run it with both the
// JIT interpreter and Static Runtime, and compare the results.
void testStaticRuntime(
    const std::string& source,
    const std::vector<c10::IValue>& args,
    const std::vector<c10::IValue>& args2 = {},
    const bool use_allclose = false,
    const bool use_equalnan = false,
    const bool check_resize = true);

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);

bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name);

at::Tensor getTensor(const at::IValue& ival);

Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);

void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose = false,
    const bool use_equalnan = false);

} // namespace test
} // namespace jit
} // namespace torch
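Both `testStaticRuntime` and `compareResultsWithJIT` take `use_allclose` and `use_equalnan` flags. The following is a hypothetical, tensor-free sketch of what such flags typically control, assuming they follow the documented `torch.allclose` rule `|actual - expected| <= atol + rtol * |expected|` with torch's default `rtol = 1e-5` and `atol = 1e-8`. `valuesMatch` is not part of this header, and honoring `use_equalnan` in the exact-comparison mode as well is an assumption of this sketch that may not match the real helpers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: compares two flat buffers elementwise.
// use_allclose = false -> exact equality per element.
// use_allclose = true  -> torch.allclose-style tolerance check, where
//                         `expected` plays the role of allclose's `other`.
// use_equalnan         -> a NaN in both buffers at the same index counts as equal.
bool valuesMatch(
    const std::vector<double>& expected,
    const std::vector<double>& actual,
    bool use_allclose,
    bool use_equalnan,
    double rtol = 1e-5,
    double atol = 1e-8) {
  if (expected.size() != actual.size()) {
    return false;
  }
  for (std::size_t i = 0; i < expected.size(); ++i) {
    const double e = expected[i];
    const double a = actual[i];
    if (std::isnan(e) || std::isnan(a)) {
      // NaN never compares equal to anything unless equal_nan is requested,
      // and even then both sides must be NaN.
      if (!(use_equalnan && std::isnan(e) && std::isnan(a))) {
        return false;
      }
      continue;
    }
    if (a == e) {
      continue;  // exact match, which also covers matching infinities
    }
    if (!use_allclose || std::isinf(a) || std::isinf(e)) {
      return false;
    }
    if (std::fabs(a - e) > atol + rtol * std::fabs(e)) {
      return false;
    }
  }
  return true;
}
```

With these defaults, `1.0` and `1.0 + 1e-7` match only when `use_allclose` is set, and a pair of NaNs matches only when `use_equalnan` is set.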