[2/N] Simplify "in" operation for containers of a single item (#164323)

These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164323
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
This commit is contained in:
Yuanyuan Chen 2025-10-01 05:39:11 +00:00 committed by PyTorch MergeBot
parent 96c3b9e275
commit cc8b14d09a
20 changed files with 50 additions and 42 deletions

View File

@ -264,7 +264,7 @@ def _cuda_system_info_comment() -> str:
try: try:
cuda_version_out = subprocess.check_output(["nvcc", "--version"]) cuda_version_out = subprocess.check_output(["nvcc", "--version"])
cuda_version_lines = cuda_version_out.decode().split("\n") cuda_version_lines = cuda_version_out.decode().split("\n")
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]]) comment = "".join([f"# {s} \n" for s in cuda_version_lines if s != ""])
model_str += f"{comment}\n" model_str += f"{comment}\n"
except (FileNotFoundError, subprocess.CalledProcessError): except (FileNotFoundError, subprocess.CalledProcessError):
model_str += "# nvcc not found\n" model_str += "# nvcc not found\n"

View File

@ -3623,7 +3623,7 @@ class CheckFunctionManager:
# Leave out the builtins dict key, as we will special handle # Leave out the builtins dict key, as we will special handle
# it later because the guarded code rarely use the entire # it later because the guarded code rarely use the entire
# builtin dict in the common case. # builtin dict in the common case.
if name not in (builtins_dict_name,): if name != builtins_dict_name:
used_global_vars.add(name) used_global_vars.add(name)
elif name := get_local_source_name(source): elif name := get_local_source_name(source):
assert isinstance(name, str) assert isinstance(name, str)

View File

@ -4228,9 +4228,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
# _origin marks this as coming from an internal dynamo known function that is safe to # _origin marks this as coming from an internal dynamo known function that is safe to
# trace through. # trace through.
if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [ if (
produce_trampoline_autograd_apply, hasattr(getattr(func, "fn", None), "_origin")
]: and func.fn._origin is produce_trampoline_autograd_apply
):
# Known sound # Known sound
return trace_rules.SkipResult( return trace_rules.SkipResult(
False, "allowlist in dynamo known function" False, "allowlist in dynamo known function"

View File

@ -3017,7 +3017,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
] ]
or ( or (
# TODO: this is a little sus, because we didn't check what the self is # TODO: this is a little sus, because we didn't check what the self is
proxy.node.op == "call_method" and proxy.node.target in ["bit_length"] proxy.node.op == "call_method" and proxy.node.target == "bit_length"
) )
): ):
set_example_value(proxy.node, example_value) set_example_value(proxy.node, example_value)

View File

@ -154,7 +154,7 @@ class PlacementClassVariable(DistributedVariable):
kwargs: "dict[str, VariableTracker]", kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker": ) -> "VariableTracker":
if ( if (
inspect.getattr_static(self.value, "__new__", None) in (object.__new__,) inspect.getattr_static(self.value, "__new__", None) == object.__new__
and self.source and self.source
): ):
# NOTE: we don't need to track mutations to the placement class as they # NOTE: we don't need to track mutations to the placement class as they

View File

@ -1536,7 +1536,7 @@ class NumpyNdarrayVariable(TensorVariable):
explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.",
hints=[], hints=[],
) )
elif name in ["__version__"]: elif name == "__version__":
unimplemented_v2( unimplemented_v2(
gb_type="Unsupported ndarray.__version__ access", gb_type="Unsupported ndarray.__version__ access",
context=f"var_getattr {self} {name}", context=f"var_getattr {self} {name}",

View File

@ -497,7 +497,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs self, tx: "InstructionTranslator", *args, **kwargs
): ):
assert not kwargs assert not kwargs
if self.value in (torch._C._dispatch_keys,): if self.value is torch._C._dispatch_keys:
assert len(args) == 1 assert len(args) == 1
assert isinstance(args[0], variables.TensorVariable) assert isinstance(args[0], variables.TensorVariable)
example_value = args[0].proxy.node.meta["example_value"] example_value = args[0].proxy.node.meta["example_value"]
@ -1862,7 +1862,7 @@ class DispatchKeySetVariable(BaseTorchVariable):
return cls(value, source=source) return cls(value, source=source)
def is_constant_fold_method(self, name): def is_constant_fold_method(self, name):
return name in ["has"] return name == "has"
def call_method( def call_method(
self, self,

View File

@ -151,7 +151,7 @@ def _should_lower_as_one_shot_all_reduce(
config._collective.auto_select config._collective.auto_select
and is_symm_mem_enabled_for_group(group_name) and is_symm_mem_enabled_for_group(group_name)
and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM) and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM)
and reduce_op in ("sum",) and reduce_op == "sum"
and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes
) )

