[StaticRuntime] Fix bug in MemoryPlanner (#51342)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51342

There is a subtle bug in the MemoryPlanner's handling of view ops that have out variants. Consider:

```
  def forward(self, a: Tensor, shape: List[int]):
      b = a.reshape(shape)
      return b + b
```
In this case, if we replace reshape with its out variant, b is managed by the MemoryPlanner, and when opts.cleanup_activations is true the MemoryPlanner sets b's storage to nullptr right after inference. Because b is a view of a, the storage of a is set to nullptr as well, which violates the API contract that the inputs are const.
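
For intuition, here is a minimal plain-PyTorch illustration (not Static Runtime code) of why b aliases a, assuming a contiguous input so that reshape returns a view:

```
import torch

a = torch.randn(2, 6)
b = a.reshape(3, 4)  # contiguous input, so reshape returns a view of a

print(b._base is a)                  # True: b is a view over a
print(a.data_ptr() == b.data_ptr())  # True: a and b share the same storage
```

Freeing b's storage therefore also frees a's, which is exactly what the managed cleanup path would do after inference.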

To fix this bug, I changed the MemoryPlanner so that it puts b in the unmanaged part.
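
Condensed, the new classification rule amounts to the following sketch (Python pseudocode with a stand-in NodeInfo type and a hypothetical should_manage helper, not the real Static Runtime classes; the actual change is written inline in the MemoryPlanner constructor, as the impl.cpp diff below shows):

```
from dataclasses import dataclass, field
from typing import List, Set

@dataclass
class NodeInfo:
    # simplified stand-in for the IR node info the planner inspects
    has_out_variant: bool
    is_view_op: bool
    inputs: List[int] = field(default_factory=list)  # ids of input Values

def should_manage(node: NodeInfo, graph_inputs: Set[int]) -> bool:
    # Outputs are planner-managed only if the op has an out variant and,
    # when it is a view op, none of its inputs is a graph input.
    if not node.has_out_variant:
        return False
    if node.is_view_op and any(i in graph_inputs for i in node.inputs):
        return False  # output may alias a graph input; keep it unmanaged
    return True
```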

Test Plan:
Added a unit test to enforce the constness of the inputs:

```
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```
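
The added check clones every tensor input before running and compares them afterwards via compareTensorLists. A rough Python analogue of the same pattern (hypothetical helper for illustration, not the actual C++ test):

```
import torch

def check_inputs_unmodified(fn, *args):
    # Clone tensor inputs up front, run inference, then verify the originals
    # are still bitwise-equal to the clones, i.e. the call treated them as const.
    tensors = [a for a in args if isinstance(a, torch.Tensor)]
    copies = [t.clone() for t in tensors]
    out = fn(*args)
    for original, copy in zip(tensors, copies):
        assert original.equal(copy), "input was modified during inference"
    return out
```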

Reviewed By: ajyu

Differential Revision: D26144203

fbshipit-source-id: 2dbacccf7685d0fe0f0b1195166e0510b2069fe3
Hao Lu, 2021-01-29 21:13:35 -08:00, committed by Facebook GitHub Bot
parent 09e48dbd33, commit 11cda929fb
5 changed files with 39 additions and 3 deletions

@@ -27,7 +27,8 @@ const auto add_script = R"JIT(
 const auto reshape_script_1 = R"JIT(
   def forward(self, a: Tensor, shape: List[int]):
-      return a.reshape(shape)
+      b = a.reshape(shape)
+      return b + b
 )JIT";
 
 const auto reshape_script_2 = R"JIT(
@@ -38,7 +39,8 @@ const auto reshape_script_2 = R"JIT(
 const auto flatten_script_1 = R"JIT(
   def forward(self, a: Tensor, start_dim: int, end_dim: int):
-      return torch.flatten(a, start_dim, end_dim)
+      b = torch.flatten(a, start_dim, end_dim)
+      return b + b
 )JIT";
 
 const auto flatten_script_2 = R"JIT(

@@ -59,6 +59,15 @@ void testStaticRuntime(
   script::Module module("module");
   module.define(jit_script);
 
+  std::vector<IValue> args_tensors, args_copy;
+  for (const auto& ival : args) {
+    if (ival.isTensor()) {
+      args_tensors.emplace_back(ival);
+      const at::Tensor& t = ival.toTensor();
+      args_copy.emplace_back(t.clone());
+    }
+  }
+
   auto expect = module.forward(args);
 
   StaticRuntime runtime(module);
@@ -72,6 +81,8 @@ void testStaticRuntime(
   } else {
     EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
   }
+  // make sure inputs were not modified
+  compareTensorLists(args_tensors, args_copy);
 }
 
 } // namespace

@@ -655,11 +655,28 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
 MemoryPlanner::MemoryPlanner(
     StaticRuntime* runtime,
     std::unordered_map<Value*, std::vector<Value*>> should_share) {
+  // get input Value*
+  at::ArrayRef<Value*> inputs =
+      runtime->get_inference_module()->graph->inputs();
+  std::unordered_set<Value*> graph_input_values(inputs.begin(), inputs.end());
+
   // collect register indices of outputs of ops with out variant
   std::unordered_set<Value*> managed_values;
   std::unordered_set<IValue*> unmanaged_value_set;
   for (ProcessedNode& pnode : runtime->get_nodes()) {
-    if (pnode.has_out_variant()) {
+    bool should_manage = pnode.has_out_variant();
+    if (should_manage && isViewOp(pnode.get_node())) {
+      // outputs of view ops with inputs as the graph inputs shouldn't be
+      // managed by the MemoryPlanner. It may release the storage of the graph
+      // inputs.
+      for (Value* in : pnode.get_node()->inputs()) {
+        if (graph_input_values.count(in) > 0) {
+          should_manage = false;
+          break;
+        }
+      }
+    }
+    if (should_manage) {
       // Types are stored in the underlying TorchScript IR
       for (Value* out : pnode.get_node()->outputs()) {
         if (out->type()->cast<TensorType>()) {

@@ -117,6 +117,11 @@ bool canReuseInputsOutputs(Node* n) {
   return !SRViewOperatorRegistry()->Has(op_name);
 }
 
+bool isViewOp(Node* n) {
+  auto op_name = std::string(n->kind().toQualString());
+  return SRViewOperatorRegistry()->Has(op_name);
+}
+
 bool canReuseInputs(Node* n) {
   auto op_name = std::string(n->kind().toQualString());
   if (SROperatorRegistry()->Has(op_name)) {

@@ -74,6 +74,7 @@ bool canRunOutOfPlace(Node* n);
 bool canReuseInputsOutputs(Node* n);
 bool canReuseInputs(Node* n);
 bool canReuseOutputs(Node* n);
+bool isViewOp(Node* n);
 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);