mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[2/N] Simplify "in" operation for containers of a single item (#164323)
These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171). Pull Request resolved: https://github.com/pytorch/pytorch/pull/164323 Approved by: https://github.com/justinchuby, https://github.com/Skylion007
This commit is contained in:
parent
96c3b9e275
commit
cc8b14d09a
|
|
@ -264,7 +264,7 @@ def _cuda_system_info_comment() -> str:
|
|||
try:
|
||||
cuda_version_out = subprocess.check_output(["nvcc", "--version"])
|
||||
cuda_version_lines = cuda_version_out.decode().split("\n")
|
||||
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
|
||||
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s != ""])
|
||||
model_str += f"{comment}\n"
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
model_str += "# nvcc not found\n"
|
||||
|
|
|
|||
|
|
@ -3623,7 +3623,7 @@ class CheckFunctionManager:
|
|||
# Leave out the builtins dict key, as we will special handle
|
||||
# it later because the guarded code rarely use the entire
|
||||
# builtin dict in the common case.
|
||||
if name not in (builtins_dict_name,):
|
||||
if name != builtins_dict_name:
|
||||
used_global_vars.add(name)
|
||||
elif name := get_local_source_name(source):
|
||||
assert isinstance(name, str)
|
||||
|
|
|
|||
|
|
@ -4228,9 +4228,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
|
||||
# _origin marks this as coming from an internal dynamo known function that is safe to
|
||||
# trace through.
|
||||
if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
|
||||
produce_trampoline_autograd_apply,
|
||||
]:
|
||||
if (
|
||||
hasattr(getattr(func, "fn", None), "_origin")
|
||||
and func.fn._origin is produce_trampoline_autograd_apply
|
||||
):
|
||||
# Known sound
|
||||
return trace_rules.SkipResult(
|
||||
False, "allowlist in dynamo known function"
|
||||
|
|
|
|||
|
|
@ -3017,7 +3017,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
|||
]
|
||||
or (
|
||||
# TODO: this is a little sus, because we didn't check what the self is
|
||||
proxy.node.op == "call_method" and proxy.node.target in ["bit_length"]
|
||||
proxy.node.op == "call_method" and proxy.node.target == "bit_length"
|
||||
)
|
||||
):
|
||||
set_example_value(proxy.node, example_value)
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class PlacementClassVariable(DistributedVariable):
|
|||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if (
|
||||
inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
|
||||
inspect.getattr_static(self.value, "__new__", None) == object.__new__
|
||||
and self.source
|
||||
):
|
||||
# NOTE: we don't need to track mutations to the placement class as they
|
||||
|
|
|
|||
|
|
@ -1536,7 +1536,7 @@ class NumpyNdarrayVariable(TensorVariable):
|
|||
explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.",
|
||||
hints=[],
|
||||
)
|
||||
elif name in ["__version__"]:
|
||||
elif name == "__version__":
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported ndarray.__version__ access",
|
||||
context=f"var_getattr {self} {name}",
|
||||
|
|
|
|||
|
|
@ -497,7 +497,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
assert not kwargs
|
||||
if self.value in (torch._C._dispatch_keys,):
|
||||
if self.value is torch._C._dispatch_keys:
|
||||
assert len(args) == 1
|
||||
assert isinstance(args[0], variables.TensorVariable)
|
||||
example_value = args[0].proxy.node.meta["example_value"]
|
||||
|
|
@ -1862,7 +1862,7 @@ class DispatchKeySetVariable(BaseTorchVariable):
|
|||
return cls(value, source=source)
|
||||
|
||||
def is_constant_fold_method(self, name):
|
||||
return name in ["has"]
|
||||
return name == "has"
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def _should_lower_as_one_shot_all_reduce(
|
|||
config._collective.auto_select
|
||||
and is_symm_mem_enabled_for_group(group_name)
|
||||
and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM)
|
||||
and reduce_op in ("sum",)
|
||||
and reduce_op == "sum"
|
||||
and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1037,7 +1037,7 @@ def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
|
|||
"""
|
||||
|
||||
def _is_compute_intensive(node: torch.fx.Node) -> bool:
|
||||
return node.target in [torch.ops.aten.mm.default]
|
||||
return node.target is torch.ops.aten.mm.default
|
||||
|
||||
collective_to_overlapping_candidates = defaultdict(list)
|
||||
available_nodes = OrderedSet[torch.fx.Node]()
|
||||
|
|
|
|||
|
|
@ -614,7 +614,7 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
|
|||
class NormalizedLinearNode:
|
||||
def __init__(self, node: torch.fx.Node) -> None:
|
||||
assert node.op == "call_function"
|
||||
assert node.target in [torch.nn.functional.linear]
|
||||
assert node.target is torch.nn.functional.linear
|
||||
self.node: torch.fx.Node = node
|
||||
|
||||
def get_input(self) -> torch.fx.Node:
|
||||
|
|
|
|||
|
|
@ -3846,7 +3846,7 @@ def meta__dyn_quant_matmul_4bit(
|
|||
):
|
||||
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
|
||||
torch._check(
|
||||
inp.dtype in [torch.float32],
|
||||
inp.dtype == torch.float32,
|
||||
lambda: f"expected input to be f32, got {inp.dtype}",
|
||||
)
|
||||
M = inp.size(0)
|
||||
|
|
|
|||
|
|
@ -352,12 +352,14 @@ def _make_prim(
|
|||
|
||||
from torch._subclasses.fake_tensor import contains_tensor_types
|
||||
|
||||
if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
|
||||
_prim
|
||||
) in [
|
||||
# See https://github.com/pytorch/pytorch/issues/103532
|
||||
"prims.device_put.default"
|
||||
]:
|
||||
if (
|
||||
not any(contains_tensor_types(a.type) for a in _prim._schema.arguments)
|
||||
or str(
|
||||
_prim
|
||||
# See https://github.com/pytorch/pytorch/issues/103532
|
||||
)
|
||||
== "prims.device_put.default"
|
||||
):
|
||||
prim_backend_select_impl.impl(name, _backend_select_impl)
|
||||
|
||||
for p in (_prim_packet, _prim):
|
||||
|
|
|
|||
|
|
@ -890,7 +890,7 @@ class Tracer(TracerBase):
|
|||
new_tracer = Tracer.__new__(Tracer)
|
||||
|
||||
for k, v in self.__dict__.items():
|
||||
if k in {"_autowrap_search"}:
|
||||
if k == "_autowrap_search":
|
||||
new_obj = copy.copy(v)
|
||||
else:
|
||||
new_obj = copy.deepcopy(v, memo)
|
||||
|
|
|
|||
|
|
@ -882,7 +882,7 @@ def proxy_call(
|
|||
def can_handle_tensor(x: Tensor) -> bool:
|
||||
r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
|
||||
if proxy_mode._allow_fake_constant:
|
||||
r = r or type(x) in (torch._subclasses.FakeTensor,)
|
||||
r = r or type(x) is torch._subclasses.FakeTensor
|
||||
if not r:
|
||||
unrecognized_types.append(type(x))
|
||||
return r
|
||||
|
|
@ -1534,7 +1534,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||
with set_original_aten_op(func):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if func in (prim.device.default,):
|
||||
if func == prim.device.default:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ class CudaGraphsSupport(OperatorSupport):
|
|||
if node.op not in CALLABLE_NODE_OPS:
|
||||
return False
|
||||
|
||||
if node.target in [torch.ops.aten.embedding_dense_backward.default]:
|
||||
if node.target == torch.ops.aten.embedding_dense_backward.default:
|
||||
return False
|
||||
|
||||
if node.target in [operator.getitem]:
|
||||
if node.target == operator.getitem:
|
||||
return True
|
||||
|
||||
found_not_cuda = False
|
||||
|
|
|
|||
|
|
@ -815,9 +815,10 @@ def _is_fp(value) -> bool:
|
|||
|
||||
|
||||
def _is_bool(value) -> bool:
|
||||
return _type_utils.JitScalarType.from_value(
|
||||
value, _type_utils.JitScalarType.UNDEFINED
|
||||
) in {_type_utils.JitScalarType.BOOL}
|
||||
return (
|
||||
_type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
|
||||
== _type_utils.JitScalarType.BOOL
|
||||
)
|
||||
|
||||
|
||||
def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
|
||||
|
|
|
|||
|
|
@ -3011,7 +3011,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
|
|||
err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
|
||||
if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
|
||||
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
|
||||
elif op_info.name in ['aminmax']:
|
||||
elif op_info.name == 'aminmax':
|
||||
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
|
||||
|
||||
# Error Inputs for tensors with more than 64 dimension
|
||||
|
|
@ -3050,7 +3050,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
|
|||
if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
|
||||
yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
|
||||
error_regex=err_msg_amax_amin2)
|
||||
elif op_info.name in ['aminmax']:
|
||||
elif op_info.name == 'aminmax':
|
||||
yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
|
||||
error_regex=err_msg_aminmax2)
|
||||
|
||||
|
|
@ -4883,7 +4883,7 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
|
|||
def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
|
||||
yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)
|
||||
|
||||
if mode in ('bilinear', ):
|
||||
if mode == 'bilinear':
|
||||
make_arg = partial(
|
||||
make_tensor,
|
||||
device=device,
|
||||
|
|
@ -9468,7 +9468,7 @@ class foreach_inputs_sample_func:
|
|||
# unary
|
||||
if opinfo.ref in (torch.abs, torch.neg):
|
||||
return False
|
||||
if opinfo.ref_inplace in (torch.Tensor.zero_,):
|
||||
if opinfo.ref_inplace == torch.Tensor.zero_:
|
||||
return False
|
||||
return dtype in integral_types_and(torch.bool)
|
||||
if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
|
||||
|
|
@ -9698,7 +9698,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
|
|||
super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)
|
||||
|
||||
def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
|
||||
return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
|
||||
return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul
|
||||
|
||||
def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
|
||||
assert "num_input_tensors" not in kwargs
|
||||
|
|
|
|||
|
|
@ -572,7 +572,7 @@ def optim_inputs_func_adam(device, dtype=None):
|
|||
+ (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||||
+ (mps_supported_configs if _get_device_type(device) == "mps" else [])
|
||||
)
|
||||
if dtype in (torch.float16,):
|
||||
if dtype == torch.float16:
|
||||
for input in total:
|
||||
"""
|
||||
Too small eps will make denom to be zero for low precision dtype
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kw
|
|||
def sample_inputs_linalg_norm(
|
||||
op_info, device, dtype, requires_grad, *, variant=None, **kwargs
|
||||
):
|
||||
if variant is not None and variant not in ("subgradient_at_zero",):
|
||||
if variant is not None and variant != "subgradient_at_zero":
|
||||
raise ValueError(
|
||||
f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals
|
|||
if op_info.name == "sum":
|
||||
sample = _validate_sample_input_sparse_reduction_sum(sample)
|
||||
|
||||
if op_info.name in {"masked.sum"}:
|
||||
if op_info.name == "masked.sum":
|
||||
mask = sample.kwargs.get("mask", UNSPECIFIED)
|
||||
if (
|
||||
mask not in {None, UNSPECIFIED}
|
||||
|
|
@ -792,12 +792,16 @@ def _sample_inputs_sparse_like_fns(
|
|||
|
||||
|
||||
def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
|
||||
if sample.input.layout in {
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
} and op_info.name not in {"zeros_like"}:
|
||||
if (
|
||||
sample.input.layout
|
||||
in {
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
}
|
||||
and op_info.name != "zeros_like"
|
||||
):
|
||||
if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
|
||||
return ErrorInput(
|
||||
sample,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user