Use OrderedSet in _functorch/partitioners (#146102)

In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model, but it at least eliminates this potential source of non-determinism before we investigate any changes to the min-cut approach.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
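
To illustrate why this helps: `fx.Node` objects hash by identity, so iterating a plain `set` of nodes can yield a different order on each interpreter run, while an insertion-ordered set always yields nodes in the order they were added. The sketch below is only an illustration; `Node` and `InsertionOrderedSet` are hypothetical stand-ins, and the class actually used by this change is `torch.utils._ordered_set.OrderedSet`.

```python
class Node:
    """Stand-in for fx.Node: uses the default identity-based hash."""

    def __init__(self, name: str) -> None:
        self.name = name


class InsertionOrderedSet:
    """Toy ordered set backed by a dict (dicts preserve insertion order)."""

    def __init__(self, items=()):
        self._data = dict.fromkeys(items)

    def add(self, item) -> None:
        self._data[item] = None

    def __iter__(self):
        return iter(self._data)

    def __contains__(self, item) -> bool:
        return item in self._data


nodes = [Node(f"n{i}") for i in range(5)]

plain = set(nodes)                    # iteration order depends on id()-based hashes
ordered = InsertionOrderedSet(nodes)  # iteration order == insertion order

print([n.name for n in ordered])  # always ['n0', 'n1', 'n2', 'n3', 'n4']
print([n.name for n in plain])    # may differ from run to run
```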
Sam Larsen 2025-02-03 16:02:20 -08:00 committed by PyTorch MergeBot
parent 53759ccca8
commit 23fffb54d5
2 changed files with 57 additions and 48 deletions

View File

@@ -1703,6 +1703,7 @@ command = [
 ]
 include_patterns = [
     "torch/_inductor/**/*.py",
+    "torch/_functorch/partitioners.py",
 ]
 is_formatter = true

View File

@@ -25,6 +25,7 @@ from torch.fx.experimental.symbolic_shapes import (
     is_symbol_binding_fx_node,
 )
 from torch.fx.passes import graph_drawer
+from torch.utils._ordered_set import OrderedSet
 from torch.utils.checkpoint import CheckpointPolicy
 
 from . import config
@@ -55,11 +56,11 @@ prims = torch.ops.prims
 class OpTypes:
     """Class for keeping track of different operator categories"""
 
-    fusible_ops: set[Callable]
-    compute_intensive_ops: set[Callable]
-    random_ops: set[Callable]
-    view_ops: set[Callable]
-    recomputable_ops: set[Callable]
+    fusible_ops: OrderedSet[Callable]
+    compute_intensive_ops: OrderedSet[Callable]
+    random_ops: OrderedSet[Callable]
+    view_ops: OrderedSet[Callable]
+    recomputable_ops: OrderedSet[Callable]
 
     def is_fusible(self, node: fx.Node):
         return get_aten_target(node) in self.fusible_ops
@@ -82,9 +83,9 @@ class NodeInfo:
     # Be careful about iterating over these explicitly, as their order may not
     # be deterministic
     inputs: list[fx.Node]
-    _required_fw_nodes: set[fx.Node]
-    required_bw_nodes: set[fx.Node]
-    unclaimed_nodes: set[fx.Node]
+    _required_fw_nodes: OrderedSet[fx.Node]
+    required_bw_nodes: OrderedSet[fx.Node]
+    unclaimed_nodes: OrderedSet[fx.Node]
     fw_order: dict[fx.Node, int]
 
     @functools.cached_property
@@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules(
     # we propagate all symbols which are referenced by backwards inputs.
     # These are not directly used in the graph but are required for downstream
     # sizevar assignment
-    saved_symbols: set[sympy.Symbol] = set()
+    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
     saved_sym_nodes_binding = []
     saved_sym_nodes_derived = []
@@ -426,9 +427,9 @@ def default_partition(
     forward_only_graph = _extract_graph_with_inputs_outputs(
         joint_module.graph, inputs, fwd_outputs, "forward"
     )
-    forward_node_names = {
+    forward_node_names = OrderedSet(
         node.name for node in forward_only_graph.nodes if node.op != "output"
-    }
+    )
 
     saved_values = []
     saved_sym_nodes = []
@@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     def insert_node_in_graph(node):
         cur_nodes = [node]
-        insertable_nodes = set()
+        insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
         while len(cur_nodes) > 0:
             node = cur_nodes.pop()
             if node in insertable_nodes or node in env:
@@ -817,19 +818,21 @@ def solve_min_cut(
     joint_graph: fx.Graph,
     node_info: NodeInfo,
     min_cut_options: MinCutOptions,
-    dont_ban=None,
+    dont_ban: Optional[OrderedSet[fx.Node]] = None,
 ):
     if dont_ban is None:
-        dont_ban = set()
+        dont_ban = OrderedSet()
     op_types = get_default_op_list()
 
     if AOT_PARTITIONER_DEBUG:
-        joint_module_ops = {
+        joint_module_ops = OrderedSet(
             str(node.target._overloadpacket)
             for node in joint_graph.nodes
             if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
-        }
-        ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
+        )
+        ops_ignored = joint_module_ops - OrderedSet(
+            str(i) for i in op_types.recomputable_ops
+        )
         log.info("Ops banned from re-materialization: %s", ops_ignored)
 
     def can_fuse_into_auto_functionalized(a, b):
@@ -888,7 +891,7 @@ def solve_min_cut(
     def is_materialized_backwards(node):
         if op_types.is_view(node):
             return False
-        cur_nodes = {node}
+        cur_nodes = OrderedSet([node])
         while len(cur_nodes) > 0:
             cur = cur_nodes.pop()
             for user in cur.users:
@@ -981,7 +984,7 @@ def solve_min_cut(
         return mem_sz * 2
 
     nx_graph = nx.DiGraph()
-    banned_nodes = set()
+    banned_nodes: OrderedSet[fx.Node] = OrderedSet()
 
     def ban_recomputation_if_allowed(node):
         if op_types.is_view(node):
@@ -1091,12 +1094,13 @@ def solve_min_cut(
                 if node_info.is_required_fw(user):
                     if node_info.get_fw_order(user) > max_range:
                         continue
-                    val = (node_info.get_fw_order(user), user, is_fusible(node, user))
+                    val: tuple[int, fx.Node, bool] = (
+                        node_info.get_fw_order(user),
+                        user,
+                        is_fusible(node, user),
+                    )
                     if val not in sorted_nodes:
-                        heapq.heappush(
-                            sorted_nodes,
-                            val,
-                        )
+                        heapq.heappush(sorted_nodes, val)
         return max_range
 
     if min_cut_options.ban_if_used_far_apart:
@@ -1141,11 +1145,13 @@ def solve_min_cut(
     # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
     if min_cut_options.ban_if_long_fusible_chains:
-        visited = set()
+        visited: OrderedSet[fx.Node] = OrderedSet()
         for start_node in joint_graph.nodes:
             if not node_info.is_required_fw(start_node):
                 continue
-            fusible = [(node_info.get_fw_order(start_node), start_node)]
+            fusible: list[tuple[int, fx.Node]] = [
+                (node_info.get_fw_order(start_node), start_node)
+            ]
             start_order = node_info.get_fw_order(start_node)
             while len(fusible) > 0:
                 _, cur = heapq.heappop(fusible)
@@ -1184,11 +1190,11 @@ def solve_min_cut(
         raise
 
     reachable, non_reachable = partition
-    cutset: set[tuple[str, str]] = set()
+    cutset: OrderedSet[tuple[str, str]] = OrderedSet()
     for u, nbrs in ((n, nx_graph[n]) for n in reachable):
         cutset.update((u, v) for v in nbrs if v in non_reachable)
 
-    cut_nodes = set()
+    cut_nodes: OrderedSet[str] = OrderedSet()
     for node_in, node_out in cutset:
         assert node_in[:-3] == node_out[:-4]
         node_name = node_in[:-3]
@@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes:
     ]
     default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
 
-    recomputable_ops = set(default_recomputable_ops)
+    recomputable_ops = OrderedSet(default_recomputable_ops)
 
-    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
+    random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
     compute_intensive_ops = [
         aten.mm,
         aten.convolution,
@@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes:
         aten._scaled_mm,
     ]  # noqa: E501,B950
 
-    fusible_ops = recomputable_ops | set(random_ops)
+    fusible_ops = recomputable_ops | random_ops
 
     return OpTypes(
-        set(fusible_ops),
-        set(compute_intensive_ops),
-        set(random_ops),
-        set(view_ops),
-        set(recomputable_ops),
+        fusible_ops,
+        OrderedSet(compute_intensive_ops),
+        random_ops,
+        OrderedSet(view_ops),
+        recomputable_ops,
     )
@@ -1567,9 +1573,11 @@ def choose_saved_values_set(
     from torch._inductor.fx_utils import get_node_storage
 
-    input_storages = {get_node_storage(node) for node in node_info.inputs}
+    input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)
 
-    def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
+    def get_recomputable_banned_nodes(
+        banned_nodes: OrderedSet[fx.Node],
+    ) -> list[fx.Node]:
         return [
             i
             for i in banned_nodes
@@ -1653,7 +1661,7 @@ Activation Checkpointing - Knapsack Problem Summary:
             payload_fn=lambda: knapsack_summary,
         )
         log.info(knapsack_summary)
-    dont_ban = set()
+    dont_ban: OrderedSet[fx.Node] = OrderedSet()
     for idx in recomputable_node_idxs:
         # if idx in all_recomputable_banned_nodes:
         try:
@@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition(
     def classify_nodes(joint_module):
         name_to_node = get_name_to_node(joint_module.graph)
-        required_bw_nodes = set()
+        required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
         for node in joint_module.graph.nodes:
             if node.op == "placeholder" and "tangents" in node.target:
                 required_bw_nodes.add(node)
@@ -1800,16 +1808,16 @@ def min_cut_rematerialization_partition(
         forward_only_graph = _extract_graph_with_inputs_outputs(
             joint_module.graph, inputs, fwd_outputs, "forward"
         )
-        required_fw_nodes: set[fx.Node] = {
+        required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
             name_to_node[node.name]
             for node in forward_only_graph.nodes
             if node.op != "output"
-        }
-        unclaimed_nodes = {
+        )
+        unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
             node
             for node in joint_module.graph.nodes
             if node not in required_fw_nodes and node not in required_bw_nodes
-        }
+        )
         fw_cnt = 0
         fw_order = {}
         for node in joint_module.graph.nodes:
@@ -1879,12 +1887,12 @@ def min_cut_rematerialization_partition(
         # Log theoretical per activation storage sizes
         log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
 
-    fw_module_nodes = {
+    fw_module_nodes = OrderedSet(
         node.name for node in fw_module.graph.nodes if node.op == "call_function"
-    }
-    bw_module_nodes = {
+    )
+    bw_module_nodes = OrderedSet(
         node.name for node in bw_module.graph.nodes if node.op == "call_function"
-    }
+    )
     remat_nodes = fw_module_nodes & bw_module_nodes
 
     counts: dict[str, int] = defaultdict(int)
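
For reference, a short sketch of the `OrderedSet` operations the updated partitioners.py relies on: construction from an iterable, `add`/`update`, membership tests, and `|`/`&`/`-` set algebra. It assumes a PyTorch build where `torch.utils._ordered_set` is importable; the variable names below are illustrative only and mirror usages in the diff above.

```python
from torch.utils._ordered_set import OrderedSet

# Construction from a generator or list, as in forward_node_names / fw_module_nodes.
fw_names = OrderedSet(n for n in ["add", "mm", "relu", "mm"])  # duplicates collapse
bw_names = OrderedSet(["mm", "relu_backward", "add"])

# Membership and mutation, as in required_bw_nodes.add(...) and cutset.update(...).
assert "mm" in fw_names
fw_names.add("sum")
fw_names.update(["view"])

# Set algebra used by the partitioner.
fusible = fw_names | bw_names  # cf. recomputable_ops | random_ops
remat = fw_names & bw_names    # cf. fw_module_nodes & bw_module_nodes
ignored = fw_names - bw_names  # cf. joint_module_ops - OrderedSet(...)

# Iteration follows insertion order, which is what makes these passes deterministic.
print(list(remat))
```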