[2/N] Fix ruff warnings (#164460)

Apply ruff `SIM` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
Yuanyuan Chen 2025-10-04 03:40:29 +00:00 committed by PyTorch MergeBot
parent 34042a9145
commit 35c4130fd1
58 changed files with 82 additions and 98 deletions
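
The diffs below are mechanical rewrites of two patterns flagged by ruff's flake8-simplify (`SIM`) checks: negated equality comparisons and redundant `True if ... else False` expressions. A minimal sketch of the before/after forms (illustrative only, not part of the commit; SIM201/SIM210 are the rule codes ruff commonly assigns to these checks):

```python
import torch

xs = [torch.zeros(2), 1, "a"]

# SIM201: `not a == b` is simplified to `a != b`.
before = not len(xs) == 2
after = len(xs) != 2
assert before == after

# SIM210: `True if cond else False` is simplified to `bool(cond)`.
masks_before = [True if isinstance(x, torch.Tensor) else False for x in xs]
masks_after = [bool(isinstance(x, torch.Tensor)) for x in xs]
assert masks_before == masks_after
```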

@@ -220,7 +220,7 @@ def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
 f"hold a list of gradients but got object of type "
 f"{type(grad)}."
 )
-if not len(grad) == len(arg_info):
+if len(grad) != len(arg_info):
 error(
 f"for input '{name}' expected the grad_input dict to "
 f"hold a list of {len(arg_info)} gradients but got "

@@ -604,7 +604,7 @@ class OutputGraph(OutputGraphCommon):
 fake_mode = torch._subclasses.FakeTensorMode(
 shape_env=shape_env,
 # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
-allow_non_fake_inputs=True if self.export else False,
+allow_non_fake_inputs=bool(self.export),
 export=self.export,
 )
 self.tracing_context: TracingContext = TracingContext(fake_mode)

@@ -453,9 +453,7 @@ def _call_while_loop(
 cond_r_meta = _extract_tensor_metadata(
 cond_r.proxy.node.meta["example_value"], include_contiguity=False
 )
-if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
-[]
-):
+if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]):
 unimplemented(
 f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
 )

@@ -325,8 +325,7 @@ class CondAutogradOp(torch.autograd.Function):
 true_outputs = fn(*args)
 grads_tensor_masks = [
-True if isinstance(out, torch.Tensor) else False
-for out in true_outputs
+bool(isinstance(out, torch.Tensor)) for out in true_outputs
 ]
 return filter_with_masks(true_outputs, grads_tensor_masks)

@@ -134,9 +134,7 @@ def create_hop_fw_bw(
 "Dynamo traced submodule should return tuple"
 )
 return fw_out, [
-True
-if isinstance(ret, torch.Tensor) and ret.requires_grad
-else False
+bool(isinstance(ret, torch.Tensor) and ret.requires_grad)
 for ret in fw_out
 ]

@@ -675,8 +675,7 @@ class ScanAutogradImpl:
 grad_carry, grad_ys = grad_fw_outputs[:n_carry], grad_fw_outputs[n_carry:]
 additional_inputs_tensor_masks = [
-True if isinstance(t, torch.Tensor) else False
-for t in self.additional_inputs
+bool(isinstance(t, torch.Tensor)) for t in self.additional_inputs
 ]
 grad_additional_inputs = [
 torch.zeros_like(t)

@@ -500,8 +500,7 @@ def prepare_fw_with_masks(fn):
 def fw_with_masks(*args):
 fw_out = fn(*args)
 return fw_out, [
-True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
-for ret in fw_out
+bool(isinstance(ret, torch.Tensor) and ret.requires_grad) for ret in fw_out
 ]
 return fw_with_masks
@@ -857,7 +856,7 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
 # Returns a mask whether a list element is a tensor or not
 def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]:
-return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
+return [bool(isinstance(v, torch.Tensor)) for v in tensor_list]
 def mask_list(

@@ -767,11 +767,11 @@ class WhileLoopAutogradOp(torch.autograd.Function):
 # inductor codegen, where we need to do a non-unform treatment for None and tensors.
 # So we set up masks and filter the None gradients so that only tensors are returned from each step.
 carries_tensor_masks = [
-True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
+bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point)
 for t in ctx.carries
 ]
 additional_inputs_tensor_masks = [
-True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
+bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point)
 for t in ctx.additional_inputs
 ]

@@ -1413,7 +1413,7 @@ Resize can only operate on graph inputs, but got {node} which is resizing non-gr
 )
 resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])
-if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
+if len(resized_to_full_idxes) != len(resized_to_0_idxes):
 log.warning(
 f"""
 Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:

@@ -383,7 +383,7 @@ class SamplingMethod(Enum):
 elif TypeExemplars.contains(type_hint):
 return TypeExemplars.example(type_hint)
 elif type_hint == Any:
-return 1 if not default == 1 else 2
+return 1 if default != 1 else 2
 else:
 raise ValueError(f"Unable to process type {type_hint}. PRs welcome :)")

@@ -19,18 +19,18 @@ def mark_mixed_dtype(computation_node):
 if computation_node_dtype not in (torch.float16, torch.bfloat16):
 return
-if not len(computation_node.users) == 1:
+if len(computation_node.users) != 1:
 return
 computation_node_user = next(iter(computation_node.users.keys()))
 if not isinstance(computation_node_user.meta["val"], torch.Tensor):
 return
-if not computation_node_user.meta["val"].dtype == torch.float32:
+if computation_node_user.meta["val"].dtype != torch.float32:
 return
 while computation_node_user.target in _binary_ops:
-if not len(computation_node_user.users) == 1:
+if len(computation_node_user.users) != 1:
 return
 computation_node_user = next(iter(computation_node_user.users.keys()))
@@ -188,7 +188,7 @@ def binary_folding_init():
 ):
 return False
-if not len(conv_node.args[1].users) == 1:
+if len(conv_node.args[1].users) != 1:
 return False
 weight_meta_value = conv_node.args[1].meta.get("val")
@@ -242,7 +242,7 @@ def binary_folding_init():
 ):
 return False
-if not len(weight_node.users) == 1:
+if len(weight_node.users) != 1:
 return False
 weight_meta_value = weight_node.meta.get("val")

@@ -594,7 +594,7 @@ class OverlapScheduler:
 if is_wait_tensor(node):
 info = self.collective_info[self.wait_to_start[node]]
-assert not info.hiding_node == curr_compute_node
+assert info.hiding_node != curr_compute_node
 self._handle_wait(node)
 continue

@@ -705,7 +705,7 @@ def reorder_for_locality(graph: torch.fx.Graph):
 iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
 None,
 )
-past_mutating_epilogue = True if first_copy is None else False
+past_mutating_epilogue = first_copy is None
 for node in reversed(graph.nodes):
 seen_nodes.add(node)
@@ -1761,7 +1761,7 @@ class ConstructorMoverPass:
 if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
 continue
-if not node.kwargs.get("device") == torch.device("cpu"):
+if node.kwargs.get("device") != torch.device("cpu"):
 continue
 constructors.append(node)
@@ -1922,13 +1922,9 @@ def move_constructors_to_gpu(graph: fx.Graph) -> None:
 # by explicitly moving cpu scalar tensors to gpu when profitable, relying on
 # graph partition to split off this data copy, and cudagraphifying
 # the remaining gpu ops.
-allow_inputs_outputs = (
-True
-if (
-torch._inductor.config.triton.cudagraphs
-and torch._inductor.config.graph_partition
-)
-else False
+allow_inputs_outputs = bool(
+torch._inductor.config.triton.cudagraphs
+and torch._inductor.config.graph_partition
 )
 ConstructorMoverPass(
 get_gpu_type(),

@@ -615,7 +615,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
 copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
 if copy_node is not None:
 replace_dict[copy_node] = copy_node.args[0]
-if not trigger == ReInplaceTrigger.AUTO_FUNC_V2:
+if trigger != ReInplaceTrigger.AUTO_FUNC_V2:
 for user in node.users:
 # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
 # output atindex size(out)+i.

@@ -469,9 +469,7 @@ def significant_strides_equal(
 if not V.graph.sizevars.statically_known_equals(
 s1, s2
-) and not V.graph.sizevars.symbolic_hint(s1) == V.graph.sizevars.symbolic_hint(
-s2
-):
+) and V.graph.sizevars.symbolic_hint(s1) != V.graph.sizevars.symbolic_hint(s2):
 return False
 return True
@@ -5443,7 +5441,7 @@ class ConcatKernel(NopKernel):
 return True
 # otherwise, check equality of layouts
-if not len(src.get_stride()) == len(dst.get_stride()):
+if len(src.get_stride()) != len(dst.get_stride()):
 return False
 return all(

@@ -35,7 +35,7 @@ def check_cpu_supported():
 supported = (
 requires_avx2_on_cpu
 and not torch.xpu.is_available()
-and not sys.platform == "darwin"
+and sys.platform != "darwin"
 )
 return supported

@@ -1398,7 +1398,7 @@ def check_and_add_duplicate_pattern(
 new_graph_str = str(graph)
 for graph_str in equiv_pattern_reprs:
-if not new_graph_str == graph_str:
+if new_graph_str != graph_str:
 continue
 if skip_duplicates:
 return True

@@ -237,7 +237,7 @@ def check_autotune_cache(
 not disabled
 and filename is not None
 and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
-and not os.environ.get("TRITON_INTERPRET", "0") == "1"
+and os.environ.get("TRITON_INTERPRET", "0") != "1"
 ):
 configs_hash = hash_configs(configs)

@@ -4976,7 +4976,7 @@ class Scheduler:
 if name in name_to_node
 }
 input_deallocation = {
-name: True if name in buffer_names_to_free else False
+name: name in buffer_names_to_free
 for name in partition_input_names
 if name in name_to_node
 }

@@ -132,7 +132,7 @@ def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
 out = _solve_simple_expr(eq_1_expr_simplified)
 # since we approximated FloorDiv/ModularIndexing, double check here
-if not out or not (sympy_subs(eq_1_expr, {free_symbol: out})) == 1:
+if not out or sympy_subs(eq_1_expr, {free_symbol: out}) != 1:
 return None
 required_values.append(out)

@@ -504,7 +504,7 @@ def is_pointwise_use(
 Uses in views ops will follow the views uses
 """
-if not use.op == "call_function":
+if use.op != "call_function":
 return False
 if not (
 isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
@@ -2020,7 +2020,7 @@ def use_ck_template(layout: Layout) -> bool:
 if not torch.version.hip:
 return False
 # tensors must be on GPU
-if not layout.device.type == "cuda":
+if layout.device.type != "cuda":
 return False
 # hardware check
 # if config arch list is not specified, get the native arch from the device properties

@@ -443,7 +443,7 @@ def get_callable_argument_names(fn) -> list[str]:
 for name, param in callable_signature.parameters.items():
 # All four other types of arguments do not map to individual values
 # with a keyword as name.
-if not param.kind == param.POSITIONAL_OR_KEYWORD:
+if param.kind != param.POSITIONAL_OR_KEYWORD:
 continue
 argument_names.append(name)

@@ -135,7 +135,7 @@ def mutates_and_returns_first_arg(op: OpOverload):
 if op.namespace != "aten":
 return False
 schema = op._schema
-if not len(schema.returns) == 1:
+if len(schema.returns) != 1:
 return False
 if schema.returns[0].alias_info is None:
 return False

@@ -1207,7 +1207,7 @@ def safe_grad_filter(message, category, filename, lineno, file=None, line=None)
 def user_warning_filter(
 message, category, filename, lineno, file=None, line=None
 ) -> bool:
-return not category == UserWarning
+return category != UserWarning
 @contextlib.contextmanager

@@ -317,7 +317,7 @@ def print_assert_equal(test_string, actual, desired):
 __tracebackhide__ = True  # Hide traceback for py.test
 import pprint
-if not (actual == desired):
+if actual != desired:
 msg = StringIO()
 msg.write(test_string)
 msg.write(" failed\nACTUAL: \n")
@@ -1505,7 +1505,7 @@ def _integer_repr(x, vdt, comp):
 # See also
 # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
 rx = x.view(vdt)
-if not (rx.size == 1):
+if rx.size != 1:
 rx[rx < 0] = comp - rx[rx < 0]
 else:
 if rx < 0:

@@ -4316,7 +4316,7 @@ def tensor_split(
 # If indices_or_sections is a tensor, it must be a CPU Long tensor
 if isinstance(indices_or_sections, TensorLike):
-if not indices_or_sections.device.type == "cpu":
+if indices_or_sections.device.type != "cpu":
 msg = (
 f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
 f"but received one on {indices_or_sections.device}"

@@ -1335,7 +1335,7 @@ def make_fast_binary_impl(
 # Use elementwise_dtypes for the tricky case
 has_different_input_dtypes = True
 continue
-if common_device == cpu and not op.device.type == "cpu":
+if common_device == cpu and op.device.type != "cpu":
 common_device = op.device
 # Slightly simplified here as target_dtype cannot vary
 if common_dtype is None:

@@ -3012,7 +3012,7 @@ class FakeTensorMode(TorchDispatchMode):
 t.numel() <= CONSTANT_NUMEL_LIMIT
 and not is_sparse_any(t)
 and not self.is_our_fake(t)
-and not t.device.type == "meta"
+and t.device.type != "meta"
 )
 def invalidate_written_to_constants(

@@ -127,7 +127,7 @@ def try_convert_fake_to_real(
 key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}
 out = []
 for t in ten_list:
-if not isinstance(t, FakeTensor) or not t.layout == torch.strided:
+if not isinstance(t, FakeTensor) or t.layout != torch.strided:
 out.append(t)
 continue

@@ -113,7 +113,7 @@ class SchemaCheckMode(TorchDispatchMode):
 return name if name != "self" else "input"
 def unwrap(e):
-if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
+if isinstance(e, torch.Tensor) and type(e) != torch.Tensor:
 try:
 return e.elem
 except AttributeError:
@@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
 def parse_metadata(e):
 if isinstance(e, torch.Tensor):
-if not type(e) == torch.Tensor:
+if type(e) != torch.Tensor:
 try:
 current = e.elem
 return (

@@ -518,7 +518,7 @@ class ModelReportVisualizer:
 # the index of the feature will the 0 + num non feature columns
 tensor_feature_index = feature_column_offset
 row_value = row[tensor_feature_index]
-if not type(row_value) == str:
+if type(row_value) != str:
 x_data.append(x_val_to_append)
 y_data.append(row_value)
 elif is_valid_per_channel_plot:
@@ -541,7 +541,7 @@ class ModelReportVisualizer:
 # the index of the feature will the 0 + num non feature columns
 tensor_feature_index = feature_column_offset
 row_value = row[tensor_feature_index]
-if not type(row_value) == str:
+if type(row_value) != str:
 # only append if new index we are appending
 if len(x_data) == 0 or x_data[-1] != x_val_to_append:
 x_data.append(x_val_to_append)

@@ -51,7 +51,7 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
 if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
 if node.op != "call_module":
 return False
-if not type_before_parametrizations(modules[node.target]) == self_match:
+if type_before_parametrizations(modules[node.target]) != self_match:
 return False
 elif callable(self_match):
 if node.op != "call_function" or node.target is not self_match:

@@ -1708,7 +1708,7 @@ def insert_observers_for_model(
 skip_inserting_observers = (
 (qconfig is None) or not output_is_a_tensor
-) and (not node.op == "output")
+) and (node.op != "output")
 # TODO: take a closer look to see if we can remove this check
 # right now it is here because of `observed_node_names`, we are using

@@ -733,9 +733,7 @@ def _stack_and_check_tensors(
 if tensor is None:
 out_jacobian[:, j].zero_()
 else:
-dense = (
-tensor.to_dense() if not tensor.layout == torch.strided else tensor
-)
+dense = tensor.to_dense() if tensor.layout != torch.strided else tensor
 assert out_jacobian[:, j].numel() == dense.numel()
 out_jacobian[:, j] = dense.reshape(-1)
 return out_jacobians, correct_grad_sizes, correct_grad_types

@@ -116,5 +116,5 @@ class OptEinsumModule(PropModule):
 # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
 sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
-enabled = True if is_available() else False
+enabled = bool(is_available())
 strategy = "auto" if is_available() else None

@@ -200,7 +200,7 @@ class _TensorsAccessed:
 del self.accesses[data_ptr]
 def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
-return True if self.accesses[data_ptr].reads else False
+return bool(self.accesses[data_ptr].reads)
 def get_allocation_stack_trace(
 self, data_ptr: DataPtr

@@ -548,7 +548,7 @@ def create_default_global_save_plan(
 for plan in all_plans:
 new_items = []
 for item in plan.items:
-if not item.type == WriteItemType.SHARD:
+if item.type != WriteItemType.SHARD:
 assert item.index.fqn not in md
 if item.type == WriteItemType.BYTE_IO:

@@ -1508,8 +1508,7 @@ def _allgather_orig_param_states(
 return output_states
 has_state_params: list[bool] = [
-True if fqn in output_states else False
-for fqn, idx in fsdp_param_info.param_indices.items()
+fqn in output_states for fqn, idx in fsdp_param_info.param_indices.items()
 ]
 # Loop through the ``state_buffers`` and construct the flattened, concatenated,

@@ -1191,7 +1191,7 @@ def _add_send_recv(
 """
 if action is None:
 return True
-elif action.computation_type == F and not action.stage_index == 0:
+elif action.computation_type == F and action.stage_index != 0:
 if (
 _Action(action.stage_index, RECV_F, action.microbatch_index)
 in prev_actions
@@ -1205,7 +1205,7 @@ def _add_send_recv(
 return False
 elif (
 action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
-and not action.stage_index == num_stages - 1
+and action.stage_index != num_stages - 1
 ):
 if (
 _Action(action.stage_index, RECV_B, action.microbatch_index)

@@ -350,7 +350,7 @@ def _tensorpipe_init_backend_handler(
 device_count = torch.cuda.device_count()
-is_static_group = True if world_size else False
+is_static_group = bool(world_size)
 # world_size is specified so this is a static group (ranks cannot join and leave)
 if is_static_group:
 # The agent's join method is required to behave like a barrier and perform

@@ -1067,12 +1067,12 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
 if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max(
 1, meta.shape[end_dim - 1]
 ):
-if not meta.stride[end_dim] % alignment == 0:
+if meta.stride[end_dim] % alignment != 0:
 return False
 elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max(
 1, meta.shape[end_dim]
 ):
-if not meta.stride[end_dim - 1] % alignment == 0:
+if meta.stride[end_dim - 1] % alignment != 0:
 return False
 else:
 return False

@@ -187,7 +187,7 @@ class LocalShardsWrapper(torch.Tensor):
 aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
 ):
 return False
-if not a.storage_metadata() == b.storage_metadata():
+if a.storage_metadata() != b.storage_metadata():
 return False
 return True

@@ -464,7 +464,7 @@ def _insert_reshard_gm(
 if reshard_node.op not in ["placeholder", "output"]:
 reshard_node.meta["nn_module_stack"] = (
 copy.copy(input_arg.meta["nn_module_stack"])
-if not input_arg.op == "placeholder"
+if input_arg.op != "placeholder"
 else copy.copy(node.meta["nn_module_stack"])
 )
 output_node = gm.graph.graph_copy(

@@ -142,7 +142,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
 return
 next_module_node = next(iter(unflatten_getitem_getitem_users))
-if not (next_module_node.op == "call_module"):
+if next_module_node.op != "call_module":
 log.debug(
 "Unflatten node %s's user is not a call_module. "
 "Instead it is: %s. Passing...",

@@ -590,7 +590,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
 ep.graph_module,
 fake_args,
 decompositions=python_decomp_table,
-trace_joint=True if joint_loss_index is not None else False,
+trace_joint=joint_loss_index is not None,
 output_loss_index=(
 joint_loss_index if joint_loss_index is not None else None
 ),

@@ -1942,7 +1942,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
 torch._check_type(
 not indices.is_complex()
 and not indices.is_floating_point()
-and not indices.dtype == torch.bool,
+and indices.dtype != torch.bool,
 lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
 )

@@ -7056,7 +7056,7 @@ class ShapeEnv:
 expr,
 concrete_val,
 # only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic")
-stack_info=True if log.getEffectiveLevel() < logging.WARNING else False,
+stack_info=log.getEffectiveLevel() < logging.WARNING,
 )
 def _get_user_frame(self) -> Optional[types.FrameType]:

@@ -668,7 +668,7 @@ class _MinimizerBase:
 final_start_idx: Optional[int] = start_idx
 final_end_idx: Optional[int] = end_idx
-run_both = True if find_last_node is None else False
+run_both = find_last_node is None
 # step 1: find (0, end_idx) of culprit block
 if run_both or find_last_node:

@@ -3649,7 +3649,7 @@ def smooth_l1_loss(
 reduction=reduction,
 beta=beta,
 )
-if not (target.size() == input.size()):
+if target.size() != input.size():
 warnings.warn(
 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
 "This will likely lead to incorrect results due to broadcasting. "
@@ -3712,7 +3712,7 @@ def huber_loss(
 weight=weight,
 )
-if not (target.size() == input.size()):
+if target.size() != input.size():
 warnings.warn(
 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
 "This will likely lead to incorrect results due to broadcasting. "
@@ -3789,7 +3789,7 @@ def l1_loss(
 reduce=reduce,
 reduction=reduction,
 )
-if not (target.size() == input.size()):
+if target.size() != input.size():
 warnings.warn(
 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
 "This will likely lead to incorrect results due to broadcasting. "
@@ -3862,7 +3862,7 @@ def mse_loss(
 weight=weight,
 )
-if not (target.size() == input.size()):
+if target.size() != input.size():
 warnings.warn(
 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
 "This will likely lead to incorrect results due to broadcasting. "

@@ -242,7 +242,7 @@ class RNNBase(Module):
 for fw in self._flat_weights:
 if (
 not isinstance(fw, Tensor)
-or not (fw.dtype == dtype)
+or fw.dtype != dtype
 or not fw.is_cuda
 or not torch.backends.cudnn.is_acceptable(fw)
 ):

@@ -382,7 +382,7 @@ class TransformerEncoder(Module):
 why_not_sparsity_fast_path = (
 f"{enc_layer}.activation_relu_or_gelu was not True"
 )
-elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
+elif encoder_layer.norm1.eps != encoder_layer.norm2.eps:
 why_not_sparsity_fast_path = (
 f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
 )
@@ -458,7 +458,7 @@ class TransformerEncoder(Module):
 )
 elif first_layer.training:
 why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
-elif not src.dim() == 3:
+elif src.dim() != 3:
 why_not_sparsity_fast_path = (
 f"input not batched; expected src.dim() of 3 but got {src.dim()}"
 )
@@ -832,7 +832,7 @@ class TransformerEncoderLayer(Module):
 why_not_sparsity_fast_path = (
 "torch.backends.mha.get_fastpath_enabled() was not True"
 )
-elif not src.dim() == 3:
+elif src.dim() != 3:
 why_not_sparsity_fast_path = (
 f"input not batched; expected src.dim() of 3 but got {src.dim()}"
 )
@@ -846,7 +846,7 @@ class TransformerEncoderLayer(Module):
 why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
 elif not self.activation_relu_or_gelu:
 why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
-elif not (self.norm1.eps == self.norm2.eps):
+elif self.norm1.eps != self.norm2.eps:
 why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
 elif src.is_nested and (
 src_key_padding_mask is not None or src_mask is not None

@@ -136,7 +136,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
 try:
 yield
 finally:
-if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
+if hasattr(model, "training") and mode != _C_onnx.TrainingMode.PRESERVE:
 model.train(originally_training)

@@ -9346,7 +9346,7 @@ def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_g
 "k_proj_weight" : k_proj_weight,
 "v_proj_weight" : v_proj_weight,
 "attn_mask" : attn_mask,
-"training" : True if dropout_p > 0.0 else False,
+"training" : dropout_p > 0.0,
 "use_separate_proj_weight" : use_separate_proj_weight
 }

@@ -1287,7 +1287,7 @@ class QuantizationTestCase(TestCase):
 prepare_custom_config=prepare_custom_config,
 backend_config=backend_config,
 )
-if not quant_type == QuantType.DYNAMIC:
+if quant_type != QuantType.DYNAMIC:
 prepared(*inputs)
 if print_debug_info:

@@ -1470,7 +1470,7 @@ TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_
 TEST_MPS = torch.backends.mps.is_available()
 MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
 TEST_XPU = torch.xpu.is_available()
-TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False
+TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
 TEST_CUDA = torch.cuda.is_available()
 custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
 TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
@@ -3195,7 +3195,7 @@ class TestCase(expecttest.TestCase):
 def remove_empty_lines(self, input_string):
 lines = input_string.split('\n')
-filtered_lines = [line for line in lines if not line.strip() == '']
+filtered_lines = [line for line in lines if line.strip() != '']
 return '\n'.join(filtered_lines)
 # ignore comments will ignore lines that starts with # after being stripped

@@ -132,7 +132,7 @@ def _infer_device_type(*args):
 def add_device_types(arg):
 nonlocal device_types
-if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
+if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
 device_types.append(arg.device.type)
 tree_map(add_device_types, args)

@@ -119,7 +119,7 @@ PIP_PATTERNS = [
 def run(command):
 """Return (return-code, stdout, stderr)."""
-shell = True if type(command) is str else False
+shell = type(command) is str
 p = subprocess.Popen(
 command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
 )

@@ -228,7 +228,7 @@ CUDA_NOT_FOUND_MESSAGE = (
 )
 ROCM_HOME = _find_rocm_home() if (torch.cuda._is_compiled() and torch.version.hip) else None
 HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None
-IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
+IS_HIP_EXTENSION = bool(ROCM_HOME is not None and torch.version.hip is not None)
 ROCM_VERSION = None
 if torch.version.hip is not None:
 ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])