[Inductor] Implement a deduplist data structure for name to user tracking (#115609)

Summary:
An internal MRS model was taking over a day to compile because of the many duplicate entries in dependency tracking. This PR replaces the plain list with a custom deduplicating list.
Normally one could use a set/dict for this purpose; however, the list in question gets elements appended while it is being iterated over, so we need to keep list semantics.

Test Plan: ad hoc testing

Differential Revision: D52060659

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115609
Approved by: https://github.com/jansel
This commit is contained in:
Oguz Ulgen 2023-12-12 22:28:26 +00:00 committed by PyTorch MergeBot
parent ffb2a28a67
commit af09fe256a

View File

@ -12,11 +12,13 @@ from typing import (
Counter,
DefaultDict,
Dict,
Generic,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
@ -1168,6 +1170,16 @@ class NodeUser:
# use the result
is_weak: bool = False
def __hash__(self):
return hash((self.node.get_name(), self.can_inplace, self.is_weak))
def __eq__(self, other):
    """
    Compare two users by name, ``can_inplace`` and ``is_weak``.

    Returns ``NotImplemented`` when ``other`` is not a ``NodeUser`` so that
    comparing against an unrelated object evaluates to ``False`` instead of
    raising ``AttributeError`` on the missing ``get_name``.
    """
    if not isinstance(other, NodeUser):
        return NotImplemented
    return (
        self.get_name() == other.get_name()
        and self.can_inplace == other.can_inplace
        and self.is_weak == other.is_weak
    )
def get_name(self):
    """Return whatever name the wrapped ``node`` reports."""
    wrapped = self.node
    return wrapped.get_name()
@ -1330,7 +1342,39 @@ class Scheduler:
Create dependency edges between nodes, handling aliasing and
mutation properly.
"""
name_to_users: DefaultDict[str, List[NodeUser]] = collections.defaultdict(list)
T = TypeVar("T")


class DedupList(Generic[T]):
    """
    This data structure behaves like a list except it makes sure the
    elements remain unique.
    Normally one could use a set/dict for this purpose however
    the list in question gets elements appended as it is being
    iterated over which means that we need to keep the list
    semantics.

    ``items`` holds the elements in insertion order; ``membership`` holds
    the same elements as a set and is what duplicate checks consult, so
    elements must be hashable.
    """

    def __init__(
        self,
        items: Optional[List[T]] = None,
        membership: Optional[Set[T]] = None,
    ) -> None:
        """
        Args:
            items: initial elements, assumed to be already unique. The list
                object is stored as-is (not copied).
            membership: set mirroring ``items``. When omitted it is built
                from ``items`` so that ``append`` recognises the initial
                elements as already present.
        """
        # ``is None`` rather than ``or``: an empty list handed in by the
        # caller is still the caller's list and must not be swapped out.
        self.items = [] if items is None else items
        self.membership = set(self.items) if membership is None else membership

    def append(self, node_user: T) -> None:
        """Add ``node_user`` at the end unless an equal element is already present."""
        if node_user in self.membership:
            return
        self.items.append(node_user)
        self.membership.add(node_user)

    def __contains__(self, node_user: object) -> bool:
        """Membership test answered from the ``membership`` set."""
        return node_user in self.membership

    def __iter__(self):
        """
        Iterate over the elements in insertion order.

        This is the iterator of the underlying list, so elements appended
        while iterating are visited as well.
        """
        return iter(self.items)

    def __add__(self, other: "DedupList[T]") -> "DedupList[T]":
        """
        Return a new DedupList with this list's elements followed by the
        elements of ``other`` that are not already in this list.
        """
        if not isinstance(other, DedupList):
            return NotImplemented
        new_membership = self.membership | other.membership
        new_items = self.items + [
            x for x in other.items if x not in self.membership
        ]
        return DedupList(new_items, new_membership)
# name_to_users maps a name (str) to the DedupList of NodeUser entries
# recorded for it; looking up a name that has no entry yet creates an empty
# DedupList for it.
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
DedupList
)
# handle aliasing by using python aliasing in name_to_users
# if foo aliases bar then we will make name_to_users["foo"] point
@ -1405,7 +1449,7 @@ class Scheduler:
# this node must run after the prior writer
add_user(alt_name, node)
node.add_mutation_dep(StarDep(alt_name))
for other_node in name_to_users[alt_name]:
for other_node in name_to_users[alt_name].items:
# this node must run after all prior readers
other_name = rename(other_node.get_name())
known_dep_node_names = dep_closure(node.get_name())
@ -1463,7 +1507,7 @@ class Scheduler:
# copy users information onto the nodes
for node in self.nodes:
node.set_users(name_to_users[node.get_name()])
node.set_users(name_to_users[node.get_name()].items)
# populate inverse_users
for node in self.nodes: