Disable fusion of grad_sum_to_size (#23372)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/22833

grad_sum_to_size does not commute with AutogradAdd after all, because moving it past the add turns the non-broadcasting AutogradAdd into a broadcasting add.
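
To make the failure mode concrete, here is a minimal pure-Python sketch of one way the problem can arise. The helpers `sum_to_size` and `broadcasting_add` are hypothetical stand-ins written for this example and operate on 1-D lists; they are not PyTorch APIs, and the shapes are assumed for illustration rather than taken from the linked issue.

```python
# Illustrative stand-ins only; these are not PyTorch APIs.

def sum_to_size(values, size):
    """Reduce a 1-D list to `size` elements, mimicking a sum-to-size reduction.

    Only the two cases needed here are handled: same size (no-op) and
    reduction to a single element (sum over the broadcast dimension).
    """
    if len(values) == size:
        return list(values)
    if size == 1:
        return [sum(values)]
    raise ValueError("unsupported reduction in this sketch")


def broadcasting_add(a, b):
    """Elementwise add that expands a length-1 operand to match the other."""
    n = max(len(a), len(b))
    if len(a) not in (1, n) or len(b) not in (1, n):
        raise ValueError("operands do not broadcast")
    a = a * n if len(a) == 1 else a
    b = b * n if len(b) == 1 else b
    return [x + y for x, y in zip(a, b)]


grad_a = [1.0]            # gradient that already has the target size 1
grad_b = [1.0, 2.0, 3.0]  # gradient that was broadcast to size 3 in the forward

# Reduce each gradient first, then add two operands of equal shape.
reduce_then_add = broadcasting_add(sum_to_size(grad_a, 1), sum_to_size(grad_b, 1))

# Move the reduction past the add: the add now has to broadcast, so grad_a
# is counted once per broadcast element before the sum.
add_then_reduce = sum_to_size(broadcasting_add(grad_a, grad_b), 1)

print(reduce_then_add)  # [7.0]
print(add_then_reduce)  # [9.0]
```

The two orders disagree as soon as the operands of the add have different shapes, which is the situation the fused version could produce.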

Chillee did most of the work of tracking the bug down to the fusion of grad_sum_to_size and pinged me once he had found the cause. Thank you!

On the choice to remove the fusion completely instead of restricting it more precisely:
- We do have grad_sum_to_size elimination, which works for cases where broadcasting does not actually happen in the forward, so the set of cases where fusing grad_sum_to_size is actually beneficial is much smaller than when the fusion was initially proposed.
- There will be less fusion: among the tests, IOU stops being fully fused. I suspect this is a case we could handle with refined logic.
- Keeping it would require additional checks on when fusion groups may be merged, on top of the complexity that this PR removes.
- The future of fusion probably lies in more complete solutions that include reductions (TVM, KeOps, our own, ...).
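
The first point above relies on the reduction being a no-op whenever the forward did not broadcast. The following sketch illustrates only that condition; `broadcast_shape` and `reduction_needed` are hypothetical helpers written for this example, and the actual elimination pass in PyTorch may decide this differently.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Result shape of broadcasting two shapes, aligning trailing dimensions."""
    result = []
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for da, db in pairs:
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {shape_a} and {shape_b} do not broadcast")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def reduction_needed(input_shape, output_shape):
    """True if a gradient shaped like the output must be summed back
    to the input's shape; False means the reduction is a no-op."""
    return tuple(input_shape) != tuple(output_shape)


# No broadcasting in the forward: both reductions are no-ops.
same = broadcast_shape((4, 3), (4, 3))
print(reduction_needed((4, 3), same))   # False

# Broadcasting in the forward: only the smaller input needs a reduction.
mixed = broadcast_shape((4, 3), (3,))
print(reduction_needed((4, 3), mixed))  # False
print(reduction_needed((3,), mixed))    # True
```

Only the last case leaves a reduction that fusion could have absorbed, which is why the benefit of fusing it is limited.
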
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23372

Differential Revision: D16489930

Pulled By: soumith

fbshipit-source-id: bc0431b0d3eda264c401b634675872c4ce46f0f4
Thomas Viehmann 2019-07-25 08:51:47 -07:00 committed by Facebook Github Bot
parent 82545ecc71
commit cf50249bde
6 changed files with 5 additions and 233 deletions


@ -543,6 +543,7 @@ class TestFuser(JitTestCase):
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
@enable_cpu_fuser
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
def test_fuser_iou(self):
# This checks if most of Intersection over Union is fused.
# In particular, the backward contains many _grad_sum_to_size.


@ -11,7 +11,6 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <atomic>
@ -157,71 +156,6 @@ static void setInputBroadcastGroups(KernelSpec& spec) {
std::back_inserter(spec.inputBroadcastGroups()));
}
// This function moves _grad_sum_to_size nodes along the computation graph
// of the fusion group to the outputs and then records the shape inputs
// in order for summation to be applied after the kernel.
// Note that the correctness relies on the invariant that
// _grad_sum_to_size is only applied to gradient nodes created by autodiff.
// This is important because it ensures that in the mul and div nodes only
// one argument (in the case of div the numerator) has a summed value.
// If two arguments to mul had one, we would be in trouble, but thanks
// to the chain rule, we're OK.
// Note that this means that one kernel output may lead to several fusion
// group outputs when several outputs had the same calculation except
// for the final _grad_sum_to_size. This is also the reason why
// we need to deduplicate kernel outputs at the end of this function.
void processGradSumToSize(KernelSpec& spec) {
auto graph = spec.graph();
std::vector<int64_t> outputGradSumToSizes(graph->outputs().size(), -1);
// these are expressions that might occur during autotdiff operating
// on the gradient (matmul would likely be, too but we don't fuse it)
// note that for mul, we know (from the chain rule) that only one
// factor will be stemming from a calculation involving gradients so
// we know that we can move _grad_sum_to_size across it
// Scan the graph. We will delete nodes. We want later (in the graph)
// _grad_sum_to_size nodes to have priority over earlier ones. Thus
// we scan backwards.
for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend(); it++) {
auto* node = *it;
if (node->kind() != aten::_grad_sum_to_size) {
continue;
}
bool success = trackSingleGradSumToSizeToOutputs(
node->output(), &outputGradSumToSizes);
AT_ASSERT(success); // check that we didn't hit anything unknown
// remove the GradSumToSize node, a new node outside the fusion graph
// will be inserted below
node->output()->replaceAllUsesWith(node->inputs()[0]);
it.destroyCurrent();
}
// By removing the _grad_sum_to_size notes, we might end up with
// duplicate outputs, e.g. when having the autodiff backwards of
// x + y + z of something with x, y, z, those will have different
// _grad_sum_to_sizes but of the same kernel output.
// for each fusion group output, record the corresponding kernel
// output and possibly a _grad_sum_to_size for that output
auto& outputMapAndSizes = spec.outputMapAndSizes();
AT_ASSERT(outputMapAndSizes.empty());
std::unordered_map<const Value*, int64_t> reduced_output_indices;
int64_t newo = 0;
for (auto osize : outputGradSumToSizes) {
auto it = reduced_output_indices.find(graph->outputs()[newo]);
if (it == reduced_output_indices.end()) {
reduced_output_indices.emplace(graph->outputs()[newo], newo);
outputMapAndSizes.emplace_back(newo, osize);
newo++;
} else {
graph->eraseOutput(newo);
outputMapAndSizes.emplace_back(it->second, osize);
}
}
}
// Performs "upfront" compilation where storage is known but shapes are not.
// Currently identifies how to expand all tensors so that all intermediate
// tensors are the same shape, simplifying code generation.
@ -234,7 +168,6 @@ void processGradSumToSize(KernelSpec& spec) {
static void upfrontCompilation(KernelSpec& spec) {
setInputBroadcastGroups(spec);
setInputChunkDescriptors(spec);
processGradSumToSize(spec);
}
int64_t registerFusion(const Node* fusion_group) {


@ -329,7 +329,6 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
auto maybe_spec = retrieve(key);
AT_ASSERT(maybe_spec);
auto& spec = *(*maybe_spec);
// Acquires inputs from stack
auto all_inputs = last(stack, spec.nInputs());
std::vector<at::Tensor> inputs;
@ -381,18 +380,8 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
}
// Launches fusion
std::vector<at::Tensor> raw_outputs;
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
auto outputs = fmap(spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
if (omap.needsSumToSize()) {
return at::sum_to(
raw_outputs[omap.offset()],
all_inputs[omap.sizeInput()].toIntListRef());
} else {
return raw_outputs[omap.offset()];
}
});
std::vector<at::Tensor> outputs;
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, outputs);
// Updates stack
drop(stack, spec.nInputs());


@ -41,32 +41,6 @@ struct TORCH_API PartitionInfo {
int64_t dim_;
};
// This is a helper struct to record the following:
// for each fusion group output, it records the corresponding
// kernel output offset (in offset) and the fusion group input
// to that is to be applied with sumtosize on the output (if any).
// This mapping is necessar as a single kernel output might be
// summed to different sizes.
// These mappings are created during compilation in processGradSumToSize.
struct TORCH_API OutputMapAndSize {
OutputMapAndSize(const int64_t _offset, const int64_t _sizeInput)
: offset_{_offset}, sizeInput_{_sizeInput} {};
int64_t offset() const {
return offset_;
}
int64_t sizeInput() const {
return sizeInput_;
}
bool needsSumToSize() const {
return sizeInput_ != -1;
}
private:
int64_t offset_;
int64_t sizeInput_;
};
// "Kernel Specification." - Contains device-independent fusion information.
// Each kernel specification contains a map of instantiated generated functions
// that implement some or most of its functionality. Multiple generated
@ -90,7 +64,6 @@ struct TORCH_API KernelSpec {
nTensorInputs_{},
inputBroadcastGroups_{},
inputChunks_{},
outputMapAndSizes_{},
has_random_{false},
kernels_{} {
for (const auto& n : graph_->nodes()) {
@ -136,10 +109,6 @@ struct TORCH_API KernelSpec {
return inputChunks_;
}
std::vector<OutputMapAndSize>& outputMapAndSizes() {
return outputMapAndSizes_;
}
bool hasRandom() const {
return has_random_;
}
@ -167,11 +136,6 @@ struct TORCH_API KernelSpec {
uint64_t nTensorInputs_;
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
std::vector<PartitionInfo> inputChunks_;
// This will initially be an empty vector. During kernel compilation
// in processGradSumToSize it will be filled and will contain one
// element per fusion group output (which may be larger than the
// number of kernel outputs).
std::vector<OutputMapAndSize> outputMapAndSizes_;
bool has_random_;
mutable std::mutex mutex_;
mutable std::


@ -170,13 +170,6 @@ struct GraphFuser {
});
}
bool containsGradSumToSize(Node* fusion_group) {
auto nodes = getSubgraph(fusion_group).nodes();
return std::any_of(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == aten::_grad_sum_to_size;
});
}
bool isFusable(Node* node) {
return callback_(node);
}
@ -212,14 +205,6 @@ struct GraphFuser {
// are not necessarily correct.
if (node->owningBlock() != block_)
return false;
if (node->kind() == aten::_grad_sum_to_size) {
// We only fuse _grad_sum_to_size if
// - we will fuse its input next (checked here)
// - we can commute the _grad_sum_to_size with everything
// along the computation graph until we reach the outputs,
// but this is checked later
return isFusable(node->inputs()[0]->node());
}
return node->kind() == prim::FusionGroup || isSimpleMap(node);
}
@ -444,23 +429,6 @@ struct GraphFuser {
return at::nullopt;
}
if (producer->node()->kind() == aten::_grad_sum_to_size &&
consumer->kind() == kind_) {
// check that we will be able to move the _grad_sum_to_size to be fused
// to the end of the fusion group in the fusion compiler
// the difficulty here is that the producer is not part of the fusion
// group yet
for (auto& u : producer->uses()) {
if (u.user == consumer) {
auto subgraph = &getSubgraph(consumer);
if (!trackSingleGradSumToSizeToOutputs(
subgraph->inputs().at(u.offset), nullptr)) {
return at::nullopt;
}
}
}
}
auto group = consumer;
if (consumer->kind() != kind_) {
group = createSingletonFusionGroup(consumer);
@ -1039,13 +1007,11 @@ struct GraphFuser {
}
// Fusion groups can be merged with concat's group if and only if
// - the value they produce isn't already coming from a concat and
// - the fusion group does not contain GradSumToSize
// the value they produce isn't already coming from a concat
if (producer->node()->kind() == prim::FusionGroup) {
auto subgraph = producer->node()->g(attr::Subgraph);
auto* node = subgraph->outputs().at(producer->offset())->node();
return node->kind() != prim::FusedConcat &&
!containsGradSumToSize(producer->node());
return node->kind() != prim::FusedConcat;
}
return true;
}
@ -1227,83 +1193,6 @@ void PeepholeOptimizeShapeExpressions(Block* block) {
} // anonymous namespace
// This takes a _grad_sum_to_size output and tracks it to the return
// statements that depend on it, checking that it only hits nodes
// that commute with _grad_sum_to_size on its path.
// If a non-nullptr vector pointer outputGradSumToSizes is passed, the sizes
// will be recorded as target sizes for the outputs as applicable.
// In the graph_fuser pass we only need to check that we can go to the
// outputs while in the fuser's compiler we want to record the sizes.
// Note: This will only record a new sum_to_size if there is not one
// already. As we want the last grad_sum_to_size, you need to call
// it in reverse order when recording and removing outputs.
bool trackSingleGradSumToSizeToOutputs(
Value* gradSumToSizeOutput,
std::vector<int64_t>* outputGradSumToSizes) {
static OperatorSet commutes_with_SumToSize{{
"aten::mul(Tensor self, Tensor other) -> Tensor",
"aten::div(Tensor self, Tensor other) -> Tensor",
// for div we might check whether we're the first argument
"aten::mul(Tensor self, Scalar other) -> Tensor",
"aten::div(Tensor self, Scalar other) -> Tensor",
"aten::neg(Tensor self) -> Tensor",
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
// add this used to be prim::AutogradAdd
}};
std::queue<Use> uses_to_process{};
auto add_to_uses = [&](const use_list& uses) {
for (auto u : uses) {
uses_to_process.push(u);
}
};
add_to_uses(gradSumToSizeOutput->uses());
while (!uses_to_process.empty()) {
auto user = uses_to_process.front().user;
auto offset = uses_to_process.front().offset;
uses_to_process.pop();
if (user->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
// sometimes, a mask or similar is cast to the same type as the gradient,
// i.e. we see other. Then we don't need to do anything, as the shape is
// not used, only the type..
// But we might also see it as self, when the gradient is cast, then we
// want to track it.
if (offset == 0) {
add_to_uses(user->output()->uses());
}
} else if (commutes_with_SumToSize.find(user)) {
add_to_uses(user->output()->uses());
} else if (user->kind() == prim::Return) {
// During compilation and only if we don't already have a
// _grad_sum_to_size for this output we record the size to sum the output
// to. We only do this if we didn't see anything yet because we want later
// (in the graph) nodes to take precedence over earlier ones and we
// iterate backwards. The implicit assumption is that if we have several
// _grad_sumtosizes "in parallel" (from auto-diff added AutogradAdd as the
// backward of using an input in multiple places) they are the same. This
// is because AutogradAdd does not broadcast.
if (outputGradSumToSizes && (*outputGradSumToSizes)[offset] == -1) {
// note: we make the assumption that the sizes are inputs to the
// fusion group (rather than something calculated).
(*outputGradSumToSizes)[offset] =
gradSumToSizeOutput->node()->inputs()[1]->offset();
}
} else if (user->kind() == aten::_grad_sum_to_size) {
// do nothing
// this case only happens in the graph_fuser step because in the
// compile step because we iterate backwards and delete
// all _grad_sum_to_size nodes we see
} else {
// we find something we do not support. Note that this notably includes
// prim::FusedConcat, which we do not know how to deal with in conjunction
// with _grad_sum_to_size
return false;
}
}
return true;
}
void FuseGraph(std::shared_ptr<Graph>& graph) {
GraphFuser(graph->block(), graph).run();
// After FuseGraph some common subexpressions may come back


@ -28,9 +28,5 @@ TORCH_API void CustomFuseGraph(
Symbol kind,
size_t arg_limit=std::numeric_limits<size_t>::max());
TORCH_API bool trackSingleGradSumToSizeToOutputs(
Value* gradSumToSizeOutput,
std::vector<int64_t>* outputGradSumToSizes);
} // namespace jit
} // namespace torch