View File

@ -1037,7 +1037,7 @@ def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
""" """
def _is_compute_intensive(node: torch.fx.Node) -> bool: def _is_compute_intensive(node: torch.fx.Node) -> bool:
return node.target in [torch.ops.aten.mm.default] return node.target is torch.ops.aten.mm.default
collective_to_overlapping_candidates = defaultdict(list) collective_to_overlapping_candidates = defaultdict(list)
available_nodes = OrderedSet[torch.fx.Node]() available_nodes = OrderedSet[torch.fx.Node]()

View File

@ -614,7 +614,7 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
class NormalizedLinearNode: class NormalizedLinearNode:
def __init__(self, node: torch.fx.Node) -> None: def __init__(self, node: torch.fx.Node) -> None:
assert node.op == "call_function" assert node.op == "call_function"
assert node.target in [torch.nn.functional.linear] assert node.target is torch.nn.functional.linear
self.node: torch.fx.Node = node self.node: torch.fx.Node = node
def get_input(self) -> torch.fx.Node: def get_input(self) -> torch.fx.Node:

View File

@ -3846,7 +3846,7 @@ def meta__dyn_quant_matmul_4bit(
): ):
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
torch._check( torch._check(
inp.dtype in [torch.float32], inp.dtype == torch.float32,
lambda: f"expected input to be f32, got {inp.dtype}", lambda: f"expected input to be f32, got {inp.dtype}",
) )
M = inp.size(0) M = inp.size(0)

View File

@ -352,12 +352,14 @@ def _make_prim(
from torch._subclasses.fake_tensor import contains_tensor_types from torch._subclasses.fake_tensor import contains_tensor_types
if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str( if (
_prim not any(contains_tensor_types(a.type) for a in _prim._schema.arguments)
) in [ or str(
# See https://github.com/pytorch/pytorch/issues/103532 _prim
"prims.device_put.default" # See https://github.com/pytorch/pytorch/issues/103532
]: )
== "prims.device_put.default"
):
prim_backend_select_impl.impl(name, _backend_select_impl) prim_backend_select_impl.impl(name, _backend_select_impl)
for p in (_prim_packet, _prim): for p in (_prim_packet, _prim):

View File

@ -890,7 +890,7 @@ class Tracer(TracerBase):
new_tracer = Tracer.__new__(Tracer) new_tracer = Tracer.__new__(Tracer)
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k in {"_autowrap_search"}: if k == "_autowrap_search":
new_obj = copy.copy(v) new_obj = copy.copy(v)
else: else:
new_obj = copy.deepcopy(v, memo) new_obj = copy.deepcopy(v, memo)

View File

@ -882,7 +882,7 @@ def proxy_call(
def can_handle_tensor(x: Tensor) -> bool: def can_handle_tensor(x: Tensor) -> bool:
r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
if proxy_mode._allow_fake_constant: if proxy_mode._allow_fake_constant:
r = r or type(x) in (torch._subclasses.FakeTensor,) r = r or type(x) is torch._subclasses.FakeTensor
if not r: if not r:
unrecognized_types.append(type(x)) unrecognized_types.append(type(x))
return r return r
@ -1534,7 +1534,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
with set_original_aten_op(func): with set_original_aten_op(func):
kwargs = kwargs or {} kwargs = kwargs or {}
if func in (prim.device.default,): if func == prim.device.default:
return func(*args, **kwargs) return func(*args, **kwargs)
return proxy_call(self, func, self.pre_dispatch, args, kwargs) return proxy_call(self, func, self.pre_dispatch, args, kwargs)

View File

@ -15,10 +15,10 @@ class CudaGraphsSupport(OperatorSupport):
if node.op not in CALLABLE_NODE_OPS: if node.op not in CALLABLE_NODE_OPS:
return False return False
if node.target in [torch.ops.aten.embedding_dense_backward.default]: if node.target == torch.ops.aten.embedding_dense_backward.default:
return False return False
if node.target in [operator.getitem]: if node.target == operator.getitem:
return True return True
found_not_cuda = False found_not_cuda = False

View File

@ -815,9 +815,10 @@ def _is_fp(value) -> bool:
def _is_bool(value) -> bool: def _is_bool(value) -> bool:
return _type_utils.JitScalarType.from_value( return (
value, _type_utils.JitScalarType.UNDEFINED _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
) in {_type_utils.JitScalarType.BOOL} == _type_utils.JitScalarType.BOOL
)
def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):

View File

@ -3011,7 +3011,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity" err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin) yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
elif op_info.name in ['aminmax']: elif op_info.name == 'aminmax':
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax) yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
# Error Inputs for tensors with more than 64 dimension # Error Inputs for tensors with more than 64 dimension
@ -3050,7 +3050,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}), yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
error_regex=err_msg_amax_amin2) error_regex=err_msg_amax_amin2)
elif op_info.name in ['aminmax']: elif op_info.name == 'aminmax':
yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}), yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
error_regex=err_msg_aminmax2) error_regex=err_msg_aminmax2)
@ -4883,7 +4883,7 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)
if mode in ('bilinear', ): if mode == 'bilinear':
make_arg = partial( make_arg = partial(
make_tensor, make_tensor,
device=device, device=device,
@ -9468,7 +9468,7 @@ class foreach_inputs_sample_func:
# unary # unary
if opinfo.ref in (torch.abs, torch.neg): if opinfo.ref in (torch.abs, torch.neg):
return False return False
if opinfo.ref_inplace in (torch.Tensor.zero_,): if opinfo.ref_inplace == torch.Tensor.zero_:
return False return False
return dtype in integral_types_and(torch.bool) return dtype in integral_types_and(torch.bool)
if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor: if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
@ -9698,7 +9698,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist) super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)
def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,) return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul
def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
assert "num_input_tensors" not in kwargs assert "num_input_tensors" not in kwargs

