Revert "Enable all SIM rules except disabled ones (#164645)"

This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
This commit is contained in:
PyTorch MergeBot 2025-10-05 19:32:21 +00:00
parent 321e602692
commit 5d7360bb03
97 changed files with 255 additions and 182 deletions

View File

@ -4035,7 +4035,7 @@ def run(runner, args, original_dir=None):
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = (
speedup_experiment if args.backend != "torchao" else latency_experiment
speedup_experiment if not args.backend == "torchao" else latency_experiment
)
if args.accuracy:
output_filename = f"accuracy_{args.backend}.csv"

View File

@ -270,7 +270,7 @@ def run_single_backend_sdpa(
if config.calculate_bwd_time:
# TODO: debug backward pass for njt
if eager_sdpa and config.attn_type != "document_mask":
if eager_sdpa and not config.attn_type == "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True

View File

@ -181,7 +181,6 @@ ignore = [
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",
"SIM300", # Yoda condition detected
"UP007", # keep-runtime-typing
"UP045", # keep-runtime-typing
"TC006",
@ -197,7 +196,8 @@ select = [
"E",
"EXE",
"F",
"SIM",
"SIM1",
"SIM911",
"W",
# Not included in flake8
"FURB",

View File

@ -55,7 +55,7 @@ class TestActivationSparsifier(TestCase):
for key, config in sparsifier_defaults.items():
# all the keys in combined_defaults should be present in sparsifier defaults
assert config == combined_defaults.get(key)
assert config == combined_defaults.get(key, None)
def _check_register_layer(
self, activation_sparsifier, defaults, sparse_config, layer_args_list

View File

@ -3074,7 +3074,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_dtype_shards, [10, 10], init_rrefs=True
)
tensor_requires_grad = self.rank == 0
tensor_requires_grad = True if self.rank == 0 else False
wrong_requires_grad_shards = [
sharded_tensor.Shard(
torch.randn(
@ -3121,7 +3121,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
)
tensor_pin_memory = self.rank == 0
tensor_pin_memory = True if self.rank == 0 else False
wrong_pin_memory_shards_cross_ranks = [
sharded_tensor.Shard(
torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata

View File

@ -146,7 +146,7 @@ class TestStorageBase:
self.rank = 0 if not dist.is_initialized() else dist.get_rank()
def _get_ranks(self, name):
return self.fail_conf.get(name, None)
return self.fail_conf[name] if name in self.fail_conf else None
def _fail_rank(self, name):
ranks = self._get_ranks(name)

View File

@ -155,7 +155,7 @@ class TestFreezingWeights(FSDPTest):
ddp_kwargs = {
"device_ids": [self.rank],
"find_unused_parameters": bool(disable_autograd),
"find_unused_parameters": True if disable_autograd else False,
}
model = self._create_model(

View File

@ -65,7 +65,7 @@ class MockPipelineStage(_PipelineStageBase):
self.num_stages = kwargs.get("num_stages", 1)
self.group_size = kwargs.get("group_size", 1)
self.group_rank = kwargs.get("group_rank", 0)
self.group = kwargs.get("group")
self.group = kwargs.get("group", None)
def _create_grad_recv_info(self, *args, **kwargs):
return None

View File

@ -1022,7 +1022,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [
not splitted_tensor.numel() > 0
False if splitted_tensor.numel() > 0 else True
for splitted_tensor in splitted_tensor_list
]
expected_is_tensor_empty = [True] * self.world_size
@ -1045,10 +1045,12 @@ class TestDTensorPlacementTypes(DTensorTestBase):
for i, tensor in enumerate(splitted_tensor_list)
]
expected_is_tensor_empty = [
not idx < size for idx, _ in enumerate(range(self.world_size))
False if idx < size else True
for idx, _ in enumerate(range(self.world_size))
]
is_tensor_empty = [
not unpadded_tensor.numel() > 0 for unpadded_tensor in unpadded_list
False if unpadded_tensor.numel() > 0 else True
for unpadded_tensor in unpadded_list
]
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)

View File

@ -2770,7 +2770,11 @@ class WorkHookTest(MultiProcessTestCase):
# from rank0 to other ranks. However, this is DDP's internal implementation,
# which is subject to change in future versions.
self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0)
ctor_allreduce = num_hook_fired.get(OpType.ALLREDUCE, 0)
ctor_allreduce = (
num_hook_fired[OpType.ALLREDUCE]
if OpType.ALLREDUCE in num_hook_fired
else 0
)
x = torch.zeros(2, 1000).cuda(self.rank)
ddp(x).sum().backward()

View File

@ -82,7 +82,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
# look up dL_dentries. If a variable is never used to compute the loss,
# we consider its gradient None, see the note below about zeros for more information.
def gather_grad(entries: list[str]):
return [dL_d.get(entry) for entry in entries]
return [dL_d[entry] if entry in dL_d else None for entry in entries]
# propagate the gradient information backward
for entry in reversed(gradient_tape):

View File

@ -286,7 +286,7 @@ class OptionalScaledTensor(torch.Tensor):
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return OptionalScaledTensor(
inner_tensors["_data"],
inner_tensors.get("_scale", None),
inner_tensors["_scale"] if "_scale" in inner_tensors else None,
constant=metadata["_constant"],
)

View File

@ -38,7 +38,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode):
def _may_alias_or_mutate(self, func, types, args, kwargs):
def unwrap(e):
if isinstance(e, torch.Tensor) and type(e) != torch.Tensor:
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
try:
return e.elem
except AttributeError:

View File

@ -358,7 +358,9 @@ def _sequential_split_inline_tests():
for i, node in enumerate(insert_locs):
with gm.graph.inserting_before(node):
gm.graph.call_function(torch._C._set_grad_enabled, (i % 2 == 0,), {})
gm.graph.call_function(
torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}
)
return gm
x = torch.randn(2, 2)

View File

@ -2932,7 +2932,9 @@ class GraphModule(torch.nn.Module):
if autograd:
result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
exp_grad_mask = [
True if r.requires_grad else False for r in result_exp_flat
]
self.check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m],
[r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
@ -3739,7 +3741,9 @@ class AssociativeScanTests(TestCase):
):
result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
exp_grad_mask = [
True if r.requires_grad else False for r in result_exp_flat
]
self._check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m],
@ -5706,9 +5710,10 @@ def forward(self, arg0_1):
)
def test_while_loop_tracing(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
allow_non_fake_inputs = while_loop_test in (
"simple_with_linear",
"nested_with_linear",
allow_non_fake_inputs = (
False
if while_loop_test not in ("simple_with_linear", "nested_with_linear")
else True
)
self._check_tracing(fn, inp, allow_non_fake_inputs)

View File

@ -177,7 +177,9 @@ class TestFXNodeSource(TestCase):
for node_name_2 in node_name_to_from_node:
if node_name_2 in {
node_name_1,
same_ancestor_nodes.get(node_name_1),
same_ancestor_nodes[node_name_1]
if node_name_1 in same_ancestor_nodes
else None,
}:
self.assertEqual(
node_name_to_from_node[node_name_1],

View File

@ -164,7 +164,9 @@ class B2BGEMMTest(TestCase):
self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
@unittest.skipIf(
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32)
def test_plain_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@ -217,7 +219,9 @@ class B2BGEMMTest(TestCase):
# flaky test assertion: disabled
# self.assertTrue(average_speedup > 1)
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
@unittest.skipIf(
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32)
def test_gelu_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@ -272,7 +276,9 @@ class B2BGEMMTest(TestCase):
# flaky test assertion: disabled
# self.assertTrue(average_speedup > 1)
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
@unittest.skipIf(
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32)
def test_gelu_mlp_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

View File

@ -165,7 +165,7 @@ class BenchmarkFusionTestTemplate:
_, out_code = run_and_get_code(foo_c, m, inp)
# occasionally, CI will make this one kernel. just skip in this case
if out_code[0].count("def triton_") != 2:
if not out_code[0].count("def triton_") == 2:
return
# should be multiple triton invocations

View File

@ -289,7 +289,7 @@ def build_opt_kwarg_db():
has_tensor_lr = False
for key, val in kwargs.items():
if (key != "lr" and key != "betas") and (
if (not key == "lr" and not key == "betas") and (
not isinstance(val, bool) or (isinstance(val, bool) and val)
):
name += "_" + key
@ -450,7 +450,7 @@ def make_test(
stack.enter_context(config.patch({"triton.cudagraphs": True}))
kwargs_compiled = deepcopy(kwargs)
if isinstance(kwargs.get("lr"), torch.Tensor):
if isinstance(kwargs.get("lr", None), torch.Tensor):
kwargs["lr"] = kwargs["lr"].to(device)
kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)

View File

@ -583,7 +583,9 @@ class TestFlexAttention(InductorTestCase):
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial)
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
@ -757,7 +759,7 @@ class TestFlexAttention(InductorTestCase):
return_lse=return_lse,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(Q_H != KV_H),
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
else:
@ -770,7 +772,7 @@ class TestFlexAttention(InductorTestCase):
return_lse=return_lse,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(Q_H != KV_H),
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
return compiled_out, compiled_lse
@ -815,7 +817,9 @@ class TestFlexAttention(InductorTestCase):
if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@ -1460,7 +1464,7 @@ class TestFlexAttention(InductorTestCase):
block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device)
attention = functools.partial(
flex_attention, block_mask=block_mask, enable_gqa=(Hq != Hkv)
flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv)
)
self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D)

