pytorch/torch/fx/passes/annotate_getitem_nodes.py
Angel Yang d7f3986314 Fix S367052 to unblock ICVR MC3 (#109853)
Summary: For reasons not yet understood, starting with ads_ranking:996 the sequence input of some "getitem" nodes became a Tensor, which broke the SDD pipelining FX-transformer. Tensor-typed sequence nodes therefore need to be skipped during annotation.

Test Plan:
N4326037

# Before
 {F1099052907}

# With this diff

 {F1099052270}

Differential Revision: D49528046

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109853
Approved by: https://github.com/jackiexu1992, https://github.com/lanza, https://github.com/xush6528
2023-09-23 00:23:42 +00:00

45 lines
1.9 KiB
Python

import operator
import torch
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of sequence node.

    If the sequence node is not annotated with a type, do nothing.
    Currently supports getitem nodes whose sequence node is a Tuple, List or
    NamedTuple. Both ``typing`` aliases (``Tuple[int, str]``, ``List[int]``)
    and builtin generics (``tuple[int, str]``, ``list[int]``) are recognized.
    Tensor-typed sequence nodes are skipped. So is any access whose result
    type cannot be determined statically, e.g. a slice or a runtime-valued
    index into a fixed-length tuple. Such nodes keep their existing type.

    This is helpful since annotations on local names within function are lost
    during FX transforms. Adding back known type annotations for getitem nodes
    improves jit scriptability.

    Args:
        graph (Graph): The graph to be annotated. Nodes are updated in place.
    """
    for node in graph.nodes:
        if node.target != operator.getitem:
            continue
        sequence_node, index = node.args
        # The sequence argument may be a literal (e.g. a tuple of Nodes) rather
        # than a Node; only Nodes carry a ``.type`` annotation.
        if not isinstance(sequence_node, torch.fx.Node) or not sequence_node.type:
            continue
        item_type = _infer_getitem_type(sequence_node.type, index)
        if item_type is not None:
            node.type = item_type


def _infer_getitem_type(sequence_type: object, index: object) -> object:
    """
    Infer the static type of ``sequence[index]``.

    Args:
        sequence_type: Type annotation of the sequence being indexed.
        index: The getitem index argument. It is either a constant (int or
            slice) or a value only known at runtime (e.g. another ``Node``).

    Returns:
        The inferred item type, or None when it cannot be determined.
    """
    # ``__origin__`` is ``tuple``/``list`` both for the ``typing`` aliases
    # (``Tuple[...]``, ``List[...]``) and for builtin generics (``tuple[...]``).
    origin = getattr(sequence_type, "__origin__", None)

    if origin is list:
        parameterized_types = getattr(sequence_type, "__args__", None)
        if not parameterized_types:
            # Bare ``List``: no element type to infer from.
            return None
        assert len(parameterized_types) == 1, (
            f"List type {sequence_type} should have exactly one type argument"
        )
        # Slicing a list yields a list of the same type; any other index
        # (constant or runtime-valued) yields a single element.
        if isinstance(index, slice):
            return sequence_type
        return parameterized_types[0]

    if origin is tuple:
        parameterized_types = getattr(sequence_type, "__args__", None)
        if not parameterized_types:
            # Bare ``Tuple``: no element types to infer from.
            return None
        # Homogeneous variable-length tuple, e.g. ``Tuple[int, ...]``.
        if len(parameterized_types) == 2 and parameterized_types[1] is Ellipsis:
            if isinstance(index, slice):
                return sequence_type
            return parameterized_types[0]
        # Fixed-length tuple: only a constant integer index selects a known
        # element. Slices and runtime-valued indices are left unannotated.
        if isinstance(index, int):
            size = len(parameterized_types)
            assert -size <= index < size, (
                f"Tuple type {sequence_type} has no element at index {index}"
            )
            return parameterized_types[index]
        return None

    if origin is not None:
        # Other generic aliases (Dict[...], Union[...], ...) are not supported.
        return None

    # Tensors support ``__getitem__`` but are not typed containers, so getitem
    # on a Tensor must not be annotated from the Tensor type.
    if getattr(sequence_type, "_name", None) == "Tensor":
        return None
    if not isinstance(sequence_type, type) or issubclass(sequence_type, torch.Tensor):
        return None

    # NamedTuple type: field types are recorded in ``__annotations__`` keyed by
    # field name. Only a constant integer index maps to a single field.
    if issubclass(sequence_type, tuple) and isinstance(index, int):
        fields = getattr(sequence_type, "_fields", None)
        field_types = getattr(sequence_type, "__annotations__", None)
        if fields is not None and field_types:
            return field_types.get(fields[index])
    return None