[BE] Enable flake8-simplify checks (#97984)

Enable some sensible flake8-simplify rules. Mainly I wanted to enable the SIM101 check and the `yield from` SIM103 check. @kit1980 since you wanted to be tagged on this CI check.

Enabling this check also helped flag one logical bug so it's definitely beneficial (also fixed in this PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan 2023-03-31 03:40:21 +00:00 committed by PyTorch MergeBot
parent 3dc4405278
commit 9c3fbe7475
9 changed files with 29 additions and 30 deletions

View File

@ -1,6 +1,6 @@
[flake8] [flake8]
enable-extensions = G enable-extensions = G
select = B,C,E,F,G,P,T4,W,B9 select = B,C,E,F,G,P,SIM1,T4,W,B9
max-line-length = 120 max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax # C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead # E501 is not flexible enough, we're using B950 instead
@ -16,7 +16,11 @@ ignore =
# these ignores are from flake8-comprehensions; please fix! # these ignores are from flake8-comprehensions; please fix!
C407 C407
# these ignores are from flake8-logging-format; please fix! # these ignores are from flake8-logging-format; please fix!
G001,G002,G003,G004,G100,G101,G200,G201,G202 G001,G002,G003,G004,G100,G101,G200,G201,G202,
# these ignores are from flake8-simplify. please fix or ignore with commented reason
SIM105,SIM108,SIM109,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
# flake8-simplify code styles
SIM102,SIM103,SIM106,SIM112,
per-file-ignores = per-file-ignores =
__init__.py: F401 __init__.py: F401
torch/utils/cpp_extension.py: B950 torch/utils/cpp_extension.py: B950

View File

@ -39,6 +39,7 @@ init_command = [
'flake8-executable==2.1.3', 'flake8-executable==2.1.3',
'flake8-logging-format==0.9.0', 'flake8-logging-format==0.9.0',
'flake8-pyi==23.3.1', 'flake8-pyi==23.3.1',
'flake8-simplify==0.19.3',
'mccabe==0.7.0', 'mccabe==0.7.0',
'pycodestyle==2.10.0', 'pycodestyle==2.10.0',
'pyflakes==3.0.1', 'pyflakes==3.0.1',

View File

@ -457,8 +457,8 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
if self.kwdefaults: if self.kwdefaults:
flags |= 0x02 flags |= 0x02
codegen(self.kwdefaults) codegen(self.kwdefaults)
if isinstance(self.annotations, variables.ConstDictVariable) or isinstance( if isinstance(
self.annotations, variables.TupleVariable self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
): ):
flags |= 0x04 flags |= 0x04
try: try:

View File

@ -200,9 +200,7 @@ class ContextWrappingVariable(VariableTracker):
assert len(args) == 1 assert len(args) == 1
if isinstance(args[0], NestedUserFunctionVariable): if isinstance(args[0], NestedUserFunctionVariable):
args[0] = UserFunctionVariable(args[0].get_function()) args[0] = UserFunctionVariable(args[0].get_function())
assert isinstance(args[0], UserMethodVariable) or isinstance( assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))
args[0], UserFunctionVariable
)
if isinstance(args[0], UserMethodVariable): if isinstance(args[0], UserMethodVariable):
return WrappedUserMethodVariable(args[0], self) return WrappedUserMethodVariable(args[0], self)

View File