View File

@ -572,7 +572,7 @@ def optim_inputs_func_adam(device, dtype=None):
+ (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+ (mps_supported_configs if _get_device_type(device) == "mps" else []) + (mps_supported_configs if _get_device_type(device) == "mps" else [])
) )
if dtype in (torch.float16,): if dtype == torch.float16:
for input in total: for input in total:
""" """
Too small eps will make denom to be zero for low precision dtype Too small eps will make denom to be zero for low precision dtype

View File

@ -322,7 +322,7 @@ def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kw
def sample_inputs_linalg_norm( def sample_inputs_linalg_norm(
op_info, device, dtype, requires_grad, *, variant=None, **kwargs op_info, device, dtype, requires_grad, *, variant=None, **kwargs
): ):
if variant is not None and variant not in ("subgradient_at_zero",): if variant is not None and variant != "subgradient_at_zero":
raise ValueError( raise ValueError(
f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}" f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
) )

View File

@ -204,7 +204,7 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals
if op_info.name == "sum": if op_info.name == "sum":
sample = _validate_sample_input_sparse_reduction_sum(sample) sample = _validate_sample_input_sparse_reduction_sum(sample)
if op_info.name in {"masked.sum"}: if op_info.name == "masked.sum":
mask = sample.kwargs.get("mask", UNSPECIFIED) mask = sample.kwargs.get("mask", UNSPECIFIED)
if ( if (
mask not in {None, UNSPECIFIED} mask not in {None, UNSPECIFIED}
@ -792,12 +792,16 @@ def _sample_inputs_sparse_like_fns(
def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False): def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
if sample.input.layout in { if (
torch.sparse_csr, sample.input.layout
torch.sparse_csc, in {
torch.sparse_bsr, torch.sparse_csr,
torch.sparse_bsc, torch.sparse_csc,
} and op_info.name not in {"zeros_like"}: torch.sparse_bsr,
torch.sparse_bsc,
}
and op_info.name != "zeros_like"
):
if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout: if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
return ErrorInput( return ErrorInput(
sample, sample,