Enable all SIM rules except disabled ones (#164645)

`SIM` rules are useful for simplifying boolean expressions and enhancing code readability.
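
Most of the mechanical cleanups in this commit follow a handful of recurring flake8-simplify patterns. A minimal sketch of the kinds of rewrites the rules prescribe (variable names here are hypothetical, for illustration only):

```python
# Representative flake8-simplify (SIM) patterns fixed throughout this commit.
# Variable names are hypothetical, for illustration only.

flag = 1
kwargs = {"lr": 0.1}
a, b = 3, 4

# SIM210: prefer bool() over a redundant ternary
enabled = True if flag else False   # before
enabled = bool(flag)                # after

# SIM401: prefer dict.get() over an `in` check plus indexing
lr = kwargs["lr"] if "lr" in kwargs else None   # before
lr = kwargs.get("lr")                           # after

# SIM201: prefer a direct inequality over a negated equality
if not a == b:   # before
    print("different")
if a != b:       # after
    print("different")
```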

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
Yuanyuan Chen 2025-10-17 07:27:06 +00:00 committed by PyTorch MergeBot
parent f1d882212a
commit e925dfcc6b
54 changed files with 98 additions and 134 deletions

View File

@ -1092,7 +1092,7 @@ class GitHubPR:
editor = node["editor"] editor = node["editor"]
return GitHubComment( return GitHubComment(
body_text=node["bodyText"], body_text=node["bodyText"],
created_at=node["createdAt"] if "createdAt" in node else "", created_at=node.get("createdAt", ""),
author_login=node["author"]["login"], author_login=node["author"]["login"],
author_url=node["author"].get("url", None), author_url=node["author"].get("url", None),
author_association=node["authorAssociation"], author_association=node["authorAssociation"],

View File

@ -4060,7 +4060,7 @@ def run(runner, args, original_dir=None):
else: else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = ( experiment = (
speedup_experiment if not args.backend == "torchao" else latency_experiment speedup_experiment if args.backend != "torchao" else latency_experiment
) )
if args.accuracy: if args.accuracy:
output_filename = f"accuracy_{args.backend}.csv" output_filename = f"accuracy_{args.backend}.csv"

View File

@ -271,7 +271,7 @@ def run_single_backend_sdpa(
if config.calculate_bwd_time: if config.calculate_bwd_time:
# TODO: debug backward pass for njt # TODO: debug backward pass for njt
if eager_sdpa and not config.attn_type == "document_mask": if eager_sdpa and config.attn_type != "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2) d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds( backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True out_eager.backward, d_out, retain_graph=True

View File

@ -180,6 +180,7 @@ ignore = [
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117", "SIM117",
"SIM118", "SIM118",
"SIM300", # Yoda condition detected
"UP007", # keep-runtime-typing "UP007", # keep-runtime-typing
"UP045", # keep-runtime-typing "UP045", # keep-runtime-typing
"TC006", "TC006",
@ -195,8 +196,7 @@ select = [
"E", "E",
"EXE", "EXE",
"F", "F",
"SIM1", "SIM",
"SIM911",
"W", "W",
# Not included in flake8 # Not included in flake8
"FURB", "FURB",

View File

@ -55,7 +55,7 @@ class TestActivationSparsifier(TestCase):
for key, config in sparsifier_defaults.items(): for key, config in sparsifier_defaults.items():
# all the keys in combined_defaults should be present in sparsifier defaults # all the keys in combined_defaults should be present in sparsifier defaults
assert config == combined_defaults.get(key, None) assert config == combined_defaults.get(key)
def _check_register_layer( def _check_register_layer(
self, activation_sparsifier, defaults, sparse_config, layer_args_list self, activation_sparsifier, defaults, sparse_config, layer_args_list

View File

@ -3074,7 +3074,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_dtype_shards, [10, 10], init_rrefs=True wrong_dtype_shards, [10, 10], init_rrefs=True
) )
tensor_requires_grad = True if self.rank == 0 else False tensor_requires_grad = self.rank == 0
wrong_requires_grad_shards = [ wrong_requires_grad_shards = [
sharded_tensor.Shard( sharded_tensor.Shard(
torch.randn( torch.randn(
@ -3121,7 +3121,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_pin_memory_local_shards, [10, 10], init_rrefs=True wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
) )
tensor_pin_memory = True if self.rank == 0 else False tensor_pin_memory = self.rank == 0
wrong_pin_memory_shards_cross_ranks = [ wrong_pin_memory_shards_cross_ranks = [
sharded_tensor.Shard( sharded_tensor.Shard(
torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata

View File

@ -152,7 +152,7 @@ class TestStorageBase:
self.rank = 0 if not dist.is_initialized() else dist.get_rank() self.rank = 0 if not dist.is_initialized() else dist.get_rank()
def _get_ranks(self, name): def _get_ranks(self, name):
return self.fail_conf[name] if name in self.fail_conf else None return self.fail_conf.get(name, None)
def _fail_rank(self, name): def _fail_rank(self, name):
ranks = self._get_ranks(name) ranks = self._get_ranks(name)

View File

@ -155,7 +155,7 @@ class TestFreezingWeights(FSDPTest):
ddp_kwargs = { ddp_kwargs = {
"device_ids": [self.rank], "device_ids": [self.rank],
"find_unused_parameters": True if disable_autograd else False, "find_unused_parameters": bool(disable_autograd),
} }
model = self._create_model( model = self._create_model(

View File

@ -66,7 +66,7 @@ class MockPipelineStage(_PipelineStageBase):
self.num_stages = kwargs.get("num_stages", 1) self.num_stages = kwargs.get("num_stages", 1)
self.group_size = kwargs.get("group_size", 1) self.group_size = kwargs.get("group_size", 1)
self.group_rank = kwargs.get("group_rank", 0) self.group_rank = kwargs.get("group_rank", 0)
self.group = kwargs.get("group", None) self.group = kwargs.get("group")
def _create_grad_recv_info(self, *args, **kwargs): def _create_grad_recv_info(self, *args, **kwargs):
return None return None

View File

@ -1066,7 +1066,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
assert_array_equal(expected_pad_sizes, pad_sizes) assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [ is_tensor_empty = [
False if splitted_tensor.numel() > 0 else True not splitted_tensor.numel() > 0
for splitted_tensor in splitted_tensor_list for splitted_tensor in splitted_tensor_list
] ]
expected_is_tensor_empty = [True] * self.world_size expected_is_tensor_empty = [True] * self.world_size
@ -1089,12 +1089,10 @@ class TestDTensorPlacementTypes(DTensorTestBase):
for i, tensor in enumerate(splitted_tensor_list) for i, tensor in enumerate(splitted_tensor_list)
] ]
expected_is_tensor_empty = [ expected_is_tensor_empty = [
False if idx < size else True not idx < size for idx, _ in enumerate(range(self.world_size))
for idx, _ in enumerate(range(self.world_size))
] ]
is_tensor_empty = [ is_tensor_empty = [
False if unpadded_tensor.numel() > 0 else True not unpadded_tensor.numel() > 0 for unpadded_tensor in unpadded_list
for unpadded_tensor in unpadded_list
] ]
assert_array_equal(expected_is_tensor_empty, is_tensor_empty) assert_array_equal(expected_is_tensor_empty, is_tensor_empty)

View File

@ -2770,11 +2770,7 @@ class WorkHookTest(MultiProcessTestCase):
# from rank0 to other ranks. However, this is DDP's internal implementation, # from rank0 to other ranks. However, this is DDP's internal implementation,
# which is subject to change in future versions. # which is subject to change in future versions.
self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0) self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0)
ctor_allreduce = ( ctor_allreduce = num_hook_fired.get(OpType.ALLREDUCE, 0)
num_hook_fired[OpType.ALLREDUCE]
if OpType.ALLREDUCE in num_hook_fired
else 0
)
x = torch.zeros(2, 1000).cuda(self.rank) x = torch.zeros(2, 1000).cuda(self.rank)
ddp(x).sum().backward() ddp(x).sum().backward()

View File

@ -82,7 +82,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
# look up dL_dentries. If a variable is never used to compute the loss, # look up dL_dentries. If a variable is never used to compute the loss,
# we consider its gradient None, see the note below about zeros for more information. # we consider its gradient None, see the note below about zeros for more information.
def gather_grad(entries: list[str]): def gather_grad(entries: list[str]):
return [dL_d[entry] if entry in dL_d else None for entry in entries] return [dL_d.get(entry) for entry in entries]
# propagate the gradient information backward # propagate the gradient information backward
for entry in reversed(gradient_tape): for entry in reversed(gradient_tape):

View File

@ -286,7 +286,7 @@ class OptionalScaledTensor(torch.Tensor):
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return OptionalScaledTensor( return OptionalScaledTensor(
inner_tensors["_data"], inner_tensors["_data"],
inner_tensors["_scale"] if "_scale" in inner_tensors else None, inner_tensors.get("_scale", None),
constant=metadata["_constant"], constant=metadata["_constant"],
) )

View File

@ -358,9 +358,7 @@ def _sequential_split_inline_tests():
for i, node in enumerate(insert_locs): for i, node in enumerate(insert_locs):
with gm.graph.inserting_before(node): with gm.graph.inserting_before(node):
gm.graph.call_function( gm.graph.call_function(torch._C._set_grad_enabled, (i % 2 == 0,), {})
torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}
)
return gm return gm
x = torch.randn(2, 2) x = torch.randn(2, 2)

View File

@ -2932,9 +2932,7 @@ class GraphModule(torch.nn.Module):
if autograd: if autograd:
result_flat = pytree.tree_leaves(result) result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp) result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [ exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
True if r.requires_grad else False for r in result_exp_flat
]
self.check_autograd( self.check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m], [r for r, m in zip(result_flat, exp_grad_mask) if m],
[r for r, m in zip(result_exp_flat, exp_grad_mask) if m], [r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
@ -3741,9 +3739,7 @@ class AssociativeScanTests(TestCase):
): ):
result_flat = pytree.tree_leaves(result) result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp) result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [ exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
True if r.requires_grad else False for r in result_exp_flat
]
self._check_autograd( self._check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m], [r for r, m in zip(result_flat, exp_grad_mask) if m],
@ -5710,10 +5706,9 @@ def forward(self, arg0_1):
) )
def test_while_loop_tracing(self, while_loop_test): def test_while_loop_tracing(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test] fn, inp = WHILE_LOOP_TESTS[while_loop_test]
allow_non_fake_inputs = ( allow_non_fake_inputs = while_loop_test in (
False "simple_with_linear",
if while_loop_test not in ("simple_with_linear", "nested_with_linear") "nested_with_linear",
else True
) )
self._check_tracing(fn, inp, allow_non_fake_inputs) self._check_tracing(fn, inp, allow_non_fake_inputs)

View File

@ -177,9 +177,7 @@ class TestFXNodeSource(TestCase):
for node_name_2 in node_name_to_from_node: for node_name_2 in node_name_to_from_node:
if node_name_2 in { if node_name_2 in {
node_name_1, node_name_1,
same_ancestor_nodes[node_name_1] same_ancestor_nodes.get(node_name_1),
if node_name_1 in same_ancestor_nodes
else None,
}: }:
self.assertEqual( self.assertEqual(
node_name_to_from_node[node_name_1], node_name_to_from_node[node_name_1],

View File

@ -164,9 +164,7 @@ class B2BGEMMTest(TestCase):
self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code) self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code) self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)
@unittest.skipIf( @unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32) @torch._dynamo.config.patch(recompile_limit=32)
def test_plain_b2b_gemm_performance(self): def test_plain_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)""" """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@ -219,9 +217,7 @@ class B2BGEMMTest(TestCase):
# flaky test assertion: disabled # flaky test assertion: disabled
# self.assertTrue(average_speedup > 1) # self.assertTrue(average_speedup > 1)
@unittest.skipIf( @unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32) @torch._dynamo.config.patch(recompile_limit=32)
def test_gelu_b2b_gemm_performance(self): def test_gelu_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)""" """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@ -276,9 +272,7 @@ class B2BGEMMTest(TestCase):
# flaky test assertion: disabled # flaky test assertion: disabled
# self.assertTrue(average_speedup > 1) # self.assertTrue(average_speedup > 1)
@unittest.skipIf( @unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
)
@torch._dynamo.config.patch(recompile_limit=32) @torch._dynamo.config.patch(recompile_limit=32)
def test_gelu_mlp_b2b_gemm_performance(self): def test_gelu_mlp_b2b_gemm_performance(self):
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)""" """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

View File

@ -165,7 +165,7 @@ class BenchmarkFusionTestTemplate:
_, out_code = run_and_get_code(foo_c, m, inp) _, out_code = run_and_get_code(foo_c, m, inp)
# occasionally, CI will make this one kernel. just skip in this case # occasionally, CI will make this one kernel. just skip in this case
if not out_code[0].count("def triton_") == 2: if out_code[0].count("def triton_") != 2:
return return
# should be multiple triton invocations # should be multiple triton invocations

View File

@ -289,7 +289,7 @@ def build_opt_kwarg_db():
has_tensor_lr = False has_tensor_lr = False
for key, val in kwargs.items(): for key, val in kwargs.items():
if (not key == "lr" and not key == "betas") and ( if (key != "lr" and key != "betas") and (
not isinstance(val, bool) or (isinstance(val, bool) and val) not isinstance(val, bool) or (isinstance(val, bool) and val)
): ):
name += "_" + key name += "_" + key
@ -450,7 +450,7 @@ def make_test(
stack.enter_context(config.patch({"triton.cudagraphs": True})) stack.enter_context(config.patch({"triton.cudagraphs": True}))
kwargs_compiled = deepcopy(kwargs) kwargs_compiled = deepcopy(kwargs)
if isinstance(kwargs.get("lr", None), torch.Tensor): if isinstance(kwargs.get("lr"), torch.Tensor):
kwargs["lr"] = kwargs["lr"].to(device) kwargs["lr"] = kwargs["lr"].to(device)
kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device) kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)

View File

@ -177,7 +177,7 @@ if HAS_CUDA_AND_TRITON:
def get_manager(self, device_index=None): def get_manager(self, device_index=None):
return torch._inductor.cudagraph_trees.get_container( return torch._inductor.cudagraph_trees.get_container(
self.device_idx if not device_index else device_index device_index if device_index else self.device_idx
).tree_manager ).tree_manager
def get_roots(self): def get_roots(self):

View File

@ -585,9 +585,7 @@ class TestFlexAttention(InductorTestCase):
) )
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
sdpa_partial = create_attention( sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial) compiled_sdpa = torch.compile(sdpa_partial)
golden_out = sdpa_partial(q_gold, k_gold, v_gold) golden_out = sdpa_partial(q_gold, k_gold, v_gold)
@ -761,7 +759,7 @@ class TestFlexAttention(InductorTestCase):
return_lse=return_lse, return_lse=return_lse,
block_mask=converted_block_mask, block_mask=converted_block_mask,
score_mod=converted_score_mod, score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H), enable_gqa=(Q_H != KV_H),
kernel_options=kernel_options, kernel_options=kernel_options,
) )
else: else:
@ -774,7 +772,7 @@ class TestFlexAttention(InductorTestCase):
return_lse=return_lse, return_lse=return_lse,
block_mask=converted_block_mask, block_mask=converted_block_mask,
score_mod=converted_score_mod, score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H), enable_gqa=(Q_H != KV_H),
kernel_options=kernel_options, kernel_options=kernel_options,
) )
return compiled_out, compiled_lse return compiled_out, compiled_lse
@ -819,9 +817,7 @@ class TestFlexAttention(InductorTestCase):
if block_mask is None: if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device) block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
sdpa_partial = create_attention( sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True) golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True) ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@ -1466,7 +1462,7 @@ class TestFlexAttention(InductorTestCase):
block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device) block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device)
attention = functools.partial( attention = functools.partial(
flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv) flex_attention, block_mask=block_mask, enable_gqa=(Hq != Hkv)
) )
self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D) self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D)

View File

@ -412,7 +412,7 @@ class TestFlexDecoding(InductorTestCase):
sdpa_partial = create_attention( sdpa_partial = create_attention(
score_mod, score_mod,
block_mask, block_mask,
enable_gqa=(not Q_H == KV_H), enable_gqa=(Q_H != KV_H),
kernel_options=kernel_options, kernel_options=kernel_options,
) )
compiled_sdpa = torch.compile(sdpa_partial) compiled_sdpa = torch.compile(sdpa_partial)
@ -607,7 +607,7 @@ class TestFlexDecoding(InductorTestCase):
return_lse=True, return_lse=True,
block_mask=converted_block_mask, block_mask=converted_block_mask,
score_mod=converted_score_mod, score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H), enable_gqa=(Q_H != KV_H),
) )
else: else:
compiled_lse = None compiled_lse = None
@ -618,7 +618,7 @@ class TestFlexDecoding(InductorTestCase):
return_lse=False, return_lse=False,
block_mask=converted_block_mask, block_mask=converted_block_mask,
score_mod=converted_score_mod, score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H), enable_gqa=(Q_H != KV_H),
) )
return compiled_out, compiled_lse return compiled_out, compiled_lse
@ -664,9 +664,7 @@ class TestFlexDecoding(InductorTestCase):
if block_mask is None: if block_mask is None:
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device) block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device)
sdpa_partial = create_attention( sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True) golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True) ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@ -906,7 +904,7 @@ class TestFlexDecoding(InductorTestCase):
sdpa_partial = create_attention( sdpa_partial = create_attention(
score_mod=score_mod, score_mod=score_mod,
block_mask=None, block_mask=None,
enable_gqa=(not Hq == Hkv), enable_gqa=(Hq != Hkv),
) )
compiled_sdpa = torch.compile(sdpa_partial) compiled_sdpa = torch.compile(sdpa_partial)
ref_out = sdpa_partial(q, k, v) ref_out = sdpa_partial(q, k, v)
@ -1144,7 +1142,7 @@ class TestFlexDecoding(InductorTestCase):
def head_attention_mod(kv_head_num): def head_attention_mod(kv_head_num):
head_type = torch.tensor( head_type = torch.tensor(
[False if i % kv_head_num == 0 else True for i in range(kv_head_num)], [i % kv_head_num != 0 for i in range(kv_head_num)],
dtype=torch.bool, dtype=torch.bool,
device=device, device=device,
) )

View File

@ -22,7 +22,7 @@ except ImportError:
HAS_PYDOT = False HAS_PYDOT = False
HAS_DOT = True if shutil.which("dot") is not None else False HAS_DOT = shutil.which("dot") is not None
class TestGraphTransformObserver(TestCase): class TestGraphTransformObserver(TestCase):

View File

@ -835,9 +835,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
for dtype in dtypes: for dtype in dtypes:
torch._dynamo.reset() torch._dynamo.reset()
autocast_enabled = ( autocast_enabled = dtype in [torch.bfloat16, torch.float16]
True if dtype in [torch.bfloat16, torch.float16] else False
)
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast( torch.autocast(
@ -4421,14 +4419,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
out_feature = 64 out_feature = 64
q_min, q_max = -32, 31 q_min, q_max = -32, 31
# we only test for qlinear_binary in this case # we only test for qlinear_binary in this case
test_for_pointwise_binary = ( test_for_pointwise_binary = bool(
True M == 1
if M == 1
and inplace_add and inplace_add
and not expand_a_scale and not expand_a_scale
and not dynamic and not dynamic
and not has_bias and not has_bias
else False
) )
if test_for_pointwise_binary and not IS_X86: if test_for_pointwise_binary and not IS_X86:
self.skipTest("Some UTs are only supported on x86_64 CPUs") self.skipTest("Some UTs are only supported on x86_64 CPUs")

View File

@ -706,7 +706,7 @@ def check_model_gpu(
if check_lowp: if check_lowp:
def downcast_fn(x): def downcast_fn(x):
if not isinstance(x, torch.Tensor) or not x.dtype == torch.float: if not isinstance(x, torch.Tensor) or x.dtype != torch.float:
return x return x
return torch.empty_strided( return torch.empty_strided(
x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half
@ -4694,7 +4694,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise, # Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which # the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance. # causes numeric differences beyond tolerance.
reference_in_float=False if torch.version.hip else True, reference_in_float=not torch.version.hip,
) )
def test_convolution2(self): def test_convolution2(self):
@ -4728,7 +4728,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise, # Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which # the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance. # causes numeric differences beyond tolerance.
reference_in_float=False if torch.version.hip else True, reference_in_float=not torch.version.hip,
) )
@skip_if_gpu_halide @skip_if_gpu_halide
@ -4779,7 +4779,7 @@ class CommonTemplate:
# Make sure we compute also with fp16 in the reference. Otherwise, # Make sure we compute also with fp16 in the reference. Otherwise,
# the reference will compute with fp32 and cast back to fp16, which # the reference will compute with fp32 and cast back to fp16, which
# causes numeric differences beyond tolerance. # causes numeric differences beyond tolerance.
reference_in_float=False if torch.version.hip else True, reference_in_float=not torch.version.hip,
) )
def test_conv2d_channels_last(self): def test_conv2d_channels_last(self):
@ -12970,7 +12970,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
) )
res = torch.compile(fn)(20) res = torch.compile(fn)(20)
self.assertTrue(torch.all((0 <= res) & (res < 10)).item()) self.assertTrue(torch.all((res >= 0) & (res < 10)).item())
@torch._inductor.config.patch(force_shape_pad=True) @torch._inductor.config.patch(force_shape_pad=True)
@skip_if_gpu_halide # correctness issue @skip_if_gpu_halide # correctness issue

View File

@ -1220,7 +1220,7 @@ class TestInductorOpInfo(TestCase):
# not exercised in test_ops_gradients atm. The problem is not # not exercised in test_ops_gradients atm. The problem is not
# complex32 per-se (which is supported by data movement only ops) # complex32 per-se (which is supported by data movement only ops)
# but that when we do backwards we expect other ops like add to work # but that when we do backwards we expect other ops like add to work
and not dtype == torch.complex32 and dtype != torch.complex32
) )
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

View File

@ -17,17 +17,13 @@ with open(sys.argv[1]) as input_yaml_file:
for info in model_infos: for info in model_infos:
for op in info["root_operators"]: for op in info["root_operators"]:
# aggregate occurance per op # aggregate occurance per op
root_operators[op] = 1 + (root_operators[op] if op in root_operators else 0) root_operators[op] = 1 + (root_operators.get(op, 0))
for op in info["traced_operators"]: for op in info["traced_operators"]:
# aggregate occurance per op # aggregate occurance per op
traced_operators[op] = 1 + ( traced_operators[op] = 1 + (traced_operators.get(op, 0))
traced_operators[op] if op in traced_operators else 0
)
# merge dtypes for each kernel # merge dtypes for each kernel
for kernal, dtypes in info["kernel_metadata"].items(): for kernal, dtypes in info["kernel_metadata"].items():
new_dtypes = dtypes + ( new_dtypes = dtypes + (kernel_metadata.get(kernal, []))
kernel_metadata[kernal] if kernal in kernel_metadata else []
)
kernel_metadata[kernal] = list(set(new_dtypes)) kernel_metadata[kernal] = list(set(new_dtypes))

View File

@ -4879,7 +4879,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
@skipScriptTest() @skipScriptTest()
def test_rnn_no_bias(self): def test_rnn_no_bias(self):
def make_model(layers, packed_sequence): def make_model(layers, packed_sequence):
batch_first = True if packed_sequence == 2 else False batch_first = packed_sequence == 2
model = torch.nn.RNN( model = torch.nn.RNN(
RNN_INPUT_SIZE, RNN_INPUT_SIZE,
RNN_HIDDEN_SIZE, RNN_HIDDEN_SIZE,
@ -4900,7 +4900,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
return model return model
def make_input(batch_size, layers, packed_sequence): def make_input(batch_size, layers, packed_sequence):
batch_first = True if packed_sequence == 2 else False batch_first = packed_sequence == 2
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
seq_lengths = sorted(map(int, seq_lengths), reverse=True) seq_lengths = sorted(map(int, seq_lengths), reverse=True)
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]

View File

@ -7045,8 +7045,8 @@ class TestQuantizedConv(TestCase):
# ONEDNN only supports symmetric quantization of weight # ONEDNN only supports symmetric quantization of weight
if W_zero_point is not None: if W_zero_point is not None:
W_zero_point = len(W_zero_point) * [0] W_zero_point = len(W_zero_point) * [0]
fp32_output = True if qconv_output_dtype is torch.float32 else False fp32_output = qconv_output_dtype is torch.float32
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False bfloat16_output = qconv_output_dtype is torch.bfloat16
if fp32_output or bfloat16_output: if fp32_output or bfloat16_output:
Y_scale = 1.0 Y_scale = 1.0
Y_zero_point = 0 Y_zero_point = 0
@ -7905,8 +7905,8 @@ class TestQuantizedConv(TestCase):
weight_in_channel_last_format=False, weight_in_channel_last_format=False,
): ):
# We assume FP8 quantization is always symmetric # We assume FP8 quantization is always symmetric
fp32_output = True if qconv_output_dtype is torch.float32 else False fp32_output = qconv_output_dtype is torch.float32
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False bfloat16_output = qconv_output_dtype is torch.bfloat16
if fp32_output or bfloat16_output: if fp32_output or bfloat16_output:
Y_scale = 1.0 Y_scale = 1.0
X2_scale = 1.0 X2_scale = 1.0

View File

@ -11861,7 +11861,7 @@ class TestAutogradDeviceType(TestCase):
def test_nonzero(tensor, value, expected): def test_nonzero(tensor, value, expected):
tensor[0] = value tensor[0] = value
self.assertEqual(expected, bool(tensor)) self.assertEqual(expected, bool(tensor))
self.assertEqual(expected, True if tensor else False) self.assertEqual(expected, bool(tensor))
test_nonzero(l, 0, False) test_nonzero(l, 0, False)
test_nonzero(l, -2, True) test_nonzero(l, -2, True)

View File

@ -577,7 +577,7 @@ print(t.is_pinned())
src = torch.randn( src = torch.randn(
1000000, 1000000,
device="cuda" if dst == "cpu" else "cpu", device="cuda" if dst == "cpu" else "cpu",
pin_memory=True if dst == "cuda" else False, pin_memory=dst == "cuda",
) )
_test_to_non_blocking(src, try_non_blocking, dst) _test_to_non_blocking(src, try_non_blocking, dst)

View File

@ -945,7 +945,7 @@ def forward(self, scores_1, mask_1, value_1):
# not exercised in test_ops_gradients atm. The problem is not # not exercised in test_ops_gradients atm. The problem is not
# complex32 per-se (which is supported by data movement only ops) # complex32 per-se (which is supported by data movement only ops)
# but that when we do backwards we expect other ops like add to work # but that when we do backwards we expect other ops like add to work
and not dtype == torch.complex32 and dtype != torch.complex32
) )
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

View File

@ -3584,7 +3584,7 @@ class TestFX(JitTestCase):
class LeafTracerNotB(Tracer): class LeafTracerNotB(Tracer):
def is_leaf_module(self, module, name): def is_leaf_module(self, module, name):
return False if "b" in name else True return "b" not in name
# Recompile calls added "for fun", since they # Recompile calls added "for fun", since they
# chain __call__ wrappers. # chain __call__ wrappers.

View File

@ -2036,7 +2036,7 @@ class TestIndexing(TestCase):
index = torch.tensor([0], device=device) index = torch.tensor([0], device=device)
x.index_fill_(1, index, 0) x.index_fill_(1, index, 0)
self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device)) self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
if not x.is_complex() and not device == "meta": if not x.is_complex() and device != "meta":
with self.assertRaisesRegex(RuntimeError, r"Scalar"): with self.assertRaisesRegex(RuntimeError, r"Scalar"):
x.index_fill_(1, index, 1 + 1j) x.index_fill_(1, index, 1 + 1j)
# Make sure that the result stays 0-dim while applied to # Make sure that the result stays 0-dim while applied to

View File

@ -6723,7 +6723,7 @@ a")
@torch.jit.script @torch.jit.script
def testNoThrows(t): def testNoThrows(t):
c1 = 1 c1 = 1
if (False and bool(t[1])) or (True or bool(t[1])): if (False and bool(t[1])) or (True or bool(t[1])): # noqa: SIM222,SIM223
c1 = 0 c1 = 0
return c1 return c1
@ -15758,7 +15758,7 @@ dedent """
def fn(d): def fn(d):
# type: (Dict[str, int]) -> List[int] # type: (Dict[str, int]) -> List[int]
out = [1] out = [1]
for i in range(d["hi"] if "hi" in d else 6): for i in range(d.get("hi", 6)):
out.append(i) # noqa: PERF402 out.append(i) # noqa: PERF402
return out return out
@ -16104,7 +16104,7 @@ M = 10
S = 5 S = 5
def add_nn_module_test(*args, **kwargs): def add_nn_module_test(*args, **kwargs):
no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad'] no_grad = kwargs.get('no_grad', False)
if 'desc' in kwargs and 'eval' in kwargs['desc']: if 'desc' in kwargs and 'eval' in kwargs['desc']:
# eval() is not supported, so skip these tests # eval() is not supported, so skip these tests

View File

@ -111,7 +111,7 @@ class TestAutocast(JitTestCase):
def test_runtime_autocast_state_expr(self): def test_runtime_autocast_state_expr(self):
@torch.jit.script @torch.jit.script
def fn(a, b): def fn(a, b):
with autocast(enabled=True if a[0][0] > 0.5 else False): with autocast(enabled=bool((a[0][0] > 0.5).item())):
return torch.mm(a, b) return torch.mm(a, b)
# runtime values for autocast enable argument are not supported # runtime values for autocast enable argument are not supported
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):

View File

@ -3522,7 +3522,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
nn.RNN(10, 20, batch_first=True) nn.RNN(10, 20, batch_first=True)
] ]
# ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
first_warn = False if torch.version.hip else True first_warn = not torch.version.hip
for rnn in rnns: for rnn in rnns:
rnn.cuda() rnn.cuda()
input = torch.randn(5, 4, 10, requires_grad=True, device="cuda") input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")

View File

@ -205,7 +205,7 @@ class TestNumPyInterop(TestCase):
x = x.conj() x = x.conj()
y = x.resolve_conj() y = x.resolve_conj()
expect_error = ( expect_error = (
requires_grad or sparse or conj or not device == "cpu" requires_grad or sparse or conj or device != "cpu"
) )
error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?" error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
if not force and expect_error: if not force and expect_error:

View File

@ -18,7 +18,7 @@ class PruningOpTest(TestCase):
def _generate_rowwise_mask(self, embedding_rows): def _generate_rowwise_mask(self, embedding_rows):
indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32)) indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32))
threshold = float(np.random.random_sample()) threshold = float(np.random.random_sample())
mask = torch.BoolTensor([True if val >= threshold else False for val in indicator]) mask = torch.BoolTensor([val >= threshold for val in indicator])
return mask return mask
def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype): def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):

View File

@ -1899,7 +1899,7 @@ class TestReductions(TestCase):
# Note [all, any uint8 compatibility]: However for compatibility reason, # Note [all, any uint8 compatibility]: However for compatibility reason,
# for `uint8`, they return Tensor of same dtype `uint8`. # for `uint8`, they return Tensor of same dtype `uint8`.
# Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
exact_dtype = True if dtype != torch.uint8 else False exact_dtype = dtype != torch.uint8
def _test_all_any(x): def _test_all_any(x):
self.compare_with_numpy(torch.all, np.all, x) self.compare_with_numpy(torch.all, np.all, x)

View File

@ -1204,7 +1204,7 @@ class TestFP8Matmul(TestCase):
events = sorted(events, key=lambda x: x['ts']) events = sorted(events, key=lambda x: x['ts'])
# ROCm carveout is invisible except for kernels running slower on fewer CUs # ROCm carveout is invisible except for kernels running slower on fewer CUs
no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events] no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout): if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout): # noqa: SIM222
# something went wrong, print more info to help debug flaky test # something went wrong, print more info to help debug flaky test
print("ROCm debug info for test_honor_sm_carveout") print("ROCm debug info for test_honor_sm_carveout")
print("cu_count", cu_count) print("cu_count", cu_count)

View File

@ -129,7 +129,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions: for reduction in reductions:
for initial in [0, None]: for initial in [0, None]:
check_backward = True if initial is not None else False check_backward = initial is not None
initial_value = initial initial_value = initial
default_value = get_default_value(initial_value, reduction) default_value = get_default_value(initial_value, reduction)
if reduction == "max": if reduction == "max":
@ -186,7 +186,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions: for reduction in reductions:
for initial in [0, None]: for initial in [0, None]:
check_backward = True if initial is not None else False check_backward = initial is not None
initial_value = initial initial_value = initial
default_value = get_default_value(initial_value, reduction) default_value = get_default_value(initial_value, reduction)
if reduction == "max": if reduction == "max":
@ -244,7 +244,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions: for reduction in reductions:
for initial in [0, None]: for initial in [0, None]:
check_backward = True if initial is not None else False check_backward = initial is not None
initial_value = initial initial_value = initial
default_value = get_default_value(initial_value, reduction) default_value = get_default_value(initial_value, reduction)
if reduction == "max": if reduction == "max":

View File

@ -4553,7 +4553,7 @@ class TestSerialization(TestCase, SerializationMixin):
with TemporaryFileName() as f: with TemporaryFileName() as f:
torch.save(m, f) torch.save(m, f)
try: try:
old_value = os.environ[env_var] if env_var in os.environ else None old_value = os.environ.get(env_var, None)
os.environ[env_var] = "1" os.environ[env_var] = "1"
# if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it # if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"): with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):

View File

@ -4099,7 +4099,7 @@ class TestSparseCompressedTritonKernels(TestCase):
left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None
right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None
if 0 and op == "bsr_dense_addmm": if 0 and op == "bsr_dense_addmm": # noqa: SIM223
# Find optimal kernel parameters, the speed-up is # Find optimal kernel parameters, the speed-up is
# about 10x for running this test. # about 10x for running this test.
# #

View File

@ -3498,7 +3498,7 @@ class TestRandomTensorCreation(TestCase):
else: else:
t.uniform_(from_, to_) t.uniform_(from_, to_)
range_ = to_ - from_ range_ = to_ - from_
if not (dtype == torch.bfloat16) and not ( if dtype != torch.bfloat16 and not (
dtype == torch.half and device == 'cpu') and not torch.isnan(t).all(): dtype == torch.half and device == 'cpu') and not torch.isnan(t).all():
delta = alpha * range_ delta = alpha * range_
double_t = t.to(torch.double) double_t = t.to(torch.double)

View File

@ -359,7 +359,9 @@ class TestFuzzerCompileIssues(TestCase):
t3 = arg1 # size=(1,), stride=(1,), dtype=int64, device=cuda t3 = arg1 # size=(1,), stride=(1,), dtype=int64, device=cuda
t4 = arg2 # size=(1,), stride=(1,), dtype=int64, device=cuda t4 = arg2 # size=(1,), stride=(1,), dtype=int64, device=cuda
t5 = t3 + t3 + t4 # size=(1,), stride=(1,), dtype=int64, device=cuda t5 = t3 + t3 + t4 # size=(1,), stride=(1,), dtype=int64, device=cuda
t6 = torch.exp(t5) # size=(1,), stride=(1,), dtype=int64, device=cuda t6 = torch.exp( # noqa: F841
t5
) # size=(1,), stride=(1,), dtype=int64, device=cuda # noqa: F841
t7 = torch.nn.functional.layer_norm( t7 = torch.nn.functional.layer_norm(
t2, (111,) t2, (111,)
) # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda ) # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda
@ -436,7 +438,7 @@ class TestFuzzerCompileIssues(TestCase):
torch.manual_seed(9) torch.manual_seed(9)
def foo(arg0): def foo(arg0):
var_node_1 = arg0 # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda var_node_1 = arg0 # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda # noqa: F841
var_node_5 = torch.full( var_node_5 = torch.full(
(1, 2), -66, dtype=torch.int32 (1, 2), -66, dtype=torch.int32
) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda ) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda

View File

@ -100,7 +100,7 @@ class TestBuiltin(TestCase):
# dtypes results in False/True when compared to valid dtypes. # dtypes results in False/True when compared to valid dtypes.
# Here 7 cannot be converted to dtype. No exceptions should be raised # Here 7 cannot be converted to dtype. No exceptions should be raised
assert not np.dtype(np.int32) == 7, "dtype richcompare failed for ==" assert np.dtype(np.int32) != 7, "dtype richcompare failed for =="
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !=" assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
@parametrize("operation", [operator.le, operator.lt, operator.ge, operator.gt]) @parametrize("operation", [operator.le, operator.lt, operator.ge, operator.gt])

View File

@ -416,11 +416,8 @@ class JsonProfile:
# pyrefly: ignore # bad-assignment # pyrefly: ignore # bad-assignment
self.dtype = dtype self.dtype = dtype
else: else:
if dtype in _dtype_map: # pyrefly: ignore # bad-assignment
# pyrefly: ignore # bad-assignment self.dtype = _dtype_map.get(dtype)
self.dtype = _dtype_map[dtype]
else:
self.dtype = None
self._create_devices() self._create_devices()
def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]: def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:

View File

@ -1363,7 +1363,7 @@ class TritonOverrides(OpOverrides):
value = triton_reshape(value, initial_shape, shape_2d) value = triton_reshape(value, initial_shape, shape_2d)
# broadcast if needed # broadcast if needed
broadcast_needed = not (shape_2d == [YBLOCK, RBLOCK]) broadcast_needed = shape_2d != [YBLOCK, RBLOCK]
if broadcast_needed: if broadcast_needed:
value = f"tl.broadcast_to({value}, ({YBLOCK}, {RBLOCK}))" value = f"tl.broadcast_to({value}, ({YBLOCK}, {RBLOCK}))"
@ -1385,7 +1385,7 @@ class TritonOverrides(OpOverrides):
value = f"tl.trans({value})" value = f"tl.trans({value})"
# broadcast if needed # broadcast if needed
broadcast_needed = not (shape_2d == [XBLOCK, RBLOCK]) broadcast_needed = shape_2d != [XBLOCK, RBLOCK]
if broadcast_needed: if broadcast_needed:
value = f"tl.broadcast_to({value}, ({RBLOCK}, {XBLOCK}))" value = f"tl.broadcast_to({value}, ({RBLOCK}, {XBLOCK}))"
else: else:

View File

@ -1570,7 +1570,7 @@ class Reduction(Loops):
and V.graph.sizevars.size_hint_or_throw(reduction_numel) and V.graph.sizevars.size_hint_or_throw(reduction_numel)
< config.unroll_reductions_threshold < config.unroll_reductions_threshold
and (sympy_product(ranges) != 1 or is_gpu(device.type)) and (sympy_product(ranges) != 1 or is_gpu(device.type))
and not (reduction_type == "dot") and reduction_type != "dot"
): ):
# When native matmul, don't unroll the dot reduction. # When native matmul, don't unroll the dot reduction.

View File

@ -834,7 +834,7 @@ class SizeVarAllocator:
any_unbacked_lhs = has_free_unbacked_symbols(lhs) any_unbacked_lhs = has_free_unbacked_symbols(lhs)
any_unbacked_rhs = has_free_unbacked_symbols(rhs) any_unbacked_rhs = has_free_unbacked_symbols(rhs)
if any_unbacked_lhs != any_unbacked_rhs: if any_unbacked_lhs != any_unbacked_rhs:
return True if any_unbacked_rhs else False return bool(any_unbacked_rhs)
# Handles cases where LHS contains the RHS. In other words, # Handles cases where LHS contains the RHS. In other words,
# RHS is a sub-expression of LHS. For example: # RHS is a sub-expression of LHS. For example:
@ -848,12 +848,12 @@ class SizeVarAllocator:
degrees_lhs = len(self.eq_graph[lhs]) degrees_lhs = len(self.eq_graph[lhs])
degrees_rhs = len(self.eq_graph[rhs]) degrees_rhs = len(self.eq_graph[rhs])
if degrees_lhs != degrees_rhs: if degrees_lhs != degrees_rhs:
return True if degrees_lhs > degrees_rhs else False return degrees_lhs > degrees_rhs
# Try to apply union-by-rank optimization to flatten the # Try to apply union-by-rank optimization to flatten the
# leader trees. # leader trees.
if self.rank[x] != self.rank[y]: if self.rank[x] != self.rank[y]:
return True if self.rank[x] > self.rank[y] else False return self.rank[x] > self.rank[y]
# Fallback to sympy.Basic.compare for a deterministic ordering. # Fallback to sympy.Basic.compare for a deterministic ordering.
return lhs.compare(rhs) == -1 return lhs.compare(rhs) == -1

View File

@ -708,7 +708,7 @@ def _distribute_state_dict(
local_state_dict[key] = value.cpu() local_state_dict[key] = value.cpu()
else: else:
assert isinstance(value, torch.Tensor) assert isinstance(value, torch.Tensor)
local_state = local_state_dict.get(key, None) local_state = local_state_dict.get(key)
if local_state is None: if local_state is None:
continue continue
elif isinstance(local_state, DTensor): elif isinstance(local_state, DTensor):

View File

@ -6686,7 +6686,7 @@ def scaled_mm(
# So, we need to convert None arguments for lists in python # So, we need to convert None arguments for lists in python
# explicitly into empty lists. # explicitly into empty lists.
def list_or_empty(l: list[_Any] | None) -> list[_Any]: def list_or_empty(l: list[_Any] | None) -> list[_Any]:
return [] if not l else l return l if l else []
def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]: def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
if not isinstance(l, list): if not isinstance(l, list):
@ -6772,7 +6772,7 @@ def scaled_grouped_mm(
# So, we need to convert None arguments for lists in python # So, we need to convert None arguments for lists in python
# explicitly into empty lists. # explicitly into empty lists.
def list_or_empty(l: list[_Any] | None) -> list[_Any]: def list_or_empty(l: list[_Any] | None) -> list[_Any]:
return [] if not l else l return l if l else []
def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]: def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
if not isinstance(l, list): if not isinstance(l, list):

View File

@ -150,7 +150,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
assert schema.kind() == SchemaKind.inplace assert schema.kind() == SchemaKind.inplace
if not is_mutated_arg(schema.arguments.flat_all[0]): if not is_mutated_arg(schema.arguments.flat_all[0]):
return None return None
if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1: if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
return None return None
# Only support cases where all returns are Tensors or vector<Tensor> # Only support cases where all returns are Tensors or vector<Tensor>