Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Use OrderedSet in _functorch/partitioners (#146102)
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model, but it at least eliminates this potential source of non-determinism before we investigate any changes to the min-cut approach.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
This commit is contained in:
parent 53759ccca8
commit 23fffb54d5
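To make the motivation concrete, here is a minimal sketch (not part of this commit) of the behavioral difference the change relies on; only torch.utils._ordered_set.OrderedSet is taken from the diff below, while FakeNode and the surrounding script are illustrative assumptions.

# Minimal sketch, not from the PR: iterating a builtin set of objects that are
# hashed by identity (as fx.Node is) depends on memory addresses and can differ
# across runs; OrderedSet iterates in insertion order, which is stable.
from torch.utils._ordered_set import OrderedSet


class FakeNode:
    """Illustrative stand-in for fx.Node: hashed by identity, like the real node."""

    def __init__(self, name: str) -> None:
        self.name = name


nodes = [FakeNode(f"n{i}") for i in range(8)]

plain = set(nodes)  # iteration order follows object hashes, so it may vary per run
ordered = OrderedSet(nodes)  # iteration order is insertion order on every run

print([n.name for n in ordered])  # always ['n0', 'n1', ..., 'n7']
print([n.name for n in plain])  # may come out in a different order between runs

# OrderedSet supports the set algebra partitioners.py relies on (union,
# difference, intersection), so swapping the container is mostly mechanical:
extra = OrderedSet([nodes[0], FakeNode("n8")])
print([n.name for n in (ordered | extra)])  # union order is deterministic too

The diff below applies the swap throughout torch/_functorch/partitioners.py and adds the file to a formatter's include_patterns.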
@@ -1703,6 +1703,7 @@ command = [
 ]
 include_patterns = [
     "torch/_inductor/**/*.py",
+    "torch/_functorch/partitioners.py",
 ]
 is_formatter = true
@@ -25,6 +25,7 @@ from torch.fx.experimental.symbolic_shapes import (
     is_symbol_binding_fx_node,
 )
 from torch.fx.passes import graph_drawer
+from torch.utils._ordered_set import OrderedSet
 from torch.utils.checkpoint import CheckpointPolicy

 from . import config
@@ -55,11 +56,11 @@ prims = torch.ops.prims
 class OpTypes:
     """Class for keeping track of different operator categories"""

-    fusible_ops: set[Callable]
-    compute_intensive_ops: set[Callable]
-    random_ops: set[Callable]
-    view_ops: set[Callable]
-    recomputable_ops: set[Callable]
+    fusible_ops: OrderedSet[Callable]
+    compute_intensive_ops: OrderedSet[Callable]
+    random_ops: OrderedSet[Callable]
+    view_ops: OrderedSet[Callable]
+    recomputable_ops: OrderedSet[Callable]

     def is_fusible(self, node: fx.Node):
         return get_aten_target(node) in self.fusible_ops
@@ -82,9 +83,9 @@ class NodeInfo:
     # Be careful about iterating over these explicitly, as their order may not
     # be deterministic
     inputs: list[fx.Node]
-    _required_fw_nodes: set[fx.Node]
-    required_bw_nodes: set[fx.Node]
-    unclaimed_nodes: set[fx.Node]
+    _required_fw_nodes: OrderedSet[fx.Node]
+    required_bw_nodes: OrderedSet[fx.Node]
+    unclaimed_nodes: OrderedSet[fx.Node]
     fw_order: dict[fx.Node, int]

     @functools.cached_property
@@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules(
     # we propagate all symbols which are referenced by backwards inputs.
     # These are not directly used in the graph but are required for downstream
     # sizevar assignment
-    saved_symbols: set[sympy.Symbol] = set()
+    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
     saved_sym_nodes_binding = []
     saved_sym_nodes_derived = []

@@ -426,9 +427,9 @@ def default_partition(
     forward_only_graph = _extract_graph_with_inputs_outputs(
         joint_module.graph, inputs, fwd_outputs, "forward"
     )
-    forward_node_names = {
+    forward_node_names = OrderedSet(
         node.name for node in forward_only_graph.nodes if node.op != "output"
-    }
+    )
     saved_values = []
     saved_sym_nodes = []

@@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:

     def insert_node_in_graph(node):
         cur_nodes = [node]
-        insertable_nodes = set()
+        insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
         while len(cur_nodes) > 0:
             node = cur_nodes.pop()
             if node in insertable_nodes or node in env:
@@ -817,19 +818,21 @@ def solve_min_cut(
     joint_graph: fx.Graph,
     node_info: NodeInfo,
     min_cut_options: MinCutOptions,
-    dont_ban=None,
+    dont_ban: Optional[OrderedSet[fx.Node]] = None,
 ):
     if dont_ban is None:
-        dont_ban = set()
+        dont_ban = OrderedSet()
     op_types = get_default_op_list()

     if AOT_PARTITIONER_DEBUG:
-        joint_module_ops = {
+        joint_module_ops = OrderedSet(
             str(node.target._overloadpacket)
             for node in joint_graph.nodes
             if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
-        }
-        ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
+        )
+        ops_ignored = joint_module_ops - OrderedSet(
+            str(i) for i in op_types.recomputable_ops
+        )
         log.info("Ops banned from re-materialization: %s", ops_ignored)

     def can_fuse_into_auto_functionalized(a, b):
@@ -888,7 +891,7 @@ def solve_min_cut(
     def is_materialized_backwards(node):
         if op_types.is_view(node):
             return False
-        cur_nodes = {node}
+        cur_nodes = OrderedSet([node])
         while len(cur_nodes) > 0:
             cur = cur_nodes.pop()
             for user in cur.users:
@@ -981,7 +984,7 @@ def solve_min_cut(
         return mem_sz * 2

     nx_graph = nx.DiGraph()
-    banned_nodes = set()
+    banned_nodes: OrderedSet[fx.Node] = OrderedSet()

     def ban_recomputation_if_allowed(node):
         if op_types.is_view(node):
@@ -1091,12 +1094,13 @@ def solve_min_cut(
                 if node_info.is_required_fw(user):
                     if node_info.get_fw_order(user) > max_range:
                         continue
-                    val = (node_info.get_fw_order(user), user, is_fusible(node, user))
+                    val: tuple[int, fx.Node, bool] = (
+                        node_info.get_fw_order(user),
+                        user,
+                        is_fusible(node, user),
+                    )
                     if val not in sorted_nodes:
-                        heapq.heappush(
-                            sorted_nodes,
-                            val,
-                        )
+                        heapq.heappush(sorted_nodes, val)
         return max_range

     if min_cut_options.ban_if_used_far_apart:
@@ -1141,11 +1145,13 @@ def solve_min_cut(
     # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36

     if min_cut_options.ban_if_long_fusible_chains:
-        visited = set()
+        visited: OrderedSet[fx.Node] = OrderedSet()
         for start_node in joint_graph.nodes:
             if not node_info.is_required_fw(start_node):
                 continue
-            fusible = [(node_info.get_fw_order(start_node), start_node)]
+            fusible: list[tuple[int, fx.Node]] = [
+                (node_info.get_fw_order(start_node), start_node)
+            ]
             start_order = node_info.get_fw_order(start_node)
             while len(fusible) > 0:
                 _, cur = heapq.heappop(fusible)
@@ -1184,11 +1190,11 @@ def solve_min_cut(
         raise

     reachable, non_reachable = partition
-    cutset: set[tuple[str, str]] = set()
+    cutset: OrderedSet[tuple[str, str]] = OrderedSet()
     for u, nbrs in ((n, nx_graph[n]) for n in reachable):
         cutset.update((u, v) for v in nbrs if v in non_reachable)

-    cut_nodes = set()
+    cut_nodes: OrderedSet[str] = OrderedSet()
     for node_in, node_out in cutset:
         assert node_in[:-3] == node_out[:-4]
         node_name = node_in[:-3]
@@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes:
     ]

     default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
-    recomputable_ops = set(default_recomputable_ops)
+    recomputable_ops = OrderedSet(default_recomputable_ops)

-    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
+    random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
     compute_intensive_ops = [
         aten.mm,
         aten.convolution,
@@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes:
         aten._scaled_mm,
     ]  # noqa: E501,B950

-    fusible_ops = recomputable_ops | set(random_ops)
+    fusible_ops = recomputable_ops | random_ops
     return OpTypes(
-        set(fusible_ops),
-        set(compute_intensive_ops),
-        set(random_ops),
-        set(view_ops),
-        set(recomputable_ops),
+        fusible_ops,
+        OrderedSet(compute_intensive_ops),
+        random_ops,
+        OrderedSet(view_ops),
+        recomputable_ops,
     )

@@ -1567,9 +1573,11 @@ def choose_saved_values_set(

     from torch._inductor.fx_utils import get_node_storage

-    input_storages = {get_node_storage(node) for node in node_info.inputs}
+    input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)

-    def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
+    def get_recomputable_banned_nodes(
+        banned_nodes: OrderedSet[fx.Node],
+    ) -> list[fx.Node]:
         return [
             i
             for i in banned_nodes
@@ -1653,7 +1661,7 @@ Activation Checkpointing - Knapsack Problem Summary:
         payload_fn=lambda: knapsack_summary,
     )
     log.info(knapsack_summary)
-    dont_ban = set()
+    dont_ban: OrderedSet[fx.Node] = OrderedSet()
     for idx in recomputable_node_idxs:
         # if idx in all_recomputable_banned_nodes:
         try:
@@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition(

     def classify_nodes(joint_module):
         name_to_node = get_name_to_node(joint_module.graph)
-        required_bw_nodes = set()
+        required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
         for node in joint_module.graph.nodes:
             if node.op == "placeholder" and "tangents" in node.target:
                 required_bw_nodes.add(node)
@@ -1800,16 +1808,16 @@ def min_cut_rematerialization_partition(
         forward_only_graph = _extract_graph_with_inputs_outputs(
             joint_module.graph, inputs, fwd_outputs, "forward"
         )
-        required_fw_nodes: set[fx.Node] = {
+        required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
             name_to_node[node.name]
             for node in forward_only_graph.nodes
             if node.op != "output"
-        }
-        unclaimed_nodes = {
+        )
+        unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
             node
             for node in joint_module.graph.nodes
             if node not in required_fw_nodes and node not in required_bw_nodes
-        }
+        )
         fw_cnt = 0
         fw_order = {}
         for node in joint_module.graph.nodes:
@@ -1879,12 +1887,12 @@ def min_cut_rematerialization_partition(

     # Log theoretical per activation storage sizes
     log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
-    fw_module_nodes = {
+    fw_module_nodes = OrderedSet(
         node.name for node in fw_module.graph.nodes if node.op == "call_function"
-    }
-    bw_module_nodes = {
+    )
+    bw_module_nodes = OrderedSet(
         node.name for node in bw_module.graph.nodes if node.op == "call_function"
-    }
+    )
     remat_nodes = fw_module_nodes & bw_module_nodes

     counts: dict[str, int] = defaultdict(int)