[2/N] Simplify "in" operation for containers of a single item (#164323)

These issues are detected by the ruff rule [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171) (single-item membership test).
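For context, here is a minimal sketch of the pattern FURB171 flags and the two rewrites this commit applies: equality (`==`) when comparing against a plain value, identity (`is`) when the single item is a unique object. The helper names below are illustrative stand-ins, not functions from the diff:

```python
import torch


def should_use_one_shot_all_reduce(reduce_op: str) -> bool:
    # Before (flagged by FURB171): membership test against a one-item tuple,
    # e.g. `reduce_op in ("sum",)`, which builds a container just to compare.
    # After: a plain equality check.
    return reduce_op == "sum"


def is_compute_intensive(target: object) -> bool:
    # When the single item is a unique object (a function, an aten overload),
    # `target in [torch.ops.aten.mm.default]` becomes an identity test.
    return target is torch.ops.aten.mm.default


print(should_use_one_shot_all_reduce("sum"))            # True
print(is_compute_intensive(torch.ops.aten.mm.default))  # True
```

To reproduce the findings locally, something like `ruff check --select FURB171 .` should list the same sites (on older ruff versions the FURB rules may additionally require `--preview`).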

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164323
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
Author: Yuanyuan Chen (2025-10-01 05:39:11 +00:00), committed by PyTorch MergeBot
parent 96c3b9e275
commit cc8b14d09a
20 changed files with 50 additions and 42 deletions

@@ -264,7 +264,7 @@ def _cuda_system_info_comment() -> str:
     try:
         cuda_version_out = subprocess.check_output(["nvcc", "--version"])
         cuda_version_lines = cuda_version_out.decode().split("\n")
-        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
+        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s != ""])
         model_str += f"{comment}\n"
     except (FileNotFoundError, subprocess.CalledProcessError):
         model_str += "# nvcc not found\n"

@@ -3623,7 +3623,7 @@ class CheckFunctionManager:
                 # Leave out the builtins dict key, as we will special handle
                 # it later because the guarded code rarely use the entire
                 # builtin dict in the common case.
-                if name not in (builtins_dict_name,):
+                if name != builtins_dict_name:
                     used_global_vars.add(name)
             elif name := get_local_source_name(source):
                 assert isinstance(name, str)

@@ -4228,9 +4228,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         # _origin marks this as coming from an internal dynamo known function that is safe to
         # trace through.
-        if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
-            produce_trampoline_autograd_apply,
-        ]:
+        if (
+            hasattr(getattr(func, "fn", None), "_origin")
+            and func.fn._origin is produce_trampoline_autograd_apply
+        ):
             # Known sound
             return trace_rules.SkipResult(
                 False, "allowlist in dynamo known function"

@@ -3017,7 +3017,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         ]
         or (
             # TODO: this is a little sus, because we didn't check what the self is
-            proxy.node.op == "call_method" and proxy.node.target in ["bit_length"]
+            proxy.node.op == "call_method" and proxy.node.target == "bit_length"
         )
     ):
         set_example_value(proxy.node, example_value)

@@ -154,7 +154,7 @@ class PlacementClassVariable(DistributedVariable):
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if (
-            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
+            inspect.getattr_static(self.value, "__new__", None) == object.__new__
             and self.source
         ):
             # NOTE: we don't need to track mutations to the placement class as they

@@ -1536,7 +1536,7 @@ class NumpyNdarrayVariable(TensorVariable):
                 explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.",
                 hints=[],
             )
-        elif name in ["__version__"]:
+        elif name == "__version__":
             unimplemented_v2(
                 gb_type="Unsupported ndarray.__version__ access",
                 context=f"var_getattr {self} {name}",

@@ -497,7 +497,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             self, tx: "InstructionTranslator", *args, **kwargs
         ):
             assert not kwargs
-            if self.value in (torch._C._dispatch_keys,):
+            if self.value is torch._C._dispatch_keys:
                 assert len(args) == 1
                 assert isinstance(args[0], variables.TensorVariable)
                 example_value = args[0].proxy.node.meta["example_value"]
@@ -1862,7 +1862,7 @@ class DispatchKeySetVariable(BaseTorchVariable):
         return cls(value, source=source)

     def is_constant_fold_method(self, name):
-        return name in ["has"]
+        return name == "has"

     def call_method(
         self,

@@ -151,7 +151,7 @@ def _should_lower_as_one_shot_all_reduce(
         config._collective.auto_select
         and is_symm_mem_enabled_for_group(group_name)
         and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM)
-        and reduce_op in ("sum",)
+        and reduce_op == "sum"
         and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes
     )

@@ -1037,7 +1037,7 @@ def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
     """

     def _is_compute_intensive(node: torch.fx.Node) -> bool:
-        return node.target in [torch.ops.aten.mm.default]
+        return node.target is torch.ops.aten.mm.default

     collective_to_overlapping_candidates = defaultdict(list)
     available_nodes = OrderedSet[torch.fx.Node]()

@@ -614,7 +614,7 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
 class NormalizedLinearNode:
     def __init__(self, node: torch.fx.Node) -> None:
         assert node.op == "call_function"
-        assert node.target in [torch.nn.functional.linear]
+        assert node.target is torch.nn.functional.linear
         self.node: torch.fx.Node = node

     def get_input(self) -> torch.fx.Node:

@@ -3846,7 +3846,7 @@ def meta__dyn_quant_matmul_4bit(
 ):
     torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
     torch._check(
-        inp.dtype in [torch.float32],
+        inp.dtype == torch.float32,
         lambda: f"expected input to be f32, got {inp.dtype}",
     )
     M = inp.size(0)

@@ -352,12 +352,14 @@ def _make_prim(
     from torch._subclasses.fake_tensor import contains_tensor_types

-    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
-        _prim
-    ) in [
-        # See https://github.com/pytorch/pytorch/issues/103532
-        "prims.device_put.default"
-    ]:
+    if (
+        not any(contains_tensor_types(a.type) for a in _prim._schema.arguments)
+        or str(
+            _prim
+            # See https://github.com/pytorch/pytorch/issues/103532
+        )
+        == "prims.device_put.default"
+    ):
         prim_backend_select_impl.impl(name, _backend_select_impl)

     for p in (_prim_packet, _prim):

@@ -890,7 +890,7 @@ class Tracer(TracerBase):
         new_tracer = Tracer.__new__(Tracer)

         for k, v in self.__dict__.items():
-            if k in {"_autowrap_search"}:
+            if k == "_autowrap_search":
                 new_obj = copy.copy(v)
             else:
                 new_obj = copy.deepcopy(v, memo)

@@ -882,7 +882,7 @@ def proxy_call(
     def can_handle_tensor(x: Tensor) -> bool:
         r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
         if proxy_mode._allow_fake_constant:
-            r = r or type(x) in (torch._subclasses.FakeTensor,)
+            r = r or type(x) is torch._subclasses.FakeTensor
         if not r:
             unrecognized_types.append(type(x))
         return r
@@ -1534,7 +1534,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         with set_original_aten_op(func):
             kwargs = kwargs or {}

-            if func in (prim.device.default,):
+            if func == prim.device.default:
                 return func(*args, **kwargs)

             return proxy_call(self, func, self.pre_dispatch, args, kwargs)

@@ -15,10 +15,10 @@ class CudaGraphsSupport(OperatorSupport):
         if node.op not in CALLABLE_NODE_OPS:
             return False

-        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
+        if node.target == torch.ops.aten.embedding_dense_backward.default:
             return False

-        if node.target in [operator.getitem]:
+        if node.target == operator.getitem:
             return True

         found_not_cuda = False

@@ -815,9 +815,10 @@ def _is_fp(value) -> bool:

 def _is_bool(value) -> bool:
-    return _type_utils.JitScalarType.from_value(
-        value, _type_utils.JitScalarType.UNDEFINED
-    ) in {_type_utils.JitScalarType.BOOL}
+    return (
+        _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
+        == _type_utils.JitScalarType.BOOL
+    )


 def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):

@@ -3011,7 +3011,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
     err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
     if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
-    elif op_info.name in ['aminmax']:
+    elif op_info.name == 'aminmax':
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)

     # Error Inputs for tensors with more than 64 dimension
@@ -3050,7 +3050,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
     if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
                          error_regex=err_msg_amax_amin2)
-    elif op_info.name in ['aminmax']:
+    elif op_info.name == 'aminmax':
         yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
                          error_regex=err_msg_aminmax2)
@@ -4883,7 +4883,7 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
 def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
     yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)

-    if mode in ('bilinear', ):
+    if mode == 'bilinear':
         make_arg = partial(
             make_tensor,
             device=device,
@@ -9468,7 +9468,7 @@ class foreach_inputs_sample_func:
             # unary
             if opinfo.ref in (torch.abs, torch.neg):
                 return False
-            if opinfo.ref_inplace in (torch.Tensor.zero_,):
+            if opinfo.ref_inplace == torch.Tensor.zero_:
                 return False
             return dtype in integral_types_and(torch.bool)
         if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
@@ -9698,7 +9698,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
         super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)

     def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
-        return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
+        return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul

     def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
         assert "num_input_tensors" not in kwargs

@@ -572,7 +572,7 @@ def optim_inputs_func_adam(device, dtype=None):
         + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
         + (mps_supported_configs if _get_device_type(device) == "mps" else [])
     )
-    if dtype in (torch.float16,):
+    if dtype == torch.float16:
         for input in total:
             """
             Too small eps will make denom to be zero for low precision dtype

@@ -322,7 +322,7 @@ def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kw
 def sample_inputs_linalg_norm(
     op_info, device, dtype, requires_grad, *, variant=None, **kwargs
 ):
-    if variant is not None and variant not in ("subgradient_at_zero",):
+    if variant is not None and variant != "subgradient_at_zero":
         raise ValueError(
             f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
         )

@@ -204,7 +204,7 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals
     if op_info.name == "sum":
         sample = _validate_sample_input_sparse_reduction_sum(sample)

-    if op_info.name in {"masked.sum"}:
+    if op_info.name == "masked.sum":
         mask = sample.kwargs.get("mask", UNSPECIFIED)
         if (
             mask not in {None, UNSPECIFIED}
@@ -792,12 +792,16 @@ def _sample_inputs_sparse_like_fns(
 def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
-    if sample.input.layout in {
-        torch.sparse_csr,
-        torch.sparse_csc,
-        torch.sparse_bsr,
-        torch.sparse_bsc,
-    } and op_info.name not in {"zeros_like"}:
+    if (
+        sample.input.layout
+        in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
+        and op_info.name != "zeros_like"
+    ):
         if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
             return ErrorInput(
                 sample,