mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE]: RUF018 - ban assignment in assert (#125125)
Ban assignment inside of assert. Python code should ideally not break with assertions disabled. Adds a ruff lint rule to enforce this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125125 Approved by: https://github.com/ezyang
This commit is contained in:
parent
a05b2ae302
commit
3e1fb96964
|
|
@ -131,6 +131,7 @@ select = [
|
|||
"RUF015", # access first ele in constant time
|
||||
"RUF016", # type error non-integer index
|
||||
"RUF017",
|
||||
"RUF018", # no assignment in assert
|
||||
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
||||
"TRY200", # TODO: migrate from deprecated alias
|
||||
"TRY302",
|
||||
|
|
|
|||
|
|
@ -2265,13 +2265,11 @@ class ExportedProgramDeserializer(metaclass=Final):
|
|||
key for key in model_opset_version if key in self.expected_opset_version
|
||||
}
|
||||
for namespace in common_namespaces:
|
||||
assert isinstance(
|
||||
model_version := model_opset_version[namespace], int
|
||||
), f"model_opset_version value should be int, got {model_opset_version[namespace]}"
|
||||
model_version = model_opset_version[namespace]
|
||||
assert isinstance(model_version, int), f"model_opset_version value should be int, got {model_version}"
|
||||
|
||||
assert isinstance(
|
||||
compiler_version := self.expected_opset_version[namespace], int
|
||||
), f"expected_opset_version value should be int, got {self.expected_opset_version[namespace]}"
|
||||
compiler_version = self.expected_opset_version[namespace]
|
||||
assert isinstance(compiler_version, int), f"expected_opset_version value should be int, got {compiler_version}"
|
||||
|
||||
# TODO(larryliu0820): Add support for upgrader & downgrader
|
||||
if model_version != compiler_version:
|
||||
|
|
|
|||
|
|
@ -292,9 +292,11 @@ class ReductionTypePromotionRule(TypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert len(args) >= 1 and isinstance(
|
||||
arg := args[0], torch.Tensor
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
dtype: Optional[torch.dtype] = kwargs.get("dtype", None)
|
||||
|
||||
computation_dtype, result_dtype = _prims_common.reduction_dtypes(
|
||||
|
|
@ -329,9 +331,11 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert len(args) >= 1 and isinstance(
|
||||
arg := args[0], torch.Tensor
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
computation_dtype = torch.bool
|
||||
# Preserves uint8 -- probably a legacy mask thing
|
||||
result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool
|
||||
|
|
@ -352,9 +356,11 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert len(args) >= 1 and isinstance(
|
||||
arg := args[0], torch.Tensor
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
dtype: Optional[torch.dtype] = kwargs.get("dtype", None)
|
||||
# The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
|
||||
if dtype is None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user