View File

@ -412,7 +412,7 @@ class TestFlexDecoding(InductorTestCase):
sdpa_partial = create_attention(
score_mod,
block_mask,
enable_gqa=(Q_H != KV_H),
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
compiled_sdpa = torch.compile(sdpa_partial)
@ -607,7 +607,7 @@ class TestFlexDecoding(InductorTestCase):
return_lse=True,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(Q_H != KV_H),
enable_gqa=(not Q_H == KV_H),
)
else:
compiled_lse = None
@ -618,7 +618,7 @@ class TestFlexDecoding(InductorTestCase):
return_lse=False,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(Q_H != KV_H),
enable_gqa=(not Q_H == KV_H),
)
return compiled_out, compiled_lse
@ -664,7 +664,9 @@ class TestFlexDecoding(InductorTestCase):
if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device)
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@ -904,7 +906,7 @@ class TestFlexDecoding(InductorTestCase):
sdpa_partial = create_attention(
score_mod=score_mod,
block_mask=None,
enable_gqa=(Hq != Hkv),
enable_gqa=(not Hq == Hkv),
)
compiled_sdpa = torch.compile(sdpa_partial)
ref_out = sdpa_partial(q, k, v)
@ -1142,7 +1144,7 @@ class TestFlexDecoding(InductorTestCase):
def head_attention_mod(kv_head_num):
head_type = torch.tensor(
[i % kv_head_num != 0 for i in range(kv_head_num)],
[False if i % kv_head_num == 0 else True for i in range(kv_head_num)],
dtype=torch.bool,
device=device,
)

View File

@ -22,7 +22,7 @@ except ImportError:
HAS_PYDOT = False
HAS_DOT = shutil.which("dot") is not None
HAS_DOT = True if shutil.which("dot") is not None else False
class TestGraphTransformObserver(TestCase):

View File

