mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT] Optimize DCE by storing a MemoryLocations for an entire set<Value*> (#153645)
Summary:
**TL;DR**: make DCE faster by replacing a Set<Value*> with a MemoryLocations sparse bitset (representing all the memory locations stored by the collection of all values in the set).
**Details**
The goal of this PR is to optimize this function from AliasDb:
```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}
MemoryLocations locs;
for (const auto v : vs) {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
if (writtenTo.intersects(vlocs)) {
return true;
}
}
}
return false;
}
```
In the DCE use case, we have a ValueSet of live values, into which we insert `Value*`s; and sometimes need to check whether a node mutates any of the live values using `writesToAlias`.
Looping through all the values in the ValueSet and indexing into the elementMap_ is slow; so if we can pre-compute the MemoryLocations set, this speeds up the function. In some large model examples, I see ~15-25x speedups from this change.
**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class `ValueAndMemoryLocationSet`, which is an insert-only set of Values, which also maintains the corresponding MemoryLocations.
Then in AliasDb, I use `ValueAndMemoryLocationSet` if we're using AliasDb for analysis, and otherwise use a `Set<Value*>` if we don't have AliasDb.
Test Plan: Rely on unit tests.
Differential Revision: D74827086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
This commit is contained in:
parent
cc48550e6f
commit
a237831bc2
|
|
@ -44,3 +44,32 @@ class TestDCE(JitTestCase):
|
|||
# freezing inlines t1.__init__(), after which DCE can occur.
|
||||
t2 = torch.jit.freeze(t2)
|
||||
FileCheck().check_not("prim::SetAttr").run(t2.graph)
|
||||
|
||||
def test_mutated_simple(self):
|
||||
def fn(x: torch.Tensor):
|
||||
y = x.sin()
|
||||
y_slice = y[::2]
|
||||
y_slice.add_(x[::2])
|
||||
z = y.cos()
|
||||
return z
|
||||
|
||||
fn_s = torch.jit.script(fn)
|
||||
torch._C._jit_pass_dce_graph(fn_s.graph)
|
||||
|
||||
FileCheck().check("aten::add_").run(fn_s.graph)
|
||||
|
||||
def test_mutated_loop(self):
|
||||
def fn(x: torch.Tensor):
|
||||
y = x.sin()
|
||||
y_slice = y[::2]
|
||||
y_slice.add_(x[::2])
|
||||
for _ in range(2):
|
||||
y_slice = y[::2]
|
||||
y = y.repeat(2)
|
||||
z = y.cos()
|
||||
return z
|
||||
|
||||
fn_s = torch.jit.script(fn)
|
||||
torch._C._jit_pass_dce_graph(fn_s.graph)
|
||||
|
||||
FileCheck().check("aten::add_").run(fn_s.graph)
|
||||
|
|
|
|||
|
|
@ -1602,6 +1602,7 @@ def _jit_set_logging_option(option: str) -> None: ...
|
|||
def _jit_set_logging_stream(stream_name: str) -> None: ...
|
||||
def _jit_pass_cse(Graph) -> _bool: ...
|
||||
def _jit_pass_dce(Graph) -> None: ...
|
||||
def _jit_pass_dce_graph(Graph) -> None: ...
|
||||
def _jit_pass_lint(Graph) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_custom_class.cpp
|
||||
|
|
|
|||
|
|
@ -396,6 +396,14 @@ MemoryLocations AliasDb::getReads(Node* n) const {
|
|||
return reads;
|
||||
}
|
||||
|
||||
MemoryLocations AliasDb::getMemoryLocations(Value* v) const {
|
||||
auto it = elementMap_.find(v);
|
||||
if (it != elementMap_.end()) {
|
||||
return memoryDAG_->getMemoryLocations(it->second);
|
||||
}
|
||||
return MemoryLocations();
|
||||
}
|
||||
|
||||
std::string AliasDb::getElementName(const Element* e) const {
|
||||
if (e->values.empty()) {
|
||||
// Not the most efficient way, but given the fact there are
|
||||
|
|
@ -2014,4 +2022,27 @@ void Lint(const AliasDb* db) {
|
|||
// - All container values have contained elements
|
||||
}
|
||||
|
||||
ValueAndMemoryLocationSet AliasDb::getValueAndMemoryLocationSet() const {
|
||||
return ValueAndMemoryLocationSet(this);
|
||||
}
|
||||
|
||||
bool AliasDb::writesToAlias(Node* n, const ValueAndMemoryLocationSet& vls)
|
||||
const {
|
||||
const auto writtenTo = getWrites(n);
|
||||
if (writtenTo.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return writtenTo.intersects(vls.memoryLocations_);
|
||||
}
|
||||
|
||||
void ValueAndMemoryLocationSet::insert(Value* v) {
|
||||
valueSet_.insert(v);
|
||||
memoryLocations_ |= aliasDb_->getMemoryLocations(v);
|
||||
}
|
||||
|
||||
ValueSet& ValueAndMemoryLocationSet::getValueSet() {
|
||||
return valueSet_;
|
||||
}
|
||||
|
||||
} // namespace torch::jit
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
|
||||
namespace torch::jit {
|
||||
|
||||
class ValueAndMemoryLocationSet;
|
||||
|
||||
/**
|
||||
* Alias analysis pass.
|
||||
*
|
||||
|
|
@ -70,6 +72,12 @@ class AliasDb {
|
|||
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
|
||||
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const;
|
||||
|
||||
// Does `n` write to any of the values in `vls`?
|
||||
TORCH_API bool writesToAlias(Node* n, const ValueAndMemoryLocationSet& vls)
|
||||
const;
|
||||
|
||||
TORCH_API ValueAndMemoryLocationSet getValueAndMemoryLocationSet() const;
|
||||
|
||||
// Does `a` and `b` potentially share a memory location or do either
|
||||
// hold in memory any element that exists in the other
|
||||
TORCH_API bool mayContainAlias(Value* a, Value* b) const;
|
||||
|
|
@ -170,6 +178,7 @@ class AliasDb {
|
|||
void enablePreciseTupleContainerAnalysis();
|
||||
|
||||
friend struct MutationRemover;
|
||||
friend class ValueAndMemoryLocationSet;
|
||||
|
||||
private:
|
||||
// Helper for topologically-safe node moves.
|
||||
|
|
@ -197,6 +206,7 @@ class AliasDb {
|
|||
// if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
|
||||
MemoryLocations getReads(Node* n) const;
|
||||
void getReadsImpl(Node* n, MemoryLocations& ret) const;
|
||||
MemoryLocations getMemoryLocations(Value* v) const;
|
||||
|
||||
/**
|
||||
* Wildcard methods
|
||||
|
|
@ -317,4 +327,37 @@ class AliasDb {
|
|||
// the right thing.
|
||||
TORCH_API void Lint(const AliasDb* db);
|
||||
|
||||
/**
|
||||
* ValueAndMemoryLocationSet
|
||||
*
|
||||
* A insert-only set of values which also maintains a MemoryLocations bitset
|
||||
* of the memory locations that the values alias. It is insert-only. It
|
||||
* should be constructed by calling aliasDb.getValueAndMemoryLocationSet().
|
||||
*
|
||||
* WARNING:
|
||||
* * The AliasDb must not be mutated after construction of a
|
||||
* ValueAndMemoryLocationsSet, or else the MemoryLocations stored in the
|
||||
* ValueAndMemoryLocationSet will no longer be accurate.
|
||||
* * A ValueAndMemoryLocationsSet is tied to an instsance of AliasDb but
|
||||
* does not own the AliasDb. It is the user's responsibility to ensure
|
||||
* that the AliasDb outlives the ValuesAndMemoryLocationsSet.
|
||||
*
|
||||
* The use case for this is to be able to implement writesToAlias
|
||||
* more efficiently for a set of values.
|
||||
*/
|
||||
class ValueAndMemoryLocationSet {
|
||||
public:
|
||||
TORCH_API void insert(Value* v);
|
||||
TORCH_API ValueSet& getValueSet();
|
||||
|
||||
friend class AliasDb;
|
||||
|
||||
private:
|
||||
ValueAndMemoryLocationSet(const AliasDb* db) : aliasDb_(db){};
|
||||
|
||||
const AliasDb* aliasDb_;
|
||||
ValueSet valueSet_;
|
||||
MemoryLocations memoryLocations_;
|
||||
};
|
||||
|
||||
} // namespace torch::jit
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class DeadCodeEliminator {
|
|||
|
||||
mark(block);
|
||||
|
||||
deleteCallback_(liveValues_);
|
||||
deleteCallback_(getLiveValues());
|
||||
|
||||
sweep(block, recurse);
|
||||
}
|
||||
|
|
@ -120,27 +120,27 @@ class DeadCodeEliminator {
|
|||
// Special handling for onnx loop.
|
||||
// The number of body carried inputs and outputs are different.
|
||||
// They cannot be mapped to each other easily by the same index.
|
||||
liveValues_.insert(loop.bodyCarriedOutputs().at(i));
|
||||
insertLiveValue(loop.bodyCarriedOutputs().at(i));
|
||||
continue;
|
||||
}
|
||||
auto innerInput = loop.bodyCarriedInputs().at(i);
|
||||
auto innerOutput = loop.bodyCarriedOutputs().at(i);
|
||||
auto outerOutput = loop.carriedOutputs().at(i);
|
||||
if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
|
||||
liveValues_.insert(innerOutput);
|
||||
if (liveValuesContains(outerOutput) || innerInput->hasUses()) {
|
||||
insertLiveValue(innerOutput);
|
||||
}
|
||||
}
|
||||
|
||||
// Also mark the loop next condition as live, since it will be used inside
|
||||
// the loop body.
|
||||
liveValues_.insert(loop.nextCond());
|
||||
insertLiveValue(loop.nextCond());
|
||||
} else {
|
||||
AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
|
||||
for (const auto i : c10::irange(outerNode->outputs().size())) {
|
||||
auto innerOutput = node->inputs()[i];
|
||||
auto outerOutput = outerNode->outputs()[i];
|
||||
if (liveValues_.count(outerOutput)) {
|
||||
liveValues_.insert(innerOutput);
|
||||
if (liveValuesContains(outerOutput)) {
|
||||
insertLiveValue(innerOutput);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -215,13 +215,14 @@ class DeadCodeEliminator {
|
|||
// Returns true iff this marked something we haven't marked before.
|
||||
bool markIfLive(Node* node) {
|
||||
for (const auto output : node->outputs()) {
|
||||
if (liveValues_.count(output)) {
|
||||
if (liveValuesContains(output)) {
|
||||
return mark(node);
|
||||
}
|
||||
}
|
||||
|
||||
if (useAliasDb_) {
|
||||
if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
|
||||
if (getOrCreateAliasDb()->writesToAlias(
|
||||
node, getLiveValuesAndMemoryLocations())) {
|
||||
return mark(node);
|
||||
}
|
||||
}
|
||||
|
|
@ -252,10 +253,10 @@ class DeadCodeEliminator {
|
|||
}
|
||||
|
||||
for (const auto input : node->inputs()) {
|
||||
if (liveValues_.count(input)) {
|
||||
if (liveValuesContains(input)) {
|
||||
continue;
|
||||
}
|
||||
liveValues_.insert(input);
|
||||
insertLiveValue(input);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
@ -419,6 +420,46 @@ class DeadCodeEliminator {
|
|||
return aliasDb_.get();
|
||||
}
|
||||
|
||||
ValueAndMemoryLocationSet& getLiveValuesAndMemoryLocations() {
|
||||
if (!liveValuesAndMemoryLocations_) {
|
||||
liveValuesAndMemoryLocations_ =
|
||||
std::make_unique<ValueAndMemoryLocationSet>(
|
||||
getOrCreateAliasDb()->getValueAndMemoryLocationSet());
|
||||
}
|
||||
return *liveValuesAndMemoryLocations_;
|
||||
}
|
||||
|
||||
ValueSet& getLiveValuesSet() {
|
||||
if (!liveValuesSet_) {
|
||||
liveValuesSet_ = std::make_unique<ValueSet>();
|
||||
}
|
||||
return *liveValuesSet_;
|
||||
}
|
||||
|
||||
ValueSet& getLiveValues() {
|
||||
if (useAliasDb_) {
|
||||
return getLiveValuesAndMemoryLocations().getValueSet();
|
||||
} else {
|
||||
return getLiveValuesSet();
|
||||
}
|
||||
}
|
||||
|
||||
void insertLiveValue(Value* v) {
|
||||
if (useAliasDb_) {
|
||||
getLiveValuesAndMemoryLocations().insert(v);
|
||||
} else {
|
||||
getLiveValuesSet().insert(v);
|
||||
}
|
||||
}
|
||||
|
||||
bool liveValuesContains(Value* v) {
|
||||
if (useAliasDb_) {
|
||||
return getLiveValuesAndMemoryLocations().getValueSet().count(v);
|
||||
} else {
|
||||
return getLiveValuesSet().count(v);
|
||||
}
|
||||
}
|
||||
|
||||
DCESideEffectPolicy sideEffectPolicy_;
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
|
@ -427,7 +468,15 @@ class DeadCodeEliminator {
|
|||
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
|
||||
std::unordered_map<Node*, bool> memo_;
|
||||
std::unordered_set<Node*> marked_;
|
||||
std::unordered_set<const Value*> liveValues_;
|
||||
|
||||
// we should have at most 1 of these as a non-nullptr; they are lazily
|
||||
// initialized. liveValuesAndMemoryLocations_ is used if we are using AliasDb
|
||||
// (in order to store aliasing info),
|
||||
// otherwise liveValuesSet_ is used.
|
||||
std::unique_ptr<ValueAndMemoryLocationSet> liveValuesAndMemoryLocations_ =
|
||||
nullptr;
|
||||
std::unique_ptr<ValueSet> liveValuesSet_ = nullptr;
|
||||
|
||||
std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
|
||||
[](const std::unordered_set<const Value*>&) {};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -293,6 +293,9 @@ void initJITBindings(PyObject* module) {
|
|||
[](std::shared_ptr<Graph>& g) {
|
||||
return EliminateDeadCode(g->block()); // overload resolution
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_dce_graph",
|
||||
[](std::shared_ptr<Graph>& g) { return EliminateDeadCode(g); })
|
||||
.def(
|
||||
"_jit_pass_dce_allow_deleting_nodes_with_side_effects",
|
||||
[](std::shared_ptr<Graph>& g) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user