pytorch/torch/csrc/jit/passes/batch_mm.cpp
Peter Goldsborough 18a975210d Add explicit to conversions (#9336)
Summary:
Another code-mod for clang-tidy: Conversion operators should be marked explicit so that they don't cause unwanted implicit conversions. This is especially important for `operator bool()`, see https://stackoverflow.com/questions/39995573/when-can-i-use-explicit-operator-bool-without-a-cast

ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9336

Reviewed By: apaszke

Differential Revision: D8807065

Pulled By: goldsborough

fbshipit-source-id: 0e9f4ebd0048a2a510c0d05fa410695d7e977eb1
2018-07-11 12:10:30 -07:00

213 lines
7.8 KiB
C++

#include "torch/csrc/jit/passes/batch_mm.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/utils/functional.h"
#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
namespace torch { namespace jit {
// This pass looks for trees in the graph, where leaves are mm ops, and the inner
// vertices are add nodes. Once we have such a tree they can be reduced to two
// concats and a single mm (basically into a single multiply of a wide matrix, with
// a tall matrix).
// Such patterns show up mostly in backward of RNNs, since the derivative of many
// uses of matrix multiplies with same weights forms exactly such a tree
// (note that it's usually also highly imbalanced i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+ +------+ +------+ +------+
// | | | | | | | | | |
// | L1 | | R1 | + | L2 | | R2 | = | O |
// | | | | | | | | | |
// +------+ +------+ +------+ +------+ +------+
//
// can be basically transformed into a single MM which looks like this
// (we concat all lhs operands, concat rhs operands, do mm):
//
// +------+
// | |
// | R1 |
// | |
// +------+
// | |
// | R2 |
// | |
// +------+
// +------+------+ +------+
// | | | | |
// | L1 | L2 | | O |
// | | | | |
// +------+------+ +------+
// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if all
// MMs had the same lhs/rhs. In such case it's more efficient to expand the lhs
// and use bmm + sum instead of repeating it in memory via concat.
// Note [Overlapping trees]
// Additionally it wouldn't be too hard to add support for partially overlapping
// trees. Right now the it's forbidden in the algorithm (only a single tree will
// be allowed), so theoretically we might miss some optimization options, especially
// that the rejected tree could be much larger. I didn't implement that because it's
// not necessary for the simple RNN cases I saw, so I decided to keep stuff simple.
// If we ever get around implementing this, the right solution is probably to fuse
// MMs for the common part, and assume it's an input leaf for the outer two parts
// (I don't think it's beneficial to recompute, unless the subtree is super small,
// but let's not get into such details).
// The algorithm we're using is simple. We're iterating through the graph in the
// topological order and labeling nodes with TreeTokens. Then, we look for roots of
// the trees we formed and fuse them.
// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 2;
static std::array<int64_t, 2> as_array(at::IntList sizes) {
JIT_ASSERT(sizes.size() == 2);
std::array<int64_t, 2> arr;
arr[0] = sizes[0];
arr[1] = sizes[1];
return arr;
}
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
// when we reach node N with inputs A and B, then A and B have already been
// procesed, and we can try to unify their TreeTokens (if they have them)
// and build a larger tree.
struct TreeToken {
uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
std::array<int64_t, 2> lhs_sizes;
std::array<int64_t, 2> rhs_sizes;
Node *node = nullptr;
bool is_root = false;
static TreeToken fromMM(Node *mm) {
TreeToken token;
token.tree_size = 1;
Value *lhs = mm->inputs()[0];
Value *rhs = mm->inputs()[1];
token.lhs_sizes = as_array(lhs->type()->expect<TensorType>()->sizes());
token.rhs_sizes = as_array(rhs->type()->expect<TensorType>()->sizes());
token.node = mm;
token.is_root = true;
return token;
}
static TreeToken unify(Node *add, TreeToken& l, TreeToken& r) {
TreeToken token;
// See Note [Overlapping trees]
if (&l == &r || !l.is_root || !r.is_root)
return token;
// We can batch the tree only if all sizes match, because we need to
// cat inputs for both operands
if (l.lhs_sizes != r.lhs_sizes)
return token;
if (l.rhs_sizes != r.rhs_sizes)
return token;
token.tree_size = l.tree_size + r.tree_size;
token.lhs_sizes = l.lhs_sizes;
token.rhs_sizes = l.rhs_sizes;
token.node = add;
token.is_root = true;
l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
return token;
}
explicit operator bool() {
return is_root;
}
std::vector<Node*> gatherMatMuls() {
std::vector<Node*> matmuls;
std::vector<Node*> queue {node};
while (!queue.empty()) {
auto n = queue.back(); queue.pop_back();
if (n->kind() == aten::mm) {
matmuls.push_back(n);
} else {
queue.push_back(n->inputs()[0]->node());
queue.push_back(n->inputs()[1]->node());
}
}
return matmuls;
}
};
void BatchMMBlock(Block* block) {
enum class Side { LHS, RHS };
auto graph = block->owningGraph();
// Look for trees in the block
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : block->nodes()) {
if (node->kind() == aten::mm) {
tokens[node] = TreeToken::fromMM(node);
} else if (node->kind() == aten::add) {
// NOTE: x + 2 is add[other={2}](%x)
if (node->inputs().size() != 2) continue;
Node *lhs = node->inputs()[0]->node();
Node *rhs = node->inputs()[1]->node();
auto lhs_it = tokens.find(lhs);
auto rhs_it = tokens.find(rhs);
// See Note [Overlapping trees] (regarding the uses().size() == 1 check)
// We could treat a subtree with multiple uses as if it was overlapping.
// XXX: uses().size() == 1 is also something that guarantees that this
// transform is valid, because we know for sure that the none of these
// operands depend on the result of the other. If we were to remove this,
// we need to compute a transitive closure and actually check the dependencies.
if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
tokens[node] = token;
}
} else {
for (auto block : node->blocks()) {
BatchMMBlock(block);
}
}
}
// Merge trees we've found
for (auto & item : tokens) {
auto & root = item.second;
if (!root || root.tree_size < min_fusion_size)
continue;
auto matmuls = root.gatherMatMuls();
auto type = root.node->output()->type()->expect<TensorType>();
auto batch_inputs = [&](Side s, std::array<int64_t, 2> cat_sizes) -> Value* {
int inputs_off = s == Side::LHS ? 0 : 1;
int cat_dim = s == Side::LHS ? 1 : 0;
cat_sizes[cat_dim] *= matmuls.size(); // make them really cat_sizes
auto inputs = fmap(matmuls, [=](Node *mm) { return mm->inputs()[inputs_off]; });
Node *cat = graph->create(aten::cat, inputs)
->i_(attr::dim, cat_dim);
cat->insertBefore(root.node);
cat->output()->setType(type->withSizes(cat_sizes));
return cat->output();
};
auto lhs_batch = batch_inputs(Side::LHS, root.lhs_sizes);
auto rhs_batch = batch_inputs(Side::RHS, root.rhs_sizes);
Node *batch_mm = graph->create(aten::mm, {lhs_batch, rhs_batch});
batch_mm->output()->setType(type->asShared());
batch_mm->insertBefore(root.node);
root.node->output()->replaceAllUsesWith(batch_mm->output());
// NB: don't bother with cleaning up after yourself. We'll use DCE for that.
}
EliminateDeadCode(block);
}
void BatchMM(std::shared_ptr<Graph>& graph) {
BatchMMBlock(graph->block());
}
}}