@ -834,7 +834,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
for dtype in dtypes:
torch._dynamo.reset()
autocast_enabled = dtype in [torch.bfloat16, torch.float16]
autocast_enabled = (
True if dtype in [torch.bfloat16, torch.float16] else False
)
with (
torch.no_grad(),
torch.autocast(
@ -4418,12 +4420,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
out_feature = 64
q_min, q_max = -32, 31
# we only test for qlinear_binary in this case
test_for_pointwise_binary = bool(
M == 1
test_for_pointwise_binary = (
True
if M == 1
and inplace_add
and not expand_a_scale
and not dynamic
and not has_bias
else False
)
if test_for_pointwise_binary and not IS_X86:
self.skipTest("Some UTs are only supported on x86_64 CPUs")

View File

@ -706,7 +706,7 @@ def check_model_gpu(
if check_lowp:
def downcast_fn(x):
if not isinstance(x, torch.Tensor) or x.dtype != torch.float:
if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
return x
return torch.empty_strided(
x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half
@ -4693,7 +4693,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance.
reference_in_float=not torch.version.hip,
reference_in_float=False if torch.version.hip else True,
)
def test_convolution2(self):
@ -4727,7 +4727,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance.
reference_in_float=not torch.version.hip,
reference_in_float=False if torch.version.hip else True,
)
@skip_if_gpu_halide
@ -4778,7 +4778,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance.
reference_in_float=not torch.version.hip,
reference_in_float=False if torch.version.hip else True,
)
def test_conv2d_channels_last(self):
@ -12937,7 +12937,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
)
res = torch.compile(fn)(20)
self.assertTrue(torch.all((res >= 0) & (res < 10)).item())
self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
@torch._inductor.config.patch(force_shape_pad=True)
@skip_if_gpu_halide # correctness issue

View File

@ -1217,7 +1217,7 @@ class TestInductorOpInfo(TestCase):
# not exercised in test_ops_gradients atm. The problem is not
# complex32 per-se (which is supported by data movement only ops)
# but that when we do backwards we expect other ops like add to work
and dtype != torch.complex32
and not dtype == torch.complex32
)
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

View File

@ -17,13 +17,17 @@ with open(sys.argv[1]) as input_yaml_file:
for info in model_infos:
for op in info["root_operators"]:
# aggregate occurance per op
root_operators[op] = 1 + (root_operators.get(op, 0))
root_operators[op] = 1 + (root_operators[op] if op in root_operators else 0)
for op in info["traced_operators"]:
# aggregate occurance per op
traced_operators[op] = 1 + (traced_operators.get(op, 0))
traced_operators[op] = 1 + (
traced_operators[op] if op in traced_operators else 0
)
# merge dtypes for each kernel
for kernal, dtypes in info["kernel_metadata"].items():
new_dtypes = dtypes + (kernel_metadata.get(kernal, []))
new_dtypes = dtypes + (
kernel_metadata[kernal] if kernal in kernel_metadata else []
)
kernel_metadata[kernal] = list(set(new_dtypes))

View File

@ -4879,7 +4879,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
@skipScriptTest()
def test_rnn_no_bias(self):
def make_model(layers, packed_sequence):
batch_first = packed_sequence == 2
batch_first = True if packed_sequence == 2 else False
model = torch.nn.RNN(
RNN_INPUT_SIZE,
RNN_HIDDEN_SIZE,
@ -4900,7 +4900,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
return model
def make_input(batch_size, layers, packed_sequence):
batch_first = packed_sequence == 2
batch_first = True if packed_sequence == 2 else False
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]

View File

@ -7045,8 +7045,8 @@ class TestQuantizedConv(TestCase):
# ONEDNN only supports symmetric quantization of weight
if W_zero_point is not None:
W_zero_point = len(W_zero_point) * [0]
fp32_output = qconv_output_dtype is torch.float32
bfloat16_output = qconv_output_dtype is torch.bfloat16
fp32_output = True if qconv_output_dtype is torch.float32 else False
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
if fp32_output or bfloat16_output:
Y_scale = 1.0
Y_zero_point = 0
@ -7905,8 +7905,8 @@ class TestQuantizedConv(TestCase):
weight_in_channel_last_format=False,
):
# We assume FP8 quantization is always symmetric
fp32_output = qconv_output_dtype is torch.float32
bfloat16_output = qconv_output_dtype is torch.bfloat16
fp32_output = True if qconv_output_dtype is torch.float32 else False
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
if fp32_output or bfloat16_output:
Y_scale = 1.0
X2_scale = 1.0

View File

@ -11712,7 +11712,7 @@ class TestAutogradDeviceType(TestCase):
def test_nonzero(tensor, value, expected):
tensor[0] = value
self.assertEqual(expected, bool(tensor))
self.assertEqual(expected, bool(tensor))
self.assertEqual(expected, True if tensor else False)
test_nonzero(l, 0, False)
test_nonzero(l, -2, True)

View File

@ -577,7 +577,7 @@ print(t.is_pinned())
src = torch.randn(
1000000,
device="cuda" if dst == "cpu" else "cpu",
pin_memory=dst == "cuda",
pin_memory=True if dst == "cuda" else False,
)
_test_to_non_blocking(src, try_non_blocking, dst)

View File

@ -942,7 +942,7 @@ def forward(self, scores_1, mask_1, value_1):
# not exercised in test_ops_gradients atm. The problem is not
# complex32 per-se (which is supported by data movement only ops)
# but that when we do backwards we expect other ops like add to work
and dtype != torch.complex32
and not dtype == torch.complex32
)
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

View File

@ -3583,7 +3583,7 @@ class TestFX(JitTestCase):
class LeafTracerNotB(Tracer):
def is_leaf_module(self, module, name):
return "b" not in name
return False if "b" in name else True
# Recompile calls added "for fun", since they
# chain __call__ wrappers.

View File

@ -2036,7 +2036,7 @@ class TestIndexing(TestCase):
index = torch.tensor([0], device=device)
x.index_fill_(1, index, 0)
self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
if not x.is_complex() and device != "meta":
if not x.is_complex() and not device == "meta":
with self.assertRaisesRegex(RuntimeError, r"Scalar"):
x.index_fill_(1, index, 1 + 1j)
# Make sure that the result stays 0-dim while applied to

View File

