[JIT] Add a phase to perform inplace<->functional conversion for activation operators (#57477)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57477

Currently the conversion only handles activation operators, and the legality check is somewhat strict for now.
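
A minimal sketch of how the two passes can be driven from Python, using the `_jit_pass_*` bindings this diff registers in `initJITBindings` (the example itself is illustrative and not part of the commit):

```python
import torch

def f(x):
    y = x + 1
    return torch.relu(y)  # functional activation; y has no other use

fn = torch.jit.script(f)
# Rewrite aten::relu into aten::relu_ where the legality check allows it.
torch._C._jit_pass_functional_to_inplace_activation(fn.graph)
print(fn.graph)  # should now contain aten::relu_

# The reverse pass restores the functional form.
torch._C._jit_pass_inplace_to_functional_activation(fn.graph)
print(fn.graph)  # back to aten::relu
```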

Test Plan:
```
python test/test_jit.py -k test_functional_to_inplace_activation
python test/test_jit.py -k test_inplace_to_functional_activation
```

Reviewed By: mrshenli

Differential Revision: D28155153

Pulled By: desertfire

fbshipit-source-id: df092830c4dff3ce9578ff76285eb7a566b7d81b
Authored by Bin Bao on 2021-06-03 06:41:43 -07:00; committed by Facebook GitHub Bot
parent 91b7bcf4c0
commit add291cf66
9 changed files with 400 additions and 7 deletions


@@ -38,6 +38,7 @@
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
@@ -2593,5 +2594,39 @@ graph(%x.1 : Tensor):
testing::FileCheck().check_not("aten::add_")->run(*graph);
}
TEST(TestInplaceToFunctionalActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu_(%x.3)
return (%y))IR",
&*graph,
vmap);
InplaceToFunctionalActivation(graph);
testing::FileCheck().check("aten::relu")->run(*graph);
testing::FileCheck().check_not("aten::relu_")->run(*graph);
}
TEST(TestFunctionalToInplaceActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu(%x.3)
return (%y))IR",
&*graph,
vmap);
FunctionalToInplaceActivation(graph);
testing::FileCheck().check("aten::relu_")->run(*graph);
testing::FileCheck().check_not("aten::relu(")->run(*graph);
}
} // namespace jit
} // namespace torch


@@ -0,0 +1,175 @@
import os
import sys
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
import unittest
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
activations = [
F.celu,
F.elu,
F.hardsigmoid,
F.hardswish,
F.hardtanh,
F.leaky_relu,
F.relu,
F.relu6,
F.rrelu,
F.selu,
F.silu,
]
class TestFunctionalToInplaceActivation(JitTestCase):
def test_check_no_type_promotion(self):
dtypes = [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.float32,
torch.float64,
]
# restore_mutation.h contains a mapping from activation operators
# to whether they allow type promotion. This check guards that
# mapping: if a later change breaks the assumption, the mapping
# needs to be updated accordingly.
for activation, dtype in product(activations, dtypes):
inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
try:
out = activation(inp)
self.assertEqual(dtype, out.dtype)
except RuntimeError:
# Skip the not implemented error
pass
def test_functional_to_inplace_activation(self):
for activation in activations:
def test_basic(x):
y = x + 1
z = activation(y)
return z
fn = torch.jit.script(test_basic)
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
self.run_pass('functional_to_inplace_activation', fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
inp = torch.rand([2, 2])
self.assertEqual(fn(inp), test_basic(inp))
def test_no_functional_to_inplace(self):
# inplace conversion should not happen because sigmoid may
# perform type conversion
def test1():
y = torch.ones([2, 2])
z = torch.sigmoid(y)
return z
fn = torch.jit.script(test1)
self.run_pass('functional_to_inplace_activation', fn.graph)
FileCheck().check_not("aten::sigmoid_").run(fn.graph)
# inplace conversion should not happen because y aliases
# the input x
def test2(x):
y = x[0]
z = torch.relu(y)
return z
fn = torch.jit.script(test2)
self.run_pass('functional_to_inplace_activation', fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph)
# inplace conversion should not happen because self.x is
# module state visible outside the graph
class Test3(nn.Module):
def __init__(self, x):
super(Test3, self).__init__()
self.x = x
def forward(self):
y = torch.relu(self.x)
return y
fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
self.run_pass('functional_to_inplace_activation', fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph)
@skipIfNoTorchVision
def test_resnet18_correctness(self):
model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
self.run_pass('functional_to_inplace_activation', frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp))
class TestInplaceToFunctionalActivation(JitTestCase):
def test_inplace_to_functional_activation(self):
for activation in activations:
def test_basic(x):
y = x + 1
activation(y, inplace=True)
return y
fn = torch.jit.script(test_basic)
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
for activation in [
torch.relu_,
torch.sigmoid_,
torch.tanh_,
]:
def test_basic(x):
y = x + 1
activation(y)
return y
fn = torch.jit.script(test_basic)
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
inp = torch.rand([2, 2])
self.assertEqual(fn(inp), test_basic(inp))
@skipIfNoTorchVision
def test_resnet18_correctness(self):
model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
self.run_pass('inplace_to_functional_activation', frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp))
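
The `test_check_no_type_promotion` and `test_no_functional_to_inplace` cases above rely on the fact that some activations promote integer inputs to floating point, which makes an in-place rewrite unsafe. A small standalone illustration (not part of the test file):

```python
import torch

x = torch.ones(4, dtype=torch.int32)
print(torch.sigmoid(x).dtype)  # torch.float32: output dtype differs from the input
# An in-place sigmoid_ would have to write float results into int32 storage and
# raises a RuntimeError, so the functional->inplace rewrite must skip sigmoid here.
```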


@@ -55,6 +55,7 @@ from jit.test_pdt import TestPDT # noqa: F401
from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401
from jit.test_module_apis import TestModuleAPIs # noqa: F401
from jit.test_script_profile import TestScriptProfile # noqa: F401
from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401
# Torch
from torch import Tensor


@@ -182,6 +182,7 @@ core_sources_full_mobile = [
"torch/csrc/jit/passes/concat_opt.cpp",
"torch/csrc/jit/passes/constant_pooling.cpp",
"torch/csrc/jit/passes/constant_propagation.cpp",
"torch/csrc/jit/passes/restore_mutation.cpp",
"torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",
"torch/csrc/jit/passes/dead_code_elimination.cpp",
"torch/csrc/jit/passes/remove_redundant_profiles.cpp",


@@ -1,4 +1,5 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
namespace torch {
namespace jit {
@@ -11,7 +12,7 @@ bool MutationRemover::removeTensorMutation() {
return RemoveTensorMutation(graph_->block());
}
bool MutationRemover::newMemoryLocation(Value* v) {
bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
// bail on nodes with side effects, blocks, or graph / graph inputs
Node* n = v->node();
bool unhandled_node = n->blocks().size() != 0 ||
@@ -20,9 +21,8 @@ bool MutationRemover::newMemoryLocation(Value* v) {
// if the output isn't contained in or aliased by the inputs to its node, it's
// unique
return !unhandled_node &&
!getOrCreateAliasDb()->mayContainAlias(v->node()->inputs(), v) &&
!(v->node()->kind() == prim::Param);
return unhandled_node || aliasDb->mayContainAlias(v->node()->inputs(), v) ||
(v->node()->kind() == prim::Param);
}
Node* MutationRemover::createSpecialMappedOp(Node* n) {
@@ -81,7 +81,7 @@ bool MutationRemover::tryMakeCreationAndMutationAtomic(
// We can only remove mutation to values that are unique aliases in the
// graph. If x = y[0] or y = self.y, then removing the mutation could
// change observable semantics
if (!newMemoryLocation(mutated_value)) {
if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
return false;
}
@@ -116,7 +116,8 @@ bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
return false;
}
if (!newMemoryLocation(true_value) || !newMemoryLocation(false_value)) {
if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
return false;
}
@@ -329,5 +330,20 @@
return mr.removeTensorMutation();
}
static const std::unordered_set<Symbol> activation_ops = []() {
std::unordered_set<Symbol> target_ops;
for (const auto& iter : activation_type_promotion_mapping) {
std::string name = std::string(iter.first.toQualString()) + "_";
target_ops.insert(Symbol::fromQualString(name));
}
return target_ops;
}();
bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
return RemoveTensorMutation(graph, [](Node* node) {
return activation_ops.count(node->kind()) != 0;
});
}
} // namespace jit
} // namespace torch
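
`hasSideEffectOrAlias` is what rejects values that may alias graph inputs or module state. Mirroring `test2` from the new Python tests, a sketch of the case it guards against (using the Python bindings registered later in this diff):

```python
import torch

def g(x):
    y = x[0]              # y aliases the graph input x
    return torch.relu(y)

fn = torch.jit.script(g)
torch._C._jit_pass_functional_to_inplace_activation(fn.graph)
print(fn.graph)  # still aten::relu: an aten::relu_ here would mutate x through the alias
```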


@@ -33,8 +33,9 @@ struct TORCH_API MutationRemover {
bool inplaceOpVariant(Node* n);
static bool hasSideEffectOrAlias(Value* v, AliasDb* aliasDb);
private:
bool newMemoryLocation(Value* v);
Node* createSpecialMappedOp(Node* n);
bool listMutationFollowingListConstruct(Node* n);
bool tryMakeCreationAndMutationAtomic(
@@ -73,5 +74,9 @@ TORCH_API bool RemoveTensorMutation(
const std::shared_ptr<Graph>& graph,
c10::optional<std::function<bool(Node*)>> mutation_filter = c10::nullopt);
// Replaces in-place aten activation ops with their functional equivalents
TORCH_API bool InplaceToFunctionalActivation(
const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch


@@ -0,0 +1,85 @@
#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
namespace torch {
namespace jit {
FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
std::shared_ptr<Graph> graph)
: aliasDb_(nullptr), graph_(std::move(graph)) {}
bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) {
if (activation_type_promotion_mapping.find(node->kind()) ==
activation_type_promotion_mapping.end()) {
return false;
}
Symbol inplace_op =
Symbol::fromQualString(std::string(node->kind().toQualString()) + "_");
if (!inplace_op) {
return false;
}
// If type promotion is allowed, then perform dtype check
bool check_dtype = activation_type_promotion_mapping.at(node->kind());
Value* input = node->inputs().at(0);
Value* output = node->outputs().at(0);
auto inputDtype = input->type()->expect<TensorType>()->scalarType();
auto outputDtype = output->type()->expect<TensorType>()->scalarType();
// In general, we don't need to check shapes for activation ops, as they are
// element-wise. But for the ones where type promotion could happen, we must
// make sure the input and output dtypes are the same. For now the dtype
// check will always fail until type inference is ready.
if (check_dtype &&
(!inputDtype || !outputDtype ||
inputDtype.value() != outputDtype.value())) {
return false;
}
// Skip if input's def node has side effect or input has alias
if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) {
return false;
}
// If the input has more than one use, skip the conversion.
// TODO: Use liveness analysis to catch more general scenario
return (input->uses().size() == 1);
}
bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
bool changed = false;
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
auto* node = *it;
it++;
for (Block* sub_block : node->blocks()) {
changed |= FunctionalToInplace(sub_block);
}
if (!CanBeInplace(node)) {
continue;
}
changed = true;
Node* inplace_node = node->replaceWithNewSymbol(
Symbol::fromQualString(node->schema().name() + "_"));
inplace_node->output()->replaceAllUsesWith(node->inputs().at(0));
getOrCreateAliasDb()->replaceWithNewValue(
node->output(), inplace_node->output());
node->destroy();
}
return changed;
}
bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) {
FunctionalToInplaceRewriter rewriter(graph);
return rewriter.FunctionalToInplace(graph->block());
}
} // namespace jit
} // namespace torch
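
`CanBeInplace` also bails out when the activation's input has more than one use. A quick sketch of why (illustrative, not part of the commit):

```python
import torch

def h(x):
    y = x + 1
    z = torch.relu(y)
    return z + y          # y is read again after the activation

fn = torch.jit.script(h)
torch._C._jit_pass_functional_to_inplace_activation(fn.graph)
print(fn.graph)  # aten::relu is kept: relu_ would overwrite y before z + y runs
```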


@@ -0,0 +1,64 @@
#pragma once
#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {
// A map that records whether an activation operator can perform type promotion
const std::unordered_map<Symbol, bool> activation_type_promotion_mapping = {
{aten::sigmoid, true},
{aten::tanh, true},
{aten::celu, false},
{aten::elu, false},
{aten::gelu, false},
{aten::glu, false},
{aten::hardshrink, false},
{aten::hardsigmoid, false},
{aten::hardswish, false},
{aten::hardtanh, false},
{aten::leaky_relu, false},
{aten::prelu, false},
{aten::relu6, false},
{aten::relu, false},
{aten::rrelu, false},
{aten::selu, false},
{aten::silu, false}};
class FunctionalToInplaceRewriter {
public:
FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph);
bool FunctionalToInplace(Block* block);
private:
AliasDb* getOrCreateAliasDb() {
if (!aliasDb_) {
aliasDb_ = std::make_unique<AliasDb>(graph_);
}
return aliasDb_.get();
}
bool CanBeInplace(Node* node);
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
std::shared_ptr<Graph> graph_;
};
// A common application scenario is to apply InplaceToFunctionalActivation
// before some JIT optimization passes, so that those passes are less
// constrained by in-place ops. After those passes are done, we can call
// FunctionalToInplaceActivation to recover in-place activation ops,
// so that we won't lose the performance benefit coming from memory reduction.
// Replaces functional aten activation ops with their in-place equivalents
TORCH_API bool FunctionalToInplaceActivation(
const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
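
The usage comment in this header suggests a round-trip pattern. A hypothetical pipeline sketch (the constant-propagation call stands in for whatever mutation-sensitive passes a caller might run in between):

```python
import torch

def optimize(graph):
    # Strip in-place activations so mutation-sensitive passes see a functional graph.
    torch._C._jit_pass_inplace_to_functional_activation(graph)
    # ... passes that prefer a mutation-free graph ...
    torch._C._jit_pass_constant_propagation(graph)
    # Restore in-place activations to regain the memory savings.
    torch._C._jit_pass_functional_to_inplace_activation(graph)
    return graph
```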


@@ -67,6 +67,7 @@
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
@@ -419,6 +420,16 @@ void initJITBindings(PyObject* module) {
RemoveListMutation(g);
return RemoveTensorMutation(g);
})
.def(
"_jit_pass_functional_to_inplace_activation",
[](std::shared_ptr<Graph>& g) {
return FunctionalToInplaceActivation(g);
})
.def(
"_jit_pass_inplace_to_functional_activation",
[](std::shared_ptr<Graph>& g) {
return InplaceToFunctionalActivation(g);
})
.def(
"_jit_pass_inline_functional_graphs",
[](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })