[JIT] Optimize DCE by storing a MemoryLocations for an entire set<Value*> (#153645)

Summary:
**TL;DR**: make DCE faster by replacing a Set<Value*> with a MemoryLocations sparse bitset (representing all the memory locations stored by the collection of all values in the set).

**Details**
The goal of this PR is to optimize this function from AliasDb:

```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
  const auto writtenTo = getWrites(n);
  if (writtenTo.empty()) {
    return false;
  }

  MemoryLocations locs;
  for (const auto v : vs) {
    auto it = elementMap_.find(v);
    if (it != elementMap_.end()) {
      const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
      if (writtenTo.intersects(vlocs)) {
        return true;
      }
    }
  }

  return false;
}
```

In the DCE use case, we have a ValueSet of live values, into which we insert `Value*`s; and sometimes need to check whether a node mutates any of the live values using `writesToAlias`.

Looping through all the values in the ValueSet and indexing into the elementMap_ is slow; so if we can pre-compute the MemoryLocations set, this speeds up the function. In some large model examples, I see ~15-25x speedups from this change.

**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class `ValueAndMemoryLocationSet`, which is an insert-only set of Values, which also maintains the corresponding MemoryLocations.

Then in AliasDb, I use `ValueAndMemoryLocationSet` if we're using AliasDb for analysis, and otherwise use a `Set<Value*>` if we don't have AliasDb.

Test Plan: Rely on unit tests.

Differential Revision: D74827086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
This commit is contained in:
David Berard 2025-05-19 21:04:59 +00:00 committed by PyTorch MergeBot
parent cc48550e6f
commit a237831bc2
6 changed files with 168 additions and 12 deletions

View File

@ -44,3 +44,32 @@ class TestDCE(JitTestCase):
# freezing inlines t1.__init__(), after which DCE can occur.
t2 = torch.jit.freeze(t2)
FileCheck().check_not("prim::SetAttr").run(t2.graph)
def test_mutated_simple(self):
def fn(x: torch.Tensor):
y = x.sin()
y_slice = y[::2]
y_slice.add_(x[::2])
z = y.cos()
return z
fn_s = torch.jit.script(fn)
torch._C._jit_pass_dce_graph(fn_s.graph)
FileCheck().check("aten::add_").run(fn_s.graph)
def test_mutated_loop(self):
def fn(x: torch.Tensor):
y = x.sin()
y_slice = y[::2]
y_slice.add_(x[::2])
for _ in range(2):
y_slice = y[::2]
y = y.repeat(2)
z = y.cos()
return z
fn_s = torch.jit.script(fn)
torch._C._jit_pass_dce_graph(fn_s.graph)
FileCheck().check("aten::add_").run(fn_s.graph)

View File

@ -1602,6 +1602,7 @@ def _jit_set_logging_option(option: str) -> None: ...
def _jit_set_logging_stream(stream_name: str) -> None: ...
def _jit_pass_cse(Graph) -> _bool: ...
def _jit_pass_dce(Graph) -> None: ...
def _jit_pass_dce_graph(Graph) -> None: ...
def _jit_pass_lint(Graph) -> None: ...
# Defined in torch/csrc/jit/python/python_custom_class.cpp

View File

@ -396,6 +396,14 @@ MemoryLocations AliasDb::getReads(Node* n) const {
return reads;
}
MemoryLocations AliasDb::getMemoryLocations(Value* v) const {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
return memoryDAG_->getMemoryLocations(it->second);
}
return MemoryLocations();
}
std::string AliasDb::getElementName(const Element* e) const {
if (e->values.empty()) {
// Not the most efficient way, but given the fact there are
@ -2014,4 +2022,27 @@ void Lint(const AliasDb* db) {
// - All container values have contained elements
}
ValueAndMemoryLocationSet AliasDb::getValueAndMemoryLocationSet() const {
return ValueAndMemoryLocationSet(this);
}
bool AliasDb::writesToAlias(Node* n, const ValueAndMemoryLocationSet& vls)
const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}
return writtenTo.intersects(vls.memoryLocations_);
}
void ValueAndMemoryLocationSet::insert(Value* v) {
valueSet_.insert(v);
memoryLocations_ |= aliasDb_->getMemoryLocations(v);
}
ValueSet& ValueAndMemoryLocationSet::getValueSet() {
return valueSet_;
}
} // namespace torch::jit

View File

@ -9,6 +9,8 @@
namespace torch::jit {
class ValueAndMemoryLocationSet;
/**
* Alias analysis pass.
*
@ -70,6 +72,12 @@ class AliasDb {
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const;
// Does `n` write to any of the values in `vls`?
TORCH_API bool writesToAlias(Node* n, const ValueAndMemoryLocationSet& vls)
const;
TORCH_API ValueAndMemoryLocationSet getValueAndMemoryLocationSet() const;
// Does `a` and `b` potentially share a memory location or do either
// hold in memory any element that exists in the other
TORCH_API bool mayContainAlias(Value* a, Value* b) const;
@ -170,6 +178,7 @@ class AliasDb {
void enablePreciseTupleContainerAnalysis();
friend struct MutationRemover;
friend class ValueAndMemoryLocationSet;
private:
// Helper for topologically-safe node moves.
@ -197,6 +206,7 @@ class AliasDb {
// if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
MemoryLocations getReads(Node* n) const;
void getReadsImpl(Node* n, MemoryLocations& ret) const;
MemoryLocations getMemoryLocations(Value* v) const;
/**
* Wildcard methods
@ -317,4 +327,37 @@ class AliasDb {
// the right thing.
TORCH_API void Lint(const AliasDb* db);
/**
* ValueAndMemoryLocationSet
*
* A insert-only set of values which also maintains a MemoryLocations bitset
* of the memory locations that the values alias. It is insert-only. It
* should be constructed by calling aliasDb.getValueAndMemoryLocationSet().
*
* WARNING:
* * The AliasDb must not be mutated after construction of a
* ValueAndMemoryLocationsSet, or else the MemoryLocations stored in the
* ValueAndMemoryLocationSet will no longer be accurate.
* * A ValueAndMemoryLocationsSet is tied to an instsance of AliasDb but
* does not own the AliasDb. It is the user's responsibility to ensure
* that the AliasDb outlives the ValuesAndMemoryLocationsSet.
*
* The use case for this is to be able to implement writesToAlias
* more efficiently for a set of values.
*/
class ValueAndMemoryLocationSet {
public:
TORCH_API void insert(Value* v);
TORCH_API ValueSet& getValueSet();
friend class AliasDb;
private:
ValueAndMemoryLocationSet(const AliasDb* db) : aliasDb_(db){};
const AliasDb* aliasDb_;
ValueSet valueSet_;
MemoryLocations memoryLocations_;
};
} // namespace torch::jit