@ -6723,7 +6723,7 @@ a")
@torch.jit.script
def testNoThrows(t):
c1 = 1
if (False and bool(t[1])) or (True or bool(t[1])): # noqa: SIM222,SIM223
if (False and bool(t[1])) or (True or bool(t[1])):
c1 = 0
return c1
@ -15758,7 +15758,7 @@ dedent """
def fn(d):
# type: (Dict[str, int]) -> List[int]
out = [1]
for i in range(d.get("hi", 6)):
for i in range(d["hi"] if "hi" in d else 6):
out.append(i) # noqa: PERF402
return out
@ -16104,7 +16104,7 @@ M = 10
S = 5
def add_nn_module_test(*args, **kwargs):
no_grad = kwargs.get('no_grad', False)
no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
if 'desc' in kwargs and 'eval' in kwargs['desc']:
# eval() is not supported, so skip these tests

View File

@ -111,7 +111,7 @@ class TestAutocast(JitTestCase):
def test_runtime_autocast_state_expr(self):
@torch.jit.script
def fn(a, b):
with autocast(enabled=bool((a[0][0] > 0.5).item())):
with autocast(enabled=True if a[0][0] > 0.5 else False):
return torch.mm(a, b)
# runtime values for autocast enable argument are not supported
with self.assertRaises(RuntimeError):

View File

@ -3522,7 +3522,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
nn.RNN(10, 20, batch_first=True)
]
# ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
first_warn = not torch.version.hip
first_warn = False if torch.version.hip else True
for rnn in rnns:
rnn.cuda()
input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")

View File

@ -205,7 +205,7 @@ class TestNumPyInterop(TestCase):
x = x.conj()
y = x.resolve_conj()
expect_error = (
requires_grad or sparse or conj or device != "cpu"
requires_grad or sparse or conj or not device == "cpu"
)
error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
if not force and expect_error:

View File

@ -18,7 +18,7 @@ class PruningOpTest(TestCase):
def _generate_rowwise_mask(self, embedding_rows):
indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32))
threshold = float(np.random.random_sample())
mask = torch.BoolTensor([val >= threshold for val in indicator])
mask = torch.BoolTensor([True if val >= threshold else False for val in indicator])
return mask
def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):

View File

@ -1899,7 +1899,7 @@ class TestReductions(TestCase):
# Note [all, any uint8 compatibility]: However for compatibility reason,
# for `uint8`, they return Tensor of same dtype `uint8`.
# Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
exact_dtype = dtype != torch.uint8
exact_dtype = True if dtype != torch.uint8 else False
def _test_all_any(x):
self.compare_with_numpy(torch.all, np.all, x)

View File

@ -129,7 +129,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions:
for initial in [0, None]:
check_backward = initial is not None
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
@ -186,7 +186,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions:
for initial in [0, None]:
check_backward = initial is not None
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
@ -244,7 +244,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions:
for initial in [0, None]:
check_backward = initial is not None
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":

View File

@ -4553,7 +4553,7 @@ class TestSerialization(TestCase, SerializationMixin):
with TemporaryFileName() as f:
torch.save(m, f)
try:
old_value = os.environ.get(env_var, None)
old_value = os.environ[env_var] if env_var in os.environ else None
os.environ[env_var] = "1"
# if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):

View File

@ -3483,7 +3483,7 @@ class TestRandomTensorCreation(TestCase):
else:
t.uniform_(from_, to_)
range_ = to_ - from_
if dtype != torch.bfloat16 and not (
if not (dtype == torch.bfloat16) and not (
dtype == torch.half and device == 'cpu') and not torch.isnan(t).all():
delta = alpha * range_
double_t = t.to(torch.double)

View File

@ -100,7 +100,7 @@ class TestBuiltin(TestCase):
# dtypes results in False/True when compared to valid dtypes.
# Here 7 cannot be converted to dtype. No exceptions should be raised
assert np.dtype(np.int32) != 7, "dtype richcompare failed for =="
assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
@parametrize("operation", [operator.le, operator.lt, operator.ge, operator.gt])

View File

@ -401,7 +401,7 @@ allow_rnn = False
# exported FX graph. This flag should become the default eventually
# and be removed, but currently provides a way to fall back to old
# graph breaking behavior.
capture_sparse_compute = not is_fbcode()
capture_sparse_compute = False if is_fbcode() else True
# If true, error if we try to compile a function that has
# been seen before.

View File

@ -718,7 +718,11 @@ def validate_args_and_maybe_create_graph_inputs(
new_proxy = tracer.create_graph_input(
arg_name, a.python_type(), example_value
)
example_value = node.meta.get("example_value", None)
example_value = (
node.meta["example_value"]
if "example_value" in node.meta
else None
)
a = wrap_fx_proxy_cls(
target_cls=type(a),
tx=tx,
@ -756,7 +760,9 @@ def validate_args_and_maybe_create_graph_inputs(
# If `a` can be put into a graph
elif a.maybe_fx_node() is not None:
node = a.maybe_fx_node()
example_value = node.meta.get("example_value", None)
example_value = (
node.meta["example_value"] if "example_value" in node.meta else None
)
arg_name = node.name if sub_args_names is None else sub_args_names[idx]
new_proxy = tracer.create_graph_input(
arg_name, a.python_type(), example_value

View File

@ -270,7 +270,7 @@ backward_pass_autocast = "same_as_forward"
# This controls whether we collect donated buffer. This flag must be set
# False if a user wants to retain_graph=True for backward.
donated_buffer = not is_fbcode()
donated_buffer = False if is_fbcode() else True
# Controls the default graph output format used by draw_graph
# Supported formats are defined here https://graphviz.org/docs/outputs/

View File

@ -608,7 +608,8 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
# Use position-based lookup for building output
# only update the return node args, and remain all other users unchanged
output_updated_args = [
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
position_to_quant[i] if i in position_to_quant else node
for i, node in enumerate(fwd_outputs)
]
# add the scale nodes to the output find the first sym_node in the output
idx = find_first_sym_node(output_updated_args)

View File

@ -414,7 +414,10 @@ class JsonProfile:
elif isinstance(dtype, torch.dtype):
self.dtype = dtype
else:
self.dtype = _dtype_map.get(dtype)
if dtype in _dtype_map:
self.dtype = _dtype_map[dtype]
else:
self.dtype = None
self._create_devices()
def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:

View File

@ -482,11 +482,15 @@ def get_wrapper_codegen_for_device(
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
return custom_backend_passes.get(device)
return custom_backend_passes[device] if device in custom_backend_passes else None
def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
return custom_backend_codegen_configs.get(device)
return (
custom_backend_codegen_configs[device]
if device in custom_backend_codegen_configs
else None
)
@functools.cache

View File

@ -857,7 +857,7 @@ def parallel_compile_enabled_internally() -> bool:
jk_name = "pytorch/inductor:enable_parallel_compile_version"
version = torch._utils_internal.justknobs_getval_int(jk_name)
return version <= ENABLE_PARALLEL_COMPILE_VERSION
return ENABLE_PARALLEL_COMPILE_VERSION >= version
def decide_compile_threads() -> int:
@ -1259,7 +1259,7 @@ class triton:
cudagraph_trees_history_recording = False
# Enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_support_input_mutation = not is_fbcode()
cudagraph_support_input_mutation = False if is_fbcode() else True
# Maximal number of allowed cudagraph re-record for a function and
# a cudagraph node due to static input tensor address changes or

View File

@ -476,7 +476,9 @@ def build_subgraph_buffer(
elif node.op == "call_function":
# For call_function we use the default lowerings and pass in the
# already created TensorBoxes as args
args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs))
args, kwargs = tree_map(
lambda x: env[x] if x in env else x, (node.args, node.kwargs)
)
env[node] = lowerings[node.target](*args, **kwargs)
elif node.op == "output":
@ -689,7 +691,9 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
for node in graph.nodes: # preserve the order of nodes
if node in subgraph_node_set:
subgraph_node_list.append(node)
new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x))
new_node = new_graph.node_copy(
node, lambda x: node_remapping[x] if x in node_remapping else x
)
node_remapping[node] = new_node
if node is inner_mm:
new_input_anchor = new_node

View File

@ -531,7 +531,7 @@ def _register_quantized_linear_unary_lowering(
)
# bias
b = kwargs.get("b")
b = kwargs["b"] if "b" in kwargs else None
# Output QParams
o_inv_scale = kwargs["output_scale"]
@ -593,7 +593,7 @@ def _register_quantized_linear_binary_lowering(
kwargs["w_zp"],
)
# bias
b = kwargs.get("b")
b = kwargs["b"] if "b" in kwargs else None
# Output QParams
o_inv_scale = kwargs["output_scale"]
o_zero_point = kwargs["output_zero_point"]
@ -885,10 +885,10 @@ def _register_quantized_maxpool2d_lowering(
def qmaxpool2d(match: Match, *args, **kwargs):
x = kwargs["x"]
kernel_size = kwargs["kernel_size"]
stride = kwargs.get("stride")
padding = kwargs.get("padding", 0)
dilation = kwargs.get("dilation", 1)
ceil_mode = kwargs.get("ceil_mode", False)
stride = kwargs["stride"] if ("stride" in kwargs) else None
padding = kwargs["padding"] if ("padding" in kwargs) else 0
dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
if padding == 0:
padding = [0, 0]
@ -1976,7 +1976,7 @@ def _register_qlinear_weight_prepack_pass(
)
# Params
bias = kwargs.get("b")
bias = kwargs["b"] if "b" in kwargs else None
x_shape = qx.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
@ -2451,7 +2451,7 @@ def _register_linear_dynamic_fp16_weight_prepack_pass(
# find params
x = kwargs["x"]
w = kwargs["w"]
bias = kwargs.get("b")
bias = kwargs["b"] if "b" in kwargs else None
# find linear node
nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
@ -2727,7 +2727,7 @@ def _register_smooth_quant_int_mm_pattern():
pass_number=pass_number,
)
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
bias = kwargs.get("bias")
bias = kwargs.get("bias", None)
x = kwargs["a"]
weight = kwargs["b"]
dtype = kwargs["dtype"]
@ -2794,7 +2794,7 @@ def _register_smooth_quant_int_mm_pattern():
else:
# onednn.qlinear does not support per-channel quantization of x
# so in this case, we have to apply x scale and add bias ourselves after qlinear
in_shape = kwargs.get("in_shape")
in_shape = kwargs.get("in_shape", None)
if in_shape is None:
x_reshaped = x
else:
@ -2826,8 +2826,8 @@ def _register_smooth_quant_int_mm_pattern():
# Add bias and reshape
has_outer_reshape = (
kwargs.get("out_shape_with_bias") is not None
or kwargs.get("out_shape_no_bias") is not None
kwargs.get("out_shape_with_bias", None) is not None
or kwargs.get("out_shape_no_bias", None) is not None
)
if has_outer_reshape:
@ -3276,7 +3276,7 @@ def _register_qlinear_post_op_fusion_pass(
)
# bias
b = kwargs.get("b")
b = kwargs["b"] if "b" in kwargs else None
# Output QParams
o_inv_scale = (

View File

@ -1074,13 +1074,13 @@ def _overload_method(func):
_check_overload_body(func)
qual_name = _qualified_name(func)
global _overloaded_methods
class_name_map = _overloaded_methods.get(qual_name)
class_name_map = _overloaded_methods.get(qual_name, None)
if class_name_map is None:
class_name_map = {}
_overloaded_methods[qual_name] = class_name_map
class_name, line_no = get_class_name_lineno(func)
method_overloads = class_name_map.get(class_name)
method_overloads = class_name_map.get(class_name, None)
if method_overloads is None:
method_overloads = []
class_name_map[class_name] = method_overloads
@ -1102,7 +1102,7 @@ def _get_overloaded_methods(method, mod_class):
if not hasattr(method, "__name__"):
return None
qual_name = _qualified_name(method)
class_name_map = _overloaded_methods.get(qual_name)
class_name_map = _overloaded_methods.get(qual_name, None)
if class_name_map is None:
return None
overloads = class_name_map.get(mod_class.__name__, None)

View File

@ -5303,7 +5303,7 @@ def grid_sampler_3d_backward(
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
dtype = kwargs.get("dtype")
dtype = kwargs.get("dtype", None)
if not dtype:
dtype = utils.get_dtype(fill_value)
kwargs["dtype"] = dtype

View File

@ -1414,7 +1414,7 @@ class _HigherOrderNamespace(types.ModuleType):
def __getattr__(self, name: str) -> HigherOrderOperator:
# Following _OpNamespace.__getattr__, we cache the op on this object.
op = _higher_order_ops.get(name)
op = _higher_order_ops.get(name, None)
if op is None:
raise AttributeError(
f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"

View File

@ -84,8 +84,12 @@ class ReferenceQuantizedModule(torch.nn.Module):
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
# for capturing `.item` operations
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
"quant_min", None
)
self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
"quant_max", None
)
def get_weight(self):
"""

