[StaticRuntime] Fix bug in MemoryPlanner (#51342)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51342

There is a subtle bug in the MemoryPlanner's handling of view ops that have an out variant:

```
def forward(self, a: Tensor, shape: List[int]):
    b = a.reshape(shape)
    return b + b
```

If we replace `reshape` with its out variant here, `b` becomes managed by the MemoryPlanner, and when `opts.cleanup_activations` is true the MemoryPlanner sets the storage of `b` to nullptr right after inference. Because `b` is a view of `a`, the storage of `a` is set to nullptr as well, which violates the API's promise that `a` is const. To fix this bug, the MemoryPlanner now places `b` in the unmanaged set.

Test Plan: Added a unit test that enforces the constness of inputs:

```
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```

Reviewed By: ajyu

Differential Revision: D26144203

fbshipit-source-id: 2dbacccf7685d0fe0f0b1195166e0510b2069fe3
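Why does clearing the storage of b also clear the storage of a? A reshape that can be expressed without copying returns a view that shares its input's Storage. The following minimal standalone sketch, using only the public ATen API (not code from this diff), demonstrates the aliasing:

```
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::ones({2, 3});
  // reshape on a contiguous tensor returns a view: no copy is made,
  // and b shares the same underlying storage as a.
  at::Tensor b = a.reshape({6});
  std::cout << std::boolalpha << b.is_alias_of(a) << '\n';  // prints: true
  // Resetting b's storage would therefore invalidate a as well, which is
  // exactly what the MemoryPlanner must avoid for graph inputs.
  return 0;
}
```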
This commit is contained in:
parent 09e48dbd33
commit 11cda929fb
```
@@ -27,7 +27,8 @@ const auto add_script = R"JIT(
 
 const auto reshape_script_1 = R"JIT(
   def forward(self, a: Tensor, shape: List[int]):
-      return a.reshape(shape)
+      b = a.reshape(shape)
+      return b + b
 )JIT";
 
 const auto reshape_script_2 = R"JIT(
```
```
@@ -38,7 +39,8 @@ const auto reshape_script_2 = R"JIT(
 
 const auto flatten_script_1 = R"JIT(
   def forward(self, a: Tensor, start_dim: int, end_dim: int):
-      return torch.flatten(a, start_dim, end_dim)
+      b = torch.flatten(a, start_dim, end_dim)
+      return b + b
 )JIT";
 
 const auto flatten_script_2 = R"JIT(
```
```
@@ -59,6 +59,15 @@ void testStaticRuntime(
   script::Module module("module");
   module.define(jit_script);
 
+  std::vector<IValue> args_tensors, args_copy;
+  for (const auto& ival : args) {
+    if (ival.isTensor()) {
+      args_tensors.emplace_back(ival);
+      const at::Tensor& t = ival.toTensor();
+      args_copy.emplace_back(t.clone());
+    }
+  }
+
   auto expect = module.forward(args);
 
   StaticRuntime runtime(module);
```
```
@@ -72,6 +81,8 @@ void testStaticRuntime(
   } else {
     EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
   }
+  // make sure inputs were not modified
+  compareTensorLists(args_tensors, args_copy);
 }
 } // namespace
 
```
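The constness check above relies on the compareTensorLists helper from the test utilities, whose body is not shown in this diff. A hedged sketch of what such a comparison boils down to (the name compareTensorListsSketch and the exact assertions are assumptions, not taken from the diff):

```
#include <gtest/gtest.h>
#include <vector>
#include <ATen/core/ivalue.h>

// Sketch: every input tensor observed after inference must equal the
// clone that was taken before inference ran.
void compareTensorListsSketch(
    const std::vector<c10::IValue>& actual,
    const std::vector<c10::IValue>& expected) {
  ASSERT_EQ(actual.size(), expected.size());
  for (size_t i = 0; i < actual.size(); ++i) {
    // Tensor::equal returns true only when sizes and all values match.
    EXPECT_TRUE(actual[i].toTensor().equal(expected[i].toTensor()));
  }
}
```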
```
@@ -655,11 +655,28 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
 MemoryPlanner::MemoryPlanner(
     StaticRuntime* runtime,
     std::unordered_map<Value*, std::vector<Value*>> should_share) {
+  // get input Value*
+  at::ArrayRef<Value*> inputs =
+      runtime->get_inference_module()->graph->inputs();
+  std::unordered_set<Value*> graph_input_values(inputs.begin(), inputs.end());
+
   // collect register indices of outputs of ops with out variant
   std::unordered_set<Value*> managed_values;
   std::unordered_set<IValue*> unmanaged_value_set;
   for (ProcessedNode& pnode : runtime->get_nodes()) {
-    if (pnode.has_out_variant()) {
+    bool should_manage = pnode.has_out_variant();
+    if (should_manage && isViewOp(pnode.get_node())) {
+      // outputs of view ops with inputs as the graph inputs shouldn't be
+      // managed by the MemoryPlanner. It may release the storage of the graph
+      // inputs.
+      for (Value* in : pnode.get_node()->inputs()) {
+        if (graph_input_values.count(in) > 0) {
+          should_manage = false;
+          break;
+        }
+      }
+    }
+    if (should_manage) {
       // Types are stored in the underlying TorchScript IR
       for (Value* out : pnode.get_node()->outputs()) {
         if (out->type()->cast<TensorType>()) {
```
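Condensed, the new control flow amounts to the predicate below. This is a standalone illustration with simplified stand-in types (Node and Value here are not the JIT classes), not code from the diff:

```
#include <iostream>
#include <unordered_set>
#include <vector>

// Simplified stand-ins for the JIT types; illustration only.
struct Value {};
struct Node {
  bool has_out_variant;
  bool is_view_op;
  std::vector<Value*> inputs;
};

// A node's outputs are memory-managed only if it has an out variant and,
// when it is a view op, none of its inputs is a graph input. Otherwise,
// freeing the managed output's storage would also free a graph input's.
bool shouldManage(
    const Node& n,
    const std::unordered_set<Value*>& graph_inputs) {
  if (!n.has_out_variant) {
    return false;
  }
  if (n.is_view_op) {
    for (Value* in : n.inputs) {
      if (graph_inputs.count(in) > 0) {
        return false;
      }
    }
  }
  return true;
}

int main() {
  Value a;  // stands in for a graph input
  std::unordered_set<Value*> graph_inputs{&a};
  Node reshape{/*has_out_variant=*/true, /*is_view_op=*/true, /*inputs=*/{&a}};
  std::cout << std::boolalpha
            << shouldManage(reshape, graph_inputs) << '\n';  // prints: false
  return 0;
}
```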
```
@@ -117,6 +117,11 @@ bool canReuseInputsOutputs(Node* n) {
   return !SRViewOperatorRegistry()->Has(op_name);
 }
 
+bool isViewOp(Node* n) {
+  auto op_name = std::string(n->kind().toQualString());
+  return SRViewOperatorRegistry()->Has(op_name);
+}
+
 bool canReuseInputs(Node* n) {
   auto op_name = std::string(n->kind().toQualString());
   if (SROperatorRegistry()->Has(op_name)) {
```
```
@@ -74,6 +74,7 @@ bool canRunOutOfPlace(Node* n);
 bool canReuseInputsOutputs(Node* n);
 bool canReuseInputs(Node* n);
 bool canReuseOutputs(Node* n);
+bool isViewOp(Node* n);
 
 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
 
```