mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Inductor] Implement a deduplist data structure for name to user tracking (#115609)
Summary: An internal MRS model was taking over a day's worth of time to compile due to many duplicates in dependency tracking. This PR replaces the list with a custom dedup list. Normally one could use a set/dict for this purpose however the list in question gets elements appended as it is being iterated over which means that we need to keep the list semantics. Test Plan: ad hoc testing Differential Revision: D52060659 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115609 Approved by: https://github.com/jansel
This commit is contained in:
parent
ffb2a28a67
commit
af09fe256a
|
|
@ -12,11 +12,13 @@ from typing import (
|
|||
Counter,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
|
@ -1168,6 +1170,16 @@ class NodeUser:
|
|||
# use the result
|
||||
is_weak: bool = False
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.node.get_name(), self.can_inplace, self.is_weak))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.get_name() == other.get_name()
|
||||
and self.can_inplace == other.can_inplace
|
||||
and self.is_weak == other.is_weak
|
||||
)
|
||||
|
||||
def get_name(self):
|
||||
return self.node.get_name()
|
||||
|
||||
|
|
@ -1330,7 +1342,39 @@ class Scheduler:
|
|||
Create dependency edges between nodes, handling aliasing and
|
||||
mutation properly.
|
||||
"""
|
||||
name_to_users: DefaultDict[str, List[NodeUser]] = collections.defaultdict(list)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class DedupList(Generic[T]):
|
||||
"""
|
||||
This data structure behaves like a list except it makes sure the
|
||||
elements remain unique.
|
||||
Normally one could use a set/dict for this purpose however
|
||||
the list in question gets elements appended as it is being
|
||||
iterated over which means that we need to keep the list
|
||||
semantics.
|
||||
"""
|
||||
|
||||
def __init__(self, items=None, membership=None):
|
||||
self.items = items or list()
|
||||
self.membership = membership or set()
|
||||
|
||||
def append(self, node_user: T) -> None:
|
||||
if node_user in self.membership:
|
||||
return
|
||||
self.items.append(node_user)
|
||||
self.membership.add(node_user)
|
||||
|
||||
def __add__(self, other: "DedupList[T]") -> "DedupList[T]":
|
||||
new_membership = set.union(self.membership, other.membership)
|
||||
new_items = self.items + [
|
||||
x for x in other.items if x not in self.membership
|
||||
]
|
||||
return DedupList(new_items, new_membership)
|
||||
|
||||
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
|
||||
DedupList
|
||||
)
|
||||
|
||||
# handle aliasing by using python aliasing in name_to_users
|
||||
# if foo aliases bar then we will make name_to_users["foo"] point
|
||||
|
|
@ -1405,7 +1449,7 @@ class Scheduler:
|
|||
# this node must run after the prior writer
|
||||
add_user(alt_name, node)
|
||||
node.add_mutation_dep(StarDep(alt_name))
|
||||
for other_node in name_to_users[alt_name]:
|
||||
for other_node in name_to_users[alt_name].items:
|
||||
# this node must run after all prior readers
|
||||
other_name = rename(other_node.get_name())
|
||||
known_dep_node_names = dep_closure(node.get_name())
|
||||
|
|
@ -1463,7 +1507,7 @@ class Scheduler:
|
|||
|
||||
# copy users information onto the nodes
|
||||
for node in self.nodes:
|
||||
node.set_users(name_to_users[node.get_name()])
|
||||
node.set_users(name_to_users[node.get_name()].items)
|
||||
|
||||
# populate inverse_users
|
||||
for node in self.nodes:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user