JIT Layernorm fusion (#18266)

Summary:
Partially fuse layer_norm by decomposing it into the batch_norm kernel that computes the statistics, and then fusing the affine operations that follow the reduction. This mirrors the batch_norm fusion that apaszke did, and it likewise only works in inference mode for now. (A numerical sketch of the decomposition follows the commit metadata below.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18266

Differential Revision: D14879877

Pulled By: wanchaol

fbshipit-source-id: 0197d8f2a17ec438d3e53f4c411d759c1ae81efe
Author: Wanchao Liang, 2019-04-12 14:24:37 -07:00 (committed by Facebook Github Bot)
parent 0e435afc3c
commit a3d3008e73
2 changed files with 175 additions and 87 deletions
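
For orientation, here is a hedged eager-mode sketch of the arithmetic the decomposition relies on. It is illustrative only: the pass emits TorchScript into the JIT graph and calls torch.batch_norm_stats (a CUDA-only internal op) on a (1, n, -1) view of the input, whereas this sketch uses torch.var_mean so it runs anywhere; the helper name is made up.

import torch

def decomposed_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Collapse the leading (non-normalized) dims into a single "channel" dim so
    # that batch-norm style per-channel statistics reduce over exactly the
    # normalized dims.
    normalized_ndim = len(normalized_shape)
    n = 1
    for i in range(x.dim() - normalized_ndim):
        n *= x.size(i)
    x_reshaped = x.contiguous().view(1, n, -1)
    var, mean = torch.var_mean(x_reshaped, dim=[0, 2], unbiased=False)
    # Reshape the stats so they broadcast over the normalized dims, then apply
    # the pointwise (and hence fusable) tail.
    stat_shape = list(x.shape[:x.dim() - normalized_ndim]) + [1] * normalized_ndim
    mean = mean.view(stat_shape)
    invstd = (var.view(stat_shape) + eps).rsqrt()
    out = (x - mean) * invstd
    if weight is not None and bias is not None:
        out = torch.addcmul(bias, out, weight)
    return out

x = torch.randn(2, 3, 8)
ln = torch.nn.LayerNorm(8)
ref = ln(x)
got = decomposed_layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
print(torch.allclose(ref, got, atol=1e-6))  # True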


@@ -428,43 +428,56 @@ class TestFuser(JitTestCase):
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
def test_fuse_batch_norm(self):
def test_fuse_decompose_normalization(self):
class ResLike(torch.jit.ScriptModule):
def __init__(self, optimize=True):
def __init__(self, norm_module, optimize=True):
super(ResLike, self).__init__(optimize)
self.bn = nn.BatchNorm2d(16)
self.nm = norm_module
@torch.jit.script_method
def forward(self, x, y):
return y + torch.relu(self.bn(x))
return y + torch.relu(self.nm(x))
model = ResLike().cuda()
model_noopt = ResLike(optimize=False).cuda()
model_noopt.load_state_dict(model.state_dict())
x = torch.randn(2, 16, 8, 8, device='cuda')
y = torch.randn(2, 16, 8, 8, device='cuda')
# FIXME: We need differentiation for CNNs for this optimization to trigger
with torch.no_grad():
out = model(x, y)
graph = model.graph_for(x, y)
rep = str(graph)
def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
model = ResLike(nm).cuda()
model_noopt = ResLike(nm, optimize=False).cuda()
model_noopt.load_state_dict(model.state_dict())
x = torch.randn(2, 16, 8, 8, device='cuda')
y = torch.randn(2, 16, 8, 8, device='cuda')
out_noopt = model_noopt(x, y)
rep_noopt = str(model_noopt.graph_for(x, y))
self.assertEqual(out, out_noopt, prec=3e-5)
# FIXME: We need differentiation for CNNs for this optimization to trigger
with torch.no_grad():
out = model(x, y)
graph = model.graph_for(x, y)
rep = str(graph)
# Check that batch_norm has really been decomposed
self.assertIn('aten::batch_norm_update_stats', rep)
self.assertNotIn('aten::batch_norm(', rep)
self.assertIn('aten::batch_norm(', rep_noopt)
out_noopt = model_noopt(x, y)
rep_noopt = str(model_noopt.graph_for(x, y))
self.assertEqual(out, out_noopt, prec=3e-5)
# Make sure the fusion group is big, and contains aten::sqrt, which could
# originate only from decomposing batch_norm in this case
fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
self.assertEqual(len(fusion_groups), 1)
fused_graph = fusion_groups[0].g('Subgraph')
self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
# Check that normalization op has really been decomposed
for node_in_graph in in_opt_graph:
self.assertIn(node_in_graph, rep)
for node_not_in_graph in not_in_opt_graph:
self.assertNotIn(node_not_in_graph, rep)
self.assertIn(node_not_in_graph, rep_noopt)
fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
self.assertEqual(len(fusion_groups), 1)
fused_graph = str(fusion_groups[0].g('Subgraph'))
for node_in_fusegraph in in_fusegraph:
self.assertIn(node_in_fusegraph, fused_graph)
# test for batchnorm decompose
bm = nn.BatchNorm2d(16)
test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
['aten::batch_norm('], ['aten::sqrt'])
# test for layernorm decompose
lm = nn.LayerNorm(8)
test_norm_decompose(lm, ['aten::batch_norm_stats'],
['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::addcmul'])
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
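
For reference, a hedged way to reproduce what this test asserts outside the test harness: on a CUDA build with this fuser active, the optimized inference graph for layer_norm followed by a pointwise op should no longer contain aten::layer_norm, and the pointwise tail should sit in a prim::FusionGroup. Whether the fusion actually fires depends on the fuser shipped with the installed PyTorch; graph_for is an internal debugging helper, used the same way in the test above.

import torch

@torch.jit.script
def ln_relu(x, weight, bias):
    return torch.relu(torch.layer_norm(x, [8], weight, bias, 1e-5, True))

x = torch.randn(2, 16, 8, 8, device='cuda')
w = torch.ones(8, device='cuda')
b = torch.zeros(8, device='cuda')
with torch.no_grad():
    ln_relu(x, w, b)                       # warm-up run so an optimized graph exists
    rep = str(ln_relu.graph_for(x, w, b))  # internal debug helper
print('aten::layer_norm(' not in rep, 'prim::FusionGroup' in rep)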


@@ -109,7 +109,8 @@ bool isSimpleMap(Node* node) {
return false;
}
for (Value* input : node->inputs()) {
if (input->type()->isSubtypeOf(TensorType::get()) || input->type()->isSubtypeOf(FloatType::get())) {
if (input->type()->isSubtypeOf(TensorType::get()) ||
input->type()->isSubtypeOf(FloatType::get())) {
continue;
}
if (input->node()->kind() != prim::Constant) {
@@ -133,6 +134,23 @@ RegisterOperators reg_bn_unsqueeze({Operator(
};
})});
RegisterOperators reg_ln_view({Operator(
"aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor",
[](const Node* node) {
return [](Stack& stack) {
const int64_t normalized_ndim = pop(stack).toInt();
auto input_shape = pop(stack).toIntListRef();
auto self = pop(stack).toTensor();
const int64_t input_ndim = input_shape.size();
c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
sizes.at(i) = input_shape[i];
}
push(stack, self.reshape(sizes));
return 0;
};
})});
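The registered aten::_ncf_view op above just reshapes the flattened statistics so they broadcast over the normalized dims; it is the layer_norm analogue of the aten::_ncf_unsqueeze helper the batch_norm path uses. A small Python transliteration, for reference:

import torch

def ncf_view(stat, input_shape, normalized_ndim):
    sizes = [1] * len(input_shape)
    for i in range(len(input_shape) - normalized_ndim):
        sizes[i] = input_shape[i]
    return stat.reshape(sizes)

# e.g. stats computed on the (1, n, -1) view of a (2, 3, 8) input normalized
# over its last dim come back with 2 * 3 = 6 channels:
mean = torch.randn(1, 6)
print(ncf_view(mean, [2, 3, 8], 1).shape)  # torch.Size([2, 3, 1])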
// Yes, no, or no value if we can't tell
c10::optional<bool> isDefined(Value* tensor) {
if (tensor->type()->isSubtypeOf(TensorType::get())) {
@@ -144,16 +162,20 @@ c10::optional<bool> isDefined(Value* tensor) {
return {};
}
bool isFusableBatchNorm(Node* batch_norm) {
if (!batch_norm->matches(
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
return false;
bool isFusableNorm(Node* normalize_op) {
static const OperatorSet decomposable_normalization_ops = {
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
};
if (decomposable_normalization_ops.find(normalize_op)) {
// If we can't determine if weight and bias is defined statically there's
// really no point in decomposing normalization into simpler ops, since it
// won't get fused into a single kernel.
return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
isDefined(normalize_op->namedInput(attr::bias)).has_value();
}
// If we can't determine if weight and bias is defined statically there's
// really no point in decomposing batch norm into simpler ops, since it won't
// get fused into a single kernel.
return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
isDefined(batch_norm->namedInput(attr::bias)).has_value();
return false;
}
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
@@ -187,7 +209,7 @@ struct GraphFuser {
}
bool isFusable(Node* node) {
return isFusableMap(node) || isFusableBatchNorm(node);
return isFusableMap(node) || isFusableNorm(node);
}
bool isFusableMap(Node* node) {
@@ -249,13 +271,35 @@ struct GraphFuser {
return *n->g(attr::Subgraph);
}
void decomposeBatchNorm(Node* batch_norm) {
static std::shared_ptr<Graph> bn_graph;
static std::once_flag flag;
Value* decomposeCommonNormalization(
Node* normalization_op,
const char* source,
const std::string& method_name,
const std::vector<Value*>& inputs) {
std::shared_ptr<Graph> nm_graph;
std::once_flag flag;
std::call_once(
flag,
[](std::shared_ptr<Graph>* graph_ptr) {
static const char* source = R"SCRIPT(
[](std::shared_ptr<Graph>* graph_ptr,
const char* source,
const std::string& method_name) {
script::CompilationUnit cu;
cu.define(source, script::nativeResolver, nullptr);
*graph_ptr = cu.get_function(method_name).graph();
},
&nm_graph,
source,
method_name);
AT_ASSERT(isFusableNorm(normalization_op));
WithInsertPoint insert_guard{normalization_op};
Value* new_output =
SubgraphUtils::inlineGraph(nm_graph, inputs, normalization_op).at(0);
return new_output;
}
void decomposeNormalizationOps(Node* normalization_op) {
static const char* bm_source = R"SCRIPT(
def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
if training:
norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
@@ -264,41 +308,74 @@ struct GraphFuser {
norm_var = torch._unwrap_optional(running_var)
norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
norm_invstd = 1 / (eps + torch.sqrt(norm_var))
norm_invstd = 1 / (torch.sqrt(norm_var + eps))
return ((input - norm_mean) * norm_invstd)
)SCRIPT";
script::CompilationUnit cu;
cu.define(source, script::nativeResolver, nullptr);
*graph_ptr = cu.get_function("batch_norm").graph();
},
&bn_graph);
static const char* lm_source = R"SCRIPT(
def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
input_ndim = input.dim()
normalized_ndim = len(normalized_shape)
n = 1
for i in range(input_ndim - normalized_ndim):
n *= input.size(i)
input_reshape = input.contiguous().view(1, n, -1)
mean, invstd = torch.batch_norm_stats(input_reshape, eps)
input_shape = input.size()
mean = torch._ncf_view(mean, input_shape, normalized_ndim)
invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
AT_ASSERT(isFusableBatchNorm(batch_norm));
WithInsertPoint insert_guard{batch_norm};
Value* input = batch_norm->namedInput(attr::input);
Value* input_dim = graph_->insert(aten::dim, {input});
std::vector<Value*> inputs{input,
batch_norm->namedInput(attr::running_mean),
batch_norm->namedInput(attr::running_var),
batch_norm->namedInput(attr::training),
batch_norm->namedInput(attr::momentum),
batch_norm->namedInput(attr::eps)};
Value* new_output =
SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
auto weight = batch_norm->namedInput(attr::weight);
auto bias = batch_norm->namedInput(attr::bias);
if (isDefined(weight).value()) {
Value* expanded_weight =
graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
return (input - mean) * invstd
)SCRIPT";
Value* input = normalization_op->namedInput(attr::input);
if (normalization_op->kind() == aten::batch_norm) {
Value* input_dim = graph_->insert(aten::dim, {input});
std::vector<Value*> inputs{
input,
normalization_op->namedInput(attr::running_mean),
normalization_op->namedInput(attr::running_var),
normalization_op->namedInput(attr::training),
normalization_op->namedInput(attr::momentum),
normalization_op->namedInput(attr::eps)};
Value* new_output = decomposeCommonNormalization(
normalization_op, bm_source, "batch_norm", inputs);
auto weight = normalization_op->namedInput(attr::weight);
auto bias = normalization_op->namedInput(attr::bias);
if (isDefined(weight).value()) {
Value* expanded_weight =
graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
}
if (isDefined(bias).value()) {
Value* expanded_bias =
graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
new_output = graph_->insert(aten::add, {new_output, expanded_bias});
}
normalization_op->output()->replaceAllUsesWith(new_output);
normalization_op->destroy();
} else if (normalization_op->kind() == aten::layer_norm) {
std::vector<Value*> inputs{
input,
normalization_op->namedInput(attr::normalized_shape),
normalization_op->namedInput(attr::eps),
normalization_op->namedInput(attr::cudnn_enable)};
Value* new_output = decomposeCommonNormalization(
normalization_op, lm_source, "layer_norm", inputs);
auto weight = normalization_op->namedInput(attr::weight);
auto bias = normalization_op->namedInput(attr::bias);
auto weight_defined = isDefined(weight).value();
auto bias_defined = isDefined(bias).value();
if (weight_defined && bias_defined) {
new_output = graph_->insert(aten::addcmul, {bias, new_output, weight});
} else if (weight_defined) {
new_output = graph_->insert(aten::mul, {new_output, weight});
} else if (bias_defined) {
new_output = graph_->insert(aten::add, {new_output, bias});
}
normalization_op->output()->replaceAllUsesWith(new_output);
normalization_op->destroy();
}
if (isDefined(bias).value()) {
Value* expanded_bias =
graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
new_output = graph_->insert(aten::add, {new_output, expanded_bias});
}
batch_norm->output()->replaceAllUsesWith(new_output);
batch_norm->destroy();
}
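When both weight and bias are defined, the layer_norm path above folds the affine tail into a single aten::addcmul instead of a separate mul and add, i.e. it relies on the identity addcmul(bias, normalized, weight) == normalized * weight + bias. A quick check with arbitrary example shapes:

import torch

normalized = torch.randn(2, 4, 8)
weight, bias = torch.randn(8), torch.randn(8)
print(torch.allclose(torch.addcmul(bias, normalized, weight),
                     normalized * weight + bias))  # True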
void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
@@ -390,9 +467,10 @@ struct GraphFuser {
group->insertInput(tensor_insert_idx, input);
tensor_insert_idx++;
} else if (
(input->type()->isSubtypeOf(FloatType::get()) && input->node()->kind() != prim::Constant) ||
(n->kind() == aten::_grad_sum_to_size &&
input->type()->isSubtypeOf(ListType::ofInts()))) {
(input->type()->isSubtypeOf(FloatType::get()) &&
input->node()->kind() != prim::Constant) ||
(n->kind() == aten::_grad_sum_to_size &&
input->type()->isSubtypeOf(ListType::ofInts()))) {
auto in_group = subgraph.addInput();
in_group->setType(input->type());
inputs_map[input] = in_group;
@@ -453,12 +531,6 @@ struct GraphFuser {
return group;
}
// TODO: remove this and use WithInsertPoint instead
void insertAt(Node** insertion_point, Node* n) {
n->insertAfter(*insertion_point);
*insertion_point = n;
}
at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
@@ -506,13 +578,16 @@ struct GraphFuser {
group = createSingletonFusionGroup(consumer);
}
if (producer->node()->matches(
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor") ||
producer->node()->matches(
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
// We don't do any fusions in here, but simply decompose the batch norm
// into a kernel that computes the stats + pointwise ops which will be
// We don't do any fusions in here, but simply decompose the normalization
// ops into a kernel that computes the stats + pointwise ops which will be
// considered in this fusion next.
decomposeBatchNorm(producer->node());
decomposeNormalizationOps(producer->node());
return group;
}
if (producer->node()->kind() == prim::FusionGroup) {
mergeFusionGroups(group, producer->node());
return group;