[SR] Reverse iteration order in resetMemory (#71705)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71705

This fixes a crash in `resetMemory` caused by trying to access a `TensorImpl` via a borrowed `IValue` after it had already been destroyed. We need to clean up all borrows *before* we destroy the owning `IValue`, not after.
ghstack-source-id: 147688982

Test Plan:
New unit test covers this case

ICE w/ inline_cvr v0 [finishes successfully](https://www.internalfb.com/intern/unidash/dashboard/ads_infra_cost_estimation/a_metrics/?e[select_ESTIMATION_RUN_ID]=ICE_mikeiovine_16431103211c65), didn't see any nnpi errors

Reviewed By: ajyu

Differential Revision: D33725435

fbshipit-source-id: f8dd109382b5cf54df6f194f8dcb5c0812b174bb
(cherry picked from commit 31339d9d38)
This commit is contained in:
Mike Iovine 2022-01-26 09:12:13 -08:00 committed by PyTorch MergeBot
parent e04ade92ae
commit 7e6312a5df
2 changed files with 43 additions and 3 deletions

View File

@ -2357,6 +2357,40 @@ TEST(StaticRuntime, ModelCrashOnSecondRun) {
compareResultsWithJIT(runtime, graph, args_no_crash); compareResultsWithJIT(runtime, graph, args_no_crash);
} }
// Regression test: the very first run of a graph throws while an intermediate
// value is (presumably) borrowed by the output of static_runtime::select_tensor.
// The call must surface a std::runtime_error rather than crash.
// NOTE(review): assumes static_runtime_tests::maybe_throw throws when its
// argument is true - confirm against the test op's definition.
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
  const auto ir_text = R"JIT(
graph(%0: Tensor):
%1: Tensor = aten::mul(%0, %0)
%2: Tensor = aten::mul(%1, %1)
%3: bool = prim::Constant[value=1]()
%4: Tensor = static_runtime::select_tensor(%1, %2, %3)
static_runtime_tests::maybe_throw(%3)
return (%4)
)JIT";
  auto parsed_graph = getGraphFromIR(ir_text);
  StaticModule smod(parsed_graph);
  auto& sr = smod.runtime();

  // The inputs are passed as an lvalue, so the caller keeps its own copy.
  std::vector<IValue> inputs{at::randn({1})};
  EXPECT_THROW(sr(inputs), std::runtime_error);
}
// Regression test: the first run throws while select_tensor's output
// (presumably) borrows one of the graph inputs, and those inputs were moved
// into the runtime. The call must surface a std::runtime_error rather than
// crash.
// NOTE(review): assumes static_runtime_tests::maybe_throw throws when its
// argument is true - confirm against the test op's definition.
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
  const auto ir_text = R"JIT(
graph(%0: Tensor, %1: Tensor):
%2: bool = prim::Constant[value=1]()
%3: Tensor = static_runtime::select_tensor(%0, %1, %2)
static_runtime_tests::maybe_throw(%2)
return (%3)
)JIT";
  auto parsed_graph = getGraphFromIR(ir_text);
  StaticModule smod(parsed_graph);
  auto& sr = smod.runtime();

  // Hand the inputs over with std::move so the caller gives up its copies.
  std::vector<IValue> inputs{at::randn({1}), at::randn({1})};
  EXPECT_THROW(sr(std::move(inputs)), std::runtime_error);
}
TEST(StaticRuntime, ReplaceWithMaybeCopy) { TEST(StaticRuntime, ReplaceWithMaybeCopy) {
const std::string to = R"IR( const std::string to = R"IR(
graph(%0 : Tensor): graph(%0 : Tensor):

View File

@ -925,15 +925,21 @@ void destroyNodeOutputs(ProcessedNode& p_node) {
} // namespace } // namespace
void StaticRuntime::clean_up_intermediate_ivalues() noexcept { void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
for (auto& p_node : nodes_) { // We have to iterate in reverse order here due to borrowed
destroyNodeOutputs(p_node); // IValues - we don't want to destroy a value until all of its
// borrows are cleaned up!
for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
destroyNodeOutputs(*it);
} }
} }
void StaticRuntime::resetMemory() noexcept { void StaticRuntime::resetMemory() noexcept {
planner_.reset(); planner_.reset();
clean_up_input_ivalues(); // We must clean up intermediate values before inputs in case
// there are borrowed inputs and static runtime owns the only
// reference (e.g. the inputs were std::move'd into the runtime)
clean_up_intermediate_ivalues(); clean_up_intermediate_ivalues();
clean_up_input_ivalues();
} }
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) { c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {