[BE]: ruff - enable PIE804 (#113951)

Enables ruff PIE804 which kills some more unnecessary temporary dicts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113951
Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
Aaron Gokaslan 2023-11-17 21:22:58 +00:00 committed by PyTorch MergeBot
parent 4b1583fe57
commit 69d9267c4f
7 changed files with 13 additions and 14 deletions

View File

@ -79,6 +79,7 @@ select = [
"PGH004",
"PIE794",
"PIE800",
"PIE804",
"PIE807",
"PIE810",
"PLE",

View File

@ -1308,12 +1308,10 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
with torch.no_grad():
_ = model(
**{
"input_ids": ids["input_ids"],
"attention_mask": ids["attention_mask"],
"decoder_input_ids": ids["input_ids"],
"decoder_attention_mask": ids["attention_mask"],
}
input_ids=ids["input_ids"],
attention_mask=ids["attention_mask"],
decoder_input_ids=ids["input_ids"],
decoder_attention_mask=ids["attention_mask"],
)

View File

@ -3062,9 +3062,9 @@ class TestFrontend(JitTestCase):
res_func = traced_func(**example_input_dict_func)
self.assertEqual(res_func, 2 * torch.ones(1))
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})
res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) # noqa: PIE804
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) # noqa: PIE804
@skipIfTorchDynamo()

View File

@ -2991,14 +2991,14 @@ class TestBroadcast(TestCase):
# gh-13455
arrs = [np.empty((5, 6, 7))]
mit = np.broadcast(*arrs)
mit2 = np.broadcast(*arrs, **{})
mit2 = np.broadcast(*arrs, **{}) # noqa: PIE804
assert_equal(mit.shape, mit2.shape)
assert_equal(mit.ndim, mit2.ndim)
assert_equal(mit.nd, mit2.nd)
assert_equal(mit.numiter, mit2.numiter)
assert_(mit.iters[0].base is mit2.iters[0].base)
assert_raises(ValueError, np.broadcast, 1, **{"x": 1})
assert_raises(ValueError, np.broadcast, 1, **{"x": 1}) # noqa: PIE804
@skip(reason="error messages do not match.")
def test_shape_mismatch_error_message(self):

View File

@ -9,7 +9,7 @@ from torch._export.db.case import export_case, ExportArgs, SupportLevel
(torch.randn(4), torch.randn(4)),
*[torch.randn(4), torch.randn(4)],
mykw0=torch.randn(4),
**{"input0": torch.randn(4), "input1": torch.randn(4)}
input0=torch.randn(4), input1=torch.randn(4)
),
tags={"python.data-structure"},
support_level=SupportLevel.SUPPORTED,

View File

@ -298,7 +298,7 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
choices,
input_nodes=(mat1, mat2, mat3),
layout=layout,
**dict(mm_options(config, k, layout), **{"ACC_TYPE": "tl.int32"}),
**dict(mm_options(config, k, layout), ACC_TYPE="tl.int32"),
suffix_args=1,
epilogue_fn=V.ops.mul,
)

View File

@ -8878,7 +8878,7 @@ class foreach_norm_sample_func(foreach_inputs_sample_func):
disable_fastpath = True
if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
disable_fastpath = False
yield ForeachSampleInput(input, **{"ord": ord, "disable_fastpath": disable_fastpath})
yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
@ -8891,7 +8891,7 @@ class foreach_norm_sample_func(foreach_inputs_sample_func):
disable_fastpath = True
if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
disable_fastpath = False
yield ForeachSampleInput(input, **{"ord": ord, "disable_fastpath": disable_fastpath})
yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
class foreach_lerp_sample_func(foreach_inputs_sample_func):