pytorch/torch/csrc/jit/ir/alias_analysis.cpp
Meghan Lele cb8d9f99aa [JIT] Implement Tensor.tolist() (#33472)
**Summary**
This commit adds an implementation of `Tensor.tolist()` to the JIT interpreter.

**Testing**
This commit adds several unit tests verifying that this function works correctly for
0D, 1D, 2D, and 3D tensors of type `float`, `int`, and `bool`.

```
(base) meghanl-mbp:pytorch meghanl$ python test/test_jit.py TestList.test_to_list -v
Fail to import hypothesis in common_utils, tests are not derandomized
test_to_list (jit.test_list_dict.TestList)
Unit tests for Tensor.tolist() function. ... ok

----------------------------------------------------------------------
Ran 1 test in 0.329s

OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33472

Differential Revision: D20109738

Pulled By: SplitInfinity

fbshipit-source-id: a6e3fee5e3201d5e1f0c4ca45048488ae2bf5e33
2020-02-27 21:45:46 -08:00


#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {
// For any mutable type, map it to a type such that all other types it can
// alias are mapped to the same type. This follows logic similar to
// `unifyTypes`, because any two mutable types that can be unified can alias
// each other.
// getMutableTypePtr(Optional[List[int]]) == getMutableTypePtr(List[int])
// If a type is not mutable, return nullopt
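// More examples, following the cases below:
//   getMutableTypePtr(int) == nullopt (int is immutable)
//   getMutableTypePtr(Tuple[int, Tensor]) == Tuple[Tensor] (only the mutable
//   element types of a tuple are kept)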
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) {
switch (type->kind()) {
case TypeKind::ListType:
case TypeKind::DictType:
case TypeKind::ClassType:
case TypeKind::TensorType:
return unshapedType(type);
case TypeKind::OptionalType:
return getMutableTypePtr(type->cast<OptionalType>()->getElementType());
case TypeKind::FutureType: {
if (auto elem = getMutableTypePtr(type->cast<FutureType>()->getElementType())) {
return FutureType::create(*elem);
}
return c10::nullopt;
}
case TypeKind::TupleType: {
std::vector<TypePtr> mutable_types;
for (const auto& elem : type->expect<TupleType>()->elements()) {
if (auto mut_elem = getMutableTypePtr(elem)) {
mutable_types.push_back(*mut_elem);
}
}
if (mutable_types.size() == 0) {
return c10::nullopt;
} else {
return TupleType::create(mutable_types);
}
}
default:
return c10::nullopt;
}
}
bool AliasDb::mutableType(const TypePtr& type) {
return getMutableTypePtr(type) != c10::nullopt;
}
// We only need to annotate values that either are mutable or could contain
// mutable types.
bool AliasDb::mutableType(const Value* v) {
return mutableType(v->type());
}
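// A type is a "container" for aliasing purposes if its mutable form has
// contained types, e.g. List[Tensor] or Tuple[int, Tensor]. A plain Tensor is
// mutable but not a container.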
bool AliasDb::isContainerType(const TypePtr& type) {
auto mut_type = getMutableTypePtr(type);
return mut_type && (*mut_type)->containedTypes().size() > 0;
}
AliasDb::~AliasDb() = default;
AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
memoryDAG_ = torch::make_unique<MemoryDAG>();
analyze(graph_);
GRAPH_DEBUG(toString());
}
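// A node is considered mutable if it writes to an alias of any of its inputs.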
bool AliasDb::isMutable(Node* n) const {
ValueSet vs;
for (const auto input : n->inputs()) {
vs.insert(input);
}
return writesToAlias(n, vs);
}
bool AliasDb::hasInputWriters(const Node* n) const {
for (const auto input : n->inputs()) {
if (hasWriters(input)) {
return true;
}
}
return false;
}
bool AliasDb::hasOutputWriters(const Node* n) const {
for (const auto output : n->outputs()) {
if (hasWriters(output)) {
return true;
}
}
return false;
}
bool AliasDb::hasWriters(const Node* n) const {
return hasInputWriters(n) || hasOutputWriters(n);
}
bool AliasDb::hasWriters(const Value* v) const {
if (v->mustBeNone()) {
return false;
}
auto it = elementMap_.find(v);
if (it == elementMap_.end()) {
return false;
}
if (isWriteCacheStale_) {
rebuildWriteCache();
}
const auto& el = it->second;
return writeCache_.intersects(el->getMemoryLocations());
}
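// Collect every memory location written by `n`, including writes made by
// nodes inside its sub-blocks (e.g. the bodies of prim::If and prim::Loop).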
void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret) const {
if (writeIndex_.count(n)) {
const auto& writes = writeIndex_.at(n);
ret |= writes;
}
for (auto block : n->blocks()) {
for (auto node : block->nodes()) {
getWritesImpl(node, ret);
}
}
}
// Does `n` write to an alias of one of the values in `vs`?
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}
MemoryLocations locs;
for (const auto v : vs) {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
const auto& vlocs = it->second->getMemoryLocations();
if (writtenTo.intersects(vlocs)) {
return true;
}
}
}
return false;
}
MemoryLocations AliasDb::getWrites(Node* n) const {
MemoryLocations writes;
getWritesImpl(n, writes);
return writes;
}
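// Collect every memory location read by `n` and its sub-blocks. Reading a
// container also counts as reading the wildcard set of each contained type,
// since the contained elements may alias those wildcards.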
void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
for (const auto input : n->inputs()) {
auto it = elementMap_.find(input);
if (it != elementMap_.end()) {
auto el = it->second;
// Add all memory locations this element may alias.
ret |= el->getMemoryLocations();
// We also consider memory locations of contained values to be "read".
for (const auto& type : input->type()->containedTypes()) {
if (auto wildcard = getWildcard(type)) {
ret |= wildcard->getMemoryLocations();
}
}
}
}
for (auto block : n->blocks()) {
for (auto node : block->nodes()) {
getReadsImpl(node, ret);
}
}
}
MemoryLocations AliasDb::getReads(Node* n) const {
MemoryLocations reads;
getReadsImpl(n, reads);
return reads;
}
std::string AliasDb::getElementName(const Element* e) const {
if (e->value == nullptr) {
// Not the most efficient way, but given that there are not too many types,
// and even fewer of them will end up in wildcardIndex_, we should be fine
// with a linear search each time we hit a wildcard leaf.
for (const auto& ent : wildcardIndex_) {
if (ent.second == e) {
return std::string("WILDCARD for type ") + ent.first->str();
}
}
return "WILDCARD";
} else {
return e->value->debugName();
}
}
void AliasDb::dump() const {
std::cout << toString();
}
std::string AliasDb::toString() const {
std::stringstream ss{};
ss << "\n===1. GRAPH===\n";
ss << graph_->toString();
ss << "\n===2. ALIAS DB===\n";
for (const auto& ptrPair : elementMap_) {
const auto element = ptrPair.second;
if (!element->pointsTo.empty()) {
ss << getElementName(element) << " points to: ";
for (const auto pointedTo : element->pointsTo) {
ss << getElementName(memoryDAG_->fromIndex(pointedTo)) << ", ";
}
ss << "\n";
}
if (!element->containedElements.empty()) {
ss << getElementName(element) << " contains: ";
for (const auto contained : element->containedElements) {
ss << getElementName(memoryDAG_->fromIndex(contained)) << ", ";
}
ss << "\n";
}
}
ss << "\n===3. Writes===\n";
for (const auto& pr : writeIndex_) {
const auto node = pr.first;
const auto& values = pr.second;
ss << *node;
ss << " ";
for (const auto value : values) {
ss << getElementName(memoryDAG_->fromIndex(value)) << ", ";
}
ss << "\n";
}
ss << "\n";
return ss.str();
}
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
for (auto input : graph->inputs()) {
setWildcard(input);
}
analyze(graph->block());
}
void AliasDb::analyze(Block* block) {
for (auto node : block->nodes()) {
analyze(node);
}
}
void AliasDb::analyze(Node* node) {
analyzeImpl(node);
}
// Returns true if analysis was run using
// the registered analyzer.
bool AliasDb::tryRegisteredAnalysis(Node* node) {
const Operator& op = node->getOperator();
auto analysis = op.aliasAnalysisKind();
if (AliasAnalysisKind::PURE_FUNCTION == analysis) {
analyzeCreator(node);
return true;
}
return false;
}
// The basic strategy is:
// 1. Retrieve alias information for every input.
// 2. Use the node's schema's alias annotations to propagate alias/write
// information to the outputs. For unschematized nodes, a special analyzer
// will have to be handwritten.
void AliasDb::analyzeImpl(Node* node) {
auto op = node->maybeOperator();
const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
if (op) {
const auto analysis = op->aliasAnalysisKind();
const bool registeredAsSpecialCase =
analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
TORCH_INTERNAL_ASSERT(
false,
"Op ",
node->kind().toDisplayString(),
" is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
} else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
TORCH_INTERNAL_ASSERT(
false,
"Op ",
node->kind().toDisplayString(),
" has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
c10::toString(analysis));
}
} else {
TORCH_INTERNAL_ASSERT(
hasSpecialCase,
"We don't have an op for ",
node->kind().toDisplayString(),
" but it isn't a special case.");
}
// These nodes are not schematized, so we need to handle them specially
switch (node->kind()) {
case prim::If:
return analyzeIf(node);
case prim::Loop:
return analyzeLoop(node);
case prim::FusionGroup:
case prim::DifferentiableGraph:
return analyzeSubgraph(node);
case prim::fork:
return analyzeFork(node);
case aten::wait:
return analyzeWait(node);
case prim::GradOf:
return analyzeGradOf(node);
case prim::Constant:
case prim::AutogradZero:
case prim::AutogradAdd:
case prim::FusedConcat:
case prim::MMTreeReduce:
case prim::MMBatchSide:
case prim::BroadcastSizes:
case prim::ChunkSizes:
case prim::Function:
case prim::CreateObject:
case prim::tolist:
return analyzeCreator(node);
case prim::TupleConstruct:
case prim::DictConstruct:
case prim::ListConstruct:
return analyzeContainerConstruct(node);
case prim::TupleUnpack:
case prim::TupleIndex:
case prim::TupleSlice:
case prim::ListUnpack:
case prim::PythonOp:
case prim::GetAttr:
return analyzeExtractor(node);
case prim::unchecked_cast:
return makePointerTo(node->output(), node->input());
case prim::ConstantChunk:
return analyzeChunk(node);
case prim::BroadcastingChunk:
return analyzeBroadcastingChunk(node);
case prim::SetAttr:
return analyzeSetAttr(node);
case prim::profile:
if (node->inputs().size() > 0) {
makePointerTo(node->output(), node->inputs().at(0));
}
return;
case prim::BailOut:
TORCH_INTERNAL_ASSERT(node->inputs().at(0)->node()->kind() ==
prim::BailoutTemplate);
makePointerTo(node->output(), node->inputs().at(1));
return;
case prim::Guard:
makePointerTo(node->output(), node->inputs().at(0));
return;
case prim::CallFunction:
case prim::CallMethod:
// TODO: this can be improved with summaries of what the function does
// for now we assume the worst
// NB: update safeToChangeAliasingRelationship if changed
return analyzeConservative(node);
case prim::Uninitialized:
giveFreshAlias(node->output());
return;
case prim::Print:
case prim::isinstance:
// These ops do nothing
return;
default:
if (tryRegisteredAnalysis(node)) {
return;
}
}
TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get to here");
const AliasAnalysisKind analysis = op->aliasAnalysisKind();
TORCH_INTERNAL_ASSERT(
analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
!aliasAnalysisHasSpecialCaseFor(node->kind()),
"Special cases should be handled already if we're here.");
if (node->kind().is_aten() || node->kind().is_prim()) {
// TODO There is nothing in the system that relies on aten:: and prim::
// ops using AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::INTERNAL_SPECIAL_CASE,
// but this is the intended behavior for all current ops and a good error check.
// We can consider lifting this constraint later if we have a use case for it.
TORCH_INTERNAL_ASSERT(
analysis == AliasAnalysisKind::FROM_SCHEMA ||
analysis == AliasAnalysisKind::CONSERVATIVE,
"aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
"AliasAnalysisKind::CONSERVATIVE(if really necessary), but ",
node->kind().toDisplayString(),
" doesn't. Note: Ideally, prim:: operators actually shouldn't have a schema ",
"and then use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
}
if (analysis == AliasAnalysisKind::CONSERVATIVE) {
// TODO A previous implementation of alias analysis always accessed
// node->schema, which caused the schema caches in the Node class to be
// filled for the full graph. Unfortunately, our JIT passes started relying
// on that, so we need to keep doing this. Details: in
// caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
// graph because we called _jit_pass_erase_number_types right before and
// ints are now Tensors instead. So if _jit_pass_onnx tries to look up
// operator schemas, it will crash. However, _jit_pass_constant_propagation,
// which is called before it, runs alias analysis and prefills the schema
// cache in all Node instances so that _jit_pass_onnx doesn't look up
// operators to get the schemas anymore. We should fix this.
node->schema(); // fill the schema cache in the Node class
return analyzeConservative(node);
}
TORCH_INTERNAL_ASSERT(
analysis == AliasAnalysisKind::FROM_SCHEMA,
"AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
const auto& schema = node->schema();
// Bind the schema's "formal" alias annotation to the actual values those
// schema arguments represent
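// For example, a schema like
//   aten::add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// uses the alias set `a` to say that the output aliases `self`, and `!` to say
// that `self` is written to. (Illustrative; see the registered schemas for
// exact signatures.)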
std::unordered_map<Symbol, Value*> formalToActual;
for (size_t i = 0; i < schema.arguments().size(); i++) {
const auto& formal = schema.arguments()[i].alias_info();
const auto& actualValue = node->inputs().at(i);
// Skip if there's no alias annotation
if (!formal) {
continue;
}
// If this type cannot alias, continue. Can occur with a VarType schema
if (!mutableType(actualValue)) {
continue;
}
// Do sanity checks on the alias annotation
TORCH_INTERNAL_ASSERT(
formal->containedTypes().size() == 0,
"Composite types for alias analysis not yet supported");
TORCH_INTERNAL_ASSERT(
!formal->isWildcardBefore(),
"Doesn't make sense for a input value to begin as a wildcard");
const auto& formalAlias = formal->beforeSet();
// skip if we've already bound this alias
if (formalToActual.count(formalAlias) != 0) {
continue;
}
// Bind the formal to the actual
formalToActual[formalAlias] = actualValue;
// Record writes
if (formal->isWrite()) {
registerWrite(actualValue, node);
}
// Now deal with sets after the '->'
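// An argument annotated with a wildcard after the arrow (e.g. `t(c -> *)` on
// the element argument of list append) means that after the call the value
// may alias anything: it has been stored inside a container, so it joins the
// wildcard set. (The exact schema text here is illustrative.)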
if (formal->isWildcardAfter()) {
TORCH_INTERNAL_ASSERT(
formal->afterSets().size() == 1,
"If the after set contains a wildcard, "
"there should be no other alias sets specified.");
setWildcard(actualValue);
} else {
// We don't understand anything else in the after yet, so assert there's
// been no change.
TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
}
}
// Use the formal-actual mapping to give aliases to the outputs
for (size_t i = 0; i < schema.returns().size(); i++) {
const auto actual = node->outputs().at(i);
const auto& formal = schema.returns()[i].alias_info();
if (!formal) {
// This is a fresh tensor
giveFreshAlias(actual);
continue;
}
// If this type cannot alias, continue. Can occur with a VarType schema
if (!mutableType(actual)) {
continue;
}
TORCH_INTERNAL_ASSERT(
formal->containedTypes().size() == 0,
"Composite types for alias analysis not yet supported");
TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
if (formal->isWildcardBefore()) {
TORCH_INTERNAL_ASSERT(
formal->beforeSets().size() == 1,
"If an output is a wildcard, "
"there should be no other alias sets specified.");
setWildcard(actual);
continue;
}
for (const auto& formalAlias : formal->beforeSets()) {
// If we encounter an alias annotation that wasn't in the inputs:
if (!formalToActual.count(formalAlias)) {
// If this alias is not seen elsewhere and is the only annotation on
// the output, it's equivalent to being fresh:
// e.g. foo(Tensor(a) self) -> Tensor(b)
if (formal->beforeSets().size() == 1) {
giveFreshAlias(actual);
}
// Or it is of the form a|fresh, which we can ignore, taking the
// conservative assumption that the output must alias `a`, e.g.
// aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
// Don't assign an alias set in that case.
continue;
}
auto toAlias = formalToActual.at(formalAlias);
makePointerTo(actual, toAlias);
}
// Record writes
if (formal->isWrite()) {
registerWrite(actual, node);
}
}
}
// Register the fact that `n` writes to `v`.
void AliasDb::registerWrite(const Value* v, Node* n) {
if (!mutableType(v)) {
// don't need to register a write if the value isn't mutable
return;
}
auto it = elementMap_.find(v);
TORCH_INTERNAL_ASSERT(
it != elementMap_.end(), "Tried to write to value not in MemoryDAG");
const auto& writtenMemoryLocations = it->second->getMemoryLocations();
writeIndex_[n] |= writtenMemoryLocations;
}
void AliasDb::registerWrite(const Element* e, Node* n) {
TORCH_INTERNAL_ASSERT(
e->pointsTo.empty(),
"Can only register writes to memory location elements");
writeIndex_[n].set(e->index);
}
void AliasDb::analyzeIf(Node* node) {
// For if statements, the alias set of an output is the union of the
// alias sets generated by the if and else block
const auto trueBlock = node->blocks().at(0);
const auto falseBlock = node->blocks().at(1);
analyze(trueBlock);
analyze(falseBlock);
for (size_t i = 0; i < node->outputs().size(); i++) {
const auto nodeOutput = node->outputs()[i];
const auto trueOutput = trueBlock->outputs().at(i);
const auto falseOutput = falseBlock->outputs().at(i);
makePointerTo(nodeOutput, trueOutput);
makePointerTo(nodeOutput, falseOutput);
}
}
void AliasDb::analyzeLoop(Node* node) {
const auto bodyBlock = node->blocks().at(0);
const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
// Run alias analysis on the loop body, iterating until the block output
// alias info converges.
// Copy node input aliases to block input
mapAliases(blockInputs, loopCarriedInputs);
// Populate block output alias info by analyzing the body
analyze(bodyBlock);
// Copy the alias info from the block output to the node output
mapAliases(node->outputs(), blockOutputs);
}
void AliasDb::analyzeGradOf(Node* node) {
const auto grad_of_block = node->blocks().at(0);
analyze(grad_of_block);
mapAliases(node->outputs(), grad_of_block->outputs());
}
void AliasDb::analyzeSubgraph(Node* node) {
const auto subgraph = node->g(attr::Subgraph).get();
const auto subgraphBlock = subgraph->block();
mapAliases(subgraphBlock->inputs(), node->inputs());
analyze(subgraphBlock);
// Note: the subgraph outputs and node outputs are NOT NECESSARILY of the
// same length. Autodifferentiation may capture additional outputs in the
// subgraph block.
TORCH_INTERNAL_ASSERT(
subgraphBlock->outputs().size() >= node->outputs().size());
for (size_t i = 0; i < node->outputs().size(); i++) {
makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
}
}
// For nodes that generate a fresh value from nothing
void AliasDb::analyzeCreator(Node* node) {
for (Value* output : node->outputs()) {
giveFreshAlias(output);
}
}
// For nodes that extract values from a composite type. Right now, this just
// gives up and creates wildcards for everything.
void AliasDb::analyzeExtractor(Node* node) {
for (const auto output : node->outputs()) {
setWildcard(output);
}
}
// For torch.chunk(), all returned tensors may alias the input tensor
void AliasDb::analyzeChunk(Node* node) {
for (auto output : node->outputs()) {
makePointerTo(output, node->input());
}
}
void AliasDb::analyzeFork(Node* node) {
for (const auto input : node->inputs()) {
setWildcard(input);
}
// Give the future that the fork emits a fresh value
for (const auto output : node->outputs()) {
giveFreshAlias(output);
}
}
void AliasDb::analyzeWait(Node* node) {
TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);
for (const auto output : node->outputs()) {
setWildcard(output);
}
// the forked subgraph that `wait` is waiting on may write to any of its
// inputs. We don't have a reliable way of recovering the fork inputs, so
// for safety we just register a write to every wildcard.
for (const auto& pr : wildcardIndex_) {
registerWrite(pr.second, node);
}
}
// SetAttr: writes to the `self` field
void AliasDb::analyzeSetAttr(Node* node) {
const auto self = node->inputs().at(0);
TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
registerWrite(self, node);
// Also the value being set must become a wildcard.
const auto newValue = node->inputs().at(1);
setWildcard(newValue);
}
// Used for anything where we do not have accurate alias summaries: the node
// may write to any input and produce wildcards.
void AliasDb::analyzeConservative(Node* node) {
for (const auto input : node->inputs()) {
if (!mutableType(input)) {
continue;
}
auto elem = elementMap_.at(input);
registerWrite(input, node);
MemoryLocations mem_locations;
memoryDAG_->collectAllContainedMemoryLocations(elem, mem_locations);
for (unsigned loc : mem_locations) {
auto contained_elem = memoryDAG_->fromIndex(loc);
// we only register writes on memory locations
if (contained_elem->pointsTo.empty()) {
registerWrite(contained_elem, node);
}
}
setWildcard(input);
}
for (const auto output : node->outputs()) {
setWildcard(output);
}
}
// List, dict, or tuple construct: create an aliasing element for the actual
// container, then mark all inputs as wildcards, since they've gone inside the
// container. Then, add the wildcard set of the appropriate type to the
// contained elements of the container.
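// e.g. for `%lst : Tensor[] = prim::ListConstruct(%a, %b)`, %a and %b are set
// to the Tensor wildcard, and that wildcard element is recorded as contained
// in %lst.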
void AliasDb::analyzeContainerConstruct(Node* node) {
TORCH_INTERNAL_ASSERT(
node->kind() == prim::ListConstruct ||
node->kind() == prim::DictConstruct ||
node->kind() == prim::TupleConstruct);
// tuples that contain only immutable types are themselves immutable
if (!mutableType(node->output())) {
return;
}
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
auto container = node->output();
giveFreshAlias(container);
auto container_elem = elementMap_.at(container);
for (auto input : node->inputs()) {
auto maybe_wildcard_elem = setWildcard(input);
if (maybe_wildcard_elem) {
memoryDAG_->addToContainedElements(*maybe_wildcard_elem, container_elem);
}
}
}
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// This is an intermediate node used only in the graph fuser.
void AliasDb::analyzeBroadcastingChunk(Node* node) {
auto inputs = node->inputs();
auto outputs = node->outputs();
auto nchunks = node->i(attr::chunks);
for (size_t index = 0; index < inputs.size(); ++index) {
// Each inputs[i] is aliased by exactly `nchunks` distinct output tensors:
// inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
auto output_begin = outputs.begin() + index * nchunks;
for (auto it = output_begin; it != output_begin + nchunks; ++it) {
makePointerTo(*it, inputs.at(index));
}
}
}
bool AliasDb::nonAliasingValue(const Value* elem) const {
// These are values that can point to values of aliasing type in the graph
// (e.g. a None flowing into an Optional-typed if-node output), but they will
// never alias anything themselves.
return elem->mustBeNone() || elem->node()->kind() == prim::Uninitialized;
}
// Register the fact that `from` is a pointer to `to`
void AliasDb::makePointerTo(const Value* from, const Value* to) {
if (nonAliasingValue(from) || nonAliasingValue(to)) {
// if either value is guaranteed to be non-aliasing, we do not need to
// connect the two elements. however, it is invariant that aliasing types
// that are not wildcards have a memory dag element, so we create one if
// needed
giveFreshAlias(from);
giveFreshAlias(to);
return;
}
if (!mutableType(from)) {
TORCH_INTERNAL_ASSERT(!mutableType(to));
return;
}
if (from == to) {
return;
}
// Special case: if `from` is an optional, `to` could be a None. Don't
// create a pointer in that case
if (from->type()->kind() == TypeKind::OptionalType &&
to->type()->kind() == TypeKind::NoneType) {
return;
}
// At this point, we should be dealing with two mutable types.
TORCH_INTERNAL_ASSERT(mutableType(from) && mutableType(to));
auto fromEl = getOrCreateElement(from);
auto toEl = getOrCreateElement(to);
memoryDAG_->makePointerTo(fromEl, toEl);
}
void AliasDb::addToContainedElements(
const Value* elem,
const Value* container) {
if (!mutableType(elem)) {
return;
}
TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
auto elemEl = getOrCreateElement(elem);
auto contEl = getOrCreateElement(container);
memoryDAG_->addToContainedElements(elemEl, contEl);
}
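// Two values may alias if their elements share a memory location; values of
// immutable type never alias anything.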
bool AliasDb::mayAlias(const Value* a, const Value* b) const {
if (!mutableType(a) || !mutableType(b)) {
return false;
}
return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
}
bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
if (a.empty() || b.empty()) {
return false;
}
// Record all memory locations from group `a`
MemoryLocations aMemLocs;
for (const auto value : a) {
auto it = elementMap_.find(value);
if (it != elementMap_.end()) {
aMemLocs |= it->second->getMemoryLocations();
}
}
// If any of group `b`'s memory locations overlap, return true.
for (const auto value : b) {
auto it = elementMap_.find(value);
if (it != elementMap_.end()) {
if (aMemLocs.intersects(it->second->getMemoryLocations())) {
return true;
}
}
}
// No overlap, so group `a` and `b` do not share a memory location
return false;
}
bool AliasDb::mayContainAlias(Value* a, Value* b) const {
const std::vector<Value*> a_vec = {a};
const std::vector<Value*> b_vec = {b};
return mayContainAlias(a_vec, b_vec);
}
std::vector<Element*> AliasDb::getElements(at::ArrayRef<Value*> vs) const {
std::vector<Element*> elements;
for (const auto& val : vs) {
if (mutableType(val)) {
elements.push_back(elementMap_.at(val));
}
}
return elements;
}
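// Returns true if any value in `a` (or anything it contains) may share a
// memory location with any value in `b` (or anything it contains).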
bool AliasDb::mayContainAlias(
const at::ArrayRef<Value*> a,
const at::ArrayRef<Value*> b) const {
auto a_elems = getElements(a);
return a_elems.size() == 0 ? false : memoryDAG_->mayContainAlias(a_elems, getElements(b));
}
// Make each value in the `from` list point to its partner in the `to` list
void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
TORCH_INTERNAL_ASSERT(to.size() == from.size());
for (size_t i = 0; i < to.size(); i++) {
makePointerTo(from[i], to[i]);
}
}
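// If `value`'s type is mutable, create a fresh Element (a new memory
// location) for it, and hook up wildcard elements for any mutable types it
// contains.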
void AliasDb::giveFreshAlias(const Value* value) {
auto maybe_mut_type = getMutableTypePtr(value->type());
if (!maybe_mut_type) {
return;
}
if (elementMap_.count(value)) {
// Inside a loop, we may have given a fresh alias to this value already, so
// skip
return;
}
auto new_elem = memoryDAG_->makeFreshValue(value);
elementMap_[value] = new_elem;
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
}
Element* AliasDb::getOrCreateElement(const Value* value) {
if (!elementMap_.count(value)) {
giveFreshAlias(value);
}
return elementMap_.at(value);
}
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
}
bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
}
bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
// We have to distinguish the move side (instead of just moving after
// n->prev()). Consider the following example:
// If the dependency graph looks like
// n -> movePoint -> o
// then moveBefore(o) will end up with
// n, o, movePoint
// but moveAfter(n) will return false.
return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
}
bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
}
bool AliasDb::hasWriters(const at::ArrayRef<Value*>& values) const {
return std::any_of(values.begin(), values.end(), [&](Value* value) {
return hasWriters(value);
});
}
bool AliasDb::escapesScope(const at::ArrayRef<Value*>& vs) const {
return mayContainAlias(graph_->inputs(), vs) ||
mayContainAlias(graph_->outputs(), vs) || mayAliasWildcard(vs);
}
// Correctness conditions:
// no value in either set can have writers, and the two sets cannot both
// escape the current graph scope. A value escapes the current scope by
// aliasing a graph output or input, or by aliasing the wildcard set.
bool AliasDb::safeToChangeAliasingRelationship(
const at::ArrayRef<Value*>& a,
const at::ArrayRef<Value*>& b) const {
if (hasWriters(a) || hasWriters(b)) {
return false;
}
return !(escapesScope(a) && escapesScope(b));
}
// Helper for topologically-safe node moves. See `tryMove()` for details.
class AliasDb::WorkingSet {
public:
explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
mover_ = mover;
for (const auto user : getUsersSameBlock(mover_)) {
moverUsers_.insert(user);
}
moverWrites_ |= aliasDb_.getWrites(mover_);
moverReads_ |= aliasDb_.getReads(mover_);
}
// Add `n` to the working set
void add(Node* n) {
nodes_.push_back(n);
for (const auto user : getUsersSameBlock(n)) {
users_.insert(user);
}
writes_ |= aliasDb_.getWrites(n);
reads_ |= aliasDb_.getReads(n);
}
void eraseMover() {
mover_ = nullptr;
moverWrites_.clear();
moverReads_.clear();
moverUsers_.clear();
}
const std::vector<Node*>& dependentNodes() {
return nodes_;
}
// Does the working set depend on `n`?
bool dependsOn(Node* n) const {
if (!mover_ && nodes_.empty()) {
return false;
}
return hasDataDependency(n) || hasMutabilityDependency(n);
}
private:
bool hasDataDependency(Node* n) const {
if (!mover_ && nodes_.empty()) {
return false;
}
const Node* pivot = mover_ ? mover_ : nodes_.front();
if (n->isAfter(pivot)) {
return producesFor(n);
} else {
return consumesFrom(n);
}
}
bool hasMutabilityDependency(Node* n) const {
// Check that `n` does not write to anything used by the working set
const auto& nWrites = aliasDb_.getWrites(n);
if (reads_.intersects(nWrites)) {
return true;
}
if (mover_ && moverReads_.intersects(nWrites)) {
return true;
}
// Check that the working set doesn't write to anything that `n` uses.
const auto& nReads = aliasDb_.getReads(n);
if (writes_.intersects(nReads)) {
return true;
}
if (mover_ && moverWrites_.intersects(nReads)) {
return true;
}
return false;
}
// Does the working set produce any values consumed by `n`?
bool producesFor(Node* n) const {
// This is equivalent to asking: does the total use-set of all the nodes in the
// working set include `n`?
if (mover_ && moverUsers_.count(n)) {
return true;
}
return users_.count(n) != 0;
}
// Does the working set consume any values produced by `n`?
bool consumesFrom(Node* n) const {
const auto users = getUsersSameBlock(n);
if (mover_ && users.count(mover_)) {
return true;
}
return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
return users.count(node) != 0;
});
}
// Get all users of outputs of `n`, in the same block as `n`.
// This means if there is an `if` node that uses an output of `n` in some
// inner sub-block, we will consider the whole `if` node a user of `n`.
std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
std::unordered_set<Node*> users;
for (const auto output : n->outputs()) {
for (const auto& use : output->uses()) {
if (auto sameBlock = findSameBlock(use.user, n)) {
users.insert(sameBlock);
}
}
}
return users;
}
// Traverse `target`'s blockchain upward until we find a node that shares a
// block with `n`.
//
// If one can't be found (say, because `n` is an inner block and target is
// outside), then return nullptr. Since we can only reorder nodes within a
// block, `target` would be irrelevant.
static Node* findSameBlock(Node* target, Node* n) {
TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
if (target->owningBlock() == n->owningBlock()) {
return target;
} else {
// This user is in a sub-block. Traverse the blockchain upward until
// we arrive at a node that shares a block with `n`.
auto curNode = target;
while (curNode->owningBlock() != n->owningBlock()) {
curNode = curNode->owningBlock()->owningNode();
if (curNode == nullptr) {
return curNode;
}
}
return curNode;
}
}
const AliasDb& aliasDb_;
std::vector<Node*> nodes_;
// Mover dependencies. We track these separately since we may erase the mover
// from the working set.
Node* mover_;
MemoryLocations moverWrites_;
MemoryLocations moverReads_;
std::unordered_set<Node*> moverUsers_;
// Users of values produced by nodes in the working set.
std::unordered_set<Node*> users_;
// Memory locations written to by nodes in the working set.
MemoryLocations writes_;
MemoryLocations reads_;
};
// Try to move `toMove` before/after `movePoint` while preserving value
// dependencies. Returns false iff such a move could not be made.
//
// If `dryRun` is set, don't actually execute the move, just check if the move
// is possible
//
// The basic approach is: have a "working set" that we are moving forward, one
// node at a time. When we can't move past a node (because it depends on the
// working set), then add it to the working set and keep moving until we hit
// `movePoint`.
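// For example, given `a; b; c` where `b` uses an output of `a` and `c` uses
// neither, moving `a` after `c` drags `b` along and yields `c; a; b`,
// preserving the a-before-b dependency.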
bool AliasDb::tryMove(
Node* toMove,
Node* movePoint,
MoveSide moveSide,
bool dryRun) {
TORCH_INTERNAL_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
if (toMove == movePoint) {
return true;
}
// 1. Move from `toMove` toward `movePoint`, building up the working set of
// dependencies
WorkingSet workingSet(toMove, *this);
int direction;
if (toMove->isAfter(movePoint)) {
direction = kPrevDirection;
} else {
direction = kNextDirection;
}
auto curNode = toMove->next_in_graph[direction];
// Move forward one node at a time
while (curNode != movePoint) {
if (workingSet.dependsOn(curNode)) {
// If we can't move past this node, add it to the working set
workingSet.add(curNode);
}
curNode = curNode->next_in_graph[direction];
}
// 2. Decide whether we can move it all to `movePoint`.
// Say we are moving directly before movePoint and `toMove` starts before
// movePoint in the graph. The move looks like
//
// `toMove` `toMove` |
// <dependencies> -> `movePoint` | `toMove` and deps are split
// `movePoint` <dependencies> |
//
// Contrast with the case where `toMove` starts AFTER movePoint:
//
// `movePoint` <dependencies> |
// <dependencies> -> `toMove` | `toMove` and deps are together
// `toMove` `movePoint` |
//
// In the first case, we need to split `toMove` off from its dependencies, so
// we can move the dependencies below `movePoint` and keep `toMove` above.
const bool splitToMoveAndDeps =
(moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
(moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
if (splitToMoveAndDeps) {
// remove `toMove` from the dependencies to be moved past `movePoint`
workingSet.eraseMover();
}
// Check if we can move the working set past the move point
if (workingSet.dependsOn(movePoint)) {
// if we can't, then there are intermediate dependencies between `toMove`
// and `movePoint`, so we can't do the move
return false;
}
if (dryRun) {
return true;
}
// 3. Execute the move
TORCH_INTERNAL_ASSERT(curNode == movePoint);
if (splitToMoveAndDeps) {
// Move `toMove`
move(toMove, movePoint, moveSide);
// Then move all of its dependencies on the other side of `movePoint`
const auto reversed =
moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
for (auto n : workingSet.dependentNodes()) {
move(n, curNode, reversed);
curNode = n;
}
} else {
// Just append/prepend everything to `movePoint`
move(toMove, curNode, moveSide);
curNode = toMove;
for (auto n : workingSet.dependentNodes()) {
move(n, curNode, moveSide);
curNode = n;
}
}
return true;
}
// Helper function so we can generalize `tryMove`
void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
switch (moveSide) {
case MoveSide::BEFORE:
toMove->moveBefore(movePoint);
break;
case MoveSide::AFTER:
toMove->moveAfter(movePoint);
break;
}
}
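// Returns true if `n` writes to a memory location in the wildcard set.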
bool AliasDb::writesToWildcard(Node* n) const {
if (!writeIndex_.count(n)) {
return false;
}
const auto& writes = writeIndex_.at(n);
// Are any of these memoryLocs a wildcard element?
for (const auto& pr : wildcardIndex_) {
const auto wildcardElement = pr.second;
if (writes.test(wildcardElement->index)) {
return true;
}
}
return false;
}
bool AliasDb::mayAliasWildcard(const Value* v) const {
if (auto e = getWildcard(v->type())) {
return memoryDAG_->mayAlias(elementMap_.at(v), e);
}
// There were no wildcards of this type, so return false.
return false;
}
bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
return std::any_of(
vs.begin(), vs.end(), [&](Value* v) { return mayAliasWildcard(v); });
}
c10::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
auto updated_type = getMutableTypePtr(type);
if (!updated_type) {
return c10::nullopt;
}
auto mapped_type = *updated_type;
auto existing_wildcard = wildcardIndex_.find(mapped_type);
if (existing_wildcard != wildcardIndex_.end()) {
return existing_wildcard->second;
}
auto wildcard_elem = memoryDAG_->makeFreshValue(nullptr);
wildcardIndex_.emplace(mapped_type, wildcard_elem);
addContainedTypesToFreshElement(wildcard_elem, mapped_type);
return wildcard_elem;
}
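// For every mutable type contained in `mut_type`, register the corresponding
// wildcard element as contained in `container_elem`.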
void AliasDb::addContainedTypesToFreshElement(
Element* container_elem,
const TypePtr& mut_type) {
for (const auto& contained : mut_type->containedTypes()) {
auto maybe_elem = tryGetOrCreateWildcard(contained);
if (maybe_elem) {
memoryDAG_->addToContainedElements(*maybe_elem, container_elem);
}
}
}
// Search the wildcard index for an element that corresponds to the given
// type. Returns nullptr if no such element exists.
Element* AliasDb::getWildcard(const TypePtr& type) const {
auto maybe_mut_type = getMutableTypePtr(type);
if (!maybe_mut_type) {
return nullptr;
}
TypePtr mut_type = *maybe_mut_type;
auto wildcard = wildcardIndex_.find(mut_type);
if (wildcard != wildcardIndex_.end()) {
return wildcard->second;
}
return nullptr;
}
// Register `v` as a wildcard value.
c10::optional<Element*> AliasDb::setWildcard(const Value* v) {
auto maybe_wildcardElement = tryGetOrCreateWildcard(v->type());
if (!maybe_wildcardElement) {
return c10::nullopt;
}
auto wildcardElement = *maybe_wildcardElement;
// Making a value a wildcard means that all its potential memory locations
// may alias the wildcard set.
const MemoryLocations pointeeSet = getOrCreateElement(v)->getMemoryLocations();
for (const auto& pointee : pointeeSet) {
auto from = memoryDAG_->fromIndex(pointee);
// avoid cycles where the wildcard points to itself
if (from != wildcardElement) {
memoryDAG_->makePointerTo(from, wildcardElement);
}
}
// We also need to update the write index with new writes to the wildcard set.
for (auto& pr : writeIndex_) {
auto& writtenTo = pr.second;
if (writtenTo.intersects(pointeeSet)) {
writtenTo.set(wildcardElement->index);
}
}
isWriteCacheStale_ = true;
return wildcardElement;
}
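// Union all written-to memory locations into writeCache_, so hasWriters()
// can be answered with a single bitset intersection.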
void AliasDb::rebuildWriteCache() const {
for (const auto& pr : writeIndex_) {
const auto& writtenLocs = pr.second;
writeCache_ |= writtenLocs;
}
isWriteCacheStale_ = false;
}
} // namespace jit
} // namespace torch