pytorch/test/jit/test_dce.py
David Berard a237831bc2 [JIT] Optimize DCE by storing a MemoryLocations for an entire set<Value*> (#153645)
Summary:
**TL;DR**: make DCE faster by replacing a Set<Value*> with a MemoryLocations sparse bitset (representing all the memory locations stored by the collection of all values in the set).

**Details**
The goal of this PR is to optimize this function from AliasDb:

```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
  const auto writtenTo = getWrites(n);
  if (writtenTo.empty()) {
    return false;
  }

  MemoryLocations locs;
  for (const auto v : vs) {
    auto it = elementMap_.find(v);
    if (it != elementMap_.end()) {
      const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
      if (writtenTo.intersects(vlocs)) {
        return true;
      }
    }
  }

  return false;
}
```

In the DCE use case, we have a ValueSet of live values into which we insert `Value*`s, and we sometimes need to check whether a node mutates any of the live values using `writesToAlias`.

Looping through all the values in the ValueSet and indexing into `elementMap_` for each one is slow, so pre-computing the MemoryLocations set speeds up the function. In some large model examples, I see ~15-25x speedups from this change.
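The equivalence this optimization relies on can be sketched in Python. This is a toy model rather than the AliasDb code: memory locations are modeled as bit positions in a Python int, and `element_map`, `writes_to_alias_loop`, `union_locations`, and `writes_to_alias_precomputed` are hypothetical names chosen for illustration.

```python
from typing import Dict, Iterable

# Toy stand-in for elementMap_ plus getMemoryLocations():
# each value maps to a bitmask of the memory locations it may point to.
element_map: Dict[str, int] = {
    "a": 0b0001,
    "b": 0b0110,
    "c": 0b1000,
}


def writes_to_alias_loop(written_to: int, live: Iterable[str]) -> bool:
    """Per-value check: one map lookup and one intersection per live value."""
    if written_to == 0:
        return False
    for v in live:
        locs = element_map.get(v)
        if locs is not None and (written_to & locs):
            return True
    return False


def union_locations(live: Iterable[str]) -> int:
    """Precompute the union of the memory locations of all live values."""
    combined = 0
    for v in live:
        combined |= element_map.get(v, 0)
    return combined


def writes_to_alias_precomputed(written_to: int, live_locs: int) -> bool:
    """Single intersection against the precomputed union."""
    return (written_to & live_locs) != 0


live = ["a", "b"]
live_locs = union_locations(live)

# Both checks agree for every write set tried here.
for written in (0b0000, 0b0100, 0b1000):
    assert writes_to_alias_loop(written, live) == writes_to_alias_precomputed(
        written, live_locs
    )
```

A write set intersects at least one per-value location set exactly when it intersects their union, so the per-value loop can be replaced by one intersection, provided the union is kept up to date as values are added.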

**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class, `ValueAndMemoryLocationSet`: an insert-only set of Values that also maintains the corresponding MemoryLocations.
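As a rough illustration of the idea (not the C++ class from this PR), an insert-only set that keeps a running union of locations might look like the following in Python. The class name `ValueAndLocationSetSketch`, its methods, and the int-as-bitset representation are all assumptions made for this sketch.

```python
from typing import Dict, Hashable, Set


class ValueAndLocationSetSketch:
    """Insert-only set of values plus the union of their memory locations.

    Hypothetical model: locations are bit positions in a Python int.
    """

    def __init__(self, element_map: Dict[Hashable, int]) -> None:
        self._element_map = element_map  # value -> bitmask of locations
        self._values: Set[Hashable] = set()
        self._locations = 0  # union of locations of all inserted values

    def insert(self, value: Hashable) -> None:
        # Update the union at insertion time so later queries are one AND.
        self._values.add(value)
        self._locations |= self._element_map.get(value, 0)

    def __contains__(self, value: Hashable) -> bool:
        return value in self._values

    def intersects(self, written_to: int) -> bool:
        # True if any written location belongs to some inserted value.
        return (written_to & self._locations) != 0


element_map = {"x": 0b01, "y": 0b10}
live = ValueAndLocationSetSketch(element_map)
live.insert("x")
```

Restricting the set to insertion keeps the union cheap to maintain: removing a value would require recomputing the union from the remaining values, because two values may share a location.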

Then, in DCE, I use `ValueAndMemoryLocationSet` when AliasDb is available for analysis, and fall back to a `Set<Value*>` when it is not.

Test Plan: Rely on unit tests.

Differential Revision: D74827086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
2025-05-19 21:04:59 +00:00

# Owner(s): ["oncall: jit"]

import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase, make_global


class TestDCE(JitTestCase):
    def test_setattr_no_aliasdb(self):
        class Net(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.empty([2, 2])

            def forward(self):
                x = torch.rand([3, 3])
                self.x = x

        net = torch.jit.script(Net())

        FileCheck().check("prim::SetAttr").run(net.graph)

    def test_setattr_removed(self):
        @torch.jit.script
        class Thing1:
            def __init__(self) -> None:
                self.x = torch.zeros([2, 2])

        make_global(Thing1)

        class Thing2(torch.nn.Module):
            def forward(self):
                x = torch.rand([2, 2])
                y = torch.rand([2, 2])
                t1 = Thing1()
                t1.x = x
                return y

        unscripted = Thing2()

        t2 = torch.jit.script(unscripted)
        t2.eval()

        # freezing inlines t1.__init__(), after which DCE can occur.
        t2 = torch.jit.freeze(t2)
        FileCheck().check_not("prim::SetAttr").run(t2.graph)

    def test_mutated_simple(self):
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            z = y.cos()
            return z

        fn_s = torch.jit.script(fn)
        torch._C._jit_pass_dce_graph(fn_s.graph)

        FileCheck().check("aten::add_").run(fn_s.graph)

    def test_mutated_loop(self):
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            for _ in range(2):
                y_slice = y[::2]
                y = y.repeat(2)
            z = y.cos()
            return z

        fn_s = torch.jit.script(fn)
        torch._C._jit_pass_dce_graph(fn_s.graph)

        FileCheck().check("aten::add_").run(fn_s.graph)