mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Disable fusion of grad_sum_to_size (#23372)
Summary: Fixes: https://github.com/pytorch/pytorch/issues/22833 grad_sum_to_size does not commute with AutogradAdd after all because it turns the broadcasting AutogradAdd into a broadcasting add. Chillee did actually do most of the tracking down to the fusion of grad_sum_to_size and pinging me when he had found the cause. Thank you! About the choice of removing the fusion completely instead of being more precise: - We do have grad_sum_to_size elimination which works for cases where broadcasting does not actually happen in the forward, so the cases where the fusing of grad_sum_to_size is actually beneficial is much smaller than when initially proposed. - There will be less fusion, in terms of the tests, IOU stops being fully fused. I vaguely think that it is a case we could handle with refined logic. - Keeping it would add complexity in checking when to merge fusion groups to the complexities that this PR removes. - The future of fusion probably lies more in more complete solutions including reductions (TVM or KeOps or our own or ...). Pull Request resolved: https://github.com/pytorch/pytorch/pull/23372 Differential Revision: D16489930 Pulled By: soumith fbshipit-source-id: bc0431b0d3eda264c401b634675872c4ce46f0f4
This commit is contained in:
parent
82545ecc71
commit
cf50249bde
|
|
@ -543,6 +543,7 @@ class TestFuser(JitTestCase):
|
||||||
|
|
||||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle")
|
||||||
@enable_cpu_fuser
|
@enable_cpu_fuser
|
||||||
|
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
|
||||||
def test_fuser_iou(self):
|
def test_fuser_iou(self):
|
||||||
# This checks if most of Intersection over Union is fused.
|
# This checks if most of Intersection over Union is fused.
|
||||||
# In particular, the backward contains many _grad_sum_to_size.
|
# In particular, the backward contains many _grad_sum_to_size.
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@
|
||||||
#include <torch/csrc/jit/ir.h>
|
#include <torch/csrc/jit/ir.h>
|
||||||
#include <torch/csrc/jit/operator.h>
|
#include <torch/csrc/jit/operator.h>
|
||||||
#include <torch/csrc/jit/passes/canonicalize.h>
|
#include <torch/csrc/jit/passes/canonicalize.h>
|
||||||
#include <torch/csrc/jit/passes/graph_fuser.h>
|
|
||||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
|
@ -157,71 +156,6 @@ static void setInputBroadcastGroups(KernelSpec& spec) {
|
||||||
std::back_inserter(spec.inputBroadcastGroups()));
|
std::back_inserter(spec.inputBroadcastGroups()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function moves _grad_sum_to_size nodes along the computation graph
|
|
||||||
// of the fusion group to the outputs and then records the shape inputs
|
|
||||||
// in order for summation to be applied after the kernel.
|
|
||||||
// Note that the correctness relies on the invariant that
|
|
||||||
// _grad_sum_to_size is only applied to gradient nodes created by autodiff.
|
|
||||||
// This is important because it ensures that in the mul and div nodes only
|
|
||||||
// one argument (in the case of div the numerator) has a summed value.
|
|
||||||
// If two arguments to mul had one, we would be in trouble, but thanks
|
|
||||||
// to the chain rule, we're OK.
|
|
||||||
// Note that this means that one kernel output may lead to several fusion
|
|
||||||
// group outputs when several outputs had the same calculation except
|
|
||||||
// for the final _grad_sum_to_size. This is also the reason why
|
|
||||||
// we need to deduplicate kernel outputs at the end of this function.
|
|
||||||
void processGradSumToSize(KernelSpec& spec) {
|
|
||||||
auto graph = spec.graph();
|
|
||||||
|
|
||||||
std::vector<int64_t> outputGradSumToSizes(graph->outputs().size(), -1);
|
|
||||||
|
|
||||||
// these are expressions that might occur during autotdiff operating
|
|
||||||
// on the gradient (matmul would likely be, too but we don't fuse it)
|
|
||||||
// note that for mul, we know (from the chain rule) that only one
|
|
||||||
// factor will be stemming from a calculation involving gradients so
|
|
||||||
// we know that we can move _grad_sum_to_size across it
|
|
||||||
// Scan the graph. We will delete nodes. We want later (in the graph)
|
|
||||||
// _grad_sum_to_size nodes to have priority over earlier ones. Thus
|
|
||||||
// we scan backwards.
|
|
||||||
for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend(); it++) {
|
|
||||||
auto* node = *it;
|
|
||||||
if (node->kind() != aten::_grad_sum_to_size) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
bool success = trackSingleGradSumToSizeToOutputs(
|
|
||||||
node->output(), &outputGradSumToSizes);
|
|
||||||
AT_ASSERT(success); // check that we didn't hit anything unknown
|
|
||||||
|
|
||||||
// remove the GradSumToSize node, a new node outside the fusion graph
|
|
||||||
// will be inserted below
|
|
||||||
node->output()->replaceAllUsesWith(node->inputs()[0]);
|
|
||||||
it.destroyCurrent();
|
|
||||||
}
|
|
||||||
|
|
||||||
// By removing the _grad_sum_to_size notes, we might end up with
|
|
||||||
// duplicate outputs, e.g. when having the autodiff backwards of
|
|
||||||
// x + y + z of something with x, y, z, those will have different
|
|
||||||
// _grad_sum_to_sizes but of the same kernel output.
|
|
||||||
|
|
||||||
// for each fusion group output, record the corresponding kernel
|
|
||||||
// output and possibly a _grad_sum_to_size for that output
|
|
||||||
auto& outputMapAndSizes = spec.outputMapAndSizes();
|
|
||||||
AT_ASSERT(outputMapAndSizes.empty());
|
|
||||||
std::unordered_map<const Value*, int64_t> reduced_output_indices;
|
|
||||||
int64_t newo = 0;
|
|
||||||
for (auto osize : outputGradSumToSizes) {
|
|
||||||
auto it = reduced_output_indices.find(graph->outputs()[newo]);
|
|
||||||
if (it == reduced_output_indices.end()) {
|
|
||||||
reduced_output_indices.emplace(graph->outputs()[newo], newo);
|
|
||||||
outputMapAndSizes.emplace_back(newo, osize);
|
|
||||||
newo++;
|
|
||||||
} else {
|
|
||||||
graph->eraseOutput(newo);
|
|
||||||
outputMapAndSizes.emplace_back(it->second, osize);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Performs "upfront" compilation where storage is known but shapes are not.
|
// Performs "upfront" compilation where storage is known but shapes are not.
|
||||||
// Currently identifies how to expand all tensors so that all intermediate
|
// Currently identifies how to expand all tensors so that all intermediate
|
||||||
// tensors are the same shape, simplifying code generation.
|
// tensors are the same shape, simplifying code generation.
|
||||||
|
|
@ -234,7 +168,6 @@ void processGradSumToSize(KernelSpec& spec) {
|
||||||
static void upfrontCompilation(KernelSpec& spec) {
|
static void upfrontCompilation(KernelSpec& spec) {
|
||||||
setInputBroadcastGroups(spec);
|
setInputBroadcastGroups(spec);
|
||||||
setInputChunkDescriptors(spec);
|
setInputChunkDescriptors(spec);
|
||||||
processGradSumToSize(spec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t registerFusion(const Node* fusion_group) {
|
int64_t registerFusion(const Node* fusion_group) {
|
||||||
|
|
|
||||||
|
|
@ -329,7 +329,6 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
|
||||||
auto maybe_spec = retrieve(key);
|
auto maybe_spec = retrieve(key);
|
||||||
AT_ASSERT(maybe_spec);
|
AT_ASSERT(maybe_spec);
|
||||||
auto& spec = *(*maybe_spec);
|
auto& spec = *(*maybe_spec);
|
||||||
|
|
||||||
// Acquires inputs from stack
|
// Acquires inputs from stack
|
||||||
auto all_inputs = last(stack, spec.nInputs());
|
auto all_inputs = last(stack, spec.nInputs());
|
||||||
std::vector<at::Tensor> inputs;
|
std::vector<at::Tensor> inputs;
|
||||||
|
|
@ -381,18 +380,8 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launches fusion
|
// Launches fusion
|
||||||
std::vector<at::Tensor> raw_outputs;
|
std::vector<at::Tensor> outputs;
|
||||||
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
|
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, outputs);
|
||||||
|
|
||||||
auto outputs = fmap(spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
|
|
||||||
if (omap.needsSumToSize()) {
|
|
||||||
return at::sum_to(
|
|
||||||
raw_outputs[omap.offset()],
|
|
||||||
all_inputs[omap.sizeInput()].toIntListRef());
|
|
||||||
} else {
|
|
||||||
return raw_outputs[omap.offset()];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Updates stack
|
// Updates stack
|
||||||
drop(stack, spec.nInputs());
|
drop(stack, spec.nInputs());
|
||||||
|
|
|
||||||
|
|
@ -41,32 +41,6 @@ struct TORCH_API PartitionInfo {
|
||||||
int64_t dim_;
|
int64_t dim_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// This is a helper struct to record the following:
|
|
||||||
// for each fusion group output, it records the corresponding
|
|
||||||
// kernel output offset (in offset) and the fusion group input
|
|
||||||
// to that is to be applied with sumtosize on the output (if any).
|
|
||||||
// This mapping is necessar as a single kernel output might be
|
|
||||||
// summed to different sizes.
|
|
||||||
// These mappings are created during compilation in processGradSumToSize.
|
|
||||||
struct TORCH_API OutputMapAndSize {
|
|
||||||
OutputMapAndSize(const int64_t _offset, const int64_t _sizeInput)
|
|
||||||
: offset_{_offset}, sizeInput_{_sizeInput} {};
|
|
||||||
|
|
||||||
int64_t offset() const {
|
|
||||||
return offset_;
|
|
||||||
}
|
|
||||||
int64_t sizeInput() const {
|
|
||||||
return sizeInput_;
|
|
||||||
}
|
|
||||||
bool needsSumToSize() const {
|
|
||||||
return sizeInput_ != -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int64_t offset_;
|
|
||||||
int64_t sizeInput_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// "Kernel Specification." - Contains device-independent fusion information.
|
// "Kernel Specification." - Contains device-independent fusion information.
|
||||||
// Each kernel specification contains a map of instantiated generated functions
|
// Each kernel specification contains a map of instantiated generated functions
|
||||||
// that implement some or most of its functionality. Multiple generated
|
// that implement some or most of its functionality. Multiple generated
|
||||||
|
|
@ -90,7 +64,6 @@ struct TORCH_API KernelSpec {
|
||||||
nTensorInputs_{},
|
nTensorInputs_{},
|
||||||
inputBroadcastGroups_{},
|
inputBroadcastGroups_{},
|
||||||
inputChunks_{},
|
inputChunks_{},
|
||||||
outputMapAndSizes_{},
|
|
||||||
has_random_{false},
|
has_random_{false},
|
||||||
kernels_{} {
|
kernels_{} {
|
||||||
for (const auto& n : graph_->nodes()) {
|
for (const auto& n : graph_->nodes()) {
|
||||||
|
|
@ -136,10 +109,6 @@ struct TORCH_API KernelSpec {
|
||||||
return inputChunks_;
|
return inputChunks_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<OutputMapAndSize>& outputMapAndSizes() {
|
|
||||||
return outputMapAndSizes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasRandom() const {
|
bool hasRandom() const {
|
||||||
return has_random_;
|
return has_random_;
|
||||||
}
|
}
|
||||||
|
|
@ -167,11 +136,6 @@ struct TORCH_API KernelSpec {
|
||||||
uint64_t nTensorInputs_;
|
uint64_t nTensorInputs_;
|
||||||
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
|
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
|
||||||
std::vector<PartitionInfo> inputChunks_;
|
std::vector<PartitionInfo> inputChunks_;
|
||||||
// This will initially be an empty vector. During kernel compilation
|
|
||||||
// in processGradSumToSize it will be filled and will contain one
|
|
||||||
// element per fusion group output (which may be larger than the
|
|
||||||
// number of kernel outputs).
|
|
||||||
std::vector<OutputMapAndSize> outputMapAndSizes_;
|
|
||||||
bool has_random_;
|
bool has_random_;
|
||||||
mutable std::mutex mutex_;
|
mutable std::mutex mutex_;
|
||||||
mutable std::
|
mutable std::
|
||||||
|
|
|
||||||
|
|
@ -170,13 +170,6 @@ struct GraphFuser {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool containsGradSumToSize(Node* fusion_group) {
|
|
||||||
auto nodes = getSubgraph(fusion_group).nodes();
|
|
||||||
return std::any_of(nodes.begin(), nodes.end(), [](Node* n) {
|
|
||||||
return n->kind() == aten::_grad_sum_to_size;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isFusable(Node* node) {
|
bool isFusable(Node* node) {
|
||||||
return callback_(node);
|
return callback_(node);
|
||||||
}
|
}
|
||||||
|
|
@ -212,14 +205,6 @@ struct GraphFuser {
|
||||||
// are not necessarily correct.
|
// are not necessarily correct.
|
||||||
if (node->owningBlock() != block_)
|
if (node->owningBlock() != block_)
|
||||||
return false;
|
return false;
|
||||||
if (node->kind() == aten::_grad_sum_to_size) {
|
|
||||||
// We only fuse _grad_sum_to_size if
|
|
||||||
// - we will fuse its input next (checked here)
|
|
||||||
// - we can commute the _grad_sum_to_size with everything
|
|
||||||
// along the computation graph until we reach the outputs,
|
|
||||||
// but this is checked later
|
|
||||||
return isFusable(node->inputs()[0]->node());
|
|
||||||
}
|
|
||||||
return node->kind() == prim::FusionGroup || isSimpleMap(node);
|
return node->kind() == prim::FusionGroup || isSimpleMap(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -444,23 +429,6 @@ struct GraphFuser {
|
||||||
return at::nullopt;
|
return at::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (producer->node()->kind() == aten::_grad_sum_to_size &&
|
|
||||||
consumer->kind() == kind_) {
|
|
||||||
// check that we will be able to move the _grad_sum_to_size to be fused
|
|
||||||
// to the end of the fusion group in the fusion compiler
|
|
||||||
// the difficulty here is that the producer is not part of the fusion
|
|
||||||
// group yet
|
|
||||||
for (auto& u : producer->uses()) {
|
|
||||||
if (u.user == consumer) {
|
|
||||||
auto subgraph = &getSubgraph(consumer);
|
|
||||||
if (!trackSingleGradSumToSizeToOutputs(
|
|
||||||
subgraph->inputs().at(u.offset), nullptr)) {
|
|
||||||
return at::nullopt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto group = consumer;
|
auto group = consumer;
|
||||||
if (consumer->kind() != kind_) {
|
if (consumer->kind() != kind_) {
|
||||||
group = createSingletonFusionGroup(consumer);
|
group = createSingletonFusionGroup(consumer);
|
||||||
|
|
@ -1039,13 +1007,11 @@ struct GraphFuser {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fusion groups can be merged with concat's group if and only if
|
// Fusion groups can be merged with concat's group if and only if
|
||||||
// - the value they produce isn't already coming from a concat and
|
// the value they produce isn't already coming from a concat
|
||||||
// - the fusion group does not contain GradSumToSize
|
|
||||||
if (producer->node()->kind() == prim::FusionGroup) {
|
if (producer->node()->kind() == prim::FusionGroup) {
|
||||||
auto subgraph = producer->node()->g(attr::Subgraph);
|
auto subgraph = producer->node()->g(attr::Subgraph);
|
||||||
auto* node = subgraph->outputs().at(producer->offset())->node();
|
auto* node = subgraph->outputs().at(producer->offset())->node();
|
||||||
return node->kind() != prim::FusedConcat &&
|
return node->kind() != prim::FusedConcat;
|
||||||
!containsGradSumToSize(producer->node());
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -1227,83 +1193,6 @@ void PeepholeOptimizeShapeExpressions(Block* block) {
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
// This takes a _grad_sum_to_size output and tracks it to the return
|
|
||||||
// statements that depend on it, checking that it only hits nodes
|
|
||||||
// that commute with _grad_sum_to_size on its path.
|
|
||||||
// If a non-nullptr vector pointer outputGradSumToSizes is passed, the sizes
|
|
||||||
// will be recorded as target sizes for the outputs as applicable.
|
|
||||||
// In the graph_fuser pass we only need to check that we can go to the
|
|
||||||
// outputs while in the fuser's compiler we want to record the sizes.
|
|
||||||
// Note: This will only record a new sum_to_size if there is not one
|
|
||||||
// already. As we want the last grad_sum_to_size, you need to call
|
|
||||||
// it in reverse order when recording and removing outputs.
|
|
||||||
bool trackSingleGradSumToSizeToOutputs(
|
|
||||||
Value* gradSumToSizeOutput,
|
|
||||||
std::vector<int64_t>* outputGradSumToSizes) {
|
|
||||||
static OperatorSet commutes_with_SumToSize{{
|
|
||||||
"aten::mul(Tensor self, Tensor other) -> Tensor",
|
|
||||||
"aten::div(Tensor self, Tensor other) -> Tensor",
|
|
||||||
// for div we might check whether we're the first argument
|
|
||||||
"aten::mul(Tensor self, Scalar other) -> Tensor",
|
|
||||||
"aten::div(Tensor self, Scalar other) -> Tensor",
|
|
||||||
"aten::neg(Tensor self) -> Tensor",
|
|
||||||
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
|
||||||
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
|
|
||||||
// add this used to be prim::AutogradAdd
|
|
||||||
}};
|
|
||||||
|
|
||||||
std::queue<Use> uses_to_process{};
|
|
||||||
auto add_to_uses = [&](const use_list& uses) {
|
|
||||||
for (auto u : uses) {
|
|
||||||
uses_to_process.push(u);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
add_to_uses(gradSumToSizeOutput->uses());
|
|
||||||
while (!uses_to_process.empty()) {
|
|
||||||
auto user = uses_to_process.front().user;
|
|
||||||
auto offset = uses_to_process.front().offset;
|
|
||||||
uses_to_process.pop();
|
|
||||||
if (user->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
|
|
||||||
// sometimes, a mask or similar is cast to the same type as the gradient,
|
|
||||||
// i.e. we see other. Then we don't need to do anything, as the shape is
|
|
||||||
// not used, only the type..
|
|
||||||
// But we might also see it as self, when the gradient is cast, then we
|
|
||||||
// want to track it.
|
|
||||||
if (offset == 0) {
|
|
||||||
add_to_uses(user->output()->uses());
|
|
||||||
}
|
|
||||||
} else if (commutes_with_SumToSize.find(user)) {
|
|
||||||
add_to_uses(user->output()->uses());
|
|
||||||
} else if (user->kind() == prim::Return) {
|
|
||||||
// During compilation and only if we don't already have a
|
|
||||||
// _grad_sum_to_size for this output we record the size to sum the output
|
|
||||||
// to. We only do this if we didn't see anything yet because we want later
|
|
||||||
// (in the graph) nodes to take precedence over earlier ones and we
|
|
||||||
// iterate backwards. The implicit assumption is that if we have several
|
|
||||||
// _grad_sumtosizes "in parallel" (from auto-diff added AutogradAdd as the
|
|
||||||
// backward of using an input in multiple places) they are the same. This
|
|
||||||
// is because AutogradAdd does not broadcast.
|
|
||||||
if (outputGradSumToSizes && (*outputGradSumToSizes)[offset] == -1) {
|
|
||||||
// note: we make the assumption that the sizes are inputs to the
|
|
||||||
// fusion group (rather than something calculated).
|
|
||||||
(*outputGradSumToSizes)[offset] =
|
|
||||||
gradSumToSizeOutput->node()->inputs()[1]->offset();
|
|
||||||
}
|
|
||||||
} else if (user->kind() == aten::_grad_sum_to_size) {
|
|
||||||
// do nothing
|
|
||||||
// this case only happens in the graph_fuser step because in the
|
|
||||||
// compile step because we iterate backwards and delete
|
|
||||||
// all _grad_sum_to_size nodes we see
|
|
||||||
} else {
|
|
||||||
// we find something we do not support. Note that this notably includes
|
|
||||||
// prim::FusedConcat, which we do not know how to deal with in conjunction
|
|
||||||
// with _grad_sum_to_size
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuseGraph(std::shared_ptr<Graph>& graph) {
|
void FuseGraph(std::shared_ptr<Graph>& graph) {
|
||||||
GraphFuser(graph->block(), graph).run();
|
GraphFuser(graph->block(), graph).run();
|
||||||
// After FuseGraph some common subexpressions may come back
|
// After FuseGraph some common subexpressions may come back
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,5 @@ TORCH_API void CustomFuseGraph(
|
||||||
Symbol kind,
|
Symbol kind,
|
||||||
size_t arg_limit=std::numeric_limits<size_t>::max());
|
size_t arg_limit=std::numeric_limits<size_t>::max());
|
||||||
|
|
||||||
TORCH_API bool trackSingleGradSumToSizeToOutputs(
|
|
||||||
Value* gradSumToSizeOutput,
|
|
||||||
std::vector<int64_t>* outputGradSumToSizes);
|
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user