mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[SR] Reverse iteration order in resetMemory (#71705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71705
This fixes a crash in `resetMemory` caused by trying to access a `TensorImpl` via a borrowed `IValue` after it had already been destroyed. We need to clean up all borrows *before* we destroy the owning `IValue`, not after.
ghstack-source-id: 147688982
Test Plan:
New unit test covers this case
ICE w/ inline_cvr v0 [finishes successfully](https://www.internalfb.com/intern/unidash/dashboard/ads_infra_cost_estimation/a_metrics/?e[select_ESTIMATION_RUN_ID]=ICE_mikeiovine_16431103211c65), didn't see any nnpi errors
Reviewed By: ajyu
Differential Revision: D33725435
fbshipit-source-id: f8dd109382b5cf54df6f194f8dcb5c0812b174bb
(cherry picked from commit 31339d9d38)
This commit is contained in:
parent
e04ade92ae
commit
7e6312a5df
|
|
@ -2357,6 +2357,40 @@ TEST(StaticRuntime, ModelCrashOnSecondRun) {
|
|||
compareResultsWithJIT(runtime, graph, args_no_crash);
|
||||
}
|
||||
|
||||
// Builds a graph whose returned value %4 comes from
// static_runtime::select_tensor over two intermediate tensors (%1, %2), then
// calls static_runtime_tests::maybe_throw with the constant %3 (value=1).
// The test passes when the very first invocation of the runtime surfaces a
// std::runtime_error.
// NOTE(review): per the test name, select_tensor's output presumably borrows
// from %1/%2 - confirm against the op's implementation.
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
|
||||
const auto src = R"JIT(
|
||||
graph(%0: Tensor):
|
||||
%1: Tensor = aten::mul(%0, %0)
|
||||
%2: Tensor = aten::mul(%1, %1)
|
||||
%3: bool = prim::Constant[value=1]()
|
||||
%4: Tensor = static_runtime::select_tensor(%1, %2, %3)
|
||||
static_runtime_tests::maybe_throw(%3)
|
||||
return (%4)
|
||||
)JIT";
|
||||
auto graph = getGraphFromIR(src);
|
||||
auto static_module = StaticModule(graph);
|
||||
auto& runtime = static_module.runtime();
|
||||
|
||||
std::vector<IValue> args{at::randn({1})};
|
||||
// args is passed as an lvalue here, so the caller keeps its own vector.
EXPECT_THROW(runtime(args), std::runtime_error);
|
||||
}
|
||||
|
||||
// Builds a graph whose returned value %3 comes from
// static_runtime::select_tensor applied directly to the two graph inputs
// (%0, %1), then calls static_runtime_tests::maybe_throw with the constant %2
// (value=1). The test passes when the first invocation of the runtime
// surfaces a std::runtime_error.
// NOTE(review): per the test name, %3 presumably borrows from the graph
// inputs - confirm against select_tensor's implementation.
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
|
||||
const auto src = R"JIT(
|
||||
graph(%0: Tensor, %1: Tensor):
|
||||
%2: bool = prim::Constant[value=1]()
|
||||
%3: Tensor = static_runtime::select_tensor(%0, %1, %2)
|
||||
static_runtime_tests::maybe_throw(%2)
|
||||
return (%3)
|
||||
)JIT";
|
||||
auto graph = getGraphFromIR(src);
|
||||
auto static_module = StaticModule(graph);
|
||||
auto& runtime = static_module.runtime();
|
||||
|
||||
std::vector<IValue> args{at::randn({1}), at::randn({1})};
|
||||
// args is std::move'd into the call, unlike the lvalue call used in
// ModelCrashOnFirstRunWithBorrows.
EXPECT_THROW(runtime(std::move(args)), std::runtime_error);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, ReplaceWithMaybeCopy) {
|
||||
const std::string to = R"IR(
|
||||
graph(%0 : Tensor):
|
||||
|
|
|
|||
|
|
@ -925,15 +925,21 @@ void destroyNodeOutputs(ProcessedNode& p_node) {
|
|||
} // namespace
|
||||
|
||||
void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
|
||||
for (auto& p_node : nodes_) {
|
||||
destroyNodeOutputs(p_node);
|
||||
// We have to iterate in reverse order here due to borrowed
|
||||
// IValues - we don't want to destroy a value until all of its
|
||||
// borrows are cleaned up!
|
||||
for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
|
||||
destroyNodeOutputs(*it);
|
||||
}
|
||||
}
|
||||
|
||||
// Resets the memory planner, then releases the values held by the runtime:
// intermediate node outputs first, input IValues second.
void StaticRuntime::resetMemory() noexcept {
  planner_.reset();
  // We must clean up intermediate values before inputs in case
  // there are borrowed inputs and static runtime owns the only
  // reference (e.g. the inputs were std::move'd into the runtime)
  clean_up_intermediate_ivalues();
  clean_up_input_ivalues();
}
|
||||
|
||||
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user