mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Disable fusion of grad_sum_to_size (#23372)
Summary: Fixes: https://github.com/pytorch/pytorch/issues/22833 grad_sum_to_size does not commute with AutogradAdd after all because it turns the non-broadcasting AutogradAdd into a broadcasting add. Chillee actually did most of the work of tracking the problem down to the fusion of grad_sum_to_size, and pinged me when he had found the cause. Thank you! About the choice of removing the fusion completely instead of being more precise: - We do have grad_sum_to_size elimination, which works for cases where broadcasting does not actually happen in the forward, so the set of cases where fusing grad_sum_to_size is actually beneficial is much smaller than when it was initially proposed. - There will be less fusion; in terms of the tests, IOU stops being fully fused. I vaguely think that it is a case we could handle with refined logic. - Keeping it would add complexity in checking when to merge fusion groups, on top of the complexities that this PR removes. - The future of fusion probably lies more in more complete solutions including reductions (TVM or KeOps or our own or ...). Pull Request resolved: https://github.com/pytorch/pytorch/pull/23372 Differential Revision: D16489930 Pulled By: soumith fbshipit-source-id: bc0431b0d3eda264c401b634675872c4ce46f0f4
This commit is contained in:
parent
82545ecc71
commit
cf50249bde
|
|
@ -543,6 +543,7 @@ class TestFuser(JitTestCase):
|
|||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
|
||||
def test_fuser_iou(self):
|
||||
# This checks if most of Intersection over Union is fused.
|
||||
# In particular, the backward contains many _grad_sum_to_size.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/passes/canonicalize.h>
|
||||
#include <torch/csrc/jit/passes/graph_fuser.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
|
||||
#include <atomic>
|
||||
|
|
@ -157,71 +156,6 @@ static void setInputBroadcastGroups(KernelSpec& spec) {
|
|||
std::back_inserter(spec.inputBroadcastGroups()));
|
||||
}
|
||||
|
||||
// This function moves _grad_sum_to_size nodes along the computation graph
|
||||
// of the fusion group to the outputs and then records the shape inputs
|
||||
// in order for summation to be applied after the kernel.
|
||||
// Note that the correctness relies on the invariant that
|
||||
// _grad_sum_to_size is only applied to gradient nodes created by autodiff.
|
||||
// This is important because it ensures that in the mul and div nodes only
|
||||
// one argument (in the case of div the numerator) has a summed value.
|
||||
// If two arguments to mul had one, we would be in trouble, but thanks
|
||||
// to the chain rule, we're OK.
|
||||
// Note that this means that one kernel output may lead to several fusion
|
||||
// group outputs when several outputs had the same calculation except
|
||||
// for the final _grad_sum_to_size. This is also the reason why
|
||||
// we need to deduplicate kernel outputs at the end of this function.
|
||||
void processGradSumToSize(KernelSpec& spec) {
|
||||
auto graph = spec.graph();
|
||||
|
||||
std::vector<int64_t> outputGradSumToSizes(graph->outputs().size(), -1);
|
||||
|
||||
// these are expressions that might occur during autodiff operating
|
||||
// on the gradient (matmul would likely be, too but we don't fuse it)
|
||||
// note that for mul, we know (from the chain rule) that only one
|
||||
// factor will be stemming from a calculation involving gradients so
|
||||
// we know that we can move _grad_sum_to_size across it
|
||||
// Scan the graph. We will delete nodes. We want later (in the graph)
|
||||
// _grad_sum_to_size nodes to have priority over earlier ones. Thus
|
||||
// we scan backwards.
|
||||
for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend(); it++) {
|
||||
auto* node = *it;
|
||||
if (node->kind() != aten::_grad_sum_to_size) {
|
||||
continue;
|
||||
}
|
||||
bool success = trackSingleGradSumToSizeToOutputs(
|
||||
node->output(), &outputGradSumToSizes);
|
||||
AT_ASSERT(success); // check that we didn't hit anything unknown
|
||||
|
||||
// remove the GradSumToSize node, a new node outside the fusion graph
|
||||
// will be inserted below
|
||||
node->output()->replaceAllUsesWith(node->inputs()[0]);
|
||||
it.destroyCurrent();
|
||||
}
|
||||
|
||||
// By removing the _grad_sum_to_size nodes, we might end up with
|
||||
// duplicate outputs, e.g. when having the autodiff backwards of
|
||||
// x + y + z of something with x, y, z, those will have different
|
||||
// _grad_sum_to_sizes but of the same kernel output.
|
||||
|
||||
// for each fusion group output, record the corresponding kernel
|
||||
// output and possibly a _grad_sum_to_size for that output
|
||||
auto& outputMapAndSizes = spec.outputMapAndSizes();
|
||||
AT_ASSERT(outputMapAndSizes.empty());
|
||||
std::unordered_map<const Value*, int64_t> reduced_output_indices;
|
||||
int64_t newo = 0;
|
||||
for (auto osize : outputGradSumToSizes) {
|
||||
auto it = reduced_output_indices.find(graph->outputs()[newo]);
|
||||
if (it == reduced_output_indices.end()) {
|
||||
reduced_output_indices.emplace(graph->outputs()[newo], newo);
|
||||
outputMapAndSizes.emplace_back(newo, osize);
|
||||
newo++;
|
||||
} else {
|
||||
graph->eraseOutput(newo);
|
||||
outputMapAndSizes.emplace_back(it->second, osize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Performs "upfront" compilation where storage is known but shapes are not.
|
||||
// Currently identifies how to expand all tensors so that all intermediate
|
||||
// tensors are the same shape, simplifying code generation.
|
||||
|
|
@ -234,7 +168,6 @@ void processGradSumToSize(KernelSpec& spec) {
|
|||
static void upfrontCompilation(KernelSpec& spec) {
|
||||
setInputBroadcastGroups(spec);
|
||||
setInputChunkDescriptors(spec);
|
||||
processGradSumToSize(spec);
|
||||
}
|
||||
|
||||
int64_t registerFusion(const Node* fusion_group) {
|
||||
|
|
|
|||
|
|
@ -329,7 +329,6 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
|
|||
auto maybe_spec = retrieve(key);
|
||||
AT_ASSERT(maybe_spec);
|
||||
auto& spec = *(*maybe_spec);
|
||||
|
||||
// Acquires inputs from stack
|
||||
auto all_inputs = last(stack, spec.nInputs());
|
||||
std::vector<at::Tensor> inputs;
|
||||
|
|
@ -381,18 +380,8 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
|
|||
}
|
||||
|
||||
// Launches fusion
|
||||
std::vector<at::Tensor> raw_outputs;
|
||||
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
|
||||
|
||||
auto outputs = fmap(spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
|
||||
if (omap.needsSumToSize()) {
|
||||
return at::sum_to(
|
||||
raw_outputs[omap.offset()],
|
||||
all_inputs[omap.sizeInput()].toIntListRef());
|
||||
} else {
|
||||
return raw_outputs[omap.offset()];
|
||||
}
|
||||
});
|
||||
std::vector<at::Tensor> outputs;
|
||||
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, outputs);
|
||||
|
||||
// Updates stack
|
||||
drop(stack, spec.nInputs());
|
||||
|
|
|
|||
|
|
@ -41,32 +41,6 @@ struct TORCH_API PartitionInfo {
|
|||
int64_t dim_;
|
||||
};
|
||||
|
||||
// This is a helper struct to record the following:
|
||||
// for each fusion group output, it records the corresponding
|
||||
// kernel output offset (in offset) and the fusion group input
|
||||
// that holds the size the output is to be summed to (if any).
|
||||
// This mapping is necessary as a single kernel output might be
|
||||
// summed to different sizes.
|
||||
// These mappings are created during compilation in processGradSumToSize.
|
||||
struct TORCH_API OutputMapAndSize {
|
||||
OutputMapAndSize(const int64_t _offset, const int64_t _sizeInput)
|
||||
: offset_{_offset}, sizeInput_{_sizeInput} {};
|
||||
|
||||
int64_t offset() const {
|
||||
return offset_;
|
||||
}
|
||||
int64_t sizeInput() const {
|
||||
return sizeInput_;
|
||||
}
|
||||
bool needsSumToSize() const {
|
||||
return sizeInput_ != -1;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t offset_;
|
||||
int64_t sizeInput_;
|
||||
};
|
||||
|
||||
// "Kernel Specification." - Contains device-independent fusion information.
|
||||
// Each kernel specification contains a map of instantiated generated functions
|
||||
// that implement some or most of its functionality. Multiple generated
|
||||
|
|
@ -90,7 +64,6 @@ struct TORCH_API KernelSpec {
|
|||
nTensorInputs_{},
|
||||
inputBroadcastGroups_{},
|
||||
inputChunks_{},
|
||||
outputMapAndSizes_{},
|
||||
has_random_{false},
|
||||
kernels_{} {
|
||||
for (const auto& n : graph_->nodes()) {
|
||||
|
|
@ -136,10 +109,6 @@ struct TORCH_API KernelSpec {
|
|||
return inputChunks_;
|
||||
}
|
||||
|
||||
std::vector<OutputMapAndSize>& outputMapAndSizes() {
|
||||
return outputMapAndSizes_;
|
||||
}
|
||||
|
||||
bool hasRandom() const {
|
||||
return has_random_;
|
||||
}
|
||||
|
|
@ -167,11 +136,6 @@ struct TORCH_API KernelSpec {
|
|||
uint64_t nTensorInputs_;
|
||||
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
|
||||
std::vector<PartitionInfo> inputChunks_;
|
||||
// This will initially be an empty vector. During kernel compilation
|
||||
// in processGradSumToSize it will be filled and will contain one
|
||||
// element per fusion group output (which may be larger than the
|
||||
// number of kernel outputs).
|
||||
std::vector<OutputMapAndSize> outputMapAndSizes_;
|
||||
bool has_random_;
|
||||
mutable std::mutex mutex_;
|
||||
mutable std::
|
||||
|
|
|
|||
|
|
@ -170,13 +170,6 @@ struct GraphFuser {
|
|||
});
|
||||
}
|
||||
|
||||
bool containsGradSumToSize(Node* fusion_group) {
|
||||
auto nodes = getSubgraph(fusion_group).nodes();
|
||||
return std::any_of(nodes.begin(), nodes.end(), [](Node* n) {
|
||||
return n->kind() == aten::_grad_sum_to_size;
|
||||
});
|
||||
}
|
||||
|
||||
bool isFusable(Node* node) {
|
||||
return callback_(node);
|
||||
}
|
||||
|
|
@ -212,14 +205,6 @@ struct GraphFuser {
|
|||
// are not necessarily correct.
|
||||
if (node->owningBlock() != block_)
|
||||
return false;
|
||||
if (node->kind() == aten::_grad_sum_to_size) {
|
||||
// We only fuse _grad_sum_to_size if
|
||||
// - we will fuse its input next (checked here)
|
||||
// - we can commute the _grad_sum_to_size with everything
|
||||
// along the computation graph until we reach the outputs,
|
||||
// but this is checked later
|
||||
return isFusable(node->inputs()[0]->node());
|
||||
}
|
||||
return node->kind() == prim::FusionGroup || isSimpleMap(node);
|
||||
}
|
||||
|
||||
|
|
@ -444,23 +429,6 @@ struct GraphFuser {
|
|||
return at::nullopt;
|
||||
}
|
||||
|
||||
if (producer->node()->kind() == aten::_grad_sum_to_size &&
|
||||
consumer->kind() == kind_) {
|
||||
// check that we will be able to move the _grad_sum_to_size to be fused
|
||||
// to the end of the fusion group in the fusion compiler
|
||||
// the difficulty here is that the producer is not part of the fusion
|
||||
// group yet
|
||||
for (auto& u : producer->uses()) {
|
||||
if (u.user == consumer) {
|
||||
auto subgraph = &getSubgraph(consumer);
|
||||
if (!trackSingleGradSumToSizeToOutputs(
|
||||
subgraph->inputs().at(u.offset), nullptr)) {
|
||||
return at::nullopt;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto group = consumer;
|
||||
if (consumer->kind() != kind_) {
|
||||
group = createSingletonFusionGroup(consumer);
|
||||
|
|
@ -1039,13 +1007,11 @@ struct GraphFuser {
|
|||
}
|
||||
|
||||
// Fusion groups can be merged with concat's group if and only if
|
||||
// - the value they produce isn't already coming from a concat and
|
||||
// - the fusion group does not contain GradSumToSize
|
||||
// the value they produce isn't already coming from a concat
|
||||
if (producer->node()->kind() == prim::FusionGroup) {
|
||||
auto subgraph = producer->node()->g(attr::Subgraph);
|
||||
auto* node = subgraph->outputs().at(producer->offset())->node();
|
||||
return node->kind() != prim::FusedConcat &&
|
||||
!containsGradSumToSize(producer->node());
|
||||
return node->kind() != prim::FusedConcat;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1227,83 +1193,6 @@ void PeepholeOptimizeShapeExpressions(Block* block) {
|
|||
|
||||
} // anonymous namespace
|
||||
|
||||
// This takes a _grad_sum_to_size output and tracks it to the return
|
||||
// statements that depend on it, checking that it only hits nodes
|
||||
// that commute with _grad_sum_to_size on its path.
|
||||
// If a non-nullptr vector pointer outputGradSumToSizes is passed, the sizes
|
||||
// will be recorded as target sizes for the outputs as applicable.
|
||||
// In the graph_fuser pass we only need to check that we can go to the
|
||||
// outputs while in the fuser's compiler we want to record the sizes.
|
||||
// Note: This will only record a new sum_to_size if there is not one
|
||||
// already. As we want the last grad_sum_to_size, you need to call
|
||||
// it in reverse order when recording and removing outputs.
|
||||
bool trackSingleGradSumToSizeToOutputs(
|
||||
Value* gradSumToSizeOutput,
|
||||
std::vector<int64_t>* outputGradSumToSizes) {
|
||||
static OperatorSet commutes_with_SumToSize{{
|
||||
"aten::mul(Tensor self, Tensor other) -> Tensor",
|
||||
"aten::div(Tensor self, Tensor other) -> Tensor",
|
||||
// for div we might check whether we're the first argument
|
||||
"aten::mul(Tensor self, Scalar other) -> Tensor",
|
||||
"aten::div(Tensor self, Scalar other) -> Tensor",
|
||||
"aten::neg(Tensor self) -> Tensor",
|
||||
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
||||
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
|
||||
// add this used to be prim::AutogradAdd
|
||||
}};
|
||||
|
||||
std::queue<Use> uses_to_process{};
|
||||
auto add_to_uses = [&](const use_list& uses) {
|
||||
for (auto u : uses) {
|
||||
uses_to_process.push(u);
|
||||
}
|
||||
};
|
||||
add_to_uses(gradSumToSizeOutput->uses());
|
||||
while (!uses_to_process.empty()) {
|
||||
auto user = uses_to_process.front().user;
|
||||
auto offset = uses_to_process.front().offset;
|
||||
uses_to_process.pop();
|
||||
if (user->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
|
||||
// sometimes, a mask or similar is cast to the same type as the gradient,
|
||||
// i.e. we see other. Then we don't need to do anything, as the shape is
|
||||
// not used, only the type..
|
||||
// But we might also see it as self, when the gradient is cast, then we
|
||||
// want to track it.
|
||||
if (offset == 0) {
|
||||
add_to_uses(user->output()->uses());
|
||||
}
|
||||
} else if (commutes_with_SumToSize.find(user)) {
|
||||
add_to_uses(user->output()->uses());
|
||||
} else if (user->kind() == prim::Return) {
|
||||
// During compilation and only if we don't already have a
|
||||
// _grad_sum_to_size for this output we record the size to sum the output
|
||||
// to. We only do this if we didn't see anything yet because we want later
|
||||
// (in the graph) nodes to take precedence over earlier ones and we
|
||||
// iterate backwards. The implicit assumption is that if we have several
|
||||
// _grad_sumtosizes "in parallel" (from auto-diff added AutogradAdd as the
|
||||
// backward of using an input in multiple places) they are the same. This
|
||||
// is because AutogradAdd does not broadcast.
|
||||
if (outputGradSumToSizes && (*outputGradSumToSizes)[offset] == -1) {
|
||||
// note: we make the assumption that the sizes are inputs to the
|
||||
// fusion group (rather than something calculated).
|
||||
(*outputGradSumToSizes)[offset] =
|
||||
gradSumToSizeOutput->node()->inputs()[1]->offset();
|
||||
}
|
||||
} else if (user->kind() == aten::_grad_sum_to_size) {
|
||||
// do nothing
|
||||
// this case only happens in the graph_fuser step because in the
|
||||
// compile step we iterate backwards and delete
|
||||
// all _grad_sum_to_size nodes we see
|
||||
} else {
|
||||
// we find something we do not support. Note that this notably includes
|
||||
// prim::FusedConcat, which we do not know how to deal with in conjunction
|
||||
// with _grad_sum_to_size
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void FuseGraph(std::shared_ptr<Graph>& graph) {
|
||||
GraphFuser(graph->block(), graph).run();
|
||||
// After FuseGraph some common subexpressions may come back
|
||||
|
|
|
|||
|
|
@ -28,9 +28,5 @@ TORCH_API void CustomFuseGraph(
|
|||
Symbol kind,
|
||||
size_t arg_limit=std::numeric_limits<size_t>::max());
|
||||
|
||||
TORCH_API bool trackSingleGradSumToSizeToOutputs(
|
||||
Value* gradSumToSizeOutput,
|
||||
std::vector<int64_t>* outputGradSumToSizes);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user