[BE]: RUF018 - ban assignment in assert (#125125)

Ban assignment expressions (the walrus operator) inside assert statements. Python code should ideally not break with assertions disabled (e.g. under `python -O`): a walrus assignment inside an assert is stripped along with the assert, silently leaving the name unbound. Adds a ruff lint rule to enforce this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125125
Approved by: https://github.com/ezyang
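
For illustration, a minimal sketch (hypothetical code, not from this PR) of the failure mode the rule guards against. Run normally, the function works; run with `python -O` (or PYTHONOPTIMIZE=1), the assert, together with the walrus assignment inside it, is removed, and the later use of the name fails with a NameError (an UnboundLocalError inside a function):

# Hypothetical sketch of the hazard RUF018 catches.
def parse(data: list) -> int:
    # Under `python -O` this whole statement is stripped, so `n` is never bound.
    assert (n := len(data)) > 0, "data must be non-empty"
    return n  # unbound when assertions are disabled

# The RUF018-clean rewrite hoists the assignment out of the assert:
def parse_safe(data: list) -> int:
    n = len(data)
    assert n > 0, "data must be non-empty"
    return n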
Author: Aaron Gokaslan, 2024-04-28 21:41:34 +00:00; committed by PyTorch MergeBot
commit 3e1fb96964, parent a05b2ae302
3 changed files with 17 additions and 12 deletions

@@ -131,6 +131,7 @@ select = [
     "RUF015", # access first ele in constant time
     "RUF016", # type error non-integer index
     "RUF017",
+    "RUF018", # no assignment in assert
     "TRY002", # ban vanilla raise (todo fix NOQAs)
     "TRY200", # TODO: migrate from deprecated alias
     "TRY302",

@@ -2265,13 +2265,11 @@ class ExportedProgramDeserializer(metaclass=Final):
             key for key in model_opset_version if key in self.expected_opset_version
         }
         for namespace in common_namespaces:
-            assert isinstance(
-                model_version := model_opset_version[namespace], int
-            ), f"model_opset_version value should be int, got {model_opset_version[namespace]}"
+            model_version = model_opset_version[namespace]
+            assert isinstance(model_version, int), f"model_opset_version value should be int, got {model_version}"
-            assert isinstance(
-                compiler_version := self.expected_opset_version[namespace], int
-            ), f"expected_opset_version value should be int, got {self.expected_opset_version[namespace]}"
+            compiler_version = self.expected_opset_version[namespace]
+            assert isinstance(compiler_version, int), f"expected_opset_version value should be int, got {compiler_version}"
             # TODO(larryliu0820): Add support for upgrader & downgrader
             if model_version != compiler_version:
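
A distilled, self-contained sketch of why the old form above was unsafe (hypothetical inline data; names follow the diff): with assertions disabled, neither walrus assignment ran, so the comparison at the end raised a NameError even though the asserts themselves were meant to be skippable. Binding the names unconditionally fixes this:

# Hypothetical stand-in data for illustration.
model_opset_version = {"aten": 11}
expected_opset_version = {"aten": 12}

for namespace in model_opset_version.keys() & expected_opset_version.keys():
    model_version = model_opset_version[namespace]        # bound even under -O
    assert isinstance(model_version, int), f"got {model_version!r}"
    compiler_version = expected_opset_version[namespace]  # bound even under -O
    assert isinstance(compiler_version, int), f"got {compiler_version!r}"
    if model_version != compiler_version:                 # safe with asserts stripped
        print(f"opset version mismatch for {namespace}")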

@@ -292,9 +292,11 @@ class ReductionTypePromotionRule(TypePromotionRule):
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
     ) -> TypePromotionSnapshot:
-        assert len(args) >= 1 and isinstance(
-            arg := args[0], torch.Tensor
+        assert (
+            len(args) >= 1
         ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
         dtype: Optional[torch.dtype] = kwargs.get("dtype", None)
         computation_dtype, result_dtype = _prims_common.reduction_dtypes(
@@ -329,9 +331,11 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
     ) -> TypePromotionSnapshot:
-        assert len(args) >= 1 and isinstance(
-            arg := args[0], torch.Tensor
+        assert (
+            len(args) >= 1
         ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
         computation_dtype = torch.bool
         # Preserves uint8 -- probably a legacy mask thing
         result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool
@@ -352,9 +356,11 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
-        assert len(args) >= 1 and isinstance(
-            arg := args[0], torch.Tensor
+        assert (
+            len(args) >= 1
         ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
         dtype: Optional[torch.dtype] = kwargs.get("dtype", None)
         # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
         if dtype is None:
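
The same pattern repeats in all three rules above: the compound assert containing a walrus is split so that `arg` is always bound, and each condition reports its own failure message. A distilled standalone sketch (hypothetical function, assuming torch is installed):

import torch

def first_tensor_arg(args: tuple) -> torch.Tensor:
    # Split asserts: `arg` stays bound even under `python -O`, and each
    # failure mode gets a precise message.
    assert len(args) >= 1, "expected at least one argument"
    arg = args[0]
    assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
    return arg

print(first_tensor_arg((torch.ones(2),)))  # tensor([1., 1.])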