[2/N] Fix ruff warnings (#164460)

Apply ruff `SIM` rules.
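The changes below are all instances of two flake8-simplify rewrites: negated equality comparisons (`not a == b` becomes `a != b`) and redundant `True if cond else False` ternaries (become `bool(cond)` or just the condition). A minimal sketch of the before/after shapes, using illustrative variable names that are not taken from the diff:

```python
grad = [0.1, 0.2]
arg_info = ("x", "y", "z")

# Negated equality: `not a == b` reads better as `a != b`.
# Before:  if not len(grad) == len(arg_info):
if len(grad) != len(arg_info):
    print(f"expected {len(arg_info)} gradients, got {len(grad)}")

# Redundant ternary: `True if cond else False` is just the truthiness of cond;
# the autofix wraps the condition in `bool(...)` (or drops the wrapper when the
# expression is already boolean).
# Before:  mask = [True if isinstance(v, float) else False for v in grad]
mask = [bool(isinstance(v, float)) for v in grad]
assert mask == [True, True]
```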

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
Yuanyuan Chen 2025-10-04 03:40:29 +00:00 committed by PyTorch MergeBot
parent 34042a9145
commit 35c4130fd1
58 changed files with 82 additions and 98 deletions

View File

@ -220,7 +220,7 @@ def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
f"hold a list of gradients but got object of type "
f"{type(grad)}."
)
if not len(grad) == len(arg_info):
if len(grad) != len(arg_info):
error(
f"for input '{name}' expected the grad_input dict to "
f"hold a list of {len(arg_info)} gradients but got "

View File

@ -604,7 +604,7 @@ class OutputGraph(OutputGraphCommon):
fake_mode = torch._subclasses.FakeTensorMode(
shape_env=shape_env,
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
allow_non_fake_inputs=True if self.export else False,
allow_non_fake_inputs=bool(self.export),
export=self.export,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)

View File

@ -453,9 +453,7 @@ def _call_while_loop(
cond_r_meta = _extract_tensor_metadata(
cond_r.proxy.node.meta["example_value"], include_contiguity=False
)
-        if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
-            []
-        ):
+        if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]):
unimplemented(
f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
)

View File

@ -325,8 +325,7 @@ class CondAutogradOp(torch.autograd.Function):
true_outputs = fn(*args)
grads_tensor_masks = [
-            True if isinstance(out, torch.Tensor) else False
-            for out in true_outputs
+            bool(isinstance(out, torch.Tensor)) for out in true_outputs
]
return filter_with_masks(true_outputs, grads_tensor_masks)

View File

@ -134,9 +134,7 @@ def create_hop_fw_bw(
"Dynamo traced submodule should return tuple"
)
return fw_out, [
-            True
-            if isinstance(ret, torch.Tensor) and ret.requires_grad
-            else False
+            bool(isinstance(ret, torch.Tensor) and ret.requires_grad)
for ret in fw_out
]

View File

@ -675,8 +675,7 @@ class ScanAutogradImpl:
grad_carry, grad_ys = grad_fw_outputs[:n_carry], grad_fw_outputs[n_carry:]
additional_inputs_tensor_masks = [
-            True if isinstance(t, torch.Tensor) else False
-            for t in self.additional_inputs
+            bool(isinstance(t, torch.Tensor)) for t in self.additional_inputs
]
grad_additional_inputs = [
torch.zeros_like(t)

View File

@ -500,8 +500,7 @@ def prepare_fw_with_masks(fn):
def fw_with_masks(*args):
fw_out = fn(*args)
return fw_out, [
-            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
-            for ret in fw_out
+            bool(isinstance(ret, torch.Tensor) and ret.requires_grad) for ret in fw_out
]
return fw_with_masks
@ -857,7 +856,7 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
# Returns a mask whether a list element is a tensor or not
def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]:
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
return [bool(isinstance(v, torch.Tensor)) for v in tensor_list]
def mask_list(

View File

@ -767,11 +767,11 @@ class WhileLoopAutogradOp(torch.autograd.Function):
# inductor codegen, where we need to do a non-unform treatment for None and tensors.
# So we set up masks and filter the None gradients so that only tensors are returned from each step.
carries_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point)
for t in ctx.carries
]
additional_inputs_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point)
for t in ctx.additional_inputs
]

View File

@ -1413,7 +1413,7 @@ Resize can only operate on graph inputs, but got {node} which is resizing non-gr
)
resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])
if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
if len(resized_to_full_idxes) != len(resized_to_0_idxes):
log.warning(
f"""
Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:

View File

@ -383,7 +383,7 @@ class SamplingMethod(Enum):
elif TypeExemplars.contains(type_hint):
return TypeExemplars.example(type_hint)
elif type_hint == Any:
return 1 if not default == 1 else 2
return 1 if default != 1 else 2
else:
raise ValueError(f"Unable to process type {type_hint}. PRs welcome :)")

View File

@ -19,18 +19,18 @@ def mark_mixed_dtype(computation_node):
if computation_node_dtype not in (torch.float16, torch.bfloat16):
return
if not len(computation_node.users) == 1:
if len(computation_node.users) != 1:
return
computation_node_user = next(iter(computation_node.users.keys()))
if not isinstance(computation_node_user.meta["val"], torch.Tensor):
return
if not computation_node_user.meta["val"].dtype == torch.float32:
if computation_node_user.meta["val"].dtype != torch.float32:
return
while computation_node_user.target in _binary_ops:
if not len(computation_node_user.users) == 1:
if len(computation_node_user.users) != 1:
return
computation_node_user = next(iter(computation_node_user.users.keys()))
@ -188,7 +188,7 @@ def binary_folding_init():
):
return False
if not len(conv_node.args[1].users) == 1:
if len(conv_node.args[1].users) != 1:
return False
weight_meta_value = conv_node.args[1].meta.get("val")
@ -242,7 +242,7 @@ def binary_folding_init():
):
return False
if not len(weight_node.users) == 1:
if len(weight_node.users) != 1:
return False
weight_meta_value = weight_node.meta.get("val")

View File

@ -594,7 +594,7 @@ class OverlapScheduler:
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
assert not info.hiding_node == curr_compute_node
assert info.hiding_node != curr_compute_node
self._handle_wait(node)
continue

View File

@ -705,7 +705,7 @@ def reorder_for_locality(graph: torch.fx.Graph):
iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
None,
)
past_mutating_epilogue = True if first_copy is None else False
past_mutating_epilogue = first_copy is None
for node in reversed(graph.nodes):
seen_nodes.add(node)
@ -1761,7 +1761,7 @@ class ConstructorMoverPass:
if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
continue
if not node.kwargs.get("device") == torch.device("cpu"):
if node.kwargs.get("device") != torch.device("cpu"):
continue
constructors.append(node)
@ -1922,14 +1922,10 @@ def move_constructors_to_gpu(graph: fx.Graph) -> None:
# by explicitly moving cpu scalar tensors to gpu when profitable, relying on
# graph partition to split off this data copy, and cudagraphifying
# the remaining gpu ops.
-    allow_inputs_outputs = (
-        True
-        if (
+    allow_inputs_outputs = bool(
        torch._inductor.config.triton.cudagraphs
        and torch._inductor.config.graph_partition
    )
-        else False
-    )
ConstructorMoverPass(
get_gpu_type(),
allow_inputs=allow_inputs_outputs,

View File

@ -615,7 +615,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
replace_dict[copy_node] = copy_node.args[0]
if not trigger == ReInplaceTrigger.AUTO_FUNC_V2:
if trigger != ReInplaceTrigger.AUTO_FUNC_V2:
for user in node.users:
# For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
# output atindex size(out)+i.

View File

@ -469,9 +469,7 @@ def significant_strides_equal(
if not V.graph.sizevars.statically_known_equals(
s1, s2
-        ) and not V.graph.sizevars.symbolic_hint(s1) == V.graph.sizevars.symbolic_hint(
-            s2
-        ):
+        ) and V.graph.sizevars.symbolic_hint(s1) != V.graph.sizevars.symbolic_hint(s2):
return False
return True
@ -5443,7 +5441,7 @@ class ConcatKernel(NopKernel):
return True
# otherwise, check equality of layouts
if not len(src.get_stride()) == len(dst.get_stride()):
if len(src.get_stride()) != len(dst.get_stride()):
return False
return all(

View File

@ -35,7 +35,7 @@ def check_cpu_supported():
supported = (
requires_avx2_on_cpu
and not torch.xpu.is_available()
and not sys.platform == "darwin"
and sys.platform != "darwin"
)
return supported

View File

@ -1398,7 +1398,7 @@ def check_and_add_duplicate_pattern(
new_graph_str = str(graph)
for graph_str in equiv_pattern_reprs:
if not new_graph_str == graph_str:
if new_graph_str != graph_str:
continue
if skip_duplicates:
return True

View File

@ -237,7 +237,7 @@ def check_autotune_cache(
not disabled
and filename is not None
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
and os.environ.get("TRITON_INTERPRET", "0") != "1"
):
configs_hash = hash_configs(configs)

View File

@ -4976,7 +4976,7 @@ class Scheduler:
if name in name_to_node
}
input_deallocation = {
name: True if name in buffer_names_to_free else False
name: name in buffer_names_to_free
for name in partition_input_names
if name in name_to_node
}

View File

@ -132,7 +132,7 @@ def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
out = _solve_simple_expr(eq_1_expr_simplified)
# since we approximated FloorDiv/ModularIndexing, double check here
if not out or not (sympy_subs(eq_1_expr, {free_symbol: out})) == 1:
if not out or sympy_subs(eq_1_expr, {free_symbol: out}) != 1:
return None
required_values.append(out)

View File

@ -504,7 +504,7 @@ def is_pointwise_use(
Uses in views ops will follow the views uses
"""
if not use.op == "call_function":
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
@ -2020,7 +2020,7 @@ def use_ck_template(layout: Layout) -> bool:
if not torch.version.hip:
return False
# tensors must be on GPU
if not layout.device.type == "cuda":
if layout.device.type != "cuda":
return False
# hardware check
# if config arch list is not specified, get the native arch from the device properties

View File

@ -443,7 +443,7 @@ def get_callable_argument_names(fn) -> list[str]:
for name, param in callable_signature.parameters.items():
# All four other types of arguments do not map to individual values
# with a keyword as name.
if not param.kind == param.POSITIONAL_OR_KEYWORD:
if param.kind != param.POSITIONAL_OR_KEYWORD:
continue
argument_names.append(name)

View File

@ -135,7 +135,7 @@ def mutates_and_returns_first_arg(op: OpOverload):
if op.namespace != "aten":
return False
schema = op._schema
if not len(schema.returns) == 1:
if len(schema.returns) != 1:
return False
if schema.returns[0].alias_info is None:
return False

View File

@ -1207,7 +1207,7 @@ def safe_grad_filter(message, category, filename, lineno, file=None, line=None)
def user_warning_filter(
message, category, filename, lineno, file=None, line=None
) -> bool:
return not category == UserWarning
return category != UserWarning
@contextlib.contextmanager

View File

@ -317,7 +317,7 @@ def print_assert_equal(test_string, actual, desired):
__tracebackhide__ = True # Hide traceback for py.test
import pprint
if not (actual == desired):
if actual != desired:
msg = StringIO()
msg.write(test_string)
msg.write(" failed\nACTUAL: \n")
@ -1505,7 +1505,7 @@ def _integer_repr(x, vdt, comp):
# See also
# https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
rx = x.view(vdt)
if not (rx.size == 1):
if rx.size != 1:
rx[rx < 0] = comp - rx[rx < 0]
else:
if rx < 0:

View File

@ -4316,7 +4316,7 @@ def tensor_split(
# If indices_or_sections is a tensor, it must be a CPU Long tensor
if isinstance(indices_or_sections, TensorLike):
if not indices_or_sections.device.type == "cpu":
if indices_or_sections.device.type != "cpu":
msg = (
f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
f"but received one on {indices_or_sections.device}"

View File

@ -1335,7 +1335,7 @@ def make_fast_binary_impl(
# Use elementwise_dtypes for the tricky case
has_different_input_dtypes = True
continue
if common_device == cpu and not op.device.type == "cpu":
if common_device == cpu and op.device.type != "cpu":
common_device = op.device
# Slightly simplified here as target_dtype cannot vary
if common_dtype is None:

View File

@ -3012,7 +3012,7 @@ class FakeTensorMode(TorchDispatchMode):
t.numel() <= CONSTANT_NUMEL_LIMIT
and not is_sparse_any(t)
and not self.is_our_fake(t)
and not t.device.type == "meta"
and t.device.type != "meta"
)
def invalidate_written_to_constants(

View File

@ -127,7 +127,7 @@ def try_convert_fake_to_real(
key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}
out = []
for t in ten_list:
if not isinstance(t, FakeTensor) or not t.layout == torch.strided:
if not isinstance(t, FakeTensor) or t.layout != torch.strided:
out.append(t)
continue

View File

@ -113,7 +113,7 @@ class SchemaCheckMode(TorchDispatchMode):
return name if name != "self" else "input"
def unwrap(e):
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
if isinstance(e, torch.Tensor) and type(e) != torch.Tensor:
try:
return e.elem
except AttributeError:
@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
def parse_metadata(e):
if isinstance(e, torch.Tensor):
if not type(e) == torch.Tensor:
if type(e) != torch.Tensor:
try:
current = e.elem
return (

View File

@ -518,7 +518,7 @@ class ModelReportVisualizer:
# the index of the feature will the 0 + num non feature columns
tensor_feature_index = feature_column_offset
row_value = row[tensor_feature_index]
if not type(row_value) == str:
if type(row_value) != str:
x_data.append(x_val_to_append)
y_data.append(row_value)
elif is_valid_per_channel_plot:
@ -541,7 +541,7 @@ class ModelReportVisualizer:
# the index of the feature will the 0 + num non feature columns
tensor_feature_index = feature_column_offset
row_value = row[tensor_feature_index]
if not type(row_value) == str:
if type(row_value) != str:
# only append if new index we are appending
if len(x_data) == 0 or x_data[-1] != x_val_to_append:
x_data.append(x_val_to_append)

View File

@ -51,7 +51,7 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != "call_module":
return False
if not type_before_parametrizations(modules[node.target]) == self_match:
if type_before_parametrizations(modules[node.target]) != self_match:
return False
elif callable(self_match):
if node.op != "call_function" or node.target is not self_match:

View File

@ -1708,7 +1708,7 @@ def insert_observers_for_model(
skip_inserting_observers = (
(qconfig is None) or not output_is_a_tensor
) and (not node.op == "output")
) and (node.op != "output")
# TODO: take a closer look to see if we can remove this check
# right now it is here because of `observed_node_names`, we are using

View File

@ -733,9 +733,7 @@ def _stack_and_check_tensors(
if tensor is None:
out_jacobian[:, j].zero_()
else:
-                dense = (
-                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
-                )
+                dense = tensor.to_dense() if tensor.layout != torch.strided else tensor
assert out_jacobian[:, j].numel() == dense.numel()
out_jacobian[:, j] = dense.reshape(-1)
return out_jacobians, correct_grad_sizes, correct_grad_types

View File

@ -116,5 +116,5 @@ class OptEinsumModule(PropModule):
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
enabled = True if is_available() else False
enabled = bool(is_available())
strategy = "auto" if is_available() else None

View File

@ -200,7 +200,7 @@ class _TensorsAccessed:
del self.accesses[data_ptr]
def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
return True if self.accesses[data_ptr].reads else False
return bool(self.accesses[data_ptr].reads)
def get_allocation_stack_trace(
self, data_ptr: DataPtr

View File

@ -548,7 +548,7 @@ def create_default_global_save_plan(
for plan in all_plans:
new_items = []
for item in plan.items:
if not item.type == WriteItemType.SHARD:
if item.type != WriteItemType.SHARD:
assert item.index.fqn not in md
if item.type == WriteItemType.BYTE_IO:

View File

@ -1508,8 +1508,7 @@ def _allgather_orig_param_states(
return output_states
has_state_params: list[bool] = [
-        True if fqn in output_states else False
-        for fqn, idx in fsdp_param_info.param_indices.items()
+        fqn in output_states for fqn, idx in fsdp_param_info.param_indices.items()
]
# Loop through the ``state_buffers`` and construct the flattened, concatenated,

View File

@ -1191,7 +1191,7 @@ def _add_send_recv(
"""
if action is None:
return True
elif action.computation_type == F and not action.stage_index == 0:
elif action.computation_type == F and action.stage_index != 0:
if (
_Action(action.stage_index, RECV_F, action.microbatch_index)
in prev_actions
@ -1205,7 +1205,7 @@ def _add_send_recv(
return False
elif (
action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
and not action.stage_index == num_stages - 1
and action.stage_index != num_stages - 1
):
if (
_Action(action.stage_index, RECV_B, action.microbatch_index)

View File

@ -350,7 +350,7 @@ def _tensorpipe_init_backend_handler(
device_count = torch.cuda.device_count()
is_static_group = True if world_size else False
is_static_group = bool(world_size)
# world_size is specified so this is a static group (ranks cannot join and leave)
if is_static_group:
# The agent's join method is required to behave like a barrier and perform

View File

@ -1067,12 +1067,12 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max(
1, meta.shape[end_dim - 1]
):
if not meta.stride[end_dim] % alignment == 0:
if meta.stride[end_dim] % alignment != 0:
return False
elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max(
1, meta.shape[end_dim]
):
if not meta.stride[end_dim - 1] % alignment == 0:
if meta.stride[end_dim - 1] % alignment != 0:
return False
else:
return False

View File

@ -187,7 +187,7 @@ class LocalShardsWrapper(torch.Tensor):
aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
):
return False
if not a.storage_metadata() == b.storage_metadata():
if a.storage_metadata() != b.storage_metadata():
return False
return True

View File

@ -464,7 +464,7 @@ def _insert_reshard_gm(
if reshard_node.op not in ["placeholder", "output"]:
reshard_node.meta["nn_module_stack"] = (
copy.copy(input_arg.meta["nn_module_stack"])
if not input_arg.op == "placeholder"
if input_arg.op != "placeholder"
else copy.copy(node.meta["nn_module_stack"])
)
output_node = gm.graph.graph_copy(

View File

@ -142,7 +142,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
return
next_module_node = next(iter(unflatten_getitem_getitem_users))
if not (next_module_node.op == "call_module"):
if next_module_node.op != "call_module":
log.debug(
"Unflatten node %s's user is not a call_module. "
"Instead it is: %s. Passing...",

View File

@ -590,7 +590,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
ep.graph_module,
fake_args,
decompositions=python_decomp_table,
trace_joint=True if joint_loss_index is not None else False,
trace_joint=joint_loss_index is not None,
output_loss_index=(
joint_loss_index if joint_loss_index is not None else None
),

View File

@ -1942,7 +1942,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
torch._check_type(
not indices.is_complex()
and not indices.is_floating_point()
and not indices.dtype == torch.bool,
and indices.dtype != torch.bool,
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
)

View File

@ -7056,7 +7056,7 @@ class ShapeEnv:
expr,
concrete_val,
# only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic")
stack_info=True if log.getEffectiveLevel() < logging.WARNING else False,
stack_info=log.getEffectiveLevel() < logging.WARNING,
)
def _get_user_frame(self) -> Optional[types.FrameType]:

View File

@ -668,7 +668,7 @@ class _MinimizerBase:
final_start_idx: Optional[int] = start_idx
final_end_idx: Optional[int] = end_idx
run_both = True if find_last_node is None else False
run_both = find_last_node is None
# step 1: find (0, end_idx) of culprit block
if run_both or find_last_node:

View File

@ -3649,7 +3649,7 @@ def smooth_l1_loss(
reduction=reduction,
beta=beta,
)
if not (target.size() == input.size()):
if target.size() != input.size():
warnings.warn(
f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
"This will likely lead to incorrect results due to broadcasting. "
@ -3712,7 +3712,7 @@ def huber_loss(
weight=weight,
)
if not (target.size() == input.size()):
if target.size() != input.size():
warnings.warn(
f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
"This will likely lead to incorrect results due to broadcasting. "
@ -3789,7 +3789,7 @@ def l1_loss(
reduce=reduce,
reduction=reduction,
)
if not (target.size() == input.size()):
if target.size() != input.size():
warnings.warn(
f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
"This will likely lead to incorrect results due to broadcasting. "
@ -3862,7 +3862,7 @@ def mse_loss(
weight=weight,
)
if not (target.size() == input.size()):
if target.size() != input.size():
warnings.warn(
f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
"This will likely lead to incorrect results due to broadcasting. "

View File

@ -242,7 +242,7 @@ class RNNBase(Module):
for fw in self._flat_weights:
if (
not isinstance(fw, Tensor)
or not (fw.dtype == dtype)
or fw.dtype != dtype
or not fw.is_cuda
or not torch.backends.cudnn.is_acceptable(fw)
):

View File

@ -382,7 +382,7 @@ class TransformerEncoder(Module):
why_not_sparsity_fast_path = (
f"{enc_layer}.activation_relu_or_gelu was not True"
)
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
elif encoder_layer.norm1.eps != encoder_layer.norm2.eps:
why_not_sparsity_fast_path = (
f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
)
@ -458,7 +458,7 @@ class TransformerEncoder(Module):
)
elif first_layer.training:
why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
elif not src.dim() == 3:
elif src.dim() != 3:
why_not_sparsity_fast_path = (
f"input not batched; expected src.dim() of 3 but got {src.dim()}"
)
@ -832,7 +832,7 @@ class TransformerEncoderLayer(Module):
why_not_sparsity_fast_path = (
"torch.backends.mha.get_fastpath_enabled() was not True"
)
elif not src.dim() == 3:
elif src.dim() != 3:
why_not_sparsity_fast_path = (
f"input not batched; expected src.dim() of 3 but got {src.dim()}"
)
@ -846,7 +846,7 @@ class TransformerEncoderLayer(Module):
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
elif not self.activation_relu_or_gelu:
why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
elif not (self.norm1.eps == self.norm2.eps):
elif self.norm1.eps != self.norm2.eps:
why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
elif src.is_nested and (
src_key_padding_mask is not None or src_mask is not None

View File

@ -136,7 +136,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
try:
yield
finally:
if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
if hasattr(model, "training") and mode != _C_onnx.TrainingMode.PRESERVE:
model.train(originally_training)

View File

@ -9346,7 +9346,7 @@ def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_g
"k_proj_weight" : k_proj_weight,
"v_proj_weight" : v_proj_weight,
"attn_mask" : attn_mask,
"training" : True if dropout_p > 0.0 else False,
"training" : dropout_p > 0.0,
"use_separate_proj_weight" : use_separate_proj_weight
}

View File

@ -1287,7 +1287,7 @@ class QuantizationTestCase(TestCase):
prepare_custom_config=prepare_custom_config,
backend_config=backend_config,
)
if not quant_type == QuantType.DYNAMIC:
if quant_type != QuantType.DYNAMIC:
prepared(*inputs)
if print_debug_info:

View File

@ -1470,7 +1470,7 @@ TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_
TEST_MPS = torch.backends.mps.is_available()
MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
TEST_XPU = torch.xpu.is_available()
TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False
TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
TEST_CUDA = torch.cuda.is_available()
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
@ -3195,7 +3195,7 @@ class TestCase(expecttest.TestCase):
def remove_empty_lines(self, input_string):
lines = input_string.split('\n')
filtered_lines = [line for line in lines if not line.strip() == '']
filtered_lines = [line for line in lines if line.strip() != '']
return '\n'.join(filtered_lines)
# ignore comments will ignore lines that starts with # after being stripped

View File

@ -132,7 +132,7 @@ def _infer_device_type(*args):
def add_device_types(arg):
nonlocal device_types
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
device_types.append(arg.device.type)
tree_map(add_device_types, args)

View File

@ -119,7 +119,7 @@ PIP_PATTERNS = [
def run(command):
"""Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False
shell = type(command) is str
p = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
)

View File

@ -228,7 +228,7 @@ CUDA_NOT_FOUND_MESSAGE = (
)
ROCM_HOME = _find_rocm_home() if (torch.cuda._is_compiled() and torch.version.hip) else None
HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
IS_HIP_EXTENSION = bool(ROCM_HOME is not None and torch.version.hip is not None)
ROCM_VERSION = None
if torch.version.hip is not None:
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])