Use 'is' in callable comparisons (#166624)

Just like we use `is`/`is not` for class comparisons, it is generally advisable to use `is`/`is not` when comparing against torch callables. Targets such as `torch.relu` or `torch.ops.aten.mm.default` are module-level singletons, so an identity check expresses the intent directly and avoids dispatching to `__eq__` on an arbitrary target.
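
A minimal before/after sketch of the pattern (illustrative only; `node_target` stands in for an FX node's `target` attribute):

```python
import torch

def is_relu(node_target) -> bool:
    # Preferred: torch.relu is a single module-level object, so `is` is an
    # unambiguous, constant-time identity check.
    return node_target is torch.relu

def is_relu_eq(node_target) -> bool:
    # Gives the same answer for plain torch functions, but routes through
    # __eq__, which arbitrary targets may override.
    return node_target == torch.relu

assert is_relu(torch.relu) and is_relu_eq(torch.relu)
assert not is_relu(torch.sigmoid)
```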

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
Yuanyuan Chen 2025-10-30 19:00:04 +00:00 committed by PyTorch MergeBot
parent 639a0b1239
commit 694db5f549
45 changed files with 105 additions and 105 deletions

View File

@@ -369,7 +369,7 @@ def relu_compile_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             raise ReluCompileError
     return gm
@@ -379,7 +379,7 @@ def relu_runtime_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch._assert
             node.args = (False, "ReluRuntimeError")
     gm.recompile()
@@ -391,7 +391,7 @@ def relu_accuracy_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch.add
             node.args = (node.args[0], 1)
     gm.recompile()

View File

@@ -573,7 +573,7 @@ class AutogradCompilerInstance:
             result.name = make_unique(node.name)
             value_remap[node] = result
         elif node.op == "call_function":
-            if node.target == torch.ops.aten.view.default:
+            if node.target is torch.ops.aten.view.default:
                 # this aot bwd graph is being lazily compiled
                 # we must manually apply the view_to_reshape post grad pass
                 # since it was already applied to the aot fwd, and baked into the gradients
@@ -1241,7 +1241,7 @@ class AutogradCompilerInstance:
         to_append = []
         hook_block = [node]  # contain the hook and hook args getitem
         for n in input_nodes:
-            if n.op == "call_function" and n.target == operator.getitem:
+            if n.op == "call_function" and n.target is operator.getitem:
                 to_append.append(n.args[0])
                 to_remove.append(n)
                 hook_block.append(n)
@@ -1314,7 +1314,7 @@ class AutogradCompilerInstance:
         # find the corresponding acc_grad node
         acc_grad_node = None
         for n in list(param_node.users.keys()):
-            if n.op == "call_function" and n.target == call_accumulate_grad:
+            if n.op == "call_function" and n.target is call_accumulate_grad:
                 acc_grad_node = n
                 break
@@ -1369,7 +1369,7 @@ class AutogradCompilerInstance:
         for n in list(param_node.users.keys()):
             if (
                 n.op == "call_function"
-                and n.target == call_hook
+                and n.target is call_hook
                 and n.kwargs.get("hook_type", None) == "post_acc_grad_hook"
             ):
                 post_acc_grad_hook_node = n

View File

@@ -455,7 +455,7 @@ def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule
     for node in model.graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.prims.convert_element_type.default
+            and node.target is torch.ops.prims.convert_element_type.default
         ):
             assert len(node.args) == 2
             if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:

View File

@@ -2930,12 +2930,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         hasattr(proxy.node.target, "__name__")
         and proxy.node.target.__name__ == "set_state"
         and isinstance(proxy.node.target.__self__, torch._C.Generator)
-        or proxy.node.target == torch.random.set_rng_state
+        or proxy.node.target is torch.random.set_rng_state
     ):
         return TorchInGraphFunctionVariable(proxy.node.target)
     elif (
-        proxy.node.target == torch._C._DisableFuncTorch
-        or proxy.node.target == torch.cuda._is_in_bad_fork
+        proxy.node.target is torch._C._DisableFuncTorch
+        or proxy.node.target is torch.cuda._is_in_bad_fork
     ):
         return UserDefinedObjectVariable(example_value)
     elif istype(example_value, torch.Size) and all(

View File

@@ -2595,7 +2595,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
         from torch.utils.checkpoint import noop_context_fn
         context_fn = None
-        if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
+        if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn:
             ctx = kwargs.pop("context_fn")
             if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
                 context_fn = ctx.fn

View File

@@ -48,7 +48,7 @@ class CollectTracepointsPass(PassBase):
             for node in module.graph.nodes:
                 if node.op != "call_function":
                     continue
-                if node.target == torch.ops.higher_order._export_tracepoint:
+                if node.target is torch.ops.higher_order._export_tracepoint:
                     kind = node.kwargs["kind"]
                     if kind == "module_call_outputs":
                         nn_module_stack = node.meta["nn_module_stack"]
@@ -64,7 +64,7 @@ class CollectTracepointsPass(PassBase):
            for node in reversed(module.graph.nodes):
                 if node.op != "call_function":
                     continue
-                if node.target == torch.ops.higher_order._export_tracepoint:
+                if node.target is torch.ops.higher_order._export_tracepoint:
                     kind = node.kwargs["kind"]
                     if kind == "module_call_inputs":
                         nn_module_stack = node.meta["nn_module_stack"]
@@ -94,7 +94,7 @@ class CollectTracepointsPass(PassBase):
         for node in module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target == torch.ops.higher_order._export_tracepoint:
+            if node.target is torch.ops.higher_order._export_tracepoint:
                 # There's some subtlety worth noting. Here fqn corresponds to
                 # the call name, whereas path corresponds to the module name.
                 # They are not necessarily the same! When a submodule is shared

View File

@@ -65,7 +65,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def is_impure(self, node: torch.fx.Node) -> bool:
         if (
-            node.target == torch.ops.prims.convert_element_type.default
+            node.target is torch.ops.prims.convert_element_type.default
             and node.args[0].op == "get_attr"  # type: ignore[union-attr]
             and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
             and node.args[1] == torch.bfloat16

View File

@@ -34,7 +34,7 @@ def _is_enter_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
         and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._enter_autocast
+        and node.target is torch.amp.autocast_mode._enter_autocast
     )
@@ -42,7 +42,7 @@ def _is_exit_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
         and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._exit_autocast
+        and node.target is torch.amp.autocast_mode._exit_autocast
     )

View File

@@ -22,7 +22,7 @@ def _is_set_grad_enabled_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]
     return (
         node
         and node.op == "call_function"
-        and node.target == torch._C._set_grad_enabled
+        and node.target is torch._C._set_grad_enabled
     )

View File

@@ -88,7 +88,7 @@ def _staged_schema():
                 f"std::optional<{cpp_type}>",
                 f"optional {thrift_type}",
             )
-        elif o == Annotated:
+        elif o is Annotated:
             return dump_type(t.__origin__, level)
         else:
             raise AssertionError(f"Type {t} is not supported in export schema.")

View File

@@ -249,7 +249,7 @@ def maybe_to_fresh_input(idx, t, meta):
 def is_with_effects(node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.higher_order.with_effects
+        and node.target is torch.ops.higher_order.with_effects
     )

View File

@@ -1256,7 +1256,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     # Build the graph op-by-op by starting from the node all the way to the end
     # copy_ can be not using tangents at all, we must copy it.
     for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
-        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.aten.copy_.default:
             insert_node_in_graph(node)
     for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
@@ -1596,7 +1596,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
         if node.op == "output":
             continue
-        is_copy_ = node.target == torch.ops.aten.copy_.default
+        is_copy_ = node.target is torch.ops.aten.copy_.default
         if is_copy_:
             if _has_tag_must_be_in_backward(node):
                 has_mutation_in_bw.add(node.args[0])

View File

@@ -1393,7 +1393,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
     for idx, node in enumerate(node_list):
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
         ):
             assert node.args[0].op == "placeholder", f"""\
 Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
@@ -1444,7 +1444,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
     # Find all eligible unsharded params and their corresponding graph intermediates.
     unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
     for idx, node in enumerate(node_list):
-        if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default:
             fsdp_copy_node = node
             unsharded_param = node.args[0]
             assert unsharded_param.op == "placeholder", f"""
@@ -1456,8 +1456,8 @@ Offending node: {unsharded_param}. Graph: {graph}
     def is_allowed_mutation(node):
         return (
-            node.target == torch.ops.fsdp.copy_.default
-            or node.target == torch.ops.inductor.resize_storage_bytes_.default
+            node.target is torch.ops.fsdp.copy_.default
+            or node.target is torch.ops.inductor.resize_storage_bytes_.default
         )
     def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
@@ -1558,7 +1558,7 @@ Graph: {graph}
     for node in node_list:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
             and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
         ):
             graph.erase_node(node)

View File

@@ -122,7 +122,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def is_impure(self, node: torch.fx.node.Node) -> bool:
         def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             return (
-                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
+                node.target is torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                 and isinstance(node.args[0], torch.fx.Node)
                 and "val" in node.args[0].meta
                 and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
@@ -132,7 +132,7 @@ class ConstantFolder(torch.fx.Interpreter):
         if (
             is_woq_int8_pattern(node)
             or (
-                node.target == torch.ops.aten.permute.default
+                node.target is torch.ops.aten.permute.default
                 and len(node.users) == 1
                 and is_woq_int8_pattern(next(iter(node.users)))
             )

View File

@@ -591,7 +591,7 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
     )
     def is_mm(node: torch.fx.Node) -> bool:
-        return node.target == torch.ops.aten.mm.default
+        return node.target is torch.ops.aten.mm.default
     # the inner MM
     inner_mm = match.nodes[-1]

View File

@@ -122,28 +122,28 @@ def bucket_reduce_scatter(
 def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+        and node.target is torch.ops._c10d_functional.all_gather_into_tensor.default
     )
 def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
+        and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default
     )
 def is_wait_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.wait_tensor.default
+        and node.target is torch.ops._c10d_functional.wait_tensor.default
     )
 def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_reduce.default
+        and node.target is torch.ops._c10d_functional.all_reduce.default
     )

View File

@@ -315,7 +315,7 @@ class UniformValueConstantFolder(ConstantFolder):
         # single-elem attrs
         if node.op == "get_attr" or (
             node.op == "call_function"
-            and node.target == torch.ops.aten.lift_fresh_copy.default
+            and node.target is torch.ops.aten.lift_fresh_copy.default
         ):
             out = super(ConstantFolder, self).run_node(node)
             if isinstance(out, torch.Tensor) and out.numel() == 1:

View File

@@ -737,7 +737,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         input_node = node.kwargs["input"]
         if (
             input_node.op == "call_function"
-            and input_node.target == torch.nn.functional.linear
+            and input_node.target is torch.nn.functional.linear
         ):
             normalized = NormalizedLinearNode(input_node)
             input = normalized.get_input()

View File

@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             if copy_node is not None:
                 replace_dict[copy_node] = copy_node.args[0]
             node.target = inplaceable_op.inplace_op
-        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
+        elif node.target is torch.ops.higher_order.auto_functionalized_v2:
             _mutable_op = node.args[0]
             kwargs = node.kwargs
@@ -696,7 +696,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             # auto_functionalized into clones + a mutable op; this metadata
             # tells the decomp to only clone the following inputs
             node.meta["only_clone_these_tensors"] = new_bases_to_clone
-        elif node.target == torch.ops.higher_order.auto_functionalized:
+        elif node.target is torch.ops.higher_order.auto_functionalized:
             _mutable_op = node.args[0]
             from torch._higher_order_ops.auto_functionalize import get_mutable_args
View File

@@ -127,12 +127,12 @@ def _get_dim(node: Any):
     if "dim" in node.kwargs:
         assert isinstance(node.kwargs["dim"], int)
         return node.kwargs["dim"]
-    if node.target == torch.unbind:
+    if node.target is torch.unbind:
         if len(node.args) == 2:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
         return 0  # defaults to dim=0
-    if node.target == torch.split:
+    if node.target is torch.split:
         if len(node.args) == 3:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
@@ -351,7 +351,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
         cat_node.args == new_args
         and cat_node.kwargs == new_kwargs
         and cat_node.op == "call_function"
-        and cat_node.target == torch.cat
+        and cat_node.target is torch.cat
     ):
         return
@@ -866,7 +866,7 @@ class SplitCatSimplifier:
         cat_dim = get_arg_value(user_node, 1, "dim")
         transform_params: list[_TransformParam] = []
         for user_input in user_inputs:
-            if split_dim == cat_dim and user_node.target == torch.cat:
+            if split_dim == cat_dim and user_node.target is torch.cat:
                 # No transform needed
                 transform_params.append((None, None, None, None))
             elif isinstance(user_input, tuple):  # Split being simplified
@@ -888,7 +888,7 @@ class SplitCatSimplifier:
                     (unflatten_params, movedim_params, None, None)
                 )
             elif (
-                user_node.target == torch.stack or split_dim != cat_dim
+                user_node.target is torch.stack or split_dim != cat_dim
             ):  # We need to unsqueeze inputs not coming through split
                 transform_params.append((None, None, (cat_dim,), None))
             else:  # Non-split inputs
@@ -1107,9 +1107,9 @@ class SplitCatSimplifier:
             )
             if (
-                user_node.target == torch.cat
+                user_node.target is torch.cat
                 and split_dim != cat_dim
-                and split_node.target == torch.split
+                and split_node.target is torch.split
             ):
                 with graph.inserting_after(new_cat_node):
                     new_cat_node_meta = new_cat_node.meta["example_value"]
@@ -1225,13 +1225,13 @@ class UnbindCatRemover(SplitCatSimplifier):
                 (split_dim, cat_dim) if split_dim != cat_dim else None
             )
             flatten_params = None
-            if user_node.target == torch.cat:
+            if user_node.target is torch.cat:
                 flatten_params = (cat_dim, cat_dim + 1)
                 transform_params.append(
                     (None, movedim_params, None, flatten_params)
                 )
             elif (
-                user_node.target == torch.stack
+                user_node.target is torch.stack
             ):  # We need to unsqueeze inputs not coming through unbind into cat
                 transform_params.append((None, None, (cat_dim,), None))
             else:  # Non-unbind inputs
@@ -1298,13 +1298,13 @@ def merge_split_squeeze(
     match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int
 ):
     graph = match.graph
-    split = next(node for node in match.nodes if node.target == torch.split)
+    split = next(node for node in match.nodes if node.target is torch.split)
     if not all(s == 1 for s in split_sizes):
         return
     if isinstance(dim, Sequence):
         return
     next_users = find_next_users(split)
-    if not all(node.target == torch.squeeze for node in next_users):
+    if not all(node.target is torch.squeeze for node in next_users):
         return
     with graph.inserting_before(match.output_node()):
         unbind = graph.call_function(
@@ -1364,7 +1364,7 @@ getitem_unbind = ListOf(
     pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
 )
 def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     UnbindCatRemover().remove_unbind(match.graph, unbind_node)
@@ -1431,7 +1431,7 @@ reshape_getitem_split = ListOf(
 def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     # pyrefly: ignore [bad-argument-type]
     SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
@@ -1518,7 +1518,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1625,7 +1625,7 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     _split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1904,7 +1904,7 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
     # get the select nodes from the node
     select_nodes = list(node_input.users.keys())
     for cat_node in list(node.users.keys()):
-        if cat_node.target == torch.ops.aten.cat.default:
+        if cat_node.target is torch.ops.aten.cat.default:
             cat_dim = get_arg_value(cat_node, 1, "dim")
             cat_inputs = get_arg_value(cat_node, 0, "tensors")
             # check all select nodes has same slice dim
@@ -2020,7 +2020,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs):
     parent_of_select_node = get_arg_value(select_nodes[0], 0, "input")
     # check the target of select_nodes are the same
     if not all(
-        select_node.target == torch.ops.aten.select for select_node in select_nodes
+        select_node.target is torch.ops.aten.select for select_node in select_nodes
     ):
         return
     # check the select nodes come from the same parent node
@@ -2357,7 +2357,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.No
 def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_nodes = [node for node in match.nodes if node.target == torch.split]
+    split_nodes = [node for node in match.nodes if node.target is torch.split]
     if split_nodes:
         split_node = next(node for node in split_nodes)
     else:
@@ -2438,7 +2438,7 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"),
 )
 def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2614,7 +2614,7 @@ def convert_reshape_cat_arg_to_stack(
 def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = get_arg_value(split_node, 2, "dim") or 0
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
@@ -2685,7 +2685,7 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"),
 )
 def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2785,10 +2785,10 @@ def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]:
     pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"),
 )
 def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = _get_dim(split_node)
     split_users = list(split_node.users.keys())
-    stack_nodes = [node for node in match.nodes if node.target == torch.stack]
+    stack_nodes = [node for node in match.nodes if node.target is torch.stack]
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
         "move_reshape_out_of_split_stack_pass"
@@ -2926,7 +2926,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     split_node = next(
         node
         for node in match.nodes
-        if node.target == torch.ops.aten.split_with_sizes.default
+        if node.target is torch.ops.aten.split_with_sizes.default
     )
     split_input, split_section, split_dim = _get_split_args_default(split_node)
     split_users = list(split_node.users.keys())
@@ -2936,7 +2936,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
         return
     cat_nodes = [
-        node for node in match.nodes if node.target == torch.ops.aten.cat.default
+        node for node in match.nodes if node.target is torch.ops.aten.cat.default
     ]
     graph = match.graph
     for cat_node in cat_nodes:
@@ -2950,7 +2950,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
             continue
         # check if the cat inputs are all the view nodes
         if not all(
-            view_node.target == torch.ops.aten.reshape.default
+            view_node.target is torch.ops.aten.reshape.default
             for view_node in cat_inputs
         ):
             continue

View File

@@ -5849,7 +5849,7 @@ class ExternKernel(InputsKernel):
         if shape_env := V.fake_mode.shape_env:
             node_meta_val = V.current_node.meta.get("val")
             ctx: AbstractContextManager[None] = nullcontext()
-            if V.current_node.target == torch._higher_order_ops.effects.with_effects:
+            if V.current_node.target is torch._higher_order_ops.effects.with_effects:
                 # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
                 node_meta_val = node_meta_val[1]
                 ctx = _remove_effect_token_unbacked_bindings(V.current_node)

View File

@@ -677,7 +677,7 @@ def _convolution(
 def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
-    assert fx_node.target == torch.ops.aten.convolution.default
+    assert fx_node.target is torch.ops.aten.convolution.default
     if V.graph.layout_opt:
         return args, kwargs
     else:

View File

@@ -478,7 +478,7 @@ def validate_joint_graph(joint_graph: torch.fx.Graph):
     for node in joint_graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.flex_lib.zeros_and_scatter.default
+            and node.target is torch.ops.flex_lib.zeros_and_scatter.default
         ):
             for user in node.users:
                 if user.op != "output":

View File

@@ -3922,7 +3922,7 @@ def index_put_impl_(self, indices, values, accumulate, check, may_realize=False)
             isinstance(indice, ir.StorageBox)
             and isinstance(indice.data, ir.ExternKernel)
             and getattr(indice.data, "fx_node", None)
-            and indice.data.fx_node.target == torch.ops.aten.randperm.default
+            and indice.data.fx_node.target is torch.ops.aten.randperm.default
         )
     return False

View File

@@ -437,7 +437,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
             subclass_type = type(arg)
             if (
                 subclass_type.__torch_dispatch__
-                == torch._C._disabled_torch_dispatch_impl
+                is torch._C._disabled_torch_dispatch_impl
             ):
                 continue

View File

@@ -1050,7 +1050,7 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
         raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
     quant_node = next(iter(shadow_n.users.keys()))
     new_args: Any = None
-    if quant_node.target == torch.quantize_per_channel:
+    if quant_node.target is torch.quantize_per_channel:
         _weight, scale_node, zp_node, axis, dtype = quant_node.args
         scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
         zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)

View File

@@ -204,7 +204,7 @@ def get_node_input_qparams(
     if prev_node.op == "call_function":
         # quantize - read the args directly
-        if prev_node.target == torch.quantize_per_tensor:
+        if prev_node.target is torch.quantize_per_tensor:
             return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
         elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
             return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

View File

@@ -1181,7 +1181,7 @@ def special_pattern_replacement(model: GraphModule):
     modules = dict(model.named_modules(remove_duplicate=False))
     for n in model.graph.nodes:
         q_node = n
-        is_quantize = q_node.target == torch.quantize_per_tensor
+        is_quantize = q_node.target is torch.quantize_per_tensor
         is_to_fp16 = (
             q_node.op == "call_method"
             and q_node.target == "to"

View File

@@ -139,10 +139,10 @@ def _get_lstm_with_individually_observed_parts(
     mul_count = 0
     for node in cell.graph.nodes:
         op_index: Optional[tuple[Callable, int]] = None  # e.g. (torch.add, 1)
-        if node.target == torch.add:
+        if node.target is torch.add:
             op_index = (torch.add, add_count)
             add_count += 1
-        elif node.target == torch.mul:
+        elif node.target is torch.mul:
             op_index = (torch.mul, mul_count)
             mul_count += 1
         else:
@@ -205,7 +205,7 @@ def _get_reference_quantized_lstm_module(
     # on custom module input/output dtypes, and (2) expand support for complex
     # input/output structures.
     for node in cell.graph.nodes:
-        if node.target == torch.quantize_per_tensor:
+        if node.target is torch.quantize_per_tensor:
             arg = node.args[0]
             # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
             if arg.target == "x" or (

View File

@@ -455,9 +455,9 @@ def _maybe_insert_input_observers_for_node(
         # gelu has a has an approximate kwarg that persist in exported graph.
         # This is just a work around for these.
         if not (
-            node.target == torch.ops.aten.clone.default
-            or node.target == torch.ops.aten.zeros_like.default
-            or node.target == torch.ops.aten.gelu.default
+            node.target is torch.ops.aten.clone.default
+            or node.target is torch.ops.aten.zeros_like.default
+            or node.target is torch.ops.aten.gelu.default
             or len(node.kwargs) == 0
         ):
             raise AssertionError(" expecting kwargs for aten op IR to be empty")

View File

@@ -939,7 +939,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     # remove in place add from batchnorm tracking training stats
     for node in m.graph.nodes:
         if (
-            node.target == torch.ops.aten.add_.Tensor
+            node.target is torch.ops.aten.add_.Tensor
             and node.args[0].op == "get_attr"
             and node.args[1] == 1
             and (

View File

@@ -97,10 +97,10 @@ def _find_q_dq_node_for_user(
 def _is_sym_size_node(node: Node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.aten.sym_size.default
-        or node.target == torch.ops.aten.sym_numel.default
-        or node.target == torch.ops.aten.sym_numel
-        or node.target == torch.ops.aten.sym_size
+        and node.target is torch.ops.aten.sym_size.default
+        or node.target is torch.ops.aten.sym_numel.default
+        or node.target is torch.ops.aten.sym_numel
+        or node.target is torch.ops.aten.sym_size
     )
@@ -228,7 +228,7 @@ def fold_bn_weights_into_conv_node(
     bn_b = _get_tensor_constant_from_node(bn_args[2], m)
     bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
     bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
-    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
+    if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default:
         eps_arg_index = 6
     elif _is_supported_batch_norm_for_training(bn_node):
         eps_arg_index = 7
@@ -268,7 +268,7 @@ def fold_bn_weights_into_conv_node(
     # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
     #
-    if bn_node.target == torch.ops.aten.batch_norm.default:
+    if bn_node.target is torch.ops.aten.batch_norm.default:
         # With the new training ir, instead of batch_norm + getitem,
         # we only have the batch_norm node.
         #
@@ -377,7 +377,7 @@ def _get_aten_graph_module_for_pattern(
     for node in aten_pattern.graph.nodes:  # type: ignore[union-attr]
         if (
             node.op == "call_function"
-            and node.target == torch.ops.aten.copy_.default
+            and node.target is torch.ops.aten.copy_.default
             and len(node.users) == 0
         ):
             aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]

View File

@@ -77,7 +77,7 @@ class EmbeddingQuantizer(Quantizer):
             # just as an example of alternate ways of annotating
             if (
                 node.op == "call_function"
-                and node.target == torch.ops.aten.embedding.default
+                and node.target is torch.ops.aten.embedding.default
             ):
                 if embedding_config.config.weight is None:
                     raise ValueError(

View File

@@ -607,7 +607,7 @@ class X86InductorQuantizer(Quantizer):
             _annotate_nodes_not_quantize(linear_node)
             return
         input_qspec_map = {}
-        assert linear_node.target == torch.ops.aten.linear.default
+        assert linear_node.target is torch.ops.aten.linear.default
         has_bias = len(linear_node.args) == 3
         input_index = 0
         weight_index = 1
@@ -1396,7 +1396,7 @@ class X86InductorQuantizer(Quantizer):
         """  # noqa: B950
         edge_or_node: tuple[Node, Node]
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 maxpool_node = node
                 if not _is_all_annotated(
                     [

View File

@@ -112,7 +112,7 @@ class XPUInductorQuantizer(X86InductorQuantizer):
         node: Node,
     ) -> None:
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 return
             else:
                 input_node = node.all_input_nodes[0]

View File

@@ -111,7 +111,7 @@ def auto_quantize(func, qtype, quant_loss=None):
         async_op = kwargs.get("async_op", False)
         if async_op is True:
             raise RuntimeError("The async_op=True mode is not supported yet.")
-        if func == dist.all_gather:
+        if func is dist.all_gather:
             tensors = args[0]
             input_tensors = _quantize_tensor(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)
@@ -121,7 +121,7 @@ def auto_quantize(func, qtype, quant_loss=None):
            ):
                 tensors[i] = t
-        elif func == dist.all_to_all:
+        elif func is dist.all_to_all:
             tensors = args[0]
             input_tensors = _quantize_tensor_list(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)

View File

@@ -17,7 +17,7 @@ from torch.fx import Graph
 def remove_self_clone(graph: Graph) -> None:
     for node in graph.nodes:
-        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+        if node.target is torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
             node.replace_all_uses_with(node.args[0])
             graph.erase_node(node)

View File

@@ -69,7 +69,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     flatten_node = curr_module_users[0]
     assert (
         flatten_node.op == "call_function"
-        and flatten_node.target == fx_pytree.tree_flatten_spec
+        and flatten_node.target is fx_pytree.tree_flatten_spec
     )
     flatten_getitem_users = _get_getitem_users(flatten_node)
@@ -85,7 +85,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     unflatten_node = next(iter(flatten_getitem_users))
     if not (
         unflatten_node.op == "call_function"
-        and unflatten_node.target == pytree.tree_unflatten
+        and unflatten_node.target is pytree.tree_unflatten
     ):
         log.debug(
             "Flatten node %s's user is not a pytree.tree_unflatten. "

View File

@@ -222,7 +222,7 @@ def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
     that has the same target and args, but with the _export_root stripped from path.
     """
     for node in gm.graph.nodes:
-        if node.target == torch.ops.higher_order._export_tracepoint:
+        if node.target is torch.ops.higher_order._export_tracepoint:
             if "path" in node.kwargs:
                 path = _strip_root(node.kwargs["path"])
                 with gm.graph.inserting_before(node):

View File

@@ -50,7 +50,7 @@ def _remove_detach_pass(
         if node.op != "call_function":
             continue
         if (
-            node.target == torch.ops.aten.detach.default
+            node.target is torch.ops.aten.detach.default
             and len(node.users) == 1
             and next(iter(node.users)).target == torch.ops.aten.detach.default
         ):

View File

@@ -78,7 +78,7 @@ def move_to_device_pass(
         if (
             node.op == "call_function"
-            and node.target == torch.ops.aten.to.device
+            and node.target is torch.ops.aten.to.device
         ):
             args = list(node.args)
             # pyrefly: ignore [unsupported-operation]

View File

@@ -814,7 +814,7 @@ def bisect(shape_env):
     # Bisection happens on the assertion nodes of the recorded FX graph for
     # dynamic shapes.
     assert_nodes = [
-        node for node in shape_env.graph.nodes if node.target == torch._assert
+        node for node in shape_env.graph.nodes if node.target is torch._assert
     ]
     # Preparing the indices for binary search.

View File

@@ -15,7 +15,7 @@ class CudaGraphsSupport(OperatorSupport):
         if node.op not in CALLABLE_NODE_OPS:
             return False
-        if node.target == torch.ops.aten.embedding_dense_backward.default:
+        if node.target is torch.ops.aten.embedding_dense_backward.default:
             return False
         if node.target == operator.getitem:

View File

@@ -326,18 +326,18 @@ def split_module(
             instantiate_node_partition_mapping(node)
         if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
-            if node.target == torch._C._set_grad_enabled:
+            if node.target is torch._C._set_grad_enabled:
                 assert len(node.args) == 1
                 assert isinstance(node.args[0], bool)
                 active_grad = node
                 grad_regions[active_grad] = set({split_callback(node)})
-            elif node.target == torch.amp._enter_autocast:
+            elif node.target is torch.amp._enter_autocast:
                 # Should all be python constants
                 assert all(not isinstance(arg, Node) for arg in node.args)
                 active_autocasts.add(node)
                 autocast_regions[node] = set({split_callback(node)})
                 autocast_exits[node] = None
-            elif node.target == torch.amp._exit_autocast:
+            elif node.target is torch.amp._exit_autocast:
                 assert len(node.args) == 1
                 autocast_regions[node.args[0]].add(split_callback(node))
                 active_autocasts.remove(node.args[0])

View File

@@ -1501,7 +1501,7 @@ def _checkpoint_without_reentrant_generator(
     unpack_error_cb = None
     if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
-        if context_fn != noop_context_fn:
+        if context_fn is not noop_context_fn:
             raise ValueError(
                 "debug=True is incompatible with non-default context_fn"
             )