mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36345
During compilation, we spend a huge amount of time in alias analysis.
This PR does a few things to speed it up.
1. Separate the analysis into two phases: one where we build up the
necessary data structures, and the other where we service aliasing
queries. This allows us to defer building indices/maintaining index
consistency until after the "buildup" phase is done.
2. Properly memoize (dynamic-program) the memory-location lookups.
3. Done naively, setting wildcards invalidates the above memoization,
triggering costly recomputation. So I added a cache-aware `setWildcards`.
Sadly, that means alias analysis needs to reach into the guts of
MemoryDAG, but the speedup is worth it.
Sadly, these changes are kind of coupled for correctness reasons, so
they're all here at once.
I used this model (thanks IlyaOvodov) as a provisional benchmark. You
can get it here:
https://www.dropbox.com/s/jlyygn6yygj1jkx/yolov3.zip. Unzip it and run
`python test_timing.py`.
Baseline: (752.076s) right before 6bc8ffe824
After optimizing before inlining: (699.593s)
After deferring cache construction: (426.180s)
After cache-aware `setWildcards`: (193.678s)
So a nice 75% speedup to overall compilation. There's a lot more to do
in other places of the compilation pipeline though.
Followup to this PR specifically: Everything that fans out from the
`analyze` call is the "buildup" phase of AliasDB construction. This
should be factored into a separate analysis pass to statically
distinguish the two phases (right now we just null out stuff to
accomplish the same thing dynamically).
Test Plan: Imported from OSS
Differential Revision: D20952727
Pulled By: suo
fbshipit-source-id: 099f797222d7e71e5c04991584adc2c7eab5a70f
1353 lines
43 KiB
C++
1353 lines
43 KiB
C++
#include <torch/csrc/autograd/generated/variable_factories.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include "test/cpp/jit/test_base.h"
|
|
#include "torch/csrc/jit/frontend/ir_emitter.h"
|
|
#include "torch/csrc/jit/ir/alias_analysis.h"
|
|
#include "torch/csrc/jit/runtime/custom_operator.h"
|
|
#include "torch/csrc/utils/memory.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Shorthand for operator registrations in these tests: aliasing behavior is
// taken from the alias annotations written in each operator's schema string.
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  using Kind = c10::AliasAnalysisKind;
  return Kind::FROM_SCHEMA;
}
|
|
|
|
// Fixture to set up a graph and make assertions clearer
|
|
struct TopoMoveTestFixture {
|
|
TopoMoveTestFixture() {
|
|
createGraph();
|
|
aliasDb = torch::make_unique<AliasDb>(graph);
|
|
}
|
|
|
|
// Nodes are named after their output.
|
|
// e.g. "a" is an alias for "the node that outputs the value `a`"
|
|
void createGraph() {
|
|
graph = std::make_shared<Graph>();
|
|
createNode("a", {});
|
|
createNode("b", {"a"});
|
|
createNode("c", {});
|
|
createNode("d", {"a", "b"});
|
|
createNode("e", {"c", "b"});
|
|
createNode("f", {"e"});
|
|
createNode("g", {"e"});
|
|
createNode("h", {"g"});
|
|
createNode("i", {"g"});
|
|
createNode("j", {"i"});
|
|
createNode("k", {"i"});
|
|
createNode("l", {"a"});
|
|
createNode("m", {}, {"l"}); // block depends on l
|
|
createNode("n", {"m"});
|
|
createNode("o", {"n"});
|
|
createNode("p", {});
|
|
createNode("q", {});
|
|
createNode("r", {"q"});
|
|
createNode("s", {"q"});
|
|
|
|
graph->lint();
|
|
}
|
|
|
|
void createNode(
|
|
const std::string& name,
|
|
const std::vector<std::string>& inputNames,
|
|
const std::vector<std::string>& blockInputNames = {}) {
|
|
std::vector<Value*> inputs;
|
|
for (const auto& name_ : inputNames) {
|
|
inputs.push_back(nodes.at(name_)->output());
|
|
}
|
|
auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
|
|
node->output()->setDebugName(name);
|
|
nodes[name] = node;
|
|
|
|
if (blockInputNames.size() != 0) {
|
|
node->addBlock();
|
|
std::vector<Value*> blockDeps;
|
|
for (const auto& name_ : blockInputNames) {
|
|
blockDeps.push_back(nodes.at(name_)->output());
|
|
}
|
|
|
|
auto block = node->blocks().at(0);
|
|
block->appendNode(graph->create(prim::AutogradZero, blockDeps));
|
|
}
|
|
}
|
|
|
|
bool moveBeforeTopologicallyValid(
|
|
const std::string& toInsert,
|
|
const std::string& insertPoint) {
|
|
std::function<bool(Node*, Node*)> func =
|
|
[this](Node* toInsert, Node* insertPoint) {
|
|
return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
|
|
};
|
|
return moveWithChecks(toInsert, insertPoint, func);
|
|
}
|
|
|
|
bool moveAfterTopologicallyValid(
|
|
const std::string& toInsert,
|
|
const std::string& insertPoint) {
|
|
std::function<bool(Node*, Node*)> func =
|
|
[this](Node* toInsert, Node* insertPoint) {
|
|
return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
|
|
};
|
|
return moveWithChecks(toInsert, insertPoint, func);
|
|
}
|
|
|
|
bool moveWithChecks(
|
|
const std::string& toInsert,
|
|
const std::string& insertPoint,
|
|
std::function<bool(Node*, Node*)> func) {
|
|
auto n = nodes.at(toInsert);
|
|
auto insert = nodes.at(insertPoint);
|
|
bool isAfter = n->isAfter(insert);
|
|
|
|
std::vector<Node*> originalOrdering;
|
|
Node* original = isAfter ? n->next() : n->prev();
|
|
|
|
auto curNode = original;
|
|
while (curNode != n->owningBlock()->return_node()) {
|
|
originalOrdering.push_back(curNode);
|
|
if (isAfter) {
|
|
curNode = curNode->next();
|
|
} else {
|
|
curNode = curNode->prev();
|
|
}
|
|
}
|
|
|
|
const auto couldMove = func(n, insert);
|
|
// Check the graph is okay
|
|
graph->lint();
|
|
|
|
// If this is the picture of nodes
|
|
// <some nodes> ... toInsert ... <some more nodes> ... insertPoint
|
|
// ^----------^ check that these nodes haven't moved
|
|
curNode = original;
|
|
size_t idx = 0;
|
|
while (curNode != n->owningBlock()->return_node()) {
|
|
AT_ASSERT(originalOrdering[idx] == curNode);
|
|
if (isAfter) {
|
|
curNode = curNode->next();
|
|
} else {
|
|
curNode = curNode->prev();
|
|
}
|
|
idx++;
|
|
}
|
|
|
|
return couldMove;
|
|
}
|
|
|
|
void checkPostCondition(
|
|
const std::string& toInsert,
|
|
const std::string& insertPoint,
|
|
bool after) {
|
|
if (after) {
|
|
AT_ASSERT(nodes.at(toInsert)->prev() == nodes.at(insertPoint));
|
|
} else {
|
|
AT_ASSERT(nodes.at(toInsert)->next() == nodes.at(insertPoint));
|
|
}
|
|
}
|
|
|
|
std::shared_ptr<Graph> graph;
|
|
std::unique_ptr<AliasDb> aliasDb;
|
|
std::unordered_map<std::string, Node*> nodes;
|
|
};
|
|
|
|
void testTopologicalMove() {
|
|
{
|
|
// Check that we are removing `this`'s deps properly when we need to split
|
|
// `this` and deps (see code for what the hell that means)
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("q", "s"));
|
|
fixture.checkPostCondition("q", "s", false);
|
|
}
|
|
// Move after
|
|
{
|
|
// Simple move backward
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("c", "a"));
|
|
fixture.checkPostCondition("c", "a", true);
|
|
}
|
|
{
|
|
// simple invalid move backward
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(!fixture.moveAfterTopologicallyValid("d", "a"));
|
|
}
|
|
{
|
|
// doesn't actually move anything
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("f", "e"));
|
|
fixture.checkPostCondition("f", "e", true);
|
|
}
|
|
{
|
|
// move backward with multiple dependencies
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("e", "c"));
|
|
fixture.checkPostCondition("e", "c", true);
|
|
}
|
|
{
|
|
// Move backward with non-zero working set
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("k", "f"));
|
|
fixture.checkPostCondition("k", "f", true);
|
|
}
|
|
{
|
|
// Simple move forward
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("c", "d"));
|
|
fixture.checkPostCondition("c", "d", true);
|
|
}
|
|
{
|
|
// Move forward with non-zero working set
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveAfterTopologicallyValid("f", "l"));
|
|
fixture.checkPostCondition("f", "l", true);
|
|
}
|
|
|
|
// Move before
|
|
{
|
|
// Simple move forward
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("b", "d"));
|
|
fixture.checkPostCondition("b", "d", false);
|
|
}
|
|
{
|
|
// Simple move backward
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("c", "a"));
|
|
fixture.checkPostCondition("c", "a", false);
|
|
}
|
|
{
|
|
// doesn't actually move anything
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("a", "b"));
|
|
fixture.checkPostCondition("a", "b", false);
|
|
}
|
|
{
|
|
// move forward with deps
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("f", "m"));
|
|
fixture.checkPostCondition("f", "m", false);
|
|
}
|
|
{
|
|
// move backward with deps
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("l", "f"));
|
|
fixture.checkPostCondition("l", "f", false);
|
|
}
|
|
|
|
// check that dependencies in blocks are recognized
|
|
{
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(!fixture.moveAfterTopologicallyValid("l", "m"));
|
|
AT_ASSERT(!fixture.moveBeforeTopologicallyValid("m", "l"));
|
|
AT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "l"));
|
|
AT_ASSERT(!fixture.moveBeforeTopologicallyValid("l", "n"));
|
|
}
|
|
|
|
// Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
|
|
// equivalent. Here, the dependency ordering is n -> o -> p. So we can't
|
|
// move `n` after `o`, but we can move `n` before `p` (which pushes `o` after
|
|
// `p`)
|
|
{
|
|
TopoMoveTestFixture fixture;
|
|
AT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "o"));
|
|
AT_ASSERT(fixture.moveBeforeTopologicallyValid("o", "p"));
|
|
fixture.checkPostCondition("o", "p", false);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
Node* insertIf(
|
|
Graph& g,
|
|
Value* condValue,
|
|
std::function<std::vector<Value*>()> trueInst,
|
|
std::function<std::vector<Value*>()> falseInst) {
|
|
auto if_ = g.insertNode(g.create(prim::If, 0));
|
|
if_->addInput(condValue); // condition value
|
|
auto trueBlock = if_->addBlock();
|
|
auto falseBlock = if_->addBlock();
|
|
{
|
|
// Mutate in true block
|
|
WithInsertPoint g(trueBlock);
|
|
auto outputs = trueInst();
|
|
for (auto output : outputs) {
|
|
trueBlock->registerOutput(output);
|
|
}
|
|
}
|
|
{
|
|
WithInsertPoint g(falseBlock);
|
|
auto outputs = falseInst();
|
|
for (auto output : outputs) {
|
|
falseBlock->registerOutput(output);
|
|
}
|
|
}
|
|
|
|
AT_ASSERT(trueBlock->outputs().size() == falseBlock->outputs().size());
|
|
for (auto output : trueBlock->outputs()) {
|
|
if_->addOutput()->setType(output->type());
|
|
}
|
|
return if_;
|
|
}
|
|
|
|
template <class Exception, class Functor>
|
|
inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
|
|
try {
|
|
std::forward<Functor>(functor)();
|
|
} catch (const Exception& e) {
|
|
if (std::string(e.what()).find(expectMessageContains) ==
|
|
std::string::npos) {
|
|
AT_ERROR(
|
|
"Expected error message to contain \"",
|
|
expectMessageContains,
|
|
"\" but error message was: ",
|
|
e.what());
|
|
}
|
|
return;
|
|
}
|
|
AT_ERROR(
|
|
"Expected to throw exception containing \"",
|
|
expectMessageContains,
|
|
"\" but didn't throw");
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void testAliasAnalysis() {
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->addInput();
|
|
|
|
// addsB = b + b
|
|
// c = a + b
|
|
// a += b
|
|
// d = c + c
|
|
auto addsB = graph->insert(aten::add, {b, b});
|
|
auto c = graph->insert(aten::add, {a, b});
|
|
auto aMut = graph->insert(aten::add_, {a, b});
|
|
auto d = graph->insert(aten::add, {c, c});
|
|
|
|
graph->lint();
|
|
|
|
AliasDb aliasDb(graph);
|
|
// Can't move past a mutation of a used value
|
|
AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
|
|
AT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));
|
|
|
|
// b should alias to a (since they are both inputs)
|
|
AT_ASSERT(
|
|
!aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
|
|
AT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));
|
|
|
|
graph->lint();
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->addInput();
|
|
|
|
auto constant = graph->insertConstant(1);
|
|
auto fresh = graph->insert(aten::rand, {constant});
|
|
auto usesB = graph->insert(aten::add, {b, fresh});
|
|
auto aliasesB = graph->insert(aten::select, {a, constant, constant});
|
|
auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh});
|
|
graph->insert(aten::add, {fresh, aliasesB});
|
|
graph->lint();
|
|
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
|
|
aliasesB->node(), mutatesAliasOfB->node()));
|
|
AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
|
|
usesB->node(), mutatesAliasOfB->node()));
|
|
}
|
|
{
|
|
// Test moves across inner blocks
|
|
|
|
// a = rand(1)
|
|
// b = rand(1)
|
|
// if True:
|
|
// a.add_(b)
|
|
// c = a + b
|
|
auto graph = std::make_shared<Graph>();
|
|
auto constant = graph->insertConstant(1);
|
|
auto a = graph->insert(aten::rand, {constant});
|
|
auto b = graph->insert(aten::rand, {constant});
|
|
|
|
auto if_ = insertIf(
|
|
*graph,
|
|
constant,
|
|
[&]() -> std::vector<Value*> {
|
|
auto aMut = graph->insert(aten::add_, {a, b});
|
|
return {aMut};
|
|
},
|
|
[&]() -> std::vector<Value*> { return {a}; });
|
|
|
|
auto c = graph->insert(aten::add, {a, b});
|
|
|
|
graph->lint();
|
|
|
|
// we should not be able to move `c` before the if statement, since it
|
|
// may write to `a`.
|
|
AliasDb aliasDb(graph);
|
|
ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
|
|
}
|
|
|
|
// test none value does not have writers
|
|
{{auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%opt : Tensor? = prim::Constant()
|
|
%out : Tensor = prim::unchecked_unwrap_optional(%opt)
|
|
%ret.2 : Tensor = aten::div(%out, %out, %out)
|
|
return (%opt, %out, %ret.2)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(!aliasDb.hasWriters(vmap["opt"]->node()));
|
|
}
|
|
} // namespace jit
|
|
|
|
// test safeToIntroduceAliasingRelationship
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x : Tensor):
|
|
%3 : int = prim::Constant[value=1]()
|
|
%2 : int = prim::Constant[value=0]()
|
|
%b : Tensor = aten::add(%x, %2, %3)
|
|
%c : Tensor = aten::add(%x, %2, %3)
|
|
%d : Tensor = aten::add(%x, %2, %3)
|
|
%e : Tensor = aten::add(%x, %2, %3)
|
|
%f : Tensor[] = prim::ListConstruct(%e)
|
|
%14 : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
|
|
return (%14)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
AliasDb aliasDb(graph);
|
|
// x, b, c escape scope, so we can't introduce an aliasing relationship
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["x"], vmap["b"]));
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["b"], vmap["x"]));
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["b"], vmap["c"]));
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["c"], vmap["b"]));
|
|
|
|
// e aliases the wildcard set because it's contained in a list
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["e"], vmap["x"]));
|
|
TORCH_INTERNAL_ASSERT(
|
|
!aliasDb.safeToChangeAliasingRelationship(vmap["x"], vmap["e"]));
|
|
|
|
// d is a temporary with no writers, safe to change aliasing relationship here
|
|
TORCH_INTERNAL_ASSERT(
|
|
aliasDb.safeToChangeAliasingRelationship(vmap["c"], vmap["d"]));
|
|
TORCH_INTERNAL_ASSERT(
|
|
aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
|
|
}
|
|
} // namespace torch
|
|
|
|
void testWriteTracking() {
|
|
RegisterOperators reg({Operator(
|
|
"prim::creates_alias(Tensor(a) x) -> Tensor(a)",
|
|
[](Stack& s) { return 0; },
|
|
aliasAnalysisFromSchema())});
|
|
const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->addInput();
|
|
|
|
// aten::add(%b, %b)
|
|
// aten::add_(%a, %b)
|
|
// foo::creates_alias(%a)
|
|
auto pureNode = graph->insert(aten::add, {b, b})->node();
|
|
auto writingNode = graph->insert(aten::add_, {a, b})->node();
|
|
auto node3 = graph->insert(creates_alias, {a})->node();
|
|
auto aAlias = node3->output();
|
|
|
|
graph->lint();
|
|
|
|
AliasDb aliasDb(graph);
|
|
ASSERT_TRUE(aliasDb.mayAlias(aAlias, a));
|
|
ASSERT_TRUE(aliasDb.mayAlias(a, b));
|
|
ASSERT_FALSE(
|
|
aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
|
|
ASSERT_FALSE(
|
|
aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
|
|
ASSERT_TRUE(aliasDb.writesToAlias(
|
|
writingNode, std::unordered_set<const Value*>{a}));
|
|
ASSERT_TRUE(aliasDb.writesToAlias(
|
|
writingNode, std::unordered_set<const Value*>{a, b}));
|
|
ASSERT_TRUE(aliasDb.writesToAlias(
|
|
writingNode, std::unordered_set<const Value*>{aAlias}));
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x: Tensor):
|
|
%b : Tensor = aten::relu_(%x)
|
|
return (%b)
|
|
)IR",
|
|
&*graph);
|
|
auto node_iter = graph->block()->nodes().begin();
|
|
auto relu = *node_iter;
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(aliasDb.isMutable(relu));
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x: Tensor, %y : Tensor):
|
|
%b : Tensor = aten::mul(%x, %y)
|
|
return (%b)
|
|
)IR",
|
|
&*graph);
|
|
auto node_iter = graph->block()->nodes().begin();
|
|
auto mul = *node_iter;
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(!aliasDb.isMutable(mul));
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x: Tensor, %y : Tensor):
|
|
%c1 : int = prim::Constant[value=1]()
|
|
%b : Tensor = aten::add_(%x, %y, %c1)
|
|
return (%b)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
auto add = vmap["b"]->node();
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(aliasDb.hasWriters(add));
|
|
AT_ASSERT(aliasDb.isMutable(add));
|
|
}
|
|
}
|
|
|
|
void testContainerAliasing() {
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%inp: Tensor[]):
|
|
%x : str = prim::Constant[value="a"]()
|
|
%y : Tensor = prim::Constant()
|
|
%z : Tensor = prim::Constant()
|
|
%a : (Tensor) = prim::TupleConstruct(%y)
|
|
%b : Dict(str, Tensor) = prim::DictConstruct(%x, %y)
|
|
%c : Tensor[] = prim::ListConstruct(%y)
|
|
return (%a, %b, %c)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
auto str_output = vmap["x"];
|
|
auto ten_output = vmap["y"];
|
|
auto local_var = vmap["z"];
|
|
AliasDb aliasDb(graph);
|
|
|
|
AT_ASSERT(graph->outputs().size() == 3);
|
|
for (auto out : graph->outputs()) {
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_output, out));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(local_var, out));
|
|
}
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_output, graph->inputs()));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(local_var, graph->inputs()));
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias({ten_output}, graph->outputs()));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(str_output, graph->outputs()));
|
|
}
|
|
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%x : str = prim::Constant[value="a"]()
|
|
%y : int = prim::Constant[value=1]()
|
|
%a : (int) = prim::TupleConstruct(%y)
|
|
%b : Dict(str, int) = prim::DictConstruct(%x, %y)
|
|
%c : int[] = prim::ListConstruct(%y)
|
|
return (%a, %b, %c)
|
|
)IR",
|
|
&*graph);
|
|
|
|
auto node_iter = graph->block()->nodes().begin();
|
|
node_iter++; // string
|
|
Node* int_node = *node_iter++;
|
|
AliasDb aliasDb(graph);
|
|
|
|
AT_ASSERT(graph->outputs().size() == 3);
|
|
// primitive values don't need to alias container
|
|
for (auto out : graph->outputs()) {
|
|
AT_ASSERT(!aliasDb.mayContainAlias(int_node->output(), out));
|
|
}
|
|
}
|
|
|
|
// Test input aliasing
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x: Tensor, %y: Tensor):
|
|
%a : (Tensor) = prim::TupleConstruct(%x)
|
|
return (%a)
|
|
)IR",
|
|
&*graph);
|
|
|
|
auto node_iter = graph->block()->nodes().begin();
|
|
auto tuple_node = *node_iter;
|
|
AliasDb aliasDb(graph);
|
|
|
|
for (auto input : graph->inputs()) {
|
|
AT_ASSERT(aliasDb.mayContainAlias(input, tuple_node->output()));
|
|
}
|
|
AT_ASSERT(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
|
|
}
|
|
|
|
// Test tuple that doesn't come from construct
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x : int,
|
|
%y : Tensor,
|
|
%z : Tensor):
|
|
%3 : int = prim::Constant[value=1]()
|
|
%4 : bool = aten::eq(%x, %3)
|
|
%a : (Tensor) = prim::If(%4)
|
|
block0():
|
|
%a.1 : (Tensor) = prim::TupleConstruct(%y)
|
|
-> (%a.1)
|
|
block1():
|
|
%a.2 : (Tensor) = prim::TupleConstruct(%z)
|
|
-> (%a.2)
|
|
return (%a)
|
|
)IR",
|
|
&*graph);
|
|
|
|
AliasDb aliasDb(graph);
|
|
|
|
for (auto input : graph->inputs()) {
|
|
if (input->type() == IntType::get()) {
|
|
continue;
|
|
}
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias(input, graph->outputs().at(0)));
|
|
}
|
|
}
|
|
|
|
// test nested types
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%a : Tensor = prim::MakeTestTensor()
|
|
%a_list : Tensor[] = prim::ListConstruct(%a)
|
|
%b : Tensor = prim::MakeTestTensor()
|
|
%b_list : Tensor[] = prim::ListConstruct(%b)
|
|
%13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
|
|
return (%13)
|
|
)IR",
|
|
&*graph);
|
|
AliasDb aliasDb(graph);
|
|
auto g_output = graph->outputs().at(0);
|
|
auto list_2 = g_output->node()->inputs().at(0);
|
|
auto list_1 = g_output->node()->inputs().at(1);
|
|
|
|
// TODO FIX assume conservatively for now
|
|
AT_ASSERT(aliasDb.mayContainAlias(list_1, list_2));
|
|
AT_ASSERT(aliasDb.mayContainAlias(list_2, list_1));
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias(list_1, g_output));
|
|
AT_ASSERT(aliasDb.mayContainAlias(list_2, g_output));
|
|
}
|
|
|
|
// simple example
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%0 : Tensor = prim::Constant()
|
|
%1 : Tensor = prim::Constant()
|
|
%13 : (Tensor) = prim::TupleConstruct(%0)
|
|
return (%13)
|
|
)IR",
|
|
&*graph);
|
|
AliasDb aliasDb(graph);
|
|
|
|
auto node_iter = graph->block()->nodes().begin();
|
|
auto first_ten = *node_iter++;
|
|
auto second_ten = *node_iter++;
|
|
auto tup_node = *node_iter;
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
|
|
AT_ASSERT(
|
|
!aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));
|
|
|
|
std::vector<Value*> first_st = {first_ten->output()};
|
|
std::vector<Value*> second_st = {second_ten->output()};
|
|
std::vector<Value*> tup_st = {tup_node->output()};
|
|
AT_ASSERT(aliasDb.mayContainAlias(first_st, tup_st));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st));
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%x : str = prim::Constant[value="a"]()
|
|
%y : Tensor = prim::Constant()
|
|
%c : Tensor[] = prim::ListConstruct(%y)
|
|
%d : Tensor[] = prim::ListConstruct(%y)
|
|
return (%c, %d)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
AliasDb aliasDb(graph);
|
|
auto x = vmap["x"];
|
|
auto c = vmap["c"];
|
|
AT_ASSERT(!aliasDb.mayContainAlias(x, c));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(c, x));
|
|
|
|
auto d = vmap["d"];
|
|
|
|
AT_ASSERT(aliasDb.mayContainAlias(d, c));
|
|
AT_ASSERT(aliasDb.mayContainAlias(c, d));
|
|
}
|
|
{
|
|
// Test list container aliasing
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%0 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=3]()
|
|
%2 : int[] = prim::ListConstruct(%0, %1)
|
|
%x : Tensor = prim::MakeTestTensor()
|
|
%12 : int[] = prim::ListConstruct(%0, %1)
|
|
%y : Tensor = prim::MakeTestTensor()
|
|
%22 : int[] = prim::ListConstruct(%0, %1)
|
|
%z : Tensor = prim::MakeTestTensor()
|
|
%32 : int[] = prim::ListConstruct(%0, %1)
|
|
%fresh : Tensor = prim::MakeTestTensor()
|
|
%foo : Tensor[] = prim::ListConstruct(%x, %y)
|
|
%43 : Tensor[] = aten::append(%foo, %z)
|
|
return ()
|
|
)IR",
|
|
graph.get(),
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto x = vmap["x"];
|
|
auto y = vmap["y"];
|
|
auto z = vmap["z"];
|
|
// Tensors x, y, and z went into a list, so they all may alias each other.
|
|
ASSERT_TRUE(aliasDb.mayAlias(x, y));
|
|
ASSERT_TRUE(aliasDb.mayAlias(y, z));
|
|
ASSERT_TRUE(aliasDb.mayAlias(x, z));
|
|
|
|
// But we know `fresh` didn't go into a list, so x, y, and z should not
|
|
// alias it.
|
|
auto fresh = vmap["fresh"];
|
|
ASSERT_FALSE(aliasDb.mayAlias(x, fresh));
|
|
ASSERT_FALSE(aliasDb.mayAlias(y, fresh));
|
|
ASSERT_FALSE(aliasDb.mayAlias(z, fresh));
|
|
}
|
|
{
|
|
// test "conservative" analysis writes to the inside of a container.
|
|
auto ops = torch::RegisterOperators(
|
|
"custom::conservative", [](torch::List<at::Tensor> in) { return in; });
|
|
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%0 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=3]()
|
|
%2 : int[] = prim::ListConstruct(%0, %1)
|
|
%11 : Tensor = prim::MakeTestTensor()
|
|
%12 : Tensor[] = prim::ListConstruct(%11)
|
|
%out : Tensor[] = custom::conservative(%12)
|
|
%ret.2 : Tensor = aten::div(%11, %11)
|
|
return ()
|
|
)IR",
|
|
graph.get(),
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto conservativeOp = vmap["out"]->node();
|
|
auto tensor = vmap["11"];
|
|
ASSERT_TRUE(aliasDb.writesToAlias(conservativeOp, ValueSet{tensor}));
|
|
}
|
|
{
|
|
auto ops = torch::RegisterOperators().op(
|
|
"uses::list",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](torch::List<at::Tensor> in) {
|
|
return torch::rand({2, 3});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
// Write to the inside of a list. Check that we can't reorder a
|
|
// print across it.
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%35 : int = prim::Constant[value=1]()
|
|
%0 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=3]()
|
|
%23 : int = prim::Constant[value=0]()
|
|
%2 : int[] = prim::ListConstruct(%0, %1)
|
|
%11 : Tensor = prim::MakeTestTensor()
|
|
%12 : int[] = prim::ListConstruct(%0, %1)
|
|
%21 : Tensor = prim::MakeTestTensor()
|
|
%l : Tensor[] = prim::ListConstruct(%11, %21)
|
|
%24 : Tensor = aten::select(%l, %23)
|
|
%25 : int[] = prim::ListConstruct(%0, %1)
|
|
%34 : Tensor = prim::MakeTestTensor()
|
|
%36 : Tensor = aten::add_(%24, %34, %35)
|
|
%37 : Tensor = uses::list(%l)
|
|
return (%37)
|
|
)IR",
|
|
graph.get(),
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto listUse = vmap["37"]->node();
|
|
auto internalWrite = vmap["36"]->node();
|
|
ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
|
|
}
|
|
{
|
|
// The same as above, but with a nested list
|
|
auto ops = torch::RegisterOperators().op(
|
|
"uses::list",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](torch::List<at::Tensor> in) {
|
|
return torch::rand({2, 3});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
// Write to the inside of a list. Check that we can't reorder a
|
|
// print across it.
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%38 : int = prim::Constant[value=1]()
|
|
%0 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=3]()
|
|
%24 : int = prim::Constant[value=0]()
|
|
%2 : int[] = prim::ListConstruct(%0, %1)
|
|
%11 : Tensor = prim::MakeTestTensor()
|
|
%12 : int[] = prim::ListConstruct(%0, %1)
|
|
%21 : Tensor = prim::MakeTestTensor()
|
|
%l : Tensor[] = prim::ListConstruct(%11, %21)
|
|
%25 : Tensor = aten::select(%l, %24)
|
|
%27 : Tensor = aten::select(%25, %24, %24)
|
|
%28 : int[] = prim::ListConstruct(%0, %1)
|
|
%37 : Tensor = prim::MakeTestTensor()
|
|
%39 : Tensor = aten::add_(%27, %37, %38)
|
|
%40 : Tensor = uses::list(%l)
|
|
return (%40)
|
|
)IR",
|
|
graph.get(),
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto listUse = vmap["40"]->node();
|
|
auto internalWrite = vmap["39"]->node();
|
|
ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
|
|
}
|
|
}
|
|
|
|
void testWildcards() {
|
|
RegisterOperators reg({Operator(
|
|
"prim::returns_wildcard(Tensor a) -> Tensor(*)",
|
|
[](Stack& stack) { return 0; },
|
|
aliasAnalysisFromSchema()),
|
|
Operator(
|
|
"prim::writes(Tensor(z!) a) -> Tensor(a)",
|
|
[](Stack& stack) { return 0; },
|
|
aliasAnalysisFromSchema())});
|
|
const auto returns_wildcard =
|
|
Symbol::fromQualString("prim::returns_wildcard");
|
|
const auto writes = Symbol::fromQualString("prim::writes");
|
|
|
|
auto graph = std::make_shared<Graph>();
|
|
const auto a = graph->addInput();
|
|
|
|
const auto constant = graph->insertConstant(1);
|
|
const auto fresh = graph->insert(aten::rand, {constant});
|
|
const auto fresh2 = graph->insert(aten::rand, {constant});
|
|
const auto wildcard = graph->insert(returns_wildcard, {fresh});
|
|
|
|
{
|
|
graph->lint();
|
|
AliasDb aliasDb(graph);
|
|
|
|
ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
|
|
ASSERT_FALSE(aliasDb.mayAlias(wildcard, fresh));
|
|
ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
|
|
ASSERT_FALSE(aliasDb.mayAlias(ValueSet{wildcard}, ValueSet{}));
|
|
ASSERT_FALSE(aliasDb.hasWriters(wildcard->node()));
|
|
}
|
|
|
|
graph->insert(writes, {fresh2})->node();
|
|
{
|
|
graph->lint();
|
|
AliasDb aliasDb(graph);
|
|
ASSERT_FALSE(aliasDb.hasWriters(wildcard->node()));
|
|
}
|
|
|
|
const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
|
|
{
|
|
graph->lint();
|
|
AliasDb aliasDb(graph);
|
|
// Test writes to wildcards
|
|
ASSERT_FALSE(aliasDb.writesToAlias(
|
|
wildcardWrite, std::unordered_set<const Value*>{fresh}));
|
|
ASSERT_FALSE(aliasDb.writesToAlias(
|
|
wildcardWrite, std::unordered_set<const Value*>{fresh2}));
|
|
ASSERT_TRUE(aliasDb.writesToAlias(
|
|
wildcardWrite, std::unordered_set<const Value*>{a}));
|
|
ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
|
|
}
|
|
|
|
// test that wildcards are correctly divided by type
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
|
|
%ten : Tensor = prim::Constant()
|
|
%4 : Tensor[] = aten::append(%ten_list, %ten)
|
|
%ten_ten_list : Tensor[][] = prim::Constant()
|
|
%int_int_list : int[][] = prim::Constant()
|
|
return ()
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto opt_ten_list = vmap["opt_ten_list"];
|
|
auto ten_list = vmap["ten_list"];
|
|
auto int_list = vmap["int_list"];
|
|
AT_ASSERT(!aliasDb.hasWriters(int_list));
|
|
AT_ASSERT(aliasDb.hasWriters(opt_ten_list));
|
|
AT_ASSERT(aliasDb.hasWriters(ten_list));
|
|
AT_ASSERT(!aliasDb.mayContainAlias(int_list, opt_ten_list));
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_list, opt_ten_list));
|
|
AT_ASSERT(aliasDb.mayAlias(ten_list, opt_ten_list));
|
|
|
|
auto list_of_tensor_lists = vmap["ten_ten_list"];
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_list, list_of_tensor_lists));
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_list, vmap["ten"]));
|
|
|
|
AT_ASSERT(
|
|
!aliasDb.mayContainAlias(vmap["int_int_list"], list_of_tensor_lists));
|
|
}
|
|
|
|
// test invariant container aliasing
|
|
// the containers of different type cannot alias each other,
|
|
// however they may contain elements which alias each other
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
|
|
%ten : Tensor = prim::Constant()
|
|
%4 : Tensor[] = aten::append(%ten_list, %ten)
|
|
return ()
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
auto ten_opt_list = vmap["ten_opt_list"];
|
|
auto ten_list = vmap["ten_list"];
|
|
AT_ASSERT(!aliasDb.hasWriters(ten_opt_list));
|
|
AT_ASSERT(aliasDb.hasWriters(ten_list));
|
|
AT_ASSERT(aliasDb.mayContainAlias(ten_list, ten_opt_list));
|
|
AT_ASSERT(!aliasDb.mayAlias(ten_list, ten_opt_list));
|
|
}
|
|
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
|
|
return ()
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(aliasDb.mayAlias(vmap["float_3D"], vmap["float_2D"]));
|
|
}
|
|
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
|
|
return ()
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
AliasDb aliasDb(graph);
|
|
AT_ASSERT(aliasDb.mayAlias(vmap["float_3D_list"], vmap["float_2D_list"]));
|
|
AT_ASSERT(aliasDb.mayContainAlias(vmap["float_3D_list"], vmap["ten"]));
|
|
AT_ASSERT(aliasDb.mayContainAlias(vmap["float_2D_list"], vmap["ten"]));
|
|
}
|
|
}
|
|
|
|
void testMemoryDAG() {
|
|
auto graph = std::make_shared<Graph>();
|
|
const Value* aValue = graph->addInput();
|
|
const Value* bValue = graph->addInput();
|
|
const Value* cValue = graph->addInput();
|
|
const Value* dValue = graph->addInput();
|
|
const Value* eValue = graph->addInput();
|
|
const Value* fValue = graph->addInput();
|
|
const Value* gValue = graph->addInput();
|
|
|
|
{
|
|
// a <- b <- c
|
|
// b <- d
|
|
// a <- e
|
|
// f <- e
|
|
// g is by itself
|
|
auto t = std::make_unique<MemoryDAGBuilder>();
|
|
auto a = t->makeFreshValue(aValue);
|
|
auto b = t->makeFreshValue(bValue);
|
|
auto c = t->makeFreshValue(cValue);
|
|
auto d = t->makeFreshValue(dValue);
|
|
auto e = t->makeFreshValue(eValue);
|
|
auto f = t->makeFreshValue(fValue);
|
|
auto g = t->makeFreshValue(gValue);
|
|
t->makePointerTo(b, a);
|
|
t->makePointerTo(c, b);
|
|
t->makePointerTo(d, b);
|
|
t->makePointerTo(e, a);
|
|
t->makePointerTo(e, f);
|
|
|
|
auto dag = std::make_unique<MemoryDAG>(std::move(t));
|
|
|
|
/**
|
|
* Test mayAlias()
|
|
*/
|
|
// Values should alias themselves
|
|
ASSERT_TRUE(dag->mayAlias(a, a));
|
|
ASSERT_TRUE(dag->mayAlias(g, g));
|
|
|
|
// Values that point to the same location should alias
|
|
ASSERT_TRUE(dag->mayAlias(a, b));
|
|
ASSERT_TRUE(dag->mayAlias(a, c));
|
|
ASSERT_TRUE(dag->mayAlias(c, d));
|
|
|
|
// e may point to a OR f
|
|
ASSERT_TRUE(dag->mayAlias(e, a));
|
|
ASSERT_TRUE(dag->mayAlias(e, f));
|
|
// But a and f don't alias
|
|
ASSERT_FALSE(dag->mayAlias(a, f));
|
|
}
|
|
{
|
|
// x(y) -> x contains y
|
|
|
|
// b(a)
|
|
// c(a)
|
|
auto t = std::make_unique<MemoryDAGBuilder>();
|
|
auto a = t->makeFreshValue(aValue);
|
|
auto b = t->makeFreshValue(bValue);
|
|
t->addToContainedElements(a, b);
|
|
|
|
auto c = t->makeFreshValue(cValue);
|
|
t->addToContainedElements(a, c);
|
|
|
|
auto dag = std::make_unique<MemoryDAG>(std::move(t));
|
|
AT_ASSERT(dag->mayContainAlias(a, b));
|
|
AT_ASSERT(dag->mayContainAlias(b, a));
|
|
|
|
AT_ASSERT(dag->mayContainAlias(a, c));
|
|
AT_ASSERT(dag->mayContainAlias(c, a));
|
|
|
|
AT_ASSERT(dag->mayContainAlias(b, c));
|
|
AT_ASSERT(dag->mayContainAlias(c, b));
|
|
|
|
// containers contain an element in themselves
|
|
AT_ASSERT(dag->mayContainAlias(b, b));
|
|
AT_ASSERT(dag->mayContainAlias(c, c));
|
|
AT_ASSERT(dag->mayContainAlias(a, a));
|
|
}
|
|
{
|
|
// b(a)
|
|
// c(a)
|
|
// d(b(a))
|
|
auto t = std::make_unique<MemoryDAGBuilder>();
|
|
auto a = t->makeFreshValue(aValue);
|
|
auto b = t->makeFreshValue(bValue);
|
|
t->addToContainedElements(a, b);
|
|
|
|
auto c = t->makeFreshValue(cValue);
|
|
t->addToContainedElements(a, c);
|
|
|
|
auto d = t->makeFreshValue(dValue);
|
|
t->addToContainedElements(b, d);
|
|
|
|
auto dag = std::make_unique<MemoryDAG>(std::move(t));
|
|
AT_ASSERT(dag->mayContainAlias(b, d));
|
|
AT_ASSERT(dag->mayContainAlias(d, b));
|
|
|
|
AT_ASSERT(dag->mayContainAlias(c, d));
|
|
AT_ASSERT(dag->mayContainAlias(d, c));
|
|
|
|
AT_ASSERT(dag->mayContainAlias(a, d));
|
|
}
|
|
{
|
|
// f(e)
|
|
auto t = std::make_unique<MemoryDAGBuilder>();
|
|
auto a = t->makeFreshValue(aValue);
|
|
auto b = t->makeFreshValue(bValue);
|
|
t->addToContainedElements(a, b);
|
|
|
|
auto c = t->makeFreshValue(cValue);
|
|
t->addToContainedElements(a, c);
|
|
|
|
auto d = t->makeFreshValue(dValue);
|
|
t->addToContainedElements(b, d);
|
|
|
|
auto f = t->makeFreshValue(aValue);
|
|
auto e = t->makeFreshValue(bValue);
|
|
|
|
t->addToContainedElements(f, e);
|
|
|
|
auto dag = std::make_unique<MemoryDAG>(std::move(t));
|
|
for (auto elem : {a, b, c, d}) {
|
|
AT_ASSERT(!dag->mayContainAlias(f, elem));
|
|
AT_ASSERT(!dag->mayContainAlias(e, elem));
|
|
}
|
|
}
|
|
}
|
|
|
|
void testAliasRegistration() {
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand1",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand1");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// Conservatively we assume there is a reference
|
|
ASSERT_TRUE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand2(Tensor arg1) -> Tensor",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand2");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// Conservatively we assume there is a reference
|
|
ASSERT_TRUE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand3(Tensor(a) arg1) -> Tensor(b)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
|
|
|
|
const auto rand_op = Symbol::fromQualString("foo::rand3");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
graph->insert(rand_op, {a});
|
|
|
|
// Registration time is okay, but throw exception when fetch from
|
|
// registration.
|
|
expectThrows<c10::Error>(
|
|
[&graph] { AliasDb aliasDb(graph); },
|
|
"Tried to register operator foo::rand3(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand4(Tensor(a) arg1) -> Tensor(a)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand4");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
graph->insert(rand_op, {a});
|
|
|
|
// Registration time is okay, but throw exception when fetch from
|
|
// registration.
|
|
expectThrows<c10::Error>(
|
|
[&graph] { AliasDb aliasDb(graph); },
|
|
"Tried to register operator foo::rand4(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
|
|
}
|
|
{
|
|
expectThrows<c10::Error>(
|
|
[] {
|
|
torch::RegisterOperators().op(
|
|
"foo::rand5",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
|
|
},
|
|
"Tried to register operator foo::rand5(Tensor _0) -> (Tensor _0) with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred");
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand6(Tensor arg1) -> Tensor",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand6");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// The schema doesn't contain alias information, which means it's pure
|
|
// (meh!)
|
|
ASSERT_FALSE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand7(Tensor(a) arg1) -> Tensor(a)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
|
|
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand7");
|
|
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// The schema has an alias reference
|
|
ASSERT_TRUE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand8(Tensor(a) arg1) -> Tensor(b)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
|
|
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand8");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// The schema does not have an alias reference
|
|
ASSERT_FALSE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand9",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand9");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// The schema is pure, there cannot be any alias
|
|
ASSERT_FALSE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand10(Tensor arg1) -> Tensor",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor) -> at::Tensor {
|
|
return at::rand({2, 2});
|
|
})
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand10");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
auto b = graph->insert(rand_op, {a});
|
|
AliasDb aliasDb(graph);
|
|
// The schema is pure, there cannot be any alias
|
|
ASSERT_FALSE(aliasDb.mayAlias(a, b));
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand11(Tensor(a) arg1) -> Tensor(a)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand11");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
graph->insert(rand_op, {a});
|
|
|
|
// Registration time is okay, but throw exception when fetch from
|
|
// registration.
|
|
expectThrows<c10::Error>(
|
|
[&graph] { AliasDb aliasDb(graph); },
|
|
"Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
|
|
}
|
|
{
|
|
auto registry = torch::RegisterOperators().op(
|
|
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",
|
|
torch::RegisterOperators::options()
|
|
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
|
|
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
|
|
const auto rand_op = Symbol::fromQualString("foo::rand12");
|
|
auto graph = std::make_shared<Graph>();
|
|
auto a = graph->addInput();
|
|
graph->insert(rand_op, {a});
|
|
|
|
// Registration time is okay, but throw exception when fetch from
|
|
// registration.
|
|
expectThrows<c10::Error>(
|
|
[&graph] { AliasDb aliasDb(graph); },
|
|
"Tried to register operator foo::rand12(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
|
|
}
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|