Use OrderedSet in _functorch/partitioners (#146102)

In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model, but it at least eliminates this potential source of non-determinism before we investigate any changes to the mincut approach.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
This commit is contained in:
Sam Larsen 2025-02-03 16:02:20 -08:00 committed by PyTorch MergeBot
parent 53759ccca8
commit 23fffb54d5
2 changed files with 57 additions and 48 deletions

View File

@ -1703,6 +1703,7 @@ command = [
]
include_patterns = [
"torch/_inductor/**/*.py",
"torch/_functorch/partitioners.py",
]
is_formatter = true

View File

@ -25,6 +25,7 @@ from torch.fx.experimental.symbolic_shapes import (
is_symbol_binding_fx_node,
)
from torch.fx.passes import graph_drawer
from torch.utils._ordered_set import OrderedSet
from torch.utils.checkpoint import CheckpointPolicy
from . import config
@ -55,11 +56,11 @@ prims = torch.ops.prims
class OpTypes:
"""Class for keeping track of different operator categories"""
fusible_ops: set[Callable]
compute_intensive_ops: set[Callable]
random_ops: set[Callable]
view_ops: set[Callable]
recomputable_ops: set[Callable]
fusible_ops: OrderedSet[Callable]
compute_intensive_ops: OrderedSet[Callable]
random_ops: OrderedSet[Callable]
view_ops: OrderedSet[Callable]
recomputable_ops: OrderedSet[Callable]
def is_fusible(self, node: fx.Node):
return get_aten_target(node) in self.fusible_ops
@ -82,9 +83,9 @@ class NodeInfo:
# Be careful about iterating over these explicitly, as their order may not
# be deterministic
inputs: list[fx.Node]
_required_fw_nodes: set[fx.Node]
required_bw_nodes: set[fx.Node]
unclaimed_nodes: set[fx.Node]
_required_fw_nodes: OrderedSet[fx.Node]
required_bw_nodes: OrderedSet[fx.Node]
unclaimed_nodes: OrderedSet[fx.Node]
fw_order: dict[fx.Node, int]
@functools.cached_property
@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules(
# we propagate all symbols which are referenced by backwards inputs.
# These are not directly used in the graph but are required for downstream
# sizevar assignment
saved_symbols: set[sympy.Symbol] = set()
saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
saved_sym_nodes_binding = []
saved_sym_nodes_derived = []
@ -426,9 +427,9 @@ def default_partition(
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
)
forward_node_names = {
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
}
)
saved_values = []
saved_sym_nodes = []
@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
def insert_node_in_graph(node):
cur_nodes = [node]
insertable_nodes = set()
insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
while len(cur_nodes) > 0:
node = cur_nodes.pop()
if node in insertable_nodes or node in env:
@ -817,19 +818,21 @@ def solve_min_cut(
joint_graph: fx.Graph,
node_info: NodeInfo,
min_cut_options: MinCutOptions,
dont_ban=None,
dont_ban: Optional[OrderedSet[fx.Node]] = None,
):
if dont_ban is None:
dont_ban = set()
dont_ban = OrderedSet()
op_types = get_default_op_list()
if AOT_PARTITIONER_DEBUG:
joint_module_ops = {
joint_module_ops = OrderedSet(
str(node.target._overloadpacket)
for node in joint_graph.nodes
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
}
ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
)
ops_ignored = joint_module_ops - OrderedSet(
str(i) for i in op_types.recomputable_ops
)
log.info("Ops banned from re-materialization: %s", ops_ignored)
def can_fuse_into_auto_functionalized(a, b):
@ -888,7 +891,7 @@ def solve_min_cut(
def is_materialized_backwards(node):
if op_types.is_view(node):
return False
cur_nodes = {node}
cur_nodes = OrderedSet([node])
while len(cur_nodes) > 0:
cur = cur_nodes.pop()
for user in cur.users:
@ -981,7 +984,7 @@ def solve_min_cut(
return mem_sz * 2
nx_graph = nx.DiGraph()
banned_nodes = set()
banned_nodes: OrderedSet[fx.Node] = OrderedSet()
def ban_recomputation_if_allowed(node):
if op_types.is_view(node):
@ -1091,12 +1094,13 @@ def solve_min_cut(
if node_info.is_required_fw(user):
if node_info.get_fw_order(user) > max_range:
continue
val = (node_info.get_fw_order(user), user, is_fusible(node, user))
if val not in sorted_nodes:
heapq.heappush(
sorted_nodes,
val,
val: tuple[int, fx.Node, bool] = (
node_info.get_fw_order(user),
user,
is_fusible(node, user),
)
if val not in sorted_nodes:
heapq.heappush(sorted_nodes, val)
return max_range
if min_cut_options.ban_if_used_far_apart:
@ -1141,11 +1145,13 @@ def solve_min_cut(
# Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
if min_cut_options.ban_if_long_fusible_chains:
visited = set()
visited: OrderedSet[fx.Node] = OrderedSet()
for start_node in joint_graph.nodes:
if not node_info.is_required_fw(start_node):
continue
fusible = [(node_info.get_fw_order(start_node), start_node)]
fusible: list[tuple[int, fx.Node]] = [
(node_info.get_fw_order(start_node), start_node)
]
start_order = node_info.get_fw_order(start_node)
while len(fusible) > 0:
_, cur = heapq.heappop(fusible)
@ -1184,11 +1190,11 @@ def solve_min_cut(
raise
reachable, non_reachable = partition
cutset: set[tuple[str, str]] = set()
cutset: OrderedSet[tuple[str, str]] = OrderedSet()
for u, nbrs in ((n, nx_graph[n]) for n in reachable):
cutset.update((u, v) for v in nbrs if v in non_reachable)
cut_nodes = set()
cut_nodes: OrderedSet[str] = OrderedSet()
for node_in, node_out in cutset:
assert node_in[:-3] == node_out[:-4]
node_name = node_in[:-3]
@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes:
]
default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
recomputable_ops = set(default_recomputable_ops)
recomputable_ops = OrderedSet(default_recomputable_ops)
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
compute_intensive_ops = [
aten.mm,
aten.convolution,
@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes:
aten._scaled_mm,
] # noqa: E501,B950
fusible_ops = recomputable_ops | set(random_ops)
fusible_ops = recomputable_ops | random_ops
return OpTypes(
set(fusible_ops),
set(compute_intensive_ops),
set(random_ops),
set(view_ops),
set(recomputable_ops),
fusible_ops,
OrderedSet(compute_intensive_ops),
random_ops,
OrderedSet(view_ops),
recomputable_ops,
)
@ -1567,9 +1573,11 @@ def choose_saved_values_set(
from torch._inductor.fx_utils import get_node_storage
input_storages = {get_node_storage(node) for node in node_info.inputs}
input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)
def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
def get_recomputable_banned_nodes(
banned_nodes: OrderedSet[fx.Node],
) -> list[fx.Node]:
return [
i
for i in banned_nodes
@ -1653,7 +1661,7 @@ Activation Checkpointing - Knapsack Problem Summary:
payload_fn=lambda: knapsack_summary,
)
log.info(knapsack_summary)
dont_ban = set()
dont_ban: OrderedSet[fx.Node] = OrderedSet()
for idx in recomputable_node_idxs:
# if idx in all_recomputable_banned_nodes:
try:
@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition(
def classify_nodes(joint_module):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes = set()
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
@ -1800,16 +1808,16 @@ def min_cut_rematerialization_partition(
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
)
required_fw_nodes: set[fx.Node] = {
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
}
unclaimed_nodes = {
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
}
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
@ -1879,12 +1887,12 @@ def min_cut_rematerialization_partition(
# Log theoretical per activation storage sizes
log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
fw_module_nodes = {
fw_module_nodes = OrderedSet(
node.name for node in fw_module.graph.nodes if node.op == "call_function"
}
bw_module_nodes = {
)
bw_module_nodes = OrderedSet(
node.name for node in bw_module.graph.nodes if node.op == "call_function"
}
)
remat_nodes = fw_module_nodes & bw_module_nodes
counts: dict[str, int] = defaultdict(int)