View File

@ -240,29 +240,29 @@ scale_min_lower_bound=None, scale_max_upper_bound=None)
"bias_type": torch.dtype
"is_dynamic": bool
"""
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY)
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
if input_dtype is not None and not isinstance(
input_dtype, (torch.dtype, DTypeWithConstraints)
):
raise ValueError(
"Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
)
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY)
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
if output_dtype is not None and not isinstance(
output_dtype, (torch.dtype, DTypeWithConstraints)
):
raise ValueError(
"Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
)
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY)
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
if weight_dtype is not None and not isinstance(
weight_dtype, (torch.dtype, DTypeWithConstraints)
):
raise ValueError(
"Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
)
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY)
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY)
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
def to_dict(self) -> dict[str, Any]:
@ -673,23 +673,23 @@ class BackendPatternConfig:
for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
conf.add_dtype_config(_get_dtype_config(d))
conf.set_root_module(
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type]
)
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY)) # type: ignore[arg-type]
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type]
conf.set_reference_quantized_module(
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
)
conf.set_fused_module(
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
)
conf.set_fuser_method(
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type]
)
conf._set_root_node_getter(
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type]
)
conf._set_extra_inputs_getter(
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY) # type: ignore[arg-type]
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type]
)
conf._set_num_tensor_args_to_observation_type(
backend_pattern_config_dict.get(

View File

@ -286,7 +286,7 @@ def get_fuser_method_new(
op_patterns = _get_valid_patterns(op_pattern)
fuser_method = None
for op_pattern in op_patterns:
fuser_method = fuser_method_mapping.get(op_pattern)
fuser_method = fuser_method_mapping.get(op_pattern, None)
if fuser_method is not None:
break
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "

View File

@ -168,7 +168,7 @@ def _find_matches(
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, quantize_handler_cls in patterns.items():
root_node_getter = root_node_getter_mapping.get(pattern)
root_node_getter = root_node_getter_mapping.get(pattern, None)
if _is_match(modules, node, pattern) and node.name not in match_map:
matched_node_pattern: list[Node] = []
record_match(pattern, node, node, matched_node_pattern, match_map)

View File

@ -130,7 +130,7 @@ def _get_qspec_for_arg(
) -> Optional[QuantizationSpecBase]:
while _is_activation_post_process_node(arg, named_modules):
arg = arg.args[0] # type: ignore[assignment]
return input_qspec_map.get(arg)
return input_qspec_map.get(arg, None)
def _create_obs_or_fq_from_qspec(

View File

@ -163,7 +163,7 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
}
prepack_op = prepack_ops.get(conv_op)
prepack_op = prepack_ops.get(conv_op, None)
assert prepack_op, f"Didn't find prepack op for {conv_op}"
return prepack_op

View File

@ -803,7 +803,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
unexpected_keys: list[str],
error_msgs: list[str],
):
version = local_metadata.get("version")
version = local_metadata.get("version", None)
if version is not None and version < 3:
local_state = ["min_vals", "max_vals"]
expected_min_name = "min_vals"

View File

@ -366,7 +366,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
if input_edge_obs_or_fq is None:
return new_arg
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg)
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
# the arg is observed as the output and is using the same instance as the input_edge
# we'll reuse the inserted observer/fake_quant
if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
@ -497,7 +497,11 @@ def _maybe_insert_input_and_output_observers_for_node(
is_qat: bool,
model_device: Optional[torch.device] = None,
):
this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
this_node_quantization_annotation = (
node.meta["quantization_annotation"]
if "quantization_annotation" in node.meta
else None
)
if this_node_quantization_annotation is None:
return

View File

@ -343,7 +343,7 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
"""Get the quantized operator corresponding to the float operator"""
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
assert quantized_op is not None, (
f"Operator {str(float_op)} does not have corresponding quantized op"
)

View File

@ -1352,8 +1352,10 @@ class X86InductorQuantizer(Quantizer):
def _annotate_output_share_observer_as_input(
self, input_node: Node, source_node: Node
):
source_node_quantization_annotation = source_node.meta.get(
QUANT_ANNOTATION_KEY, None
source_node_quantization_annotation = (
source_node.meta[QUANT_ANNOTATION_KEY]
if QUANT_ANNOTATION_KEY in source_node.meta
else None
)
if (
source_node_quantization_annotation
@ -1393,8 +1395,10 @@ class X86InductorQuantizer(Quantizer):
return
# Get the quantization_annotation from getitem_node
maxpool_node_quantization_annotation = maxpool_node.meta.get(
QUANT_ANNOTATION_KEY, None
maxpool_node_quantization_annotation = (
maxpool_node.meta[QUANT_ANNOTATION_KEY]
if QUANT_ANNOTATION_KEY in maxpool_node.meta
else None
)
if (
maxpool_node_quantization_annotation

View File

@ -162,7 +162,10 @@ class EventList(list):
if p is not None:
assert p.fwd_thread is not None
t = (p.sequence_nr, p.fwd_thread)
evt.stack = fwd_stacks.get(t, [])
if t in fwd_stacks:
evt.stack = fwd_stacks[t]
else:
evt.stack = []
@property
def self_cpu_time_total(self):

View File

@ -214,7 +214,7 @@ def replicate(
state = replicate.state(module)
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
device_mesh = kwargs.get("device_mesh")
device_mesh = kwargs.get("device_mesh", None)
if device_mesh is not None:
from torch.distributed.device_mesh import _mesh_resources

View File

@ -226,7 +226,7 @@ def replicate_impl(
# Place Replicate leftmost for highest priority in the method resolution order
for module in modules:
cls = module.__class__
new_cls = cls_to_replicate_cls.get(cls)
new_cls = cls_to_replicate_cls.get(cls, None)
if not new_cls:
dct = {"__deepcopy__": _unimplemented_deepcopy}
new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct)

View File

@ -131,7 +131,7 @@ def register_tensor_creation_op(op):
takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
``torch.full_like``.
"""
creation_op = tensor_like_creation_op_map.get(op)
creation_op = tensor_like_creation_op_map.get(op, None)
if creation_op is None:
raise RuntimeError(f"Tensor creation {op} not supported!")
if kwargs is None:

