#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch::jit {

namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// This pass looks for trees in the graph, where leaves are mm ops, and the
// inner vertices are add nodes. Once we have such a tree, it can be reduced to
// two concats and a single mm (basically into a single multiply of a wide
// matrix with a tall matrix). Such patterns show up mostly in the backward
// pass of RNNs, since the derivative of many matrix multiplies that use the
// same weights forms exactly such a tree (note that it's usually also highly
// imbalanced, i.e. it has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+   +------+ +------+   +------+
// |      | |      |   |      | |      |   |      |
// |  L1  | |  R1  | + |  L2  | |  R2  | = |  O   |
// |      | |      |   |      | |      |   |      |
// +------+ +------+   +------+ +------+   +------+
//
// can basically be transformed into a single MM which looks like this
// (we concat all lhs operands, concat all rhs operands, and do one mm):
//
//                  +------+
//                  |      |
//                  |  R1  |
//                  |      |
//                  +------+
//                  |      |
//                  |  R2  |
//                  |      |
//                  +------+
// +------+------+  +------+
// |      |      |  |      |
// |  L1  |  L2  |  |  O   |
// |      |      |  |      |
// +------+------+  +------+

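// As a concrete, illustrative sketch of that identity in ATen (assuming L1 and
// L2 are [l, m] tensors and R1 and R2 are [m, r] tensors):
//
//   at::Tensor tree = at::mm(L1, R1) + at::mm(L2, R2);
//   at::Tensor fused = at::mm(
//       at::cat({L1, L2}, /*dim=*/1),  // [l, 2m] -- the wide matrix
//       at::cat({R1, R2}, /*dim=*/0)); // [2m, r] -- the tall matrix
//   // tree and fused are numerically equivalent (up to floating-point
//   // accumulation order); this is exactly what prim::MMTreeReduce exploits.
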
// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs had the same lhs/rhs. In such a case it's more efficient to expand
// the lhs and use bmm + sum instead of repeating it in memory via concat.

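// A rough sketch of that alternative (illustrative only; W is a hypothetical
// shared [l, m] lhs and R1..R3 are three different [m, r] rhs operands):
//
//   at::Tensor Rs = at::stack({R1, R2, R3});  // [3, m, r]
//   at::Tensor out =
//       at::bmm(W.unsqueeze(0).expand({3, -1, -1}), Rs).sum(0);
//   // out == W.mm(R1) + W.mm(R2) + W.mm(R3), without materializing three
//   // copies of W in memory via concat.
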
// Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now it's forbidden in the algorithm (only a single
// tree will be allowed), so theoretically we might miss some optimization
// options, especially since the rejected tree could be much larger. I didn't
// implement that because it's not necessary for the simple RNN cases I saw,
// so I decided to keep things simple. If we ever get around to implementing
// this, the right solution is probably to fuse MMs for the common part, and
// assume it's an input leaf for the outer two parts (I don't think it's
// beneficial to recompute, unless the subtree is super small, but let's not
// get into such details).

// The algorithm we're using is simple. We iterate through the graph in
// topological order and label nodes with TreeTokens. Then, we look for roots
// of the trees we formed and fuse them.

// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;

static bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return (std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      }));
}

static bool should_be_transposed(at::TensorList inputs) {
  return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
    return t.stride(0) == 1 && t.stride(1) == t.size(0);
  }));
}

static std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}

static bool shape_is_fast_for_reduce(
    const at::Tensor& lhs,
    const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}

static RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);

      AT_ASSERT(!inputs.empty());
      AT_ASSERT(inputs.size() % 2 == 0);
      size_t side_num_elems = inputs.size() / 2;
      auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
      auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
        // Sometimes lhs_inputs or rhs_inputs are not contiguous, which causes
        // at::cat to go through a slow path. View them as contiguous, if
        // possible, by transposing them.
        bool lhs_input_transposed = should_be_transposed(lhs_inputs);
        bool rhs_input_transposed = should_be_transposed(rhs_inputs);
        at::Tensor lhs, rhs;
        if (lhs_input_transposed) {
          std::vector<at::Tensor> lhs_contig_inputs =
              transpose_inputs(lhs_inputs);
          lhs = at::cat(lhs_contig_inputs, /*dim=*/0);
          lhs = lhs.t();
        } else {
          lhs = at::cat(lhs_inputs, /*dim=*/1);
        }
        if (rhs_input_transposed) {
          std::vector<at::Tensor> rhs_contig_inputs =
              transpose_inputs(rhs_inputs);
          rhs = at::cat(rhs_contig_inputs, /*dim=*/1);
          rhs = rhs.t();
        } else {
          rhs = at::cat(rhs_inputs, /*dim=*/0);
        }
        push(stack, at::mm(lhs, rhs));
      } else {
        auto acc = at::mm(inputs[0], inputs[side_num_elems]);
        for (const auto i : c10::irange(1, side_num_elems)) {
          acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
        }
        push(stack, std::move(acc));
      }
    },
    aliasAnalysisIsSpecialCase())});

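// Put differently: a fused node of the (illustrative) form
//   prim::MMTreeReduce(L1, ..., Ln, R1, ..., Rn)
// takes its lhs operands as the first half of its inputs and the matching rhs
// operands as the second half, and computes
//   mm(L1, R1) + mm(L2, R2) + ... + mm(Ln, Rn)
// either as one big mm over concatenated operands (the fast path above), or as
// a plain accumulation loop when the shapes differ or are unfavorable.
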
// TreeTokens will be used to label nodes of the graph, if the nodes fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs: by the
// time we reach a node N with inputs A and B, both A and B have already been
// processed, so we can try to unify their TreeTokens (if they have them)
// and build a larger tree.
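//
// For example (a hypothetical IR fragment, shown only to illustrate labeling):
//   %p = aten::mm(%a, %b)      -> TreeToken{tree_size=1, is_root=true}
//   %q = aten::mm(%c, %d)      -> TreeToken{tree_size=1, is_root=true}
//   %s = aten::add(%p, %q, %1) -> TreeToken{tree_size=2, is_root=true};
//                                 the tokens of %p and %q stop being roots.
// Only tokens that are still roots (and have at least min_fusion_size leaves)
// get fused below.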
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    l.is_root = r.is_root =
        false; // Reserve the subtrees, so they can't be used again.
    return token;
  }

  explicit operator bool() {
    return is_root;
  }

  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^TA^T
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};

enum class Side { LHS, RHS };

static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it was overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depends on the result of the other. If we were to remove
      // this, we would need to compute a transitive closure and actually check
      // the dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block, alias_db);
      }
    }
  }

  // Merge trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
  }
}

static bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  return other_side_input.numel() <= 1024 * 2048;
}

static RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(other_side_inputs[0])) {
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});

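// Illustrative sketch of the identity behind prim::MMBatchSide for the LHS
// case (X is a hypothetical shared [l, m] lhs; A and B are two different
// [m, r] rhs operands):
//
//   std::vector<at::Tensor> outs =
//       at::chunk(X.mm(at::cat({A, B}, /*dim=*/1)), 2, /*dim=*/1);
//   // outs[0] == X.mm(A) and outs[1] == X.mm(B)
//
// The RHS case is symmetric: concatenate the varying operands along dim 0 and
// chunk the single mm's result along dim 0.
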
static std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.empty()) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
    // have a lot of MMs that are independent of each other but all depend on
    // the first one, but I doubt this will be a common scenario.
    for (const auto i : c10::irange(mms.size())) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };

  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
  std::vector<Node*> rhses; // Like above, but rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(u.user)) {
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(
      postprocess(std::move(lhses)), postprocess(std::move(rhses)));
}

static void BatchMMSide(Block* block, AliasDb& alias_db) {
  // NB: 8 is the current loop unrolling factor
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (const auto i : c10::irange(mms.size())) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };

  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      for (Value* input : node->inputs()) {
        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}

static bool hasMMOperators(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return true;
    }
  }
  return false;
}

void BatchMM(std::shared_ptr<Graph>& graph) {
  if (!hasMMOperators(graph)) {
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block(), alias_db);
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // It's possible that transpose rearrangements have created sequences of
  // consecutive transposes that didn't exist before.

  // tensor type properties are not guaranteed to be correct
  PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}

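// Usage sketch (illustrative; the IR below is hypothetical and assumes
// torch::jit::parseIR from torch/csrc/jit/ir/irparser.h):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(
//       R"IR(
//   graph(%a : Tensor, %b : Tensor, %c : Tensor, %d : Tensor):
//     %1 : Tensor = aten::mm(%a, %b)
//     %2 : Tensor = aten::mm(%c, %d)
//     %3 : int = prim::Constant[value=1]()
//     %4 : Tensor = aten::add(%1, %2, %3)
//     return (%4)
//   )IR",
//       graph.get());
//   BatchMM(graph);
//   // This tree has only 2 mm leaves, which is below min_fusion_size, so it
//   // is left untouched; larger trees would be rewritten to
//   // prim::MMTreeReduce nodes.
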
} // namespace torch::jit