[2/N] Simplify "in" operation for containers of a single item (#164323)
These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164323
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
Parent: 96c3b9e275
Commit: cc8b14d09a
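FURB171 flags membership tests against a single-item container, such as `x in (y,)` or `x in [y]`, and suggests a plain comparison instead. As a rough sketch of the rewrite this commit applies throughout (hypothetical code, not taken from the hunks below):

```python
# Before: membership test against a single-item container (FURB171).
def is_default_mode(mode: str) -> bool:
    return mode in ("default",)


# After: a direct comparison is clearer and skips building a container.
def is_default_mode_fixed(mode: str) -> bool:
    return mode == "default"


for m in ("default", "other"):
    assert is_default_mode(m) == is_default_mode_fixed(m)
```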
@@ -264,7 +264,7 @@ def _cuda_system_info_comment() -> str:
     try:
         cuda_version_out = subprocess.check_output(["nvcc", "--version"])
         cuda_version_lines = cuda_version_out.decode().split("\n")
-        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
+        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s != ""])
         model_str += f"{comment}\n"
     except (FileNotFoundError, subprocess.CalledProcessError):
         model_str += "# nvcc not found\n"
@@ -3623,7 +3623,7 @@ class CheckFunctionManager:
                 # Leave out the builtins dict key, as we will special handle
                 # it later because the guarded code rarely use the entire
                 # builtin dict in the common case.
-                if name not in (builtins_dict_name,):
+                if name != builtins_dict_name:
                     used_global_vars.add(name)
             elif name := get_local_source_name(source):
                 assert isinstance(name, str)
@@ -4228,9 +4228,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):

         # _origin marks this as coming from an internal dynamo known function that is safe to
         # trace through.
-        if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
-            produce_trampoline_autograd_apply,
-        ]:
+        if (
+            hasattr(getattr(func, "fn", None), "_origin")
+            and func.fn._origin is produce_trampoline_autograd_apply
+        ):
             # Known sound
             return trace_rules.SkipResult(
                 False, "allowlist in dynamo known function"
@@ -3017,7 +3017,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         ]
         or (
             # TODO: this is a little sus, because we didn't check what the self is
-            proxy.node.op == "call_method" and proxy.node.target in ["bit_length"]
+            proxy.node.op == "call_method" and proxy.node.target == "bit_length"
         )
     ):
         set_example_value(proxy.node, example_value)
@@ -154,7 +154,7 @@ class PlacementClassVariable(DistributedVariable):
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if (
-            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
+            inspect.getattr_static(self.value, "__new__", None) == object.__new__
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as they
@@ -1536,7 +1536,7 @@ class NumpyNdarrayVariable(TensorVariable):
                 explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.",
                 hints=[],
             )
-        elif name in ["__version__"]:
+        elif name == "__version__":
             unimplemented_v2(
                 gb_type="Unsupported ndarray.__version__ access",
                 context=f"var_getattr {self} {name}",
@@ -497,7 +497,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             self, tx: "InstructionTranslator", *args, **kwargs
         ):
             assert not kwargs
-            if self.value in (torch._C._dispatch_keys,):
+            if self.value is torch._C._dispatch_keys:
                 assert len(args) == 1
                 assert isinstance(args[0], variables.TensorVariable)
                 example_value = args[0].proxy.node.meta["example_value"]
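Note how the hunks above and below pick the comparison operator per call site: an identity check (`is`) when the single item is a singleton object such as a function, an op overload, or a type (for example `torch._C._dispatch_keys` above), and equality (`==`) for value-compared items such as strings. A minimal sketch of why both rewrites preserve the membership semantics (hypothetical names, not from this commit):

```python
def handler():  # hypothetical stand-in for a function-valued target
    pass


target = handler
reduce_op = "sum"

# `x in (y,)` evaluates as `x is y or x == y`. Functions define no custom
# __eq__, so identity is the entire test and `is` is an exact replacement.
assert (target in (handler,)) == (target is handler)

# For strings, the equality half is the meaningful one, so `==` replaces it.
assert (reduce_op in ("sum",)) == (reduce_op == "sum")
```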
@@ -1862,7 +1862,7 @@ class DispatchKeySetVariable(BaseTorchVariable):
         return cls(value, source=source)

     def is_constant_fold_method(self, name):
-        return name in ["has"]
+        return name == "has"

     def call_method(
         self,
@@ -151,7 +151,7 @@ def _should_lower_as_one_shot_all_reduce(
         config._collective.auto_select
         and is_symm_mem_enabled_for_group(group_name)
         and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM)
-        and reduce_op in ("sum",)
+        and reduce_op == "sum"
         and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes
     )

@@ -1037,7 +1037,7 @@ def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
     """

     def _is_compute_intensive(node: torch.fx.Node) -> bool:
-        return node.target in [torch.ops.aten.mm.default]
+        return node.target is torch.ops.aten.mm.default

     collective_to_overlapping_candidates = defaultdict(list)
     available_nodes = OrderedSet[torch.fx.Node]()
@@ -614,7 +614,7 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
 class NormalizedLinearNode:
     def __init__(self, node: torch.fx.Node) -> None:
         assert node.op == "call_function"
-        assert node.target in [torch.nn.functional.linear]
+        assert node.target is torch.nn.functional.linear
         self.node: torch.fx.Node = node

     def get_input(self) -> torch.fx.Node:
@@ -3846,7 +3846,7 @@ def meta__dyn_quant_matmul_4bit(
 ):
     torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
     torch._check(
-        inp.dtype in [torch.float32],
+        inp.dtype == torch.float32,
         lambda: f"expected input to be f32, got {inp.dtype}",
     )
     M = inp.size(0)
@@ -352,12 +352,14 @@ def _make_prim(

     from torch._subclasses.fake_tensor import contains_tensor_types

-    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
-        _prim
-    ) in [
-        # See https://github.com/pytorch/pytorch/issues/103532
-        "prims.device_put.default"
-    ]:
+    if (
+        not any(contains_tensor_types(a.type) for a in _prim._schema.arguments)
+        or str(
+            _prim
+        )
+        # See https://github.com/pytorch/pytorch/issues/103532
+        == "prims.device_put.default"
+    ):
         prim_backend_select_impl.impl(name, _backend_select_impl)

     for p in (_prim_packet, _prim):
@@ -890,7 +890,7 @@ class Tracer(TracerBase):
         new_tracer = Tracer.__new__(Tracer)

         for k, v in self.__dict__.items():
-            if k in {"_autowrap_search"}:
+            if k == "_autowrap_search":
                 new_obj = copy.copy(v)
             else:
                 new_obj = copy.deepcopy(v, memo)
@@ -882,7 +882,7 @@ def proxy_call(
     def can_handle_tensor(x: Tensor) -> bool:
         r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
         if proxy_mode._allow_fake_constant:
-            r = r or type(x) in (torch._subclasses.FakeTensor,)
+            r = r or type(x) is torch._subclasses.FakeTensor
         if not r:
             unrecognized_types.append(type(x))
         return r
@@ -1534,7 +1534,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         with set_original_aten_op(func):
             kwargs = kwargs or {}

-            if func in (prim.device.default,):
+            if func == prim.device.default:
                 return func(*args, **kwargs)

             return proxy_call(self, func, self.pre_dispatch, args, kwargs)
@@ -15,10 +15,10 @@ class CudaGraphsSupport(OperatorSupport):
         if node.op not in CALLABLE_NODE_OPS:
             return False

-        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
+        if node.target == torch.ops.aten.embedding_dense_backward.default:
             return False

-        if node.target in [operator.getitem]:
+        if node.target == operator.getitem:
             return True

         found_not_cuda = False
@@ -815,9 +815,10 @@ def _is_fp(value) -> bool:


 def _is_bool(value) -> bool:
-    return _type_utils.JitScalarType.from_value(
-        value, _type_utils.JitScalarType.UNDEFINED
-    ) in {_type_utils.JitScalarType.BOOL}
+    return (
+        _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED)
+        == _type_utils.JitScalarType.BOOL
+    )


 def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
@@ -3011,7 +3011,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
     err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
     if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
-    elif op_info.name in ['aminmax']:
+    elif op_info.name == 'aminmax':
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)

     # Error Inputs for tensors with more than 64 dimension
@@ -3050,7 +3050,7 @@ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
     if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
                          error_regex=err_msg_amax_amin2)
-    elif op_info.name in ['aminmax']:
+    elif op_info.name == 'aminmax':
         yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
                          error_regex=err_msg_aminmax2)

@@ -4883,7 +4883,7 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
 def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
     yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)

-    if mode in ('bilinear', ):
+    if mode == 'bilinear':
         make_arg = partial(
             make_tensor,
             device=device,
@@ -9468,7 +9468,7 @@ class foreach_inputs_sample_func:
             # unary
             if opinfo.ref in (torch.abs, torch.neg):
                 return False
-            if opinfo.ref_inplace in (torch.Tensor.zero_,):
+            if opinfo.ref_inplace == torch.Tensor.zero_:
                 return False
             return dtype in integral_types_and(torch.bool)
         if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
@@ -9698,7 +9698,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
         super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)

     def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
-        return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
+        return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul

     def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
         assert "num_input_tensors" not in kwargs
@@ -572,7 +572,7 @@ def optim_inputs_func_adam(device, dtype=None):
         + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
         + (mps_supported_configs if _get_device_type(device) == "mps" else [])
     )
-    if dtype in (torch.float16,):
+    if dtype == torch.float16:
         for input in total:
             """
             Too small eps will make denom to be zero for low precision dtype
@@ -322,7 +322,7 @@ def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kw
 def sample_inputs_linalg_norm(
     op_info, device, dtype, requires_grad, *, variant=None, **kwargs
 ):
-    if variant is not None and variant not in ("subgradient_at_zero",):
+    if variant is not None and variant != "subgradient_at_zero":
         raise ValueError(
             f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
         )
@@ -204,7 +204,7 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals
     if op_info.name == "sum":
         sample = _validate_sample_input_sparse_reduction_sum(sample)

-    if op_info.name in {"masked.sum"}:
+    if op_info.name == "masked.sum":
         mask = sample.kwargs.get("mask", UNSPECIFIED)
         if (
             mask not in {None, UNSPECIFIED}
@@ -792,12 +792,16 @@ def _sample_inputs_sparse_like_fns(


 def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
-    if sample.input.layout in {
-        torch.sparse_csr,
-        torch.sparse_csc,
-        torch.sparse_bsr,
-        torch.sparse_bsc,
-    } and op_info.name not in {"zeros_like"}:
+    if (
+        sample.input.layout
+        in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
+        and op_info.name != "zeros_like"
+    ):
         if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
             return ErrorInput(
                 sample,
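One caveat when applying FURB171 rewrites by hand: `x in (y,)` evaluates as `x is y or x == y`, so for objects whose `__eq__` is not reflexive, a bare `==` is not strictly equivalent to the membership test. A quick illustration (not a case touched by this commit):

```python
nan = float("nan")

print(nan in (nan,))  # True: the membership test short-circuits on identity
print(nan == nan)     # False: NaN compares unequal to itself
```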