View File

@ -676,7 +676,7 @@ class ShardedTensor(ShardedTensorBase):
copy_tensor = kwargs.get("copy", False)
non_blocking = kwargs.get("non_blocking", False)
memory_format = kwargs.get("memory_format", torch.preserve_format)
process_group = kwargs.get("process_group")
process_group = kwargs.get("process_group", None)
if (
not copy_tensor

View File

@ -596,7 +596,7 @@ def _distribute_tensors(
if pg is None:
pg = dist.distributed_c10d._get_default_group()
for key in keys:
_local_state = local_state_dict.get(key)
_local_state = local_state_dict.get(key, None)
if _local_state is None or torch.is_tensor(_local_state):
continue
@ -706,7 +706,7 @@ def _distribute_state_dict(
local_state_dict[key] = value.cpu()
else:
assert isinstance(value, torch.Tensor)
local_state = local_state_dict.get(key)
local_state = local_state_dict.get(key, None)
if local_state is None:
continue
elif isinstance(local_state, DTensor):

View File

@ -127,7 +127,7 @@ def aggregate_stats(
}
for mod in model.modules():
if mod_mem_stat := mod_mem_stats.get(mod):
if mod_mem_stat := mod_mem_stats.get(mod, None):
if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
sac_runtime = tradeoff_stats.sac_runtime
sac_memory = tradeoff_stats.sac_memory

View File

@ -710,7 +710,7 @@ class SACEstimator(TorchDispatchMode):
str(i in sac_stats.view_like_ops),
str(i in sac_stats.rand_ops),
str(i in sac_stats.saved_autograd_ops),
str(op_parent.get(i)),
str(op_parent.get(i, None)),
]
table_data.append(row)
# Define headers

View File

@ -106,7 +106,7 @@ def auto_quantize(func, qtype, quant_loss=None):
@functools.wraps(func)
def wrapper(*args, **kwargs):
group = kwargs.get("group")
group = kwargs.get("group", None)
async_op = kwargs.get("async_op", False)
if async_op is True:
raise RuntimeError("The async_op=True mode is not supported yet.")
@ -132,8 +132,8 @@ def auto_quantize(func, qtype, quant_loss=None):
elif func == dist.all_to_all_single:
tensors = args[0]
out_splits = kwargs.get("out_splits")
in_splits = kwargs.get("in_splits")
out_splits = kwargs.get("out_splits", None)
in_splits = kwargs.get("in_splits", None)
# Quantizing the input/output tensor
input_tensors = _quantize_tensor(args[1], qtype)
out_tensors = _quantize_tensor(tensors, qtype)

View File

@ -631,7 +631,7 @@ class _FileSystemWriter(StorageWriter):
def set_up_storage_writer(
self, is_coordinator: bool, *args: Any, **kwargs: Any
) -> None:
self.rank = kwargs.get("rank")
self.rank = kwargs.get("rank", None)
self.use_collectives = kwargs.get("use_collectives", True)
def _metadata_exists(self) -> bool:
@ -919,7 +919,7 @@ class FileSystemReader(StorageReader):
# Implementing the abstract function in StorageReader
def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
rank = kwargs.get("rank")
rank = kwargs.get("rank", None)
path = self._get_metadata_path(rank)
with self.fs.create_stream(path, "rb") as metadata_file:
metadata = pickle.load(metadata_file)
@ -934,7 +934,7 @@ class FileSystemReader(StorageReader):
self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
) -> None:
self.storage_data = metadata.storage_data
self.rank = kwargs.get("rank")
self.rank = kwargs.get("rank", None)
self.use_collectives = kwargs.get("use_collectives", True)
assert self.storage_data is not None

View File

@ -30,11 +30,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
msg_dict = {}
# checkpoint ID can be passed in through the serializer or through the checkpoint id directly
storage_writer = kwargs.get("storage_writer")
storage_reader = kwargs.get("storage_reader")
planner = kwargs.get("planner")
storage_writer = kwargs.get("storage_writer", None)
storage_reader = kwargs.get("storage_reader", None)
planner = kwargs.get("planner", None)
checkpoint_id = kwargs.get("checkpoint_id")
checkpoint_id = kwargs.get("checkpoint_id", None)
if not checkpoint_id and (serializer := storage_writer or storage_reader):
checkpoint_id = getattr(serializer, "checkpoint_id", None)

View File

@ -305,7 +305,7 @@ def _verify_options(
continue
fqns = _get_fqns(model, name)
fqn = fqn_param_mapping.get(param)
fqn = fqn_param_mapping.get(param, None)
if fqn is not None:
cast(set[str], fqn_param_mapping[param]).update(fqns)
shared_params_mapping[param] = fqn_param_mapping[param]

View File

@ -195,7 +195,7 @@ else:
# A root mesh is not created through slicing.
# We considers the root mesh of a root mesh is itself.
root_mesh = self.child_to_root_mapping.get(device_mesh, None)
return root_mesh if root_mesh else device_mesh
return device_mesh if not root_mesh else root_mesh
def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
"""

View File

@ -5042,7 +5042,7 @@ def _is_safe_to_split() -> bool:
users must be aware that a pg is only splittable after the first collective is
issued.
"""
return _get_default_group().bound_device_id is not None
return False if _get_default_group().bound_device_id is None else True
@_time_logger

View File

@ -88,7 +88,10 @@ def configure(handler: MetricHandler, group: Optional[str] = None):
def getStream(group: str):
handler = _metrics_map.get(group, _default_metrics_handler)
if group in _metrics_map:
handler = _metrics_map[group]
else:
handler = _default_metrics_handler
return MetricStream(group, handler)

View File

@ -239,7 +239,7 @@ def fully_shard(
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
cls = module.__class__
new_cls = cls_to_fsdp_cls.get(cls)
new_cls = cls_to_fsdp_cls.get(cls, None)
if not new_cls:
dct = {"__deepcopy__": _unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)

View File

@ -1267,7 +1267,7 @@ def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
(which usually are FQNs) versus integers (which usually refer to param_ids
from a vanilla torch.optim.Optimizer).
"""
state = optim_state_dict.get("state")
state = optim_state_dict.get("state", None)
if not state:
# If we cannot find a state, assume it is not NamedOptimizer as
# NamedOptimizer has eager initialization.
@ -1715,7 +1715,7 @@ def _convert_state_with_orig_params(
# across ranks
for optim_state_key in all_optim_state_keys:
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
optim_state_key
optim_state_key, None
)
if param_key is None and not optim_state_key.is_fsdp_managed:
@ -1723,7 +1723,7 @@ def _convert_state_with_orig_params(
if optim_state_key.is_fsdp_managed:
fqn = optim_state_key.unflat_param_names[0]
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn)
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
if fsdp_param_info is None:
# This can happen if the not all FSDP instances have all the
# parameters. This can happen with FSDP + some MPMD style
@ -1801,7 +1801,7 @@ def _convert_state_with_flat_params(
# across ranks
for optim_state_key in all_optim_state_keys:
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
optim_state_key
optim_state_key, None
)
assert param_key is not None, (

View File

@ -52,7 +52,8 @@ class _ScriptLocalOptimizer(nn.Module):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
# apply functional optimizer step with a list of gradients
grads: list[Optional[Tensor]] = [
all_local_grads.get(p, None) for p in self._local_params
all_local_grads[p] if p in all_local_grads else None
for p in self._local_params
]
self.optim.step(grads)

View File

@ -189,7 +189,7 @@ def _insert_stage_symbolic_backward(
output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
if node in tuples:
stage_output = tuples[node]
output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
outputs_with_grads_idxs = [
i for i, n in enumerate(tuples[node]) if n in live_nodes
]

View File

@ -114,7 +114,7 @@ def get_param_groups(
"intermediates": intersected,
}
for input_node in intersected:
existing = param_groups.get(input_node)
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(

View File

@ -326,7 +326,8 @@ def _insert_copy_for_mutations(
return_nodes_to_copy[return_node] = copy_node
output_args = tuple(
return_nodes_to_copy.get(node, node) for node in user_output_nodes
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
for node in user_output_nodes
)
with gm.graph.inserting_before(output_node):
# Only return user outputs

View File

@ -46,7 +46,7 @@ class NormalizeArgs(Transformer):
def get_type(arg):
if isinstance(arg, fx.Node):
return n.meta.get("type", None)
return n.meta["type"] if "type" in n.meta else None
return type(arg)
arg_types = map_aggregate(n.args, get_type)

View File

@ -4378,7 +4378,7 @@ class ShapeEnv:
size = []
for i, val in enumerate(tensor_size):
sym = self.create_symbol(
hint_overrides.get(i, val),
val if i not in hint_overrides else hint_overrides[i],
TensorPropertySource(source, TensorProperty.SIZE, i),
dynamic_dims[i],
constraint_dims[i],
@ -4579,7 +4579,7 @@ class ShapeEnv:
sym_sizes = [
self.create_symintnode(
sym,
hint=hint_overrides.get(i, hint),
hint=hint if i not in hint_overrides else hint_overrides[i],
source=TensorPropertySource(source, TensorProperty.SIZE, i),
)
for i, (sym, hint) in enumerate(zip(size, ex_size))

View File

@ -923,7 +923,7 @@ class ParameterDict(Module):
key (str): key to get from the ParameterDict
default (Parameter, optional): value to return if key not present
"""
return self[key] if key in self else default # noqa: SIM401
return self[key] if key in self else default
def fromkeys(
self, keys: Iterable[str], default: Optional[Any] = None

View File

@ -218,7 +218,7 @@ def _dump_DDP_relevant_env_vars():
]
formatted_output = ""
for var in relevant_env_vars:
value = os.environ.get(var, "N/A")
value = os.environ[var] if var in os.environ else "N/A"
formatted_output += f"env:{var}={value}\n"
print(formatted_output)

View File

@ -774,8 +774,8 @@ class Optimizer:
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg.get("fused", False)
capturable = pg.get("capturable", False)
fused = pg["fused"] if "fused" in pg else False
capturable = pg["capturable"] if "capturable" in pg else False
break
if key == "step":
if capturable or fused:

View File

@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
def _apply_precision_override_for_test(self, test, param_kwargs):
dtype = param_kwargs.get("dtype", None)
dtype = param_kwargs.get("dtypes", dtype)
dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
if dtype:
self.precision = self._get_precision_override(test, dtype)
self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)

View File

@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
# The scalar we are passing to new_full must be the same dtype
# as the one of the resulting tensor
use_dtype = sample.kwargs.get('dtype', dtype)
use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
yield SampleInput(
sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)

View File

@ -725,7 +725,7 @@ class DistributedTest:
lines = out.getvalue().splitlines()
def format_line(var):
return f"env:{var}={os.environ.get(var, 'N/A')}"
return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
# Check relevant env vars
vars = [
@ -6212,7 +6212,7 @@ class DistributedTest:
)
def test_ddp_logging_data_cpu(self):
def parse_env(var):
return os.environ.get(var, "N/A")
return os.environ[var] if var in os.environ else "N/A"
dist.set_debug_level(dist.DebugLevel.INFO)
_, group_id, _ = self._init_global_test()

View File

@ -21,7 +21,7 @@ INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
return _MIRROR_REL_OP.get(type)
return _MIRROR_REL_OP.get(type, None)
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.

View File

@ -277,7 +277,7 @@ def create_graph(objects, *, context=None, filter=None):
references = annotated_references(obj)
for referrent in gc.get_referents(obj):
rid = id(referrent)
tidx = id_to_node.get(rid)
tidx = id_to_node.get(rid, None)
if tidx is None:
continue
labels = references.get(rid, ["?"])

View File

@ -150,7 +150,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
assert schema.kind() == SchemaKind.inplace
if not is_mutated_arg(schema.arguments.flat_all[0]):
return None
if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
return None
# Only support cases where all returns are Tensors or vector<Tensor>