@ -4568,7 +4568,7 @@ def meshgrid(
# This ref simultaneously handles two overloads (see stubs above) # This ref simultaneously handles two overloads (see stubs above)
# The `indexing` argument is currently optional for torch.meshgrid, but we # The `indexing` argument is currently optional for torch.meshgrid, but we
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
if isinstance(tensors[0], list) or isinstance(tensors[0], tuple): if isinstance(tensors[0], (list, tuple)):
assert len(tensors) == 1 assert len(tensors) == 1
tensors = tuple(tensors[0]) tensors = tuple(tensors[0])

View File

@ -154,7 +154,7 @@ def _dtensor_expand(
if isinstance(a, torch.Tensor): if isinstance(a, torch.Tensor):
inps.append(a) inps.append(a)
schemas.append(shard_schema) schemas.append(shard_schema)
elif isinstance(a, nn.Module) or isinstance(a, torch.optim.Optimizer): elif isinstance(a, (nn.Module, torch.optim.Optimizer)):
# nn.Module or optimizer placeholder is captured by make_fx but # nn.Module or optimizer placeholder is captured by make_fx but
# never used in the graph # never used in the graph
inps.append(torch.empty(0)) inps.append(torch.empty(0))

View File

@ -527,8 +527,8 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.gt) @register_inference_rule(operator.gt)
def gt_inference_rule(n: Node, symbols, constraints, counter): def gt_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) assert isinstance(n.args[1], (Node, int))
# We make sure this node will not be used again. We do not # We make sure this node will not be used again. We do not
# generate a constraint about that node. Only about the operands. # generate a constraint about that node. Only about the operands.
@ -586,8 +586,8 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.eq) @register_inference_rule(operator.eq)
def eq_inference_rule(n: Node, symbols, constraints, counter): def eq_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) assert isinstance(n.args[1], (Node, int))
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
@ -640,9 +640,9 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
# implementing for size 3 and 4 # implementing for size 3 and 4
if len(n.args[1]) == 3: if len(n.args[1]) == 3:
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) assert isinstance(n.args[1][2], (Node, int))
lhs = symbols[n.args[0]] lhs = symbols[n.args[0]]
@ -674,10 +674,10 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
elif len(n.args[1]) == 4: elif len(n.args[1]) == 4:
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) assert isinstance(n.args[1][2], (Node, int))
assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int) assert isinstance(n.args[1][3], (Node, int))
lhs = symbols[n.args[0]] lhs = symbols[n.args[0]]
@ -722,8 +722,8 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.lt) @register_inference_rule(operator.lt)
def lt_inference_rule(n: Node, symbols, constraints, counter): def lt_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) assert isinstance(n.args[1], (Node, int))
# We make sure this node will not be used again. We do not # We make sure this node will not be used again. We do not
# generate a constraint about that node. Only about the operands. # generate a constraint about that node. Only about the operands.
@ -845,7 +845,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
else: else:
raise NotImplementedError('Method not yet implemented') raise NotImplementedError('Method not yet implemented')
elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)): elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
if isinstance(symbols[n.args[0]], TVar): if isinstance(symbols[n.args[0]], TVar):
my_output, counter = gen_tvar(counter) my_output, counter = gen_tvar(counter)
symbols[n] = my_output symbols[n] = my_output
@ -861,7 +861,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
BinConstraintD(0, my_output, op_leq)]) BinConstraintD(0, my_output, op_leq)])
return [c], counter return [c], counter
elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)): elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
if isinstance(symbols[n.args[1]], TVar): if isinstance(symbols[n.args[1]], TVar):
my_output, counter = gen_tvar(counter) my_output, counter = gen_tvar(counter)
symbols[n] = my_output symbols[n] = my_output

View File

@ -582,9 +582,7 @@ class MatMulDimInFP16Pattern(Pattern):
def source_code_location(event: Optional[_ProfilerEvent]): def source_code_location(event: Optional[_ProfilerEvent]):
while event: while event:
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
assert isinstance(event.extra_fields, assert isinstance(event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall))
_ExtraFields_PyCall) or isinstance(
event.extra_fields, _ExtraFields_PyCCall)
if not event.extra_fields.caller.file_name.startswith("torch" + if not event.extra_fields.caller.file_name.startswith("torch" +
os.sep): os.sep):
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"

View File

@ -204,9 +204,7 @@ class GenLazyIR(ABC):
# as long as all of its arguments can be generated from information available from the schema # as long as all of its arguments can be generated from information available from the schema
base_ctor_value_args_list = [] base_ctor_value_args_list = []
for arg in value_args: for arg in value_args:
if isinstance(arg.lazy_type, BaseCType) or isinstance( if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
arg.lazy_type, VectorCType
):
base_ctor_value_args_list.append(f"{arg.name}") base_ctor_value_args_list.append(f"{arg.name}")
elif isinstance(arg.lazy_type, OptionalCType): elif isinstance(arg.lazy_type, OptionalCType):
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)") base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")