mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[SR] Reverse iteration order in resetMemory (#71705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71705
This fixes a crash `resetMemory` caused by trying to access a `TensorImpl` via a borrowed `IValue` after it had already been destroyed. We need to clean up all borrows *before* we destroy the owning `IValue`, not after.
ghstack-source-id: 147688982
Test Plan:
New unit test covers this case
ICE w/ inline_cvr v0 [finishes successfully](https://www.internalfb.com/intern/unidash/dashboard/ads_infra_cost_estimation/a_metrics/?e[select_ESTIMATION_RUN_ID]=ICE_mikeiovine_16431103211c65), didn't see any nnpi errors
Reviewed By: ajyu
Differential Revision: D33725435
fbshipit-source-id: f8dd109382b5cf54df6f194f8dcb5c0812b174bb
(cherry picked from commit 31339d9d38)
This commit is contained in:
parent
e04ade92ae
commit
7e6312a5df
|
|
@ -2357,6 +2357,40 @@ TEST(StaticRuntime, ModelCrashOnSecondRun) {
|
||||||
compareResultsWithJIT(runtime, graph, args_no_crash);
|
compareResultsWithJIT(runtime, graph, args_no_crash);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
|
||||||
|
const auto src = R"JIT(
|
||||||
|
graph(%0: Tensor):
|
||||||
|
%1: Tensor = aten::mul(%0, %0)
|
||||||
|
%2: Tensor = aten::mul(%1, %1)
|
||||||
|
%3: bool = prim::Constant[value=1]()
|
||||||
|
%4: Tensor = static_runtime::select_tensor(%1, %2, %3)
|
||||||
|
static_runtime_tests::maybe_throw(%3)
|
||||||
|
return (%4)
|
||||||
|
)JIT";
|
||||||
|
auto graph = getGraphFromIR(src);
|
||||||
|
auto static_module = StaticModule(graph);
|
||||||
|
auto& runtime = static_module.runtime();
|
||||||
|
|
||||||
|
std::vector<IValue> args{at::randn({1})};
|
||||||
|
EXPECT_THROW(runtime(args), std::runtime_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
|
||||||
|
const auto src = R"JIT(
|
||||||
|
graph(%0: Tensor, %1: Tensor):
|
||||||
|
%2: bool = prim::Constant[value=1]()
|
||||||
|
%3: Tensor = static_runtime::select_tensor(%0, %1, %2)
|
||||||
|
static_runtime_tests::maybe_throw(%2)
|
||||||
|
return (%3)
|
||||||
|
)JIT";
|
||||||
|
auto graph = getGraphFromIR(src);
|
||||||
|
auto static_module = StaticModule(graph);
|
||||||
|
auto& runtime = static_module.runtime();
|
||||||
|
|
||||||
|
std::vector<IValue> args{at::randn({1}), at::randn({1})};
|
||||||
|
EXPECT_THROW(runtime(std::move(args)), std::runtime_error);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StaticRuntime, ReplaceWithMaybeCopy) {
|
TEST(StaticRuntime, ReplaceWithMaybeCopy) {
|
||||||
const std::string to = R"IR(
|
const std::string to = R"IR(
|
||||||
graph(%0 : Tensor):
|
graph(%0 : Tensor):
|
||||||
|
|
|
||||||
|
|
@ -925,15 +925,21 @@ void destroyNodeOutputs(ProcessedNode& p_node) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
|
void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
|
||||||
for (auto& p_node : nodes_) {
|
// We have to iterate in reverse order here due to borrowed
|
||||||
destroyNodeOutputs(p_node);
|
// IValues - we don't want to destroy a value until all of its
|
||||||
|
// borrows are cleaned up!
|
||||||
|
for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
|
||||||
|
destroyNodeOutputs(*it);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::resetMemory() noexcept {
|
void StaticRuntime::resetMemory() noexcept {
|
||||||
planner_.reset();
|
planner_.reset();
|
||||||
clean_up_input_ivalues();
|
// We must clean up intermediate values before inputs in case
|
||||||
|
// there are borrowed inputs and static runtime owns the only
|
||||||
|
// reference (e.g. the inputs were std::move'd into the runtime)
|
||||||
clean_up_intermediate_ivalues();
|
clean_up_intermediate_ivalues();
|
||||||
|
clean_up_input_ivalues();
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user