Implement MM fusion (MM with add reduction tree) (#4615)

Implement MM fusion (MM with add reduction tree)

A tree where leaves are matrix multiplies and inner
vertices are adds can be computed as a single mm.
Such subgraph often appear in backward if a single weight
is reused multiple times (e.g. in RNNs).

NOTE: this seems to be slightly slower on the GPU than the
naive implementation, but it's a huge win on the CPU
(think 100x lower overhead)
This commit is contained in:
Adam Paszke 2018-01-17 21:36:21 +01:00 committed by GitHub
parent db7f5dae77
commit 1a02d3ae86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 240 additions and 8 deletions

View File

@ -473,6 +473,7 @@ main_sources = [
"torch/csrc/jit/passes/peephole.cpp",
"torch/csrc/jit/passes/inplace_check.cpp",
"torch/csrc/jit/passes/canonicalize.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",
"torch/csrc/jit/generated/aten_dispatch.cpp",
"torch/csrc/autograd/init.cpp",

View File

@ -290,6 +290,7 @@ class TestJit(TestCase):
torch._C._jit_pass_fuse(trace)
self.assertExpectedTrace(trace)
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
def test_arg_configurations(self):
"""Different arg configurations should trigger different traces"""
x = Variable(torch.FloatTensor(4, 4).uniform_())
@ -730,6 +731,7 @@ class TestJit(TestCase):
del z
check(False, True)
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
def test_multiuse_fn(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
w = Variable(torch.randn(2, 2), requires_grad=True)
@ -746,6 +748,7 @@ class TestJit(TestCase):
torch.jit.verify(cell, (x, w), devices=[])
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
def test_output_unflatten(self):
"""Check that outputs of traced functions retain the original structure and nesting"""
x = Variable(torch.randn(2, 2), requires_grad=True)

View File

@ -0,0 +1,206 @@
#include "torch/csrc/jit/passes/batch_mm.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/utils/functional.h"
#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
namespace torch { namespace jit {
// This pass looks for trees in the graph, where leaves are mm ops, and the inner
// vertices are add nodes. Once we have such a tree they can be reduced to two
// concats and a single mm (basically into a single multiply of a wide matrix, with
// a tall matrix).
// Such patterns show up mostly in backward of RNNs, since the derivative of many
// uses of matrix multiplies with same weights forms exactly such a tree
// (note that it's usually also highly imbalanced i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+ +------+ +------+ +------+
// | | | | | | | | | |
// | L1 | | R1 | + | L2 | | R2 | = | O |
// | | | | | | | | | |
// +------+ +------+ +------+ +------+ +------+
//
// can be basically transformed into a single MM which looks like this
// (we concat all lhs operands, concat rhs operands, do mm):
//
// +------+
// | |
// | R1 |
// | |
// +------+
// | |
// | R2 |
// | |
// +------+
// +------+------+ +------+
// | | | | |
// | L1 | L2 | | O |
// | | | | |
// +------+------+ +------+
// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if all
// MMs had the same lhs/rhs. In such case it's more efficient to expand the lhs
// and use bmm + sum instead of repeating it in memory via concat.
// Note [Overlapping trees]
// Additionally it wouldn't be too hard to add support for partially overlapping
// trees. Right now the it's forbidden in the algorithm (only a single tree will
// be allowed), so theoretically we might miss some optimization options, especially
// that the rejected tree could be much larger. I didn't implement that because it's
// not necessary for the simple RNN cases I saw, so I decided to keep stuff simple.
// If we ever get around implementing this, the right solution is probably to fuse
// MMs for the common part, and assume it's an input leaf for the outer two parts
// (I don't think it's beneficial to recompute, unless the subtree is super small,
// but let's not get into such details).
// The algorithm we're using is simple. We're iterating through the graph in the
// topological order and labeling nodes with TreeTokens. Then, we look for roots of
// the trees we formed and fuse them.
// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr std::size_t min_fusion_size = 2;
static std::array<int64_t, 2> as_array(at::IntList sizes) {
JIT_ASSERT(sizes.size() == 2);
std::array<int64_t, 2> arr;
arr[0] = sizes[0];
arr[1] = sizes[1];
return arr;
}
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
// when we reach node N with inputs A and B, then A and B have already been
// procesed, and we can try to unify their TreeTokens (if they have them)
// and build a larger tree.
struct TreeToken {
uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
std::array<int64_t, 2> lhs_sizes;
std::array<int64_t, 2> rhs_sizes;
Node *node = nullptr;
bool is_root = false;
static TreeToken fromMM(Node *mm) {
TreeToken token;
token.tree_size = 1;
Value *lhs = mm->inputs()[0];
Value *rhs = mm->inputs()[1];
token.lhs_sizes = as_array(lhs->type()->expect<TensorType>()->sizes());
token.rhs_sizes = as_array(rhs->type()->expect<TensorType>()->sizes());
token.node = mm;
token.is_root = true;
return token;
}
static TreeToken unify(Node *add, TreeToken& l, TreeToken& r) {
TreeToken token;
// See Note [Overlapping trees]
if (&l == &r || !l.is_root || !r.is_root)
return token;
// We can batch the tree only if all sizes match, because we need to
// cat inputs for both operands
if (l.lhs_sizes != r.lhs_sizes)
return token;
if (l.rhs_sizes != r.rhs_sizes)
return token;
token.tree_size = l.tree_size + r.tree_size;
token.lhs_sizes = l.lhs_sizes;
token.rhs_sizes = l.rhs_sizes;
token.node = add;
token.is_root = true;
l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
return token;
}
operator bool() {
return is_root;
}
std::vector<Node*> gatherMatMuls() {
static const Symbol mm_kind = "mm"_sym;
std::vector<Node*> matmuls;
std::vector<Node*> queue {node};
while (!queue.empty()) {
auto n = queue.back(); queue.pop_back();
if (n->kind() == mm_kind) {
matmuls.push_back(n);
} else {
queue.push_back(n->inputs()[0]->node());
queue.push_back(n->inputs()[1]->node());
}
}
return matmuls;
}
};
void BatchMM(std::shared_ptr<Graph>& graph) {
enum class Side { LHS, RHS };
static const Symbol mm_kind = "mm"_sym;
static const Symbol add_kind = "add"_sym;
static const Symbol cat_kind = "cat"_sym;
static const Symbol dim_sym = "dim"_sym;
// Look for trees in the graph
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : graph->nodes()) {
if (node->kind() == mm_kind) {
tokens[node] = TreeToken::fromMM(node);
} else if (node->kind() == add_kind) {
Node *lhs = node->inputs()[0]->node();
Node *rhs = node->inputs()[1]->node();
auto lhs_it = tokens.find(lhs);
auto rhs_it = tokens.find(rhs);
// See Note [Overlapping trees] (regarding the uses().size() == 1 check)
// We could treat a subtree with multiple uses as if it was overlapping.
// XXX: uses().size() == 1 is also something that guarantees that this
// transform is valid, because we know for sure that the none of these
// operands depend on the result of the other. If we were to remove this,
// we need to compute a transitive closure and actually check the dependencies.
if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
tokens[node] = token;
}
}
}
// Merge trees we've found
for (auto & item : tokens) {
auto & root = item.second;
if (!root || root.tree_size < min_fusion_size)
continue;
auto matmuls = root.gatherMatMuls();
auto type = root.node->output()->type()->expect<TensorType>();
auto batch_inputs = [&](Side s, std::array<int64_t, 2> cat_sizes) -> Value* {
int inputs_off = s == Side::LHS ? 0 : 1;
int cat_dim = s == Side::LHS ? 1 : 0;
cat_sizes[cat_dim] *= matmuls.size(); // make them really cat_sizes
auto inputs = fmap(matmuls, [=](Node *mm) { return mm->inputs()[inputs_off]; });
Node *cat = graph->create(cat_kind, inputs)
->i_(dim_sym, cat_dim);
cat->insertBefore(root.node);
cat->output()->setType(type->withSizes(cat_sizes));
return cat->output();
};
auto lhs_batch = batch_inputs(Side::LHS, root.lhs_sizes);
auto rhs_batch = batch_inputs(Side::RHS, root.rhs_sizes);
Node *batch_mm = graph->create(mm_kind, {lhs_batch, rhs_batch});
batch_mm->output()->setType(type->asShared());
batch_mm->insertBefore(root.node);
root.node->output()->replaceAllUsesWith(batch_mm->output());
// NB: don't bother with cleaning up after yourself. We'll use DCE for that.
}
EliminateDeadCode(graph);
}
}}

View File

@ -0,0 +1,9 @@
#pragma once
#include "torch/csrc/jit/ir.h"
namespace torch { namespace jit {
void BatchMM(std::shared_ptr<Graph>& graph);
}}

View File

@ -9,6 +9,7 @@
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/inplace_check.h"
#include "torch/csrc/jit/passes/batch_mm.h"
#include "torch/csrc/jit/python_arg_flatten.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
@ -80,6 +81,7 @@ struct CompiledFunction {
CheckInplace(complete_trace->graph);
if (fn_.optimize_) {
PeepholeOptimize(complete_trace->graph);
BatchMM(complete_trace->graph);
FuseGraph(complete_trace->graph);
EliminateCommonSubexpression(complete_trace->graph);
}

View File

@ -24,7 +24,7 @@ enum class TypeKind {
struct Type;
using TypePtr = std::shared_ptr<Type>;
struct Type {
struct Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;
@ -50,6 +50,9 @@ public:
JIT_ASSERT(T::Kind == kind());
return static_cast<T*>(this);
}
std::shared_ptr<Type> asShared() {
return shared_from_this();
}
};
// This node represents a single Tensor value
@ -61,7 +64,7 @@ struct TensorType : public Type {
, device_(tensor.type().is_cuda() ? tensor.get_device() : -1)
, sizes_(tensor.sizes())
, strides_(tensor.strides()) {}
TensorType(at::ScalarType scalar_type, int device, std::vector<int64_t> sizes, std::vector<int64_t> strides)
TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides)
: Type(TypeKind::TensorType)
, scalar_type_(scalar_type)
, device_(device)
@ -76,20 +79,28 @@ struct TensorType : public Type {
const std::vector<std::int64_t>& sizes() const { return sizes_; }
const std::vector<std::int64_t>& strides() const { return strides_; }
TypePtr withSizesStrides(const std::vector<std::int64_t>& sizes, const std::vector<std::int64_t>& strides) const {
TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const {
return std::make_shared<TensorType>(scalar_type_, device_, sizes, strides);
}
TypePtr withSizes(at::IntList sizes) const {
return withSizesStrides(sizes, contiguousStridesOf(sizes));
}
TypePtr contiguous() const {
auto t = std::make_shared<TensorType>(*this);
t->strides_.resize(sizes_.size());
t->strides_.back() = 1;
for(size_t i = t->strides_.size() - 1; i > 0; i--) {
t->strides_[i-1] = t->strides_[i] * t->sizes_[i];
}
t->strides_ = contiguousStridesOf(sizes_);
return t;
}
private:
std::vector<int64_t> contiguousStridesOf(at::IntList sizes) const {
std::vector<int64_t> strides(sizes.size());
strides.back() = 1;
for(std::size_t i = strides.size() - 1; i > 0; i--) {
strides[i-1] = strides[i] * sizes[i];
}
return strides;
}
at::ScalarType scalar_type_;
int device_;
std::vector<int64_t> sizes_;