View File

@ -36,7 +36,7 @@ class DeadCodeEliminator {
mark(block);
deleteCallback_(liveValues_);
deleteCallback_(getLiveValues());
sweep(block, recurse);
}
@ -120,27 +120,27 @@ class DeadCodeEliminator {
// Special handling for onnx loop.
// The number of body carried inputs and outputs are different.
// They cannot be mapped to each other easily by the same index.
liveValues_.insert(loop.bodyCarriedOutputs().at(i));
insertLiveValue(loop.bodyCarriedOutputs().at(i));
continue;
}
auto innerInput = loop.bodyCarriedInputs().at(i);
auto innerOutput = loop.bodyCarriedOutputs().at(i);
auto outerOutput = loop.carriedOutputs().at(i);
if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
liveValues_.insert(innerOutput);
if (liveValuesContains(outerOutput) || innerInput->hasUses()) {
insertLiveValue(innerOutput);
}
}
// Also mark the loop next condition as live, since it will be used inside
// the loop body.
liveValues_.insert(loop.nextCond());
insertLiveValue(loop.nextCond());
} else {
AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
for (const auto i : c10::irange(outerNode->outputs().size())) {
auto innerOutput = node->inputs()[i];
auto outerOutput = outerNode->outputs()[i];
if (liveValues_.count(outerOutput)) {
liveValues_.insert(innerOutput);
if (liveValuesContains(outerOutput)) {
insertLiveValue(innerOutput);
}
}
}
@ -215,13 +215,14 @@ class DeadCodeEliminator {
// Returns true iff this marked something we haven't marked before.
bool markIfLive(Node* node) {
for (const auto output : node->outputs()) {
if (liveValues_.count(output)) {
if (liveValuesContains(output)) {
return mark(node);
}
}
if (useAliasDb_) {
if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
if (getOrCreateAliasDb()->writesToAlias(
node, getLiveValuesAndMemoryLocations())) {
return mark(node);
}
}
@ -252,10 +253,10 @@ class DeadCodeEliminator {
}
for (const auto input : node->inputs()) {
if (liveValues_.count(input)) {
if (liveValuesContains(input)) {
continue;
}
liveValues_.insert(input);
insertLiveValue(input);
}
return true;
}
@ -419,6 +420,46 @@ class DeadCodeEliminator {
return aliasDb_.get();
}
ValueAndMemoryLocationSet& getLiveValuesAndMemoryLocations() {
if (!liveValuesAndMemoryLocations_) {
liveValuesAndMemoryLocations_ =
std::make_unique<ValueAndMemoryLocationSet>(
getOrCreateAliasDb()->getValueAndMemoryLocationSet());
}
return *liveValuesAndMemoryLocations_;
}
ValueSet& getLiveValuesSet() {
if (!liveValuesSet_) {
liveValuesSet_ = std::make_unique<ValueSet>();
}
return *liveValuesSet_;
}
ValueSet& getLiveValues() {
if (useAliasDb_) {
return getLiveValuesAndMemoryLocations().getValueSet();
} else {
return getLiveValuesSet();
}
}
void insertLiveValue(Value* v) {
if (useAliasDb_) {
getLiveValuesAndMemoryLocations().insert(v);
} else {
getLiveValuesSet().insert(v);
}
}
bool liveValuesContains(Value* v) {
if (useAliasDb_) {
return getLiveValuesAndMemoryLocations().getValueSet().count(v);
} else {
return getLiveValuesSet().count(v);
}
}
DCESideEffectPolicy sideEffectPolicy_;
std::shared_ptr<Graph> graph_;
@ -427,7 +468,15 @@ class DeadCodeEliminator {
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
std::unordered_map<Node*, bool> memo_;
std::unordered_set<Node*> marked_;
std::unordered_set<const Value*> liveValues_;
// we should have at most 1 of these as a non-nullptr; they are lazily
// initialized. liveValuesAndMemoryLocations_ is used if we are using AliasDb
// (in order to store aliasing info),
// otherwise liveValuesSet_ is used.
std::unique_ptr<ValueAndMemoryLocationSet> liveValuesAndMemoryLocations_ =
nullptr;
std::unique_ptr<ValueSet> liveValuesSet_ = nullptr;
std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
[](const std::unordered_set<const Value*>&) {};
};

View File

@ -293,6 +293,9 @@ void initJITBindings(PyObject* module) {
[](std::shared_ptr<Graph>& g) {
return EliminateDeadCode(g->block()); // overload resolution
})
.def(
"_jit_pass_dce_graph",
[](std::shared_ptr<Graph>& g) { return EliminateDeadCode(g); })
.def(
"_jit_pass_dce_allow_deleting_nodes_with_side_effects",
[](std::shared_ptr<Graph>& g) {