Summary:
**TL;DR**: make DCE faster by replacing a `Set<Value*>` with a `MemoryLocations` sparse bitset (representing the union of the memory locations of all values in the set).
**Details**
The goal of this PR is to optimize this function from AliasDb:
```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}
MemoryLocations locs;
for (const auto v : vs) {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
if (writtenTo.intersects(vlocs)) {
return true;
}
}
}
return false;
}
```
In the DCE use case, we maintain a ValueSet of live values, into which we insert `Value*`s, and we periodically need to check whether a node mutates any of the live values via `writesToAlias`.
Looping over every value in the ValueSet and indexing into `elementMap_` is slow; if we instead pre-compute the MemoryLocations set for the live values, the check becomes much cheaper (see the sketch below). On some large model examples, I see ~15-25x speedups from this change.
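Once the caller maintains the live set's memory locations incrementally, the hot-path check collapses to a single sparse-bitset intersection. A minimal sketch of what that fast path could look like (the overload name and exact signature here are illustrative, not necessarily the one the PR adds):
```
// Hypothetical fast path: `liveLocs` is the pre-computed union of memory
// locations of all live values, so checking a node is one intersection.
bool AliasDb::writesToAlias(Node* n, const MemoryLocations& liveLocs) const {
  const auto writtenTo = getWrites(n);
  if (writtenTo.empty()) {
    return false;
  }
  // One sparse-bitset intersection instead of a per-Value map lookup.
  return writtenTo.intersects(liveLocs);
}
```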
**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class `ValueAndMemoryLocationSet`: an insert-only set of Values that also maintains the union of their corresponding MemoryLocations.
Then in DCE, I use `ValueAndMemoryLocationSet` when an AliasDb is available for analysis, and fall back to a plain `Set<Value*>` when we don't have an AliasDb.
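A rough sketch of the shape such a class could take (member names, signatures, and ownership details are assumptions for illustration; the real class is a friend of AliasDb and may differ):
```
// Illustrative sketch only: an insert-only set of Values that keeps the
// union of their memory locations up to date as values are inserted.
class ValueAndMemoryLocationSet {
 public:
  explicit ValueAndMemoryLocationSet(const AliasDb* aliasDb)
      : aliasDb_(aliasDb) {}

  void insert(const Value* v) {
    values_.insert(v);
    auto it = aliasDb_->elementMap_.find(v);
    if (it != aliasDb_->elementMap_.end()) {
      // Accumulate this value's locations into the running union,
      // so writesToAlias-style checks need no per-value lookups later.
      locations_ |= aliasDb_->memoryDAG_->getMemoryLocations(it->second);
    }
  }

  const MemoryLocations& memoryLocations() const {
    return locations_;
  }

 private:
  const AliasDb* aliasDb_;
  std::unordered_set<const Value*> values_;
  MemoryLocations locations_;
};
```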
Test Plan: Rely on unit tests.
Differential Revision: D74827086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
76 lines · 2.0 KiB · Python
```
# Owner(s): ["oncall: jit"]

import torch

from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase, make_global


class TestDCE(JitTestCase):
    def test_setattr_no_aliasdb(self):
        class Net(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.empty([2, 2])

            def forward(self):
                x = torch.rand([3, 3])
                self.x = x

        net = torch.jit.script(Net())

        FileCheck().check("prim::SetAttr").run(net.graph)

    def test_setattr_removed(self):
        @torch.jit.script
        class Thing1:
            def __init__(self) -> None:
                self.x = torch.zeros([2, 2])

        make_global(Thing1)

        class Thing2(torch.nn.Module):
            def forward(self):
                x = torch.rand([2, 2])
                y = torch.rand([2, 2])
                t1 = Thing1()
                t1.x = x
                return y

        unscripted = Thing2()

        t2 = torch.jit.script(unscripted)
        t2.eval()

        # freezing inlines t1.__init__(), after which DCE can occur.
        t2 = torch.jit.freeze(t2)
        FileCheck().check_not("prim::SetAttr").run(t2.graph)

    def test_mutated_simple(self):
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            z = y.cos()
            return z

        fn_s = torch.jit.script(fn)
        torch._C._jit_pass_dce_graph(fn_s.graph)

        FileCheck().check("aten::add_").run(fn_s.graph)

    def test_mutated_loop(self):
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            for _ in range(2):
                y_slice = y[::2]
                y = y.repeat(2)
            z = y.cos()
            return z

        fn_s = torch.jit.script(fn)
        torch._C._jit_pass_dce_graph(fn_s.graph)

        FileCheck().check("aten::add_").run(fn_s.graph)
```