Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37106

Recomputing the AliasDb on every fusion iteration and in every subblock is hugely expensive. Instead, update it in place when doing fusion.

The graph fuser pass operates by pushing nodes into a fusion group. So we start with

```
x, y = f(a, b, c)
```

and end with:

```
x_out, y_out = prim::fusionGroup(a, b, c)
  x_in, y_in = f(a_in, b_in, c_in)
  -> x_in, y_in
```

We destroy the `x` and `y` `Value*`s in the process. This operation is easy to express as an update to the AliasDb: `x_out` just takes on all the aliasing information `x` used to have. In particular, since we know `f` and `prim::fusionGroup` are purely functional, we don't have to mess with any write information.

This PR is the bare minimum to get this working, in the interest of unscrewing the compilation times ASAP. Followups I want to do:
- We don't have a way of expressing deletion of values in AliasDb. In `graph_fuser.cpp` we sometimes construct nodes that we end up throwing away, littering `MemoryDAG` with references to dangling pointers. Because of the way the pass works this is fine today, but it is fragile, so I want to fix it.
- We should decouple alias analysis from write tracking, to simplify the job of keeping the write caches consistent as we mutate the aliasing information.
- The tensorexpr fuser doesn't do this and is thus incorrect today; we need to update it to match.

Test Plan: Imported from OSS

Differential Revision: D21219179

Pulled By: suo

fbshipit-source-id: 8ae5397b3a0ad90edec2fbc555647091f1ad5284
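For illustration, here is a minimal sketch of the in-place update described above, assuming the purely functional case and a hypothetical `AliasDb::replaceWithNewValue(Value*, Value*)` helper; the helper's real name and signature in the PR may differ.

```cpp
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Sketch only: after graph_fuser merges `producer` into `fusion_group`,
// each fusion-group output stands in for the corresponding old output
// (`x` -> `x_out` above), so it simply inherits that value's aliasing
// info. Both nodes are purely functional, so no write information has to
// change and no full AliasDb recomputation is needed.
void transferAliasInfoAfterFusion(
    AliasDb& db,
    Node* producer,
    Node* fusion_group) {
  for (size_t i = 0; i < producer->outputs().size(); ++i) {
    // `replaceWithNewValue` is assumed here for illustration.
    db.replaceWithNewValue(producer->output(i), fusion_group->output(i));
  }
}

} // namespace jit
} // namespace torch
```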
#include "test/cpp/jit/test_base.h"
|
|
|
|
#include <torch/csrc/jit/passes/canonicalize.h>
|
|
#include "ATen/core/interned_strings.h"
|
|
#include "torch/csrc/autograd/generated/variable_factories.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
#include "torch/csrc/jit/codegen/fuser/interface.h"
|
|
#include "torch/csrc/jit/frontend/code_template.h"
|
|
#include "torch/csrc/jit/frontend/tracer.h"
|
|
#include "torch/csrc/jit/ir/alias_analysis.h"
|
|
#include "torch/csrc/jit/ir/attributes.h"
|
|
#include "torch/csrc/jit/ir/irparser.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/constant_propagation.h"
|
|
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/lower_grad_of.h"
|
|
#include "torch/csrc/jit/passes/lower_tuples.h"
|
|
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
|
|
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
|
|
#include "torch/csrc/jit/runtime/argument_spec.h"
|
|
#include "torch/csrc/jit/runtime/autodiff.h"
|
|
#include "torch/csrc/jit/runtime/custom_operator.h"
|
|
#include "torch/csrc/jit/runtime/interpreter.h"
|
|
#include "torch/csrc/jit/runtime/symbolic_script.h"
|
|
#include "torch/csrc/jit/serialization/import.h"
|
|
|
|
#include "torch/csrc/autograd/engine.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include "ATen/core/ivalue.h"
|
|
#include "torch/csrc/jit/api/module.h"
|
|
#include "torch/csrc/jit/frontend/ir_emitter.h"
|
|
#include "torch/csrc/jit/runtime/graph_executor.h"
|
|
|
|
#include "onnx/onnx_pb.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <algorithm>
|
|
#include <cstddef>
|
|
#include <functional>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch {
namespace jit {

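// Exercises the CUDA pointwise fuser end to end: each case parses an IR
// graph, launches it with debugLaunchGraph, and compares the result
// against an eager-mode ATen reference.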
void testFusion() {
  auto testSimple = [&] {
    const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor = aten::mul(%0, %1)
        return (%2))IR";
    Graph graph;
    torch::jit::parseIR(graph_string, &graph);

    auto a = at::rand({3, 4}, at::kCUDA);
    auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
    auto o = at::zeros({3, 4}, at::kCUDA);
    auto outputs = debugLaunchGraph(graph, {a, b});
    ASSERT_EQ(outputs.size(), 1);
    auto o2 = a * b;
    float max_diff = (o2 - outputs[0]).abs().max().item<double>();
    // std::cout << "max diff: " << max_diff << "\n";
    ASSERT_EQ(max_diff, 0);
  };
  testSimple();

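  // Fuses an LSTM-cell-style pointwise graph; (ti, tj) pick which input
  // dimensions get swapped so the kernel also runs on non-contiguous
  // (transposed) inputs.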
  auto testOne = [&](int ti, int tj) {
    const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor,
            %2 : Tensor,
            %3 : Tensor,
            %4 : Tensor):
        %5 : Tensor = aten::sigmoid(%4)
        %6 : Tensor = aten::sigmoid(%3)
        %7 : Tensor = aten::tanh(%2)
        %8 : Tensor = aten::sigmoid(%1)
        %9 : Tensor = aten::mul(%6, %0)
        %10 : Tensor = aten::mul(%5, %7)
        %11 : int = prim::Constant[value=1]()
        %12 : Tensor = aten::add(%9, %10, %11)
        %13 : Tensor = aten::tanh(%12)
        %14 : Tensor = aten::mul(%8, %13)
        return (%14, %12))IR";
    Graph graph;
    torch::jit::parseIR(graph_string, &graph);

    graph.lint();

    std::vector<at::Tensor> inputs;
    // We want to generate input/output tensors with dimension 128x128x32,
    // but with different internal strides. To do this, we generate a tensor
    // with the "wrong" dimensions, and then use transpose to get an
    // appropriately sized view.
    for (size_t i = 0; i < graph.inputs().size(); i++) {
      std::vector<int64_t> dims = {128, 128, 32};
      std::swap(dims[ti], dims[tj]);
      inputs.push_back(at::rand(dims, at::kCUDA).transpose(ti, tj));
    }

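    // Eager-mode reference computation mirroring the IR graph above.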
    auto t22 = inputs[4].sigmoid();
    auto t20 = inputs[3].sigmoid();
    auto t18 = inputs[2].tanh();
    auto t16 = inputs[1].sigmoid();
    auto t14 = t20 * inputs[0];
    auto t11 = t22 * t18;
    auto out1 = t14 + t11;
    auto t5 = out1.tanh();
    auto out0 = t16 * t5;

    auto outputs = debugLaunchGraph(graph, inputs);
    ASSERT_EQ(outputs.size(), graph.outputs().size());
    ASSERT_TRUE(out0.is_same_size(outputs.front()));
    float max_diff = (outputs.front() - out0).abs().max().item<double>();
    ASSERT_TRUE(max_diff < 1e-6);
  };
  testOne(0, 0);
  testOne(0, 1);
  testOne(1, 2);
  testOne(0, 2);

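  // prim::FusedConcat: the same fused mul feeds a concat along dim 0, 1,
  // or 2; each output is checked against at::cat on the eager result.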
const auto graph_string0 = R"IR(
|
|
graph(%0 : Tensor,
|
|
%1 : Tensor):
|
|
%2 : Tensor = aten::mul(%0, %1)
|
|
%3 : Tensor = prim::FusedConcat[dim=0](%0, %2)
|
|
return (%2, %3))IR";
|
|
const auto graph_string1 = R"IR(
|
|
graph(%0 : Tensor,
|
|
%1 : Tensor):
|
|
%2 : Tensor = aten::mul(%0, %1)
|
|
%3 : Tensor = prim::FusedConcat[dim=1](%0, %2)
|
|
return (%2, %3))IR";
|
|
const auto graph_string2 = R"IR(
|
|
graph(%0 : Tensor,
|
|
%1 : Tensor):
|
|
%2 : Tensor = aten::mul(%0, %1)
|
|
%3 : Tensor = prim::FusedConcat[dim=2](%0, %2)
|
|
return (%2, %3))IR";
|
|
|
|
  auto a = at::rand({3, 4, 5}, at::kCUDA);
  auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1);
  const auto o_r = a * b;

  std::vector<std::string> graph_strings{
      graph_string0, graph_string1, graph_string2};
  for (auto i = decltype(graph_strings.size()){0}; i < graph_strings.size();
       ++i) {
    Graph g;
    torch::jit::parseIR(graph_strings[i], &g);

    auto outputs = debugLaunchGraph(g, {a, b});
    ASSERT_EQ(outputs.size(), 2);

    float max_diff = (o_r - outputs[0]).abs().max().item<double>();
    ASSERT_EQ(max_diff, 0);

    const auto o2_r = at::cat({a, o_r}, i);
    float max_diff2 = (o2_r - outputs[1]).abs().max().item<double>();
    ASSERT_EQ(max_diff2, 0);
  }
}

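// Checks that the fuser does not fuse across an in-place op: the
// aten::add_ below should split the pointwise region into two fusion
// groups.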
void testFusionAliasing() {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Tensor = aten::mul(%0, %1)
      %2 : Tensor = aten::mul(%2.1, %1)
      %3 : Tensor = aten::add_(%2, %1, %12)
      %4 : Tensor = aten::mul(%2, %1)
      %5 : Tensor = aten::add(%2, %4, %12)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseGraph(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("prim::FusionGroup_0")
      ->check("aten::add_")
      ->check("prim::FusionGroup_1")
      ->run(*g);
}

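// Registering two alpha-equivalent fusion groups with the fusion compiler
// should hit the kernel cache and return the same key.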
void testRegisterFusionCachesKernel() {
  // Constructs two functionally equivalent graphs
  const auto graph0_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c0 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
      return (%d0))IR";
  auto g0 = std::make_shared<Graph>();
  torch::jit::parseIR(graph0_string, g0.get());

  const auto graph1_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c1 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
      return (%d1))IR";
  auto g1 = std::make_shared<Graph>();
  torch::jit::parseIR(graph1_string, g1.get());

  auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
    const auto& nodes = graph->nodes();
    auto maybe_fusion_group =
        std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
          return node->kind() == prim::FusionGroup;
        });
    TORCH_CHECK(
        maybe_fusion_group != nodes.end(),
        "testRegisterFusionCachesKernel: could not create FusionGroup");
    return *maybe_fusion_group;
  };

  // Creates two alpha-equivalent fusion groups
  torch::jit::overrideCanFuseOnCPU(true);
  FuseGraph(g0);
  FuseGraph(g1);
  torch::jit::overrideCanFuseOnCPU(false);
  auto fg0 = getFusionGroup(g0);
  auto fg1 = getFusionGroup(g1);

  // Registers both with the fusion compiler.
  auto expected_key = registerFusion(fg0);
  auto second_key = registerFusion(fg1);

  // Because the graphs are alpha-equivalent, they should return the same
  // key and therefore share a KernelSpec, letting specializations reuse
  // compiled kernels.
  ASSERT_EQ(second_key, expected_key);
}

} // namespace jit
} // namespace torch