Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhancing code readability.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
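As a quick illustration of the kind of rewrites these rules ask for (a sketch with made-up values; the variable names echo the diff below, but this snippet is not itself part of the PR):

import os

packed_sequence = 2
env_var = "EXAMPLE_ENV_VAR"  # hypothetical name, for illustration only
count = 2

# Before: patterns the SIM rules flag
batch_first = True if packed_sequence == 2 else False  # redundant bool-valued ternary
old_value = os.environ[env_var] if env_var in os.environ else None  # hand-rolled dict lookup
mismatch = not count == 2  # negated equality

# After: the simplified forms the linter suggests
batch_first = packed_sequence == 2
old_value = os.environ.get(env_var, None)
mismatch = count != 2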
This commit is contained in:
parent f1d882212a
commit e925dfcc6b
.github/scripts/trymerge.py (vendored): 2 changed lines
@@ -1092,7 +1092,7 @@ class GitHubPR:
 editor = node["editor"]
 return GitHubComment(
 body_text=node["bodyText"],
-created_at=node["createdAt"] if "createdAt" in node else "",
+created_at=node.get("createdAt", ""),
 author_login=node["author"]["login"],
 author_url=node["author"].get("url", None),
 author_association=node["authorAssociation"],
@@ -4060,7 +4060,7 @@ def run(runner, args, original_dir=None):
 else:
 optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
 experiment = (
-speedup_experiment if not args.backend == "torchao" else latency_experiment
+speedup_experiment if args.backend != "torchao" else latency_experiment
 )
 if args.accuracy:
 output_filename = f"accuracy_{args.backend}.csv"
@@ -271,7 +271,7 @@ def run_single_backend_sdpa(
 if config.calculate_bwd_time:
 # TODO: debug backward pass for njt
-if eager_sdpa and not config.attn_type == "document_mask":
+if eager_sdpa and config.attn_type != "document_mask":
 d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
 backward_eager_time = benchmark_torch_function_in_microseconds(
 out_eager.backward, d_out, retain_graph=True
@@ -180,6 +180,7 @@ ignore = [
 "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
 "SIM117",
 "SIM118",
+"SIM300", # Yoda condition detected
 "UP007", # keep-runtime-typing
 "UP045", # keep-runtime-typing
 "TC006",
@@ -195,8 +196,7 @@ select = [
 "E",
 "EXE",
 "F",
-"SIM1",
-"SIM911",
+"SIM",
 "W",
 # Not included in flake8
 "FURB",
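For context on the SIM300 entry added to the ignore list above: SIM300 flags "Yoda conditions", i.e. comparisons written with the constant on the left. A minimal sketch of what it would flag (hypothetical values, not code from this repository):

device = "cpu"

# SIM300 would flag the constant-on-the-left comparison...
if "cuda" == device:
    print("using CUDA")

# ...and suggest the conventional operand order instead.
if device == "cuda":
    print("using CUDA")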
@@ -55,7 +55,7 @@ class TestActivationSparsifier(TestCase):
 for key, config in sparsifier_defaults.items():
 # all the keys in combined_defaults should be present in sparsifier defaults
-assert config == combined_defaults.get(key, None)
+assert config == combined_defaults.get(key)
 
 def _check_register_layer(
 self, activation_sparsifier, defaults, sparse_config, layer_args_list
@@ -3074,7 +3074,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 wrong_dtype_shards, [10, 10], init_rrefs=True
 )
 
-tensor_requires_grad = True if self.rank == 0 else False
+tensor_requires_grad = self.rank == 0
 wrong_requires_grad_shards = [
 sharded_tensor.Shard(
 torch.randn(
@@ -3121,7 +3121,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
 wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
 )
 
-tensor_pin_memory = True if self.rank == 0 else False
+tensor_pin_memory = self.rank == 0
 wrong_pin_memory_shards_cross_ranks = [
 sharded_tensor.Shard(
 torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata
@@ -152,7 +152,7 @@ class TestStorageBase:
 self.rank = 0 if not dist.is_initialized() else dist.get_rank()
 
 def _get_ranks(self, name):
-return self.fail_conf[name] if name in self.fail_conf else None
+return self.fail_conf.get(name, None)
 
 def _fail_rank(self, name):
 ranks = self._get_ranks(name)
@@ -155,7 +155,7 @@ class TestFreezingWeights(FSDPTest):
 ddp_kwargs = {
 "device_ids": [self.rank],
-"find_unused_parameters": True if disable_autograd else False,
+"find_unused_parameters": bool(disable_autograd),
 }
 
 model = self._create_model(
@@ -66,7 +66,7 @@ class MockPipelineStage(_PipelineStageBase):
 self.num_stages = kwargs.get("num_stages", 1)
 self.group_size = kwargs.get("group_size", 1)
 self.group_rank = kwargs.get("group_rank", 0)
-self.group = kwargs.get("group", None)
+self.group = kwargs.get("group")
 
 def _create_grad_recv_info(self, *args, **kwargs):
 return None
@@ -1066,7 +1066,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
 assert_array_equal(expected_pad_sizes, pad_sizes)
 
 is_tensor_empty = [
-False if splitted_tensor.numel() > 0 else True
+not splitted_tensor.numel() > 0
 for splitted_tensor in splitted_tensor_list
 ]
 expected_is_tensor_empty = [True] * self.world_size
@@ -1089,12 +1089,10 @@ class TestDTensorPlacementTypes(DTensorTestBase):
 for i, tensor in enumerate(splitted_tensor_list)
 ]
 expected_is_tensor_empty = [
-False if idx < size else True
-for idx, _ in enumerate(range(self.world_size))
+not idx < size for idx, _ in enumerate(range(self.world_size))
 ]
 is_tensor_empty = [
-False if unpadded_tensor.numel() > 0 else True
-for unpadded_tensor in unpadded_list
+not unpadded_tensor.numel() > 0 for unpadded_tensor in unpadded_list
 ]
 assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
@@ -2770,11 +2770,7 @@ class WorkHookTest(MultiProcessTestCase):
 # from rank0 to other ranks. However, this is DDP's internal implementation,
 # which is subject to change in future versions.
 self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0)
-ctor_allreduce = (
-num_hook_fired[OpType.ALLREDUCE]
-if OpType.ALLREDUCE in num_hook_fired
-else 0
-)
+ctor_allreduce = num_hook_fired.get(OpType.ALLREDUCE, 0)
 
 x = torch.zeros(2, 1000).cuda(self.rank)
 ddp(x).sum().backward()
@@ -82,7 +82,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
 # look up dL_dentries. If a variable is never used to compute the loss,
 # we consider its gradient None, see the note below about zeros for more information.
 def gather_grad(entries: list[str]):
-return [dL_d[entry] if entry in dL_d else None for entry in entries]
+return [dL_d.get(entry) for entry in entries]
 
 # propagate the gradient information backward
 for entry in reversed(gradient_tape):
@@ -286,7 +286,7 @@ class OptionalScaledTensor(torch.Tensor):
 def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
 return OptionalScaledTensor(
 inner_tensors["_data"],
-inner_tensors["_scale"] if "_scale" in inner_tensors else None,
+inner_tensors.get("_scale", None),
 constant=metadata["_constant"],
 )
@@ -358,9 +358,7 @@ def _sequential_split_inline_tests():
 for i, node in enumerate(insert_locs):
 with gm.graph.inserting_before(node):
-gm.graph.call_function(
-torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}
-)
+gm.graph.call_function(torch._C._set_grad_enabled, (i % 2 == 0,), {})
 return gm
 
 x = torch.randn(2, 2)
@@ -2932,9 +2932,7 @@ class GraphModule(torch.nn.Module):
 if autograd:
 result_flat = pytree.tree_leaves(result)
 result_exp_flat = pytree.tree_leaves(result_exp)
-exp_grad_mask = [
-True if r.requires_grad else False for r in result_exp_flat
-]
+exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
 self.check_autograd(
 [r for r, m in zip(result_flat, exp_grad_mask) if m],
 [r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
@@ -3741,9 +3739,7 @@ class AssociativeScanTests(TestCase):
 ):
 result_flat = pytree.tree_leaves(result)
 result_exp_flat = pytree.tree_leaves(result_exp)
-exp_grad_mask = [
-True if r.requires_grad else False for r in result_exp_flat
-]
+exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
 
 self._check_autograd(
 [r for r, m in zip(result_flat, exp_grad_mask) if m],
@@ -5710,10 +5706,9 @@ def forward(self, arg0_1):
 )
 def test_while_loop_tracing(self, while_loop_test):
 fn, inp = WHILE_LOOP_TESTS[while_loop_test]
-allow_non_fake_inputs = (
-False
-if while_loop_test not in ("simple_with_linear", "nested_with_linear")
-else True
+allow_non_fake_inputs = while_loop_test in (
+"simple_with_linear",
+"nested_with_linear",
 )
 self._check_tracing(fn, inp, allow_non_fake_inputs)
@@ -177,9 +177,7 @@ class TestFXNodeSource(TestCase):
 for node_name_2 in node_name_to_from_node:
 if node_name_2 in {
 node_name_1,
-same_ancestor_nodes[node_name_1]
-if node_name_1 in same_ancestor_nodes
-else None,
+same_ancestor_nodes.get(node_name_1),
 }:
 self.assertEqual(
 node_name_to_from_node[node_name_1],
@@ -164,9 +164,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
 self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)
 
-@unittest.skipIf(
-not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
-)
+@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
 @torch._dynamo.config.patch(recompile_limit=32)
 def test_plain_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@@ -219,9 +217,7 @@ class B2BGEMMTest(TestCase):
 # flaky test assertion: disabled
 # self.assertTrue(average_speedup > 1)
 
-@unittest.skipIf(
-not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
-)
+@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
 @torch._dynamo.config.patch(recompile_limit=32)
 def test_gelu_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@@ -276,9 +272,7 @@ class B2BGEMMTest(TestCase):
 # flaky test assertion: disabled
 # self.assertTrue(average_speedup > 1)
 
-@unittest.skipIf(
-not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
-)
+@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
 @torch._dynamo.config.patch(recompile_limit=32)
 def test_gelu_mlp_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
@@ -165,7 +165,7 @@ class BenchmarkFusionTestTemplate:
 _, out_code = run_and_get_code(foo_c, m, inp)
 
 # occasionally, CI will make this one kernel. just skip in this case
-if not out_code[0].count("def triton_") == 2:
+if out_code[0].count("def triton_") != 2:
 return
 
 # should be multiple triton invocations
@@ -289,7 +289,7 @@ def build_opt_kwarg_db():
 has_tensor_lr = False
 for key, val in kwargs.items():
-if (not key == "lr" and not key == "betas") and (
+if (key != "lr" and key != "betas") and (
 not isinstance(val, bool) or (isinstance(val, bool) and val)
 ):
 name += "_" + key
@@ -450,7 +450,7 @@ def make_test(
 stack.enter_context(config.patch({"triton.cudagraphs": True}))
 
 kwargs_compiled = deepcopy(kwargs)
-if isinstance(kwargs.get("lr", None), torch.Tensor):
+if isinstance(kwargs.get("lr"), torch.Tensor):
 kwargs["lr"] = kwargs["lr"].to(device)
 kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)
@@ -177,7 +177,7 @@ if HAS_CUDA_AND_TRITON:
 def get_manager(self, device_index=None):
 return torch._inductor.cudagraph_trees.get_container(
-self.device_idx if not device_index else device_index
+device_index if device_index else self.device_idx
 ).tree_manager
 
 def get_roots(self):
@@ -585,9 +585,7 @@ class TestFlexAttention(InductorTestCase):
 )
 q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
 q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
-sdpa_partial = create_attention(
-score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
-)
+sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
 
 compiled_sdpa = torch.compile(sdpa_partial)
 golden_out = sdpa_partial(q_gold, k_gold, v_gold)
@@ -761,7 +759,7 @@ class TestFlexAttention(InductorTestCase):
 return_lse=return_lse,
 block_mask=converted_block_mask,
 score_mod=converted_score_mod,
-enable_gqa=(not Q_H == KV_H),
+enable_gqa=(Q_H != KV_H),
 kernel_options=kernel_options,
 )
 else:
@@ -774,7 +772,7 @@ class TestFlexAttention(InductorTestCase):
 return_lse=return_lse,
 block_mask=converted_block_mask,
 score_mod=converted_score_mod,
-enable_gqa=(not Q_H == KV_H),
+enable_gqa=(Q_H != KV_H),
 kernel_options=kernel_options,
 )
 return compiled_out, compiled_lse
@@ -819,9 +817,7 @@ class TestFlexAttention(InductorTestCase):
 if block_mask is None:
 block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
 
-sdpa_partial = create_attention(
-score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
-)
+sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
 golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
 ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@@ -1466,7 +1462,7 @@ class TestFlexAttention(InductorTestCase):
 block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device)
 attention = functools.partial(
-flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv)
+flex_attention, block_mask=block_mask, enable_gqa=(Hq != Hkv)
 )
 
 self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D)
@@ -412,7 +412,7 @@ class TestFlexDecoding(InductorTestCase):
 sdpa_partial = create_attention(
 score_mod,
 block_mask,
-enable_gqa=(not Q_H == KV_H),
+enable_gqa=(Q_H != KV_H),
 kernel_options=kernel_options,
 )
 compiled_sdpa = torch.compile(sdpa_partial)
@@ -607,7 +607,7 @@ class TestFlexDecoding(InductorTestCase):
 return_lse=True,
 block_mask=converted_block_mask,
 score_mod=converted_score_mod,
-enable_gqa=(not Q_H == KV_H),
+enable_gqa=(Q_H != KV_H),
 )
 else:
 compiled_lse = None
@@ -618,7 +618,7 @@ class TestFlexDecoding(InductorTestCase):
 return_lse=False,
 block_mask=converted_block_mask,
 score_mod=converted_score_mod,
-enable_gqa=(not Q_H == KV_H),
+enable_gqa=(Q_H != KV_H),
 )
 return compiled_out, compiled_lse
@@ -664,9 +664,7 @@ class TestFlexDecoding(InductorTestCase):
 if block_mask is None:
 block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device)
 
-sdpa_partial = create_attention(
-score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
-)
+sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
 golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
 ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
@@ -906,7 +904,7 @@ class TestFlexDecoding(InductorTestCase):
 sdpa_partial = create_attention(
 score_mod=score_mod,
 block_mask=None,
-enable_gqa=(not Hq == Hkv),
+enable_gqa=(Hq != Hkv),
 )
 compiled_sdpa = torch.compile(sdpa_partial)
 ref_out = sdpa_partial(q, k, v)
@@ -1144,7 +1142,7 @@ class TestFlexDecoding(InductorTestCase):
 def head_attention_mod(kv_head_num):
 head_type = torch.tensor(
-[False if i % kv_head_num == 0 else True for i in range(kv_head_num)],
+[i % kv_head_num != 0 for i in range(kv_head_num)],
 dtype=torch.bool,
 device=device,
 )
@@ -22,7 +22,7 @@ except ImportError:
 HAS_PYDOT = False
 
 
-HAS_DOT = True if shutil.which("dot") is not None else False
+HAS_DOT = shutil.which("dot") is not None
 
 
 class TestGraphTransformObserver(TestCase):
@@ -835,9 +835,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
 for dtype in dtypes:
 torch._dynamo.reset()
-autocast_enabled = (
-True if dtype in [torch.bfloat16, torch.float16] else False
-)
+autocast_enabled = dtype in [torch.bfloat16, torch.float16]
 with (
 torch.no_grad(),
 torch.autocast(
@@ -4421,14 +4419,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
 out_feature = 64
 q_min, q_max = -32, 31
 # we only test for qlinear_binary in this case
-test_for_pointwise_binary = (
-True
-if M == 1
+test_for_pointwise_binary = bool(
+M == 1
 and inplace_add
 and not expand_a_scale
 and not dynamic
 and not has_bias
-else False
 )
 if test_for_pointwise_binary and not IS_X86:
 self.skipTest("Some UTs are only supported on x86_64 CPUs")
@@ -706,7 +706,7 @@ def check_model_gpu(
 if check_lowp:
 
 def downcast_fn(x):
-if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
+if not isinstance(x, torch.Tensor) or x.dtype != torch.float:
 return x
 return torch.empty_strided(
 x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half
@@ -4694,7 +4694,7 @@ class CommonTemplate:
 # Make sure we compute also with fp16 in the reference. Otherwise,
 # the reference will compute with fp32 and cast back to fp16, which
 # causes numeric differences beyond tolerance.
-reference_in_float=False if torch.version.hip else True,
+reference_in_float=not torch.version.hip,
 )
 
 def test_convolution2(self):
@@ -4728,7 +4728,7 @@ class CommonTemplate:
 # Make sure we compute also with fp16 in the reference. Otherwise,
 # the reference will compute with fp32 and cast back to fp16, which
 # causes numeric differences beyond tolerance.
-reference_in_float=False if torch.version.hip else True,
+reference_in_float=not torch.version.hip,
 )
 
 @skip_if_gpu_halide
@@ -4779,7 +4779,7 @@ class CommonTemplate:
 # Make sure we compute also with fp16 in the reference. Otherwise,
 # the reference will compute with fp32 and cast back to fp16, which
 # causes numeric differences beyond tolerance.
-reference_in_float=False if torch.version.hip else True,
+reference_in_float=not torch.version.hip,
 )
 
 def test_conv2d_channels_last(self):
@@ -12970,7 +12970,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
 )
 
 res = torch.compile(fn)(20)
-self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
+self.assertTrue(torch.all((res >= 0) & (res < 10)).item())
 
 @torch._inductor.config.patch(force_shape_pad=True)
 @skip_if_gpu_halide # correctness issue
@@ -1220,7 +1220,7 @@ class TestInductorOpInfo(TestCase):
 # not exercised in test_ops_gradients atm. The problem is not
 # complex32 per-se (which is supported by data movement only ops)
 # but that when we do backwards we expect other ops like add to work
-and not dtype == torch.complex32
+and dtype != torch.complex32
 )
 samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
@@ -17,17 +17,13 @@ with open(sys.argv[1]) as input_yaml_file:
 for info in model_infos:
 for op in info["root_operators"]:
 # aggregate occurance per op
-root_operators[op] = 1 + (root_operators[op] if op in root_operators else 0)
+root_operators[op] = 1 + (root_operators.get(op, 0))
 for op in info["traced_operators"]:
 # aggregate occurance per op
-traced_operators[op] = 1 + (
-traced_operators[op] if op in traced_operators else 0
-)
+traced_operators[op] = 1 + (traced_operators.get(op, 0))
 # merge dtypes for each kernel
 for kernal, dtypes in info["kernel_metadata"].items():
-new_dtypes = dtypes + (
-kernel_metadata[kernal] if kernal in kernel_metadata else []
-)
+new_dtypes = dtypes + (kernel_metadata.get(kernal, []))
 kernel_metadata[kernal] = list(set(new_dtypes))
@@ -4879,7 +4879,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
 @skipScriptTest()
 def test_rnn_no_bias(self):
 def make_model(layers, packed_sequence):
-batch_first = True if packed_sequence == 2 else False
+batch_first = packed_sequence == 2
 model = torch.nn.RNN(
 RNN_INPUT_SIZE,
 RNN_HIDDEN_SIZE,
@@ -4900,7 +4900,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
 return model
 
 def make_input(batch_size, layers, packed_sequence):
-batch_first = True if packed_sequence == 2 else False
+batch_first = packed_sequence == 2
 seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
 seq_lengths = sorted(map(int, seq_lengths), reverse=True)
 inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
@@ -7045,8 +7045,8 @@ class TestQuantizedConv(TestCase):
 # ONEDNN only supports symmetric quantization of weight
 if W_zero_point is not None:
 W_zero_point = len(W_zero_point) * [0]
-fp32_output = True if qconv_output_dtype is torch.float32 else False
-bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
+fp32_output = qconv_output_dtype is torch.float32
+bfloat16_output = qconv_output_dtype is torch.bfloat16
 if fp32_output or bfloat16_output:
 Y_scale = 1.0
 Y_zero_point = 0
@@ -7905,8 +7905,8 @@ class TestQuantizedConv(TestCase):
 weight_in_channel_last_format=False,
 ):
 # We assume FP8 quantization is always symmetric
-fp32_output = True if qconv_output_dtype is torch.float32 else False
-bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
+fp32_output = qconv_output_dtype is torch.float32
+bfloat16_output = qconv_output_dtype is torch.bfloat16
 if fp32_output or bfloat16_output:
 Y_scale = 1.0
 X2_scale = 1.0
@@ -11861,7 +11861,7 @@ class TestAutogradDeviceType(TestCase):
 def test_nonzero(tensor, value, expected):
 tensor[0] = value
 self.assertEqual(expected, bool(tensor))
-self.assertEqual(expected, True if tensor else False)
+self.assertEqual(expected, bool(tensor))
 
 test_nonzero(l, 0, False)
 test_nonzero(l, -2, True)
@@ -577,7 +577,7 @@ print(t.is_pinned())
 src = torch.randn(
 1000000,
 device="cuda" if dst == "cpu" else "cpu",
-pin_memory=True if dst == "cuda" else False,
+pin_memory=dst == "cuda",
 )
 _test_to_non_blocking(src, try_non_blocking, dst)
@@ -945,7 +945,7 @@ def forward(self, scores_1, mask_1, value_1):
 # not exercised in test_ops_gradients atm. The problem is not
 # complex32 per-se (which is supported by data movement only ops)
 # but that when we do backwards we expect other ops like add to work
-and not dtype == torch.complex32
+and dtype != torch.complex32
 )
 samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
@@ -3584,7 +3584,7 @@ class TestFX(JitTestCase):
 class LeafTracerNotB(Tracer):
 def is_leaf_module(self, module, name):
-return False if "b" in name else True
+return "b" not in name
 
 # Recompile calls added "for fun", since they
 # chain __call__ wrappers.
@@ -2036,7 +2036,7 @@ class TestIndexing(TestCase):
 index = torch.tensor([0], device=device)
 x.index_fill_(1, index, 0)
 self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
-if not x.is_complex() and not device == "meta":
+if not x.is_complex() and device != "meta":
 with self.assertRaisesRegex(RuntimeError, r"Scalar"):
 x.index_fill_(1, index, 1 + 1j)
 # Make sure that the result stays 0-dim while applied to
@@ -6723,7 +6723,7 @@ a")
 @torch.jit.script
 def testNoThrows(t):
 c1 = 1
-if (False and bool(t[1])) or (True or bool(t[1])):
+if (False and bool(t[1])) or (True or bool(t[1])): # noqa: SIM222,SIM223
 c1 = 0
 return c1
@@ -15758,7 +15758,7 @@ dedent """
 def fn(d):
 # type: (Dict[str, int]) -> List[int]
 out = [1]
-for i in range(d["hi"] if "hi" in d else 6):
+for i in range(d.get("hi", 6)):
 out.append(i) # noqa: PERF402
 return out
@@ -16104,7 +16104,7 @@ M = 10
 S = 5
 
 def add_nn_module_test(*args, **kwargs):
-no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
+no_grad = kwargs.get('no_grad', False)
 
 if 'desc' in kwargs and 'eval' in kwargs['desc']:
 # eval() is not supported, so skip these tests
@@ -111,7 +111,7 @@ class TestAutocast(JitTestCase):
 def test_runtime_autocast_state_expr(self):
 @torch.jit.script
 def fn(a, b):
-with autocast(enabled=True if a[0][0] > 0.5 else False):
+with autocast(enabled=bool((a[0][0] > 0.5).item())):
 return torch.mm(a, b)
 # runtime values for autocast enable argument are not supported
 with self.assertRaises(RuntimeError):
@@ -3522,7 +3522,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
 nn.RNN(10, 20, batch_first=True)
 ]
 # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
-first_warn = False if torch.version.hip else True
+first_warn = not torch.version.hip
 for rnn in rnns:
 rnn.cuda()
 input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
@@ -205,7 +205,7 @@ class TestNumPyInterop(TestCase):
 x = x.conj()
 y = x.resolve_conj()
 expect_error = (
-requires_grad or sparse or conj or not device == "cpu"
+requires_grad or sparse or conj or device != "cpu"
 )
 error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
 if not force and expect_error:
@@ -18,7 +18,7 @@ class PruningOpTest(TestCase):
 def _generate_rowwise_mask(self, embedding_rows):
 indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32))
 threshold = float(np.random.random_sample())
-mask = torch.BoolTensor([True if val >= threshold else False for val in indicator])
+mask = torch.BoolTensor([val >= threshold for val in indicator])
 return mask
 
 def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
@@ -1899,7 +1899,7 @@ class TestReductions(TestCase):
 # Note [all, any uint8 compatibility]: However for compatibility reason,
 # for `uint8`, they return Tensor of same dtype `uint8`.
 # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
-exact_dtype = True if dtype != torch.uint8 else False
+exact_dtype = dtype != torch.uint8
 
 def _test_all_any(x):
 self.compare_with_numpy(torch.all, np.all, x)
@@ -1204,7 +1204,7 @@ class TestFP8Matmul(TestCase):
 events = sorted(events, key=lambda x: x['ts'])
 # ROCm carveout is invisible except for kernels running slower on fewer CUs
 no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
-if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
+if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout): # noqa: SIM222
 # something went wrong, print more info to help debug flaky test
 print("ROCm debug info for test_honor_sm_carveout")
 print("cu_count", cu_count)
@@ -129,7 +129,7 @@ class TestSegmentReductions(TestCase):
 for reduction in reductions:
 for initial in [0, None]:
-check_backward = True if initial is not None else False
+check_backward = initial is not None
 initial_value = initial
 default_value = get_default_value(initial_value, reduction)
 if reduction == "max":
@@ -186,7 +186,7 @@ class TestSegmentReductions(TestCase):
 for reduction in reductions:
 for initial in [0, None]:
-check_backward = True if initial is not None else False
+check_backward = initial is not None
 initial_value = initial
 default_value = get_default_value(initial_value, reduction)
 if reduction == "max":
@@ -244,7 +244,7 @@ class TestSegmentReductions(TestCase):
 for reduction in reductions:
 for initial in [0, None]:
-check_backward = True if initial is not None else False
+check_backward = initial is not None
 initial_value = initial
 default_value = get_default_value(initial_value, reduction)
 if reduction == "max":
@@ -4553,7 +4553,7 @@ class TestSerialization(TestCase, SerializationMixin):
 with TemporaryFileName() as f:
 torch.save(m, f)
 try:
-old_value = os.environ[env_var] if env_var in os.environ else None
+old_value = os.environ.get(env_var, None)
 os.environ[env_var] = "1"
 # if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
 with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):
@@ -4099,7 +4099,7 @@ class TestSparseCompressedTritonKernels(TestCase):
 left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None
 right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None
 
-if 0 and op == "bsr_dense_addmm":
+if 0 and op == "bsr_dense_addmm": # noqa: SIM223
 # Find optimal kernel parameters, the speed-up is
 # about 10x for running this test.
 #
@@ -3498,7 +3498,7 @@ class TestRandomTensorCreation(TestCase):
 else:
 t.uniform_(from_, to_)
 range_ = to_ - from_
-if not (dtype == torch.bfloat16) and not (
+if dtype != torch.bfloat16 and not (
 dtype == torch.half and device == 'cpu') and not torch.isnan(t).all():
 delta = alpha * range_
 double_t = t.to(torch.double)
@@ -359,7 +359,9 @@ class TestFuzzerCompileIssues(TestCase):
 t3 = arg1 # size=(1,), stride=(1,), dtype=int64, device=cuda
 t4 = arg2 # size=(1,), stride=(1,), dtype=int64, device=cuda
 t5 = t3 + t3 + t4 # size=(1,), stride=(1,), dtype=int64, device=cuda
-t6 = torch.exp(t5) # size=(1,), stride=(1,), dtype=int64, device=cuda
+t6 = torch.exp( # noqa: F841
+t5
+) # size=(1,), stride=(1,), dtype=int64, device=cuda # noqa: F841
 t7 = torch.nn.functional.layer_norm(
 t2, (111,)
 ) # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda
@@ -436,7 +438,7 @@ class TestFuzzerCompileIssues(TestCase):
 torch.manual_seed(9)
 
 def foo(arg0):
-var_node_1 = arg0 # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda
+var_node_1 = arg0 # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda # noqa: F841
 var_node_5 = torch.full(
 (1, 2), -66, dtype=torch.int32
 ) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
@@ -100,7 +100,7 @@ class TestBuiltin(TestCase):
 # dtypes results in False/True when compared to valid dtypes.
 # Here 7 cannot be converted to dtype. No exceptions should be raised
 
-assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
+assert np.dtype(np.int32) != 7, "dtype richcompare failed for =="
 assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
 
 @parametrize("operation", [operator.le, operator.lt, operator.ge, operator.gt])
@@ -416,11 +416,8 @@ class JsonProfile:
 # pyrefly: ignore # bad-assignment
 self.dtype = dtype
 else:
-if dtype in _dtype_map:
-# pyrefly: ignore # bad-assignment
-self.dtype = _dtype_map[dtype]
-else:
-self.dtype = None
+# pyrefly: ignore # bad-assignment
+self.dtype = _dtype_map.get(dtype)
 self._create_devices()
 
 def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
@@ -1363,7 +1363,7 @@ class TritonOverrides(OpOverrides):
 value = triton_reshape(value, initial_shape, shape_2d)
 
 # broadcast if needed
-broadcast_needed = not (shape_2d == [YBLOCK, RBLOCK])
+broadcast_needed = shape_2d != [YBLOCK, RBLOCK]
 if broadcast_needed:
 value = f"tl.broadcast_to({value}, ({YBLOCK}, {RBLOCK}))"
@@ -1385,7 +1385,7 @@ class TritonOverrides(OpOverrides):
 value = f"tl.trans({value})"
 
 # broadcast if needed
-broadcast_needed = not (shape_2d == [XBLOCK, RBLOCK])
+broadcast_needed = shape_2d != [XBLOCK, RBLOCK]
 if broadcast_needed:
 value = f"tl.broadcast_to({value}, ({RBLOCK}, {XBLOCK}))"
 else:
@@ -1570,7 +1570,7 @@ class Reduction(Loops):
 and V.graph.sizevars.size_hint_or_throw(reduction_numel)
 < config.unroll_reductions_threshold
 and (sympy_product(ranges) != 1 or is_gpu(device.type))
-and not (reduction_type == "dot")
+and reduction_type != "dot"
 ):
 # When native matmul, don't unroll the dot reduction.
@@ -834,7 +834,7 @@ class SizeVarAllocator:
 any_unbacked_lhs = has_free_unbacked_symbols(lhs)
 any_unbacked_rhs = has_free_unbacked_symbols(rhs)
 if any_unbacked_lhs != any_unbacked_rhs:
-return True if any_unbacked_rhs else False
+return bool(any_unbacked_rhs)
 
 # Handles cases where LHS contains the RHS. In other words,
 # RHS is a sub-expression of LHS. For example:
@@ -848,12 +848,12 @@ class SizeVarAllocator:
 degrees_lhs = len(self.eq_graph[lhs])
 degrees_rhs = len(self.eq_graph[rhs])
 if degrees_lhs != degrees_rhs:
-return True if degrees_lhs > degrees_rhs else False
+return degrees_lhs > degrees_rhs
 
 # Try to apply union-by-rank optimization to flatten the
 # leader trees.
 if self.rank[x] != self.rank[y]:
-return True if self.rank[x] > self.rank[y] else False
+return self.rank[x] > self.rank[y]
 
 # Fallback to sympy.Basic.compare for a deterministic ordering.
 return lhs.compare(rhs) == -1
@@ -708,7 +708,7 @@ def _distribute_state_dict(
 local_state_dict[key] = value.cpu()
 else:
 assert isinstance(value, torch.Tensor)
-local_state = local_state_dict.get(key, None)
+local_state = local_state_dict.get(key)
 if local_state is None:
 continue
 elif isinstance(local_state, DTensor):
@@ -6686,7 +6686,7 @@ def scaled_mm(
 # So, we need to convert None arguments for lists in python
 # explicitly into empty lists.
 def list_or_empty(l: list[_Any] | None) -> list[_Any]:
-return [] if not l else l
+return l if l else []
 
 def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
 if not isinstance(l, list):
@@ -6772,7 +6772,7 @@ def scaled_grouped_mm(
 # So, we need to convert None arguments for lists in python
 # explicitly into empty lists.
 def list_or_empty(l: list[_Any] | None) -> list[_Any]:
-return [] if not l else l
+return l if l else []
 
 def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
 if not isinstance(l, list):
@@ -150,7 +150,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
 assert schema.kind() == SchemaKind.inplace
 if not is_mutated_arg(schema.arguments.flat_all[0]):
 return None
-if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
+if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
 return None
 
 # Only support cases where all returns are Tensors or vector<Tensor>