mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Enable flake8-simplify checks (#97984)
Enable some sensible flake8-simplify rules. Mainly wanted to enable the SIM101, and `yield from` SIM103 checks. @kit1980 since you wanted to be tagged on this CI check. Enabling this check also helped flag one logical bug so it's definitely beneficial (also fixed in this PR). Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984 Approved by: https://github.com/ezyang
This commit is contained in:
parent
3dc4405278
commit
9c3fbe7475
8
.flake8
8
.flake8
|
|
@ -1,6 +1,6 @@
|
||||||
[flake8]
|
[flake8]
|
||||||
enable-extensions = G
|
enable-extensions = G
|
||||||
select = B,C,E,F,G,P,T4,W,B9
|
select = B,C,E,F,G,P,SIM1,T4,W,B9
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
# C408 ignored because we like the dict keyword argument syntax
|
# C408 ignored because we like the dict keyword argument syntax
|
||||||
# E501 is not flexible enough, we're using B950 instead
|
# E501 is not flexible enough, we're using B950 instead
|
||||||
|
|
@ -16,7 +16,11 @@ ignore =
|
||||||
# these ignores are from flake8-comprehensions; please fix!
|
# these ignores are from flake8-comprehensions; please fix!
|
||||||
C407
|
C407
|
||||||
# these ignores are from flake8-logging-format; please fix!
|
# these ignores are from flake8-logging-format; please fix!
|
||||||
G001,G002,G003,G004,G100,G101,G200,G201,G202
|
G001,G002,G003,G004,G100,G101,G200,G201,G202,
|
||||||
|
# these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||||
|
SIM105,SIM108,SIM109,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
|
||||||
|
# flake8-simplify code styles
|
||||||
|
SIM102,SIM103,SIM106,SIM112,
|
||||||
per-file-ignores =
|
per-file-ignores =
|
||||||
__init__.py: F401
|
__init__.py: F401
|
||||||
torch/utils/cpp_extension.py: B950
|
torch/utils/cpp_extension.py: B950
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ init_command = [
|
||||||
'flake8-executable==2.1.3',
|
'flake8-executable==2.1.3',
|
||||||
'flake8-logging-format==0.9.0',
|
'flake8-logging-format==0.9.0',
|
||||||
'flake8-pyi==23.3.1',
|
'flake8-pyi==23.3.1',
|
||||||
|
'flake8-simplify==0.19.3',
|
||||||
'mccabe==0.7.0',
|
'mccabe==0.7.0',
|
||||||
'pycodestyle==2.10.0',
|
'pycodestyle==2.10.0',
|
||||||
'pyflakes==3.0.1',
|
'pyflakes==3.0.1',
|
||||||
|
|
|
||||||
|
|
@ -457,8 +457,8 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
||||||
if self.kwdefaults:
|
if self.kwdefaults:
|
||||||
flags |= 0x02
|
flags |= 0x02
|
||||||
codegen(self.kwdefaults)
|
codegen(self.kwdefaults)
|
||||||
if isinstance(self.annotations, variables.ConstDictVariable) or isinstance(
|
if isinstance(
|
||||||
self.annotations, variables.TupleVariable
|
self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
|
||||||
):
|
):
|
||||||
flags |= 0x04
|
flags |= 0x04
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -200,9 +200,7 @@ class ContextWrappingVariable(VariableTracker):
|
||||||
assert len(args) == 1
|
assert len(args) == 1
|
||||||
if isinstance(args[0], NestedUserFunctionVariable):
|
if isinstance(args[0], NestedUserFunctionVariable):
|
||||||
args[0] = UserFunctionVariable(args[0].get_function())
|
args[0] = UserFunctionVariable(args[0].get_function())
|
||||||
assert isinstance(args[0], UserMethodVariable) or isinstance(
|
assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))
|
||||||
args[0], UserFunctionVariable
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(args[0], UserMethodVariable):
|
if isinstance(args[0], UserMethodVariable):
|
||||||
return WrappedUserMethodVariable(args[0], self)
|
return WrappedUserMethodVariable(args[0], self)
|
||||||
|
|
|
||||||
|
|
@ -4568,7 +4568,7 @@ def meshgrid(
|
||||||
# This ref simultaneously handles two overloads (see stubs above)
|
# This ref simultaneously handles two overloads (see stubs above)
|
||||||
# The `indexing` argument is currently optional for torch.meshgrid, but we
|
# The `indexing` argument is currently optional for torch.meshgrid, but we
|
||||||
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
|
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
|
||||||
if isinstance(tensors[0], list) or isinstance(tensors[0], tuple):
|
if isinstance(tensors[0], (list, tuple)):
|
||||||
assert len(tensors) == 1
|
assert len(tensors) == 1
|
||||||
tensors = tuple(tensors[0])
|
tensors = tuple(tensors[0])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -154,7 +154,7 @@ def _dtensor_expand(
|
||||||
if isinstance(a, torch.Tensor):
|
if isinstance(a, torch.Tensor):
|
||||||
inps.append(a)
|
inps.append(a)
|
||||||
schemas.append(shard_schema)
|
schemas.append(shard_schema)
|
||||||
elif isinstance(a, nn.Module) or isinstance(a, torch.optim.Optimizer):
|
elif isinstance(a, (nn.Module, torch.optim.Optimizer)):
|
||||||
# nn.Module or optimizer placeholder is captured by make_fx but
|
# nn.Module or optimizer placeholder is captured by make_fx but
|
||||||
# never used in the graph
|
# never used in the graph
|
||||||
inps.append(torch.empty(0))
|
inps.append(torch.empty(0))
|
||||||
|
|
|
||||||
|
|
@ -527,8 +527,8 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
|
|
||||||
@register_inference_rule(operator.gt)
|
@register_inference_rule(operator.gt)
|
||||||
def gt_inference_rule(n: Node, symbols, constraints, counter):
|
def gt_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
|
assert isinstance(n.args[0], (Node, int))
|
||||||
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
|
assert isinstance(n.args[1], (Node, int))
|
||||||
|
|
||||||
# We make sure this node will not be used again. We do not
|
# We make sure this node will not be used again. We do not
|
||||||
# generate a constraint about that node. Only about the operands.
|
# generate a constraint about that node. Only about the operands.
|
||||||
|
|
@ -586,8 +586,8 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
|
|
||||||
@register_inference_rule(operator.eq)
|
@register_inference_rule(operator.eq)
|
||||||
def eq_inference_rule(n: Node, symbols, constraints, counter):
|
def eq_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
|
assert isinstance(n.args[0], (Node, int))
|
||||||
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
|
assert isinstance(n.args[1], (Node, int))
|
||||||
|
|
||||||
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
|
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
|
||||||
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
|
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
|
||||||
|
|
@ -640,9 +640,9 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
# implementing for size 3 and 4
|
# implementing for size 3 and 4
|
||||||
if len(n.args[1]) == 3:
|
if len(n.args[1]) == 3:
|
||||||
|
|
||||||
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
|
assert isinstance(n.args[1][0], (Node, int))
|
||||||
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
|
assert isinstance(n.args[1][1], (Node, int))
|
||||||
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)
|
assert isinstance(n.args[1][2], (Node, int))
|
||||||
|
|
||||||
lhs = symbols[n.args[0]]
|
lhs = symbols[n.args[0]]
|
||||||
|
|
||||||
|
|
@ -674,10 +674,10 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
|
|
||||||
elif len(n.args[1]) == 4:
|
elif len(n.args[1]) == 4:
|
||||||
|
|
||||||
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
|
assert isinstance(n.args[1][0], (Node, int))
|
||||||
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
|
assert isinstance(n.args[1][1], (Node, int))
|
||||||
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)
|
assert isinstance(n.args[1][2], (Node, int))
|
||||||
assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int)
|
assert isinstance(n.args[1][3], (Node, int))
|
||||||
|
|
||||||
lhs = symbols[n.args[0]]
|
lhs = symbols[n.args[0]]
|
||||||
|
|
||||||
|
|
@ -722,8 +722,8 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
|
|
||||||
@register_inference_rule(operator.lt)
|
@register_inference_rule(operator.lt)
|
||||||
def lt_inference_rule(n: Node, symbols, constraints, counter):
|
def lt_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
|
assert isinstance(n.args[0], (Node, int))
|
||||||
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
|
assert isinstance(n.args[1], (Node, int))
|
||||||
|
|
||||||
# We make sure this node will not be used again. We do not
|
# We make sure this node will not be used again. We do not
|
||||||
# generate a constraint about that node. Only about the operands.
|
# generate a constraint about that node. Only about the operands.
|
||||||
|
|
@ -845,7 +845,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Method not yet implemented')
|
raise NotImplementedError('Method not yet implemented')
|
||||||
|
|
||||||
elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)):
|
elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
|
||||||
if isinstance(symbols[n.args[0]], TVar):
|
if isinstance(symbols[n.args[0]], TVar):
|
||||||
my_output, counter = gen_tvar(counter)
|
my_output, counter = gen_tvar(counter)
|
||||||
symbols[n] = my_output
|
symbols[n] = my_output
|
||||||
|
|
@ -861,7 +861,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
|
||||||
BinConstraintD(0, my_output, op_leq)])
|
BinConstraintD(0, my_output, op_leq)])
|
||||||
return [c], counter
|
return [c], counter
|
||||||
|
|
||||||
elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)):
|
elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
|
||||||
if isinstance(symbols[n.args[1]], TVar):
|
if isinstance(symbols[n.args[1]], TVar):
|
||||||
my_output, counter = gen_tvar(counter)
|
my_output, counter = gen_tvar(counter)
|
||||||
symbols[n] = my_output
|
symbols[n] = my_output
|
||||||
|
|
|
||||||
|
|
@ -582,9 +582,7 @@ class MatMulDimInFP16Pattern(Pattern):
|
||||||
def source_code_location(event: Optional[_ProfilerEvent]):
|
def source_code_location(event: Optional[_ProfilerEvent]):
|
||||||
while event:
|
while event:
|
||||||
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
|
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
|
||||||
assert isinstance(event.extra_fields,
|
assert isinstance(event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall))
|
||||||
_ExtraFields_PyCall) or isinstance(
|
|
||||||
event.extra_fields, _ExtraFields_PyCCall)
|
|
||||||
if not event.extra_fields.caller.file_name.startswith("torch" +
|
if not event.extra_fields.caller.file_name.startswith("torch" +
|
||||||
os.sep):
|
os.sep):
|
||||||
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
|
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
|
||||||
|
|
|
||||||
|
|
@ -204,9 +204,7 @@ class GenLazyIR(ABC):
|
||||||
# as long as all of its arguments can be generated from information available from the schema
|
# as long as all of its arguments can be generated from information available from the schema
|
||||||
base_ctor_value_args_list = []
|
base_ctor_value_args_list = []
|
||||||
for arg in value_args:
|
for arg in value_args:
|
||||||
if isinstance(arg.lazy_type, BaseCType) or isinstance(
|
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
|
||||||
arg.lazy_type, VectorCType
|
|
||||||
):
|
|
||||||
base_ctor_value_args_list.append(f"{arg.name}")
|
base_ctor_value_args_list.append(f"{arg.name}")
|
||||||
elif isinstance(arg.lazy_type, OptionalCType):
|
elif isinstance(arg.lazy_type, OptionalCType):
|
||||||
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
|
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user