From 694db5f54927697c9e914d35029f7e5bd9b85b96 Mon Sep 17 00:00:00 2001
From: Yuanyuan Chen
Date: Thu, 30 Oct 2025 19:00:04 +0000
Subject: [PATCH] Use 'is' in callable comparisons (#166624)

Just like we use `is/is not` for class comparisons, it is generally
advised to use `is/is not` for comparisons against torch functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
---
 torch/_dynamo/backends/debugging.py           |  6 +--
 torch/_dynamo/compiled_autograd.py            |  8 +--
 torch/_dynamo/debug_utils.py                  |  2 +-
 torch/_dynamo/variables/builder.py            |  6 +--
 torch/_dynamo/variables/higher_order_ops.py   |  2 +-
 .../passes/collect_tracepoints_pass.py        |  6 +--
 torch/_export/passes/constant_folding.py      |  2 +-
 .../passes/replace_autocast_with_hop_pass.py  |  4 +-
 .../passes/replace_set_grad_with_hop_pass.py  |  2 +-
 torch/_export/serde/schema_check.py           |  2 +-
 torch/_functorch/_aot_autograd/utils.py       |  2 +-
 torch/_functorch/partitioners.py              |  4 +-
 torch/_inductor/comms.py                      | 10 ++--
 torch/_inductor/constant_folding.py           |  4 +-
 torch/_inductor/fx_passes/b2b_gemm.py         |  2 +-
 torch/_inductor/fx_passes/bucketing.py        |  8 +--
 torch/_inductor/fx_passes/joint_graph.py      |  2 +-
 torch/_inductor/fx_passes/pre_grad.py         |  2 +-
 torch/_inductor/fx_passes/reinplace.py        |  4 +-
 torch/_inductor/fx_passes/split_cat.py        | 52 +++++++++----------
 torch/_inductor/ir.py                         |  2 +-
 torch/_inductor/kernel/conv.py                |  2 +-
 torch/_inductor/kernel/flex/flex_attention.py |  2 +-
 torch/_inductor/lowering.py                   |  2 +-
 torch/_ops.py                                 |  2 +-
 torch/ao/ns/fx/n_shadows_utils.py             |  2 +-
 torch/ao/ns/fx/utils.py                       |  2 +-
 .../fx/_lower_to_native_backend.py            |  2 +-
 torch/ao/quantization/fx/lstm_utils.py        |  6 +--
 torch/ao/quantization/pt2e/prepare.py         |  6 +--
 torch/ao/quantization/pt2e/qat_utils.py       |  2 +-
 torch/ao/quantization/pt2e/utils.py           | 14 ++---
 .../quantizer/embedding_quantizer.py          |  2 +-
 .../quantizer/x86_inductor_quantizer.py       |  4 +-
 .../quantizer/xpu_inductor_quantizer.py       |  2 +-
 .../algorithms/_quantization/quantization.py  |  4 +-
 .../_remove_auto_functionalized_pass.py       |  2 +-
 torch/export/_swap.py                         |  4 +-
 torch/export/_trace.py                        |  2 +-
 torch/export/experimental/__init__.py         |  2 +-
 torch/export/passes/__init__.py               |  2 +-
 torch/fx/experimental/validator.py            |  2 +-
 torch/fx/passes/backends/cudagraphs.py        |  2 +-
 torch/fx/passes/split_module.py               |  6 +--
 torch/utils/checkpoint.py                     |  2 +-
 45 files changed, 105 insertions(+), 105 deletions(-)

diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py
index 360a3d73353..0e62e08cf1f 100644
--- a/torch/_dynamo/backends/debugging.py
+++ b/torch/_dynamo/backends/debugging.py
@@ -369,7 +369,7 @@ def relu_compile_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             raise ReluCompileError
     return gm
@@ -379,7 +379,7 @@ def relu_runtime_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch._assert
             node.args = (False, "ReluRuntimeError")
     gm.recompile()
@@ -391,7 +391,7 @@ def relu_accuracy_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch.add
             node.args = (node.args[0], 1)
     gm.recompile()
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 5af72310b3a..6ce53c77ae7 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -573,7 +573,7 @@ class AutogradCompilerInstance:
                 result.name = make_unique(node.name)
                 value_remap[node] = result
             elif node.op == "call_function":
-                if node.target == torch.ops.aten.view.default:
+                if node.target is torch.ops.aten.view.default:
                     # this aot bwd graph is being lazily compiled
                     # we must manually apply the view_to_reshape post grad pass
                     # since it was already applied to the aot fwd, and baked into the gradients
@@ -1241,7 +1241,7 @@ class AutogradCompilerInstance:
         to_append = []
         hook_block = [node]  # contain the hook and hook args getitem
         for n in input_nodes:
-            if n.op == "call_function" and n.target == operator.getitem:
+            if n.op == "call_function" and n.target is operator.getitem:
                 to_append.append(n.args[0])
                 to_remove.append(n)
                 hook_block.append(n)
@@ -1314,7 +1314,7 @@ class AutogradCompilerInstance:
         # find the corresponding acc_grad node
         acc_grad_node = None
         for n in list(param_node.users.keys()):
-            if n.op == "call_function" and n.target == call_accumulate_grad:
+            if n.op == "call_function" and n.target is call_accumulate_grad:
                 acc_grad_node = n
                 break
@@ -1369,7 +1369,7 @@ class AutogradCompilerInstance:
         for n in list(param_node.users.keys()):
             if (
                 n.op == "call_function"
-                and n.target == call_hook
+                and n.target is call_hook
                 and n.kwargs.get("hook_type", None) == "post_acc_grad_hook"
             ):
                 post_acc_grad_hook_node = n
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index 47e5fdb12df..e16fa11ed08 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -455,7 +455,7 @@ def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule
     for node in model.graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.prims.convert_element_type.default
+            and node.target is torch.ops.prims.convert_element_type.default
         ):
             assert len(node.args) == 2
             if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 13eccf68e3a..bc80f580ee3 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -2930,12 +2930,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         hasattr(proxy.node.target, "__name__")
         and proxy.node.target.__name__ == "set_state"
         and isinstance(proxy.node.target.__self__, torch._C.Generator)
-        or proxy.node.target == torch.random.set_rng_state
+        or proxy.node.target is torch.random.set_rng_state
     ):
         return TorchInGraphFunctionVariable(proxy.node.target)
     elif (
-        proxy.node.target == torch._C._DisableFuncTorch
-        or proxy.node.target == torch.cuda._is_in_bad_fork
+        proxy.node.target is torch._C._DisableFuncTorch
+        or proxy.node.target is torch.cuda._is_in_bad_fork
     ):
         return UserDefinedObjectVariable(example_value)
     elif istype(example_value, torch.Size) and all(
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 11aa77decdf..1b26bed0524 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -2595,7 +2595,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
         from torch.utils.checkpoint import noop_context_fn
 
         context_fn = None
-        if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
+        if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn:
             ctx = kwargs.pop("context_fn")
             if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
                 context_fn = ctx.fn
diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py
index 8162342e50c..aa4765d3d62 100644
--- a/torch/_export/passes/collect_tracepoints_pass.py
+++ b/torch/_export/passes/collect_tracepoints_pass.py
@@ -48,7 +48,7 @@ class CollectTracepointsPass(PassBase):
             for node in module.graph.nodes:
                 if node.op != "call_function":
                     continue
-                if node.target == torch.ops.higher_order._export_tracepoint:
+                if node.target is torch.ops.higher_order._export_tracepoint:
                     kind = node.kwargs["kind"]
                     if kind == "module_call_outputs":
                         nn_module_stack = node.meta["nn_module_stack"]
@@ -64,7 +64,7 @@ class CollectTracepointsPass(PassBase):
             for node in reversed(module.graph.nodes):
                 if node.op != "call_function":
                     continue
-                if node.target == torch.ops.higher_order._export_tracepoint:
+                if node.target is torch.ops.higher_order._export_tracepoint:
                     kind = node.kwargs["kind"]
                     if kind == "module_call_inputs":
                         nn_module_stack = node.meta["nn_module_stack"]
@@ -94,7 +94,7 @@ class CollectTracepointsPass(PassBase):
             for node in module.graph.nodes:
                 if node.op != "call_function":
                     continue
-                if node.target == torch.ops.higher_order._export_tracepoint:
+                if node.target is torch.ops.higher_order._export_tracepoint:
                     # There's some subtlety worth noting. Here fqn corresponds to
                     # the call name, whereas path corresponds to the module name.
                     # They are not necessarily the same! When a submodule is shared
diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py
index 5fdc92702a1..b51b05f78ea 100644
--- a/torch/_export/passes/constant_folding.py
+++ b/torch/_export/passes/constant_folding.py
@@ -65,7 +65,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def is_impure(self, node: torch.fx.Node) -> bool:
         if (
-            node.target == torch.ops.prims.convert_element_type.default
+            node.target is torch.ops.prims.convert_element_type.default
             and node.args[0].op == "get_attr"  # type: ignore[union-attr]
             and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
             and node.args[1] == torch.bfloat16
diff --git a/torch/_export/passes/replace_autocast_with_hop_pass.py b/torch/_export/passes/replace_autocast_with_hop_pass.py
index 71b90a3ff1b..bd2fe5c26d5 100644
--- a/torch/_export/passes/replace_autocast_with_hop_pass.py
+++ b/torch/_export/passes/replace_autocast_with_hop_pass.py
@@ -34,7 +34,7 @@ def _is_enter_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
         and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._enter_autocast
+        and node.target is torch.amp.autocast_mode._enter_autocast
     )
 
 
@@ -42,7 +42,7 @@ def _is_exit_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
         and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._exit_autocast
+        and node.target is torch.amp.autocast_mode._exit_autocast
     )
diff --git a/torch/_export/passes/replace_set_grad_with_hop_pass.py b/torch/_export/passes/replace_set_grad_with_hop_pass.py
index 4c3a9c48d75..cf75ead34fd 100644
--- a/torch/_export/passes/replace_set_grad_with_hop_pass.py
+++ b/torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -22,7 +22,7 @@ def _is_set_grad_enabled_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]
     return (
         node
         and node.op == "call_function"
-        and node.target == torch._C._set_grad_enabled
+        and node.target is torch._C._set_grad_enabled
     )
diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py
index 0890b4b2dd8..d59245ac95f 100644
--- a/torch/_export/serde/schema_check.py
+++ b/torch/_export/serde/schema_check.py
@@ -88,7 +88,7 @@ def _staged_schema():
                 f"std::optional<{cpp_type}>",
                 f"optional {thrift_type}",
             )
-        elif o == Annotated:
+        elif o is Annotated:
             return dump_type(t.__origin__, level)
         else:
             raise AssertionError(f"Type {t} is not supported in export schema.")
diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py
index 4a912d9dfd9..858c0e9e539 100644
--- a/torch/_functorch/_aot_autograd/utils.py
+++ b/torch/_functorch/_aot_autograd/utils.py
@@ -249,7 +249,7 @@ def maybe_to_fresh_input(idx, t, meta):
 def is_with_effects(node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.higher_order.with_effects
+        and node.target is torch.ops.higher_order.with_effects
     )
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index d7b59ad6075..b56e86e50ed 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -1256,7 +1256,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     # Build the graph op-by-op by starting from the node all the way to the end
     # copy_ can be not using tangents at all, we must copy it.
     for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
-        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.aten.copy_.default:
             insert_node_in_graph(node)
 
     for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
@@ -1596,7 +1596,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
         if node.op == "output":
             continue
 
-        is_copy_ = node.target == torch.ops.aten.copy_.default
+        is_copy_ = node.target is torch.ops.aten.copy_.default
         if is_copy_:
             if _has_tag_must_be_in_backward(node):
                 has_mutation_in_bw.add(node.args[0])
diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
index 5a1e39bf710..00f023be083 100644
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py
@@ -1393,7 +1393,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
     for idx, node in enumerate(node_list):
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
         ):
             assert node.args[0].op == "placeholder", f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
@@ -1444,7 +1444,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
     # Find all eligible unsharded params and their corresponding graph intermediates.
     unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
     for idx, node in enumerate(node_list):
-        if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default:
             fsdp_copy_node = node
             unsharded_param = node.args[0]
             assert unsharded_param.op == "placeholder", f"""
 Offending node: {unsharded_param}.
 Graph: {graph}
 
     def is_allowed_mutation(node):
         return (
-            node.target == torch.ops.fsdp.copy_.default
-            or node.target == torch.ops.inductor.resize_storage_bytes_.default
+            node.target is torch.ops.fsdp.copy_.default
+            or node.target is torch.ops.inductor.resize_storage_bytes_.default
         )
 
     def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
@@ -1558,7 +1558,7 @@ Graph: {graph}
     for node in node_list:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
             and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
         ):
             graph.erase_node(node)
diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py
index 85033e2b3e8..4e6f1302937 100644
--- a/torch/_inductor/constant_folding.py
+++ b/torch/_inductor/constant_folding.py
@@ -122,7 +122,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def is_impure(self, node: torch.fx.node.Node) -> bool:
         def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             return (
-                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
+                node.target is torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                 and isinstance(node.args[0], torch.fx.Node)
                 and "val" in node.args[0].meta
                 and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
@@ -132,7 +132,7 @@ class ConstantFolder(torch.fx.Interpreter):
         if (
             is_woq_int8_pattern(node)
             or (
-                node.target == torch.ops.aten.permute.default
+                node.target is torch.ops.aten.permute.default
                 and len(node.users) == 1
                 and is_woq_int8_pattern(next(iter(node.users)))
             )
diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py
index 303d9bfd59a..9faec788e9e 100644
--- a/torch/_inductor/fx_passes/b2b_gemm.py
+++ b/torch/_inductor/fx_passes/b2b_gemm.py
@@ -591,7 +591,7 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
     )
 
     def is_mm(node: torch.fx.Node) -> bool:
-        return node.target == torch.ops.aten.mm.default
+        return node.target is torch.ops.aten.mm.default
 
     # the inner MM
     inner_mm = match.nodes[-1]
diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index 555c6e8d2ba..5dd383b2bbe 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -122,28 +122,28 @@ def bucket_reduce_scatter(
 def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+        and node.target is torch.ops._c10d_functional.all_gather_into_tensor.default
     )
 
 
 def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
+        and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default
     )
 
 
 def is_wait_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.wait_tensor.default
+        and node.target is torch.ops._c10d_functional.wait_tensor.default
     )
 
 
 def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_reduce.default
+        and node.target is torch.ops._c10d_functional.all_reduce.default
     )
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index 87075efc202..8f4568bd89f 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -315,7 +315,7 @@ class UniformValueConstantFolder(ConstantFolder):
         # single-elem attrs
         if node.op == "get_attr" or (
             node.op == "call_function"
-            and node.target == torch.ops.aten.lift_fresh_copy.default
+            and node.target is torch.ops.aten.lift_fresh_copy.default
         ):
             out = super(ConstantFolder, self).run_node(node)
             if isinstance(out, torch.Tensor) and out.numel() == 1:
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index b953a7ad01a..051c75b2c2a 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -737,7 +737,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             input_node = node.kwargs["input"]
             if (
                 input_node.op == "call_function"
-                and input_node.target == torch.nn.functional.linear
+                and input_node.target is torch.nn.functional.linear
             ):
                 normalized = NormalizedLinearNode(input_node)
                 input = normalized.get_input()
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index 8b9deac6ba5..4ad15a5ee18 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                 if copy_node is not None:
                     replace_dict[copy_node] = copy_node.args[0]
                 node.target = inplaceable_op.inplace_op
-            elif node.target == torch.ops.higher_order.auto_functionalized_v2:
+            elif node.target is torch.ops.higher_order.auto_functionalized_v2:
                 _mutable_op = node.args[0]
                 kwargs = node.kwargs
@@ -696,7 +696,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                     # auto_functionalized into clones + a mutable op; this metadata
                     # tells the decomp to only clone the following inputs
                     node.meta["only_clone_these_tensors"] = new_bases_to_clone
-            elif node.target == torch.ops.higher_order.auto_functionalized:
+            elif node.target is torch.ops.higher_order.auto_functionalized:
                 _mutable_op = node.args[0]
 
                 from torch._higher_order_ops.auto_functionalize import get_mutable_args
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 15ea6867dba..773b1a3f207 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -127,12 +127,12 @@ def _get_dim(node: Any):
     if "dim" in node.kwargs:
         assert isinstance(node.kwargs["dim"], int)
         return node.kwargs["dim"]
-    if node.target == torch.unbind:
+    if node.target is torch.unbind:
         if len(node.args) == 2:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
         return 0  # defaults to dim=0
-    if node.target == torch.split:
+    if node.target is torch.split:
         if len(node.args) == 3:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
@@ -351,7 +351,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
         cat_node.args == new_args
         and cat_node.kwargs == new_kwargs
         and cat_node.op == "call_function"
-        and cat_node.target == torch.cat
+        and cat_node.target is torch.cat
     ):
         return
@@ -866,7 +866,7 @@ class SplitCatSimplifier:
             cat_dim = get_arg_value(user_node, 1, "dim")
             transform_params: list[_TransformParam] = []
             for user_input in user_inputs:
-                if split_dim == cat_dim and user_node.target == torch.cat:
+                if split_dim == cat_dim and user_node.target is torch.cat:
                     # No transform needed
                     transform_params.append((None, None, None, None))
                 elif isinstance(user_input, tuple):  # Split being simplified
@@ -888,7 +888,7 @@ class SplitCatSimplifier:
                         (unflatten_params, movedim_params, None, None)
                     )
                 elif (
-                    user_node.target == torch.stack or split_dim != cat_dim
+                    user_node.target is torch.stack or split_dim != cat_dim
                 ):  # We need to unsqueeze inputs not coming through split
                     transform_params.append((None, None, (cat_dim,), None))
                 else:  # Non-split inputs
@@ -1107,9 +1107,9 @@ class SplitCatSimplifier:
             )
 
             if (
-                user_node.target == torch.cat
+                user_node.target is torch.cat
                 and split_dim != cat_dim
-                and split_node.target == torch.split
+                and split_node.target is torch.split
             ):
                 with graph.inserting_after(new_cat_node):
                     new_cat_node_meta = new_cat_node.meta["example_value"]
@@ -1225,13 +1225,13 @@ class UnbindCatRemover(SplitCatSimplifier):
                     (split_dim, cat_dim) if split_dim != cat_dim else None
                 )
                 flatten_params = None
-                if user_node.target == torch.cat:
+                if user_node.target is torch.cat:
                     flatten_params = (cat_dim, cat_dim + 1)
                 transform_params.append(
                     (None, movedim_params, None, flatten_params)
                 )
             elif (
-                user_node.target == torch.stack
+                user_node.target is torch.stack
             ):  # We need to unsqueeze inputs not coming through unbind into cat
                 transform_params.append((None, None, (cat_dim,), None))
             else:  # Non-unbind inputs
@@ -1298,13 +1298,13 @@ def merge_split_squeeze(
     match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int
 ):
     graph = match.graph
-    split = next(node for node in match.nodes if node.target == torch.split)
+    split = next(node for node in match.nodes if node.target is torch.split)
     if not all(s == 1 for s in split_sizes):
         return
     if isinstance(dim, Sequence):
         return
     next_users = find_next_users(split)
-    if not all(node.target == torch.squeeze for node in next_users):
+    if not all(node.target is torch.squeeze for node in next_users):
         return
     with graph.inserting_before(match.output_node()):
         unbind = graph.call_function(
@@ -1364,7 +1364,7 @@ getitem_unbind = ListOf(
     pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
 )
 def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     UnbindCatRemover().remove_unbind(match.graph, unbind_node)
@@ -1431,7 +1431,7 @@ reshape_getitem_split = ListOf(
 def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     # pyrefly: ignore [bad-argument-type]
     SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
@@ -1518,7 +1518,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1625,7 +1625,7 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     _split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1904,7 +1904,7 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
         # get the select nodes from the node
         select_nodes = list(node_input.users.keys())
         for cat_node in list(node.users.keys()):
-            if cat_node.target == torch.ops.aten.cat.default:
+            if cat_node.target is torch.ops.aten.cat.default:
                 cat_dim = get_arg_value(cat_node, 1, "dim")
                 cat_inputs = get_arg_value(cat_node, 0, "tensors")
                 # check all select nodes has same slice dim
@@ -2020,7 +2020,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs):
     parent_of_select_node = get_arg_value(select_nodes[0], 0, "input")
     # check the target of select_nodes are the same
     if not all(
-        select_node.target == torch.ops.aten.select for select_node in select_nodes
+        select_node.target is torch.ops.aten.select for select_node in select_nodes
     ):
         return
     # check the select nodes come from the same parent node
@@ -2357,7 +2357,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.No
 def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_nodes = [node for node in match.nodes if node.target == torch.split]
+    split_nodes = [node for node in match.nodes if node.target is torch.split]
     if split_nodes:
         split_node = next(node for node in split_nodes)
     else:
@@ -2438,7 +2438,7 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"),
 )
 def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2614,7 +2614,7 @@ def convert_reshape_cat_arg_to_stack(
 def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = get_arg_value(split_node, 2, "dim") or 0
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
@@ -2685,7 +2685,7 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"),
 )
 def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2785,10 +2785,10 @@ def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]:
     pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"),
 )
 def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = _get_dim(split_node)
     split_users = list(split_node.users.keys())
-    stack_nodes = [node for node in match.nodes if node.target == torch.stack]
+    stack_nodes = [node for node in match.nodes if node.target is torch.stack]
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
         "move_reshape_out_of_split_stack_pass"
     ]
@@ -2926,7 +2926,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     split_node = next(
         node
         for node in match.nodes
-        if node.target == torch.ops.aten.split_with_sizes.default
+        if node.target is torch.ops.aten.split_with_sizes.default
     )
     split_input, split_section, split_dim = _get_split_args_default(split_node)
     split_users = list(split_node.users.keys())
@@ -2936,7 +2936,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
         return
     cat_nodes = [
-        node for node in match.nodes if node.target == torch.ops.aten.cat.default
+        node for node in match.nodes if node.target is torch.ops.aten.cat.default
     ]
     graph = match.graph
     for cat_node in cat_nodes:
@@ -2950,7 +2950,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
             continue
         # check if the cat inputs are all the view nodes
         if not all(
-            view_node.target == torch.ops.aten.reshape.default
+            view_node.target is torch.ops.aten.reshape.default
             for view_node in cat_inputs
         ):
             continue
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 55f582c5d78..39d32e41b4e 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5849,7 +5849,7 @@ class ExternKernel(InputsKernel):
         if shape_env := V.fake_mode.shape_env:
             node_meta_val = V.current_node.meta.get("val")
             ctx: AbstractContextManager[None] = nullcontext()
-            if V.current_node.target == torch._higher_order_ops.effects.with_effects:
+            if V.current_node.target is torch._higher_order_ops.effects.with_effects:
                 # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
                 node_meta_val = node_meta_val[1]
                 ctx = _remove_effect_token_unbacked_bindings(V.current_node)
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index 2977932c084..8e5a2aa09d4 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -677,7 +677,7 @@ def _convolution(
 
 def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
-    assert fx_node.target == torch.ops.aten.convolution.default
+    assert fx_node.target is torch.ops.aten.convolution.default
     if V.graph.layout_opt:
         return args, kwargs
     else:
diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index 9e7a217829d..bc148ebc207 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -478,7 +478,7 @@ def validate_joint_graph(joint_graph: torch.fx.Graph):
     for node in joint_graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.flex_lib.zeros_and_scatter.default
+            and node.target is torch.ops.flex_lib.zeros_and_scatter.default
         ):
             for user in node.users:
                 if user.op != "output":
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 52521285dfe..a700489bf46 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3922,7 +3922,7 @@ def index_put_impl_(self, indices, values, accumulate, check, may_realize=False)
                 isinstance(indice, ir.StorageBox)
                 and isinstance(indice.data, ir.ExternKernel)
                 and getattr(indice.data, "fx_node", None)
-                and indice.data.fx_node.target == torch.ops.aten.randperm.default
+                and indice.data.fx_node.target is torch.ops.aten.randperm.default
             )
         return False
diff --git a/torch/_ops.py b/torch/_ops.py
index cc7b3ffe2f0..9cdf735532d 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -437,7 +437,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
             subclass_type = type(arg)
             if (
                 subclass_type.__torch_dispatch__
-                == torch._C._disabled_torch_dispatch_impl
+                is torch._C._disabled_torch_dispatch_impl
             ):
                 continue
diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py
index 9adca1a7751..f16700994d0 100644
--- a/torch/ao/ns/fx/n_shadows_utils.py
+++ b/torch/ao/ns/fx/n_shadows_utils.py
@@ -1050,7 +1050,7 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
         raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
     quant_node = next(iter(shadow_n.users.keys()))
     new_args: Any = None
-    if quant_node.target == torch.quantize_per_channel:
+    if quant_node.target is torch.quantize_per_channel:
         _weight, scale_node, zp_node, axis, dtype = quant_node.args
         scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
         zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py
index a7541e8a50c..3423d853320 100644
--- a/torch/ao/ns/fx/utils.py
+++ b/torch/ao/ns/fx/utils.py
@@ -204,7 +204,7 @@ def get_node_input_qparams(
     if prev_node.op == "call_function":
         # quantize - read the args directly
-        if prev_node.target == torch.quantize_per_tensor:
+        if prev_node.target is torch.quantize_per_tensor:
             return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
         elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
             return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py
index ab6e489d53e..6c0374bdd96 100644
--- a/torch/ao/quantization/fx/_lower_to_native_backend.py
+++ b/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -1181,7 +1181,7 @@ def special_pattern_replacement(model: GraphModule):
     modules = dict(model.named_modules(remove_duplicate=False))
     for n in model.graph.nodes:
         q_node = n
-        is_quantize = q_node.target == torch.quantize_per_tensor
+        is_quantize = q_node.target is torch.quantize_per_tensor
         is_to_fp16 = (
             q_node.op == "call_method"
             and q_node.target == "to"
diff --git a/torch/ao/quantization/fx/lstm_utils.py b/torch/ao/quantization/fx/lstm_utils.py
index b49f462640f..3bc39424d9c 100644
--- a/torch/ao/quantization/fx/lstm_utils.py
+++ b/torch/ao/quantization/fx/lstm_utils.py
@@ -139,10 +139,10 @@ def _get_lstm_with_individually_observed_parts(
         mul_count = 0
         for node in cell.graph.nodes:
             op_index: Optional[tuple[Callable, int]] = None  # e.g. (torch.add, 1)
-            if node.target == torch.add:
+            if node.target is torch.add:
                 op_index = (torch.add, add_count)
                 add_count += 1
-            elif node.target == torch.mul:
+            elif node.target is torch.mul:
                 op_index = (torch.mul, mul_count)
                 mul_count += 1
             else:
@@ -205,7 +205,7 @@ def _get_reference_quantized_lstm_module(
     # on custom module input/output dtypes, and (2) expand support for complex
     # input/output structures.
     for node in cell.graph.nodes:
-        if node.target == torch.quantize_per_tensor:
+        if node.target is torch.quantize_per_tensor:
             arg = node.args[0]
             # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
             if arg.target == "x" or (
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index aedad07cc8a..6eac69a96ba 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -455,9 +455,9 @@ def _maybe_insert_input_observers_for_node(
     # gelu has a has an approximate kwarg that persist in exported graph.
     # This is just a work around for these.
     if not (
-        node.target == torch.ops.aten.clone.default
-        or node.target == torch.ops.aten.zeros_like.default
-        or node.target == torch.ops.aten.gelu.default
+        node.target is torch.ops.aten.clone.default
+        or node.target is torch.ops.aten.zeros_like.default
+        or node.target is torch.ops.aten.gelu.default
         or len(node.kwargs) == 0
     ):
         raise AssertionError(" expecting kwargs for aten op IR to be empty")
diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index b7daca97b18..1f5588376fb 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -939,7 +939,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     # remove in place add from batchnorm tracking training stats
     for node in m.graph.nodes:
         if (
-            node.target == torch.ops.aten.add_.Tensor
+            node.target is torch.ops.aten.add_.Tensor
             and node.args[0].op == "get_attr"
             and node.args[1] == 1
             and (
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index 320429a5677..bb698260539 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -97,10 +97,10 @@ def _find_q_dq_node_for_user(
 def _is_sym_size_node(node: Node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.aten.sym_size.default
-        or node.target == torch.ops.aten.sym_numel.default
-        or node.target == torch.ops.aten.sym_numel
-        or node.target == torch.ops.aten.sym_size
+        and node.target is torch.ops.aten.sym_size.default
+        or node.target is torch.ops.aten.sym_numel.default
+        or node.target is torch.ops.aten.sym_numel
+        or node.target is torch.ops.aten.sym_size
     )
@@ -228,7 +228,7 @@ def fold_bn_weights_into_conv_node(
     bn_b = _get_tensor_constant_from_node(bn_args[2], m)
     bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
     bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
-    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
+    if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default:
         eps_arg_index = 6
     elif _is_supported_batch_norm_for_training(bn_node):
         eps_arg_index = 7
@@ -268,7 +268,7 @@ def fold_bn_weights_into_conv_node(
     # native_batch_norm has 3 outputs, we expect getitem calls on the output
     # and we want to replace the uses of getitem 0 with the output of conv
     #
-    if bn_node.target == torch.ops.aten.batch_norm.default:
+    if bn_node.target is torch.ops.aten.batch_norm.default:
         # With the new training ir, instead of batch_norm + getitem,
         # we only have the batch_norm node.
         #
@@ -377,7 +377,7 @@ def _get_aten_graph_module_for_pattern(
     for node in aten_pattern.graph.nodes:  # type: ignore[union-attr]
         if (
             node.op == "call_function"
-            and node.target == torch.ops.aten.copy_.default
+            and node.target is torch.ops.aten.copy_.default
             and len(node.users) == 0
         ):
             aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]
diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py
index 88bc6f3c8c9..b0f1b823b7f 100644
--- a/torch/ao/quantization/quantizer/embedding_quantizer.py
+++ b/torch/ao/quantization/quantizer/embedding_quantizer.py
@@ -77,7 +77,7 @@ class EmbeddingQuantizer(Quantizer):
             # just as an example of alternate ways of annotating
             if (
                 node.op == "call_function"
-                and node.target == torch.ops.aten.embedding.default
+                and node.target is torch.ops.aten.embedding.default
             ):
                 if embedding_config.config.weight is None:
                     raise ValueError(
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index e2482077b73..6ffaf72b12c 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -607,7 +607,7 @@ class X86InductorQuantizer(Quantizer):
             _annotate_nodes_not_quantize(linear_node)
             return
         input_qspec_map = {}
-        assert linear_node.target == torch.ops.aten.linear.default
+        assert linear_node.target is torch.ops.aten.linear.default
         has_bias = len(linear_node.args) == 3
         input_index = 0
         weight_index = 1
@@ -1396,7 +1396,7 @@ class X86InductorQuantizer(Quantizer):
         """  # noqa: B950
         edge_or_node: tuple[Node, Node]
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 maxpool_node = node
                 if not _is_all_annotated(
                     [
diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
index eff97dbcf27..1888eb57396 100644
--- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
@@ -112,7 +112,7 @@ class XPUInductorQuantizer(X86InductorQuantizer):
         node: Node,
     ) -> None:
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 return
             else:
                 input_node = node.all_input_nodes[0]
diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py
index a1fa1fd64c0..37cc3e0deb4 100644
--- a/torch/distributed/algorithms/_quantization/quantization.py
+++ b/torch/distributed/algorithms/_quantization/quantization.py
@@ -111,7 +111,7 @@ def auto_quantize(func, qtype, quant_loss=None):
         async_op = kwargs.get("async_op", False)
         if async_op is True:
             raise RuntimeError("The async_op=True mode is not supported yet.")
-        if func == dist.all_gather:
+        if func is dist.all_gather:
             tensors = args[0]
             input_tensors = _quantize_tensor(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)
@@ -121,7 +121,7 @@ def auto_quantize(func, qtype, quant_loss=None):
             ):
                 tensors[i] = t
 
-        elif func == dist.all_to_all:
+        elif func is dist.all_to_all:
             tensors = args[0]
             input_tensors = _quantize_tensor_list(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)
diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py
index 67f84e49af6..1b483392765 100644
--- a/torch/export/_remove_auto_functionalized_pass.py
+++ b/torch/export/_remove_auto_functionalized_pass.py
@@ -17,7 +17,7 @@ from torch.fx import Graph
 
 def remove_self_clone(graph: Graph) -> None:
     for node in graph.nodes:
-        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+        if node.target is torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
             node.replace_all_uses_with(node.args[0])
             graph.erase_node(node)
diff --git a/torch/export/_swap.py b/torch/export/_swap.py
index f167842fae9..0a30f2154ce 100644
--- a/torch/export/_swap.py
+++ b/torch/export/_swap.py
@@ -69,7 +69,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     flatten_node = curr_module_users[0]
     assert (
         flatten_node.op == "call_function"
-        and flatten_node.target == fx_pytree.tree_flatten_spec
+        and flatten_node.target is fx_pytree.tree_flatten_spec
     )
 
     flatten_getitem_users = _get_getitem_users(flatten_node)
@@ -85,7 +85,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     unflatten_node = next(iter(flatten_getitem_users))
     if not (
         unflatten_node.op == "call_function"
-        and unflatten_node.target == pytree.tree_unflatten
+        and unflatten_node.target is pytree.tree_unflatten
     ):
         log.debug(
             "Flatten node %s's user is not a pytree.tree_unflatten. "
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 3e7e9d8d991..685fe149714 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -222,7 +222,7 @@ def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
     that has the same target and args, but with the _export_root stripped from path.
     """
     for node in gm.graph.nodes:
-        if node.target == torch.ops.higher_order._export_tracepoint:
+        if node.target is torch.ops.higher_order._export_tracepoint:
             if "path" in node.kwargs:
                 path = _strip_root(node.kwargs["path"])
                 with gm.graph.inserting_before(node):
diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py
index 2705b59a907..17e16b8f218 100644
--- a/torch/export/experimental/__init__.py
+++ b/torch/export/experimental/__init__.py
@@ -50,7 +50,7 @@ def _remove_detach_pass(
             if node.op != "call_function":
                 continue
             if (
-                node.target == torch.ops.aten.detach.default
+                node.target is torch.ops.aten.detach.default
                 and len(node.users) == 1
                 and next(iter(node.users)).target == torch.ops.aten.detach.default
             ):
diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py
index d36985180f5..90430608cab 100644
--- a/torch/export/passes/__init__.py
+++ b/torch/export/passes/__init__.py
@@ -78,7 +78,7 @@ def move_to_device_pass(
             if (
                 node.op == "call_function"
-                and node.target == torch.ops.aten.to.device
+                and node.target is torch.ops.aten.to.device
             ):
                 args = list(node.args)
                 # pyrefly: ignore [unsupported-operation]
diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py
index eb55b6c2050..b3eb7bcde49 100644
--- a/torch/fx/experimental/validator.py
+++ b/torch/fx/experimental/validator.py
@@ -814,7 +814,7 @@ def bisect(shape_env):
     # Bisection happens on the assertion nodes of the recorded FX graph for
     # dynamic shapes.
     assert_nodes = [
-        node for node in shape_env.graph.nodes if node.target == torch._assert
+        node for node in shape_env.graph.nodes if node.target is torch._assert
     ]
 
     # Preparing the indices for binary search.
diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py
index 657c7578f5f..d6c41841af3 100644
--- a/torch/fx/passes/backends/cudagraphs.py
+++ b/torch/fx/passes/backends/cudagraphs.py
@@ -15,7 +15,7 @@ class CudaGraphsSupport(OperatorSupport):
         if node.op not in CALLABLE_NODE_OPS:
             return False
 
-        if node.target == torch.ops.aten.embedding_dense_backward.default:
+        if node.target is torch.ops.aten.embedding_dense_backward.default:
             return False
 
         if node.target == operator.getitem:
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index e23045caccb..a4b244750f3 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -326,18 +326,18 @@ def split_module(
         instantiate_node_partition_mapping(node)
 
         if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
-            if node.target == torch._C._set_grad_enabled:
+            if node.target is torch._C._set_grad_enabled:
                 assert len(node.args) == 1
                 assert isinstance(node.args[0], bool)
                 active_grad = node
                 grad_regions[active_grad] = set({split_callback(node)})
-            elif node.target == torch.amp._enter_autocast:
+            elif node.target is torch.amp._enter_autocast:
                 # Should all be python constants
                 assert all(not isinstance(arg, Node) for arg in node.args)
                 active_autocasts.add(node)
                 autocast_regions[node] = set({split_callback(node)})
                 autocast_exits[node] = None
-            elif node.target == torch.amp._exit_autocast:
+            elif node.target is torch.amp._exit_autocast:
                 assert len(node.args) == 1
                 autocast_regions[node.args[0]].add(split_callback(node))
                 active_autocasts.remove(node.args[0])
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 1b16da9f242..565fcec6f70 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -1501,7 +1501,7 @@ def _checkpoint_without_reentrant_generator(
     unpack_error_cb = None
 
     if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
-        if context_fn != noop_context_fn:
+        if context_fn is not noop_context_fn:
             raise ValueError(
                 "debug=True is incompatible with non-default context_fn"
            )
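
Illustrative note (not part of the patch): a minimal, self-contained sketch of the comparison style this change standardizes on. FX node targets are function objects, and callables such as `torch.relu`, `operator.getitem`, and `torch.ops.*` overloads are module-level singletons, so matching a node against them is an identity check; `is` states that directly, avoids any `__eq__` dispatch, and reads unambiguously. The helper names below (`is_relu_node`, `demo`) are made up for the example and do not appear in the patch.

```python
import operator

import torch


def is_relu_node(node: torch.fx.Node) -> bool:
    # Identity check: is the target exactly the torch.relu callable?
    # This mirrors the `node.target is <callable>` pattern used throughout the patch.
    return node.op == "call_function" and node.target is torch.relu


def demo() -> None:
    def f(x):
        y = torch.relu(x)
        return torch.split(y, 2)

    gm = torch.fx.symbolic_trace(f)
    relu_nodes = [n for n in gm.graph.nodes if is_relu_node(n)]
    getitem_nodes = [n for n in gm.graph.nodes if n.target is operator.getitem]
    print(f"relu nodes: {len(relu_nodes)}, getitem nodes: {len(getitem_nodes)}")


if __name__ == "__main__":
    demo()
```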