Fix B903 lint: save memory for data classes with slots/namedtuple (#18184)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18184
ghimport-source-id: 2ce860b07c58d06dc10cd7e5b97d4ef7c709a50d

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18184 Fix B903 lint: save memory for data classes with slots/namedtuple**
* #18181 Fix B902 lint error: invalid first argument.
* #18178 Fix B006 lint errors: using mutable structure in default argument.
* #18177 Fix lstrip bug revealed by B005 lint

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14530872

fbshipit-source-id: e26cecab3a8545e7638454c28e654e7b82a3c08a
Edward Yang 2019-03-21 09:06:30 -07:00 committed by Facebook Github Bot
parent ba81074c40
commit d1497debf2
6 changed files with 12 additions and 18 deletions
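
For context: flake8-bugbear's B903 flags classes that do nothing but store the arguments passed to `__init__`. Such classes can usually be replaced by a `collections.namedtuple` (or keep the class and declare `__slots__`), so instances no longer carry a per-instance `__dict__`. A minimal sketch of the before/after pattern; the `Point*` names below are made up for illustration and are not part of this commit:

    import collections
    import sys

    # Before: a plain data class; every instance carries a __dict__.
    class PointClass(object):
        def __init__(self, x, y):
            self.x = x
            self.y = y

    # After, option 1: a namedtuple has no per-instance __dict__.
    PointTuple = collections.namedtuple('PointTuple', ['x', 'y'])

    # After, option 2: keep the class but fix the attribute set with __slots__.
    class PointSlots(object):
        __slots__ = ['x', 'y']

        def __init__(self, x, y):
            self.x = x
            self.y = y

    p, t, s = PointClass(1, 2), PointTuple(1, 2), PointSlots(1, 2)
    print(sys.getsizeof(p.__dict__))  # dict overhead paid by the plain class only
    print(t.x, s.y)                   # attribute access is unchanged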

View File

@@ -6,5 +6,5 @@ max-line-length = 120
 ignore =
     E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
     # ignores below are temporary, fix them and remove please!
-    B007,B008,B903
+    B007,B008
 exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include

View File

@@ -194,6 +194,8 @@ CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
 
 class NYIError(Exception):
     """Indicates we don't support this declaration yet"""
+    __slots__ = ['reason']
+
     def __init__(self, reason):
         self.reason = reason
 

View File

@@ -2,6 +2,7 @@ import torch
 from torch._six import inf, nan, istuple
 from functools import reduce, wraps
 from operator import mul, itemgetter
+import collections
 from torch.autograd import Variable, Function, detect_anomaly
 from torch.testing import make_non_contiguous
 from common_utils import (skipIfNoLapack,
@@ -72,9 +73,7 @@ def prod_zeros(dim_size, dim_select):
     return result
 
 
-class non_differentiable(object):
-    def __init__(self, tensor):
-        self.tensor = tensor
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
 
 
 class dont_convert(tuple):

View File

@@ -13942,7 +13942,7 @@ class TestClassType(JitTestCase):
         self.assertEqual(fn(input), input)
 
     def test_get_attr(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooTest:
             def __init__(self, x):
                 self.foo = x
@@ -14005,7 +14005,7 @@ class TestClassType(JitTestCase):
 
     def test_type_annotations(self):
         with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
-            @torch.jit.script
+            @torch.jit.script  # noqa: B903
             class FooTest:
                 def __init__(self, x):
                     # type: (bool) -> None
@@ -14026,7 +14026,7 @@
                 self.attr = x
 
     def test_class_type_as_param(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooTest:
             def __init__(self, x):
                 self.attr = x
@@ -14094,7 +14094,7 @@
         self.assertEqual(input, output)
 
     def test_save_load_with_classes_nested(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooNestedTest:
             def __init__(self, y):
                 self.y = y
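
A side note on the `# noqa: B903` lines above: the trivial `FooTest`-style classes are the fixture under test for TorchScript class support, so the commit silences the lint per line rather than rewriting them. A generic, standalone sketch of the suppression; the `Holder` class is hypothetical, and in the hunks above the noqa sits on the decorator line because flake8 reports the error there for decorated classes:

    # Deliberately minimal data class kept as-is; the inline noqa silences
    # only flake8-bugbear's B903 on this line, leaving other checks active.
    class Holder(object):  # noqa: B903
        def __init__(self, x):
            self.x = x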

View File

@@ -3,7 +3,7 @@ import re
 import os
 import sys
 import itertools
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 import torch
 from torch._six import FileNotFoundError
@@ -366,11 +366,7 @@ class Interval(object):
         return self.end - self.start
 
 
-class Kernel(object):
-    def __init__(self, name, device, interval):
-        self.name = name
-        self.device = device
-        self.interval = interval
+Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
 
 
 # TODO: record TID too
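
If it helps to see why the namedtuple form is a near drop-in replacement here: positional construction and dotted attribute access keep working as before, but instances become immutable, so any later field change has to build a new tuple. A rough standalone sketch; the sample values are made up, not taken from the profiler:

    from collections import namedtuple

    Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])

    k = Kernel('gemm', 0, (10, 42))
    print(k.name, k.device, k.interval)  # reads look identical to the old class
    k2 = k._replace(device=1)            # fields are read-only; "updates" return a new tuple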

View File

@@ -1564,10 +1564,7 @@ def annotate(the_type, the_value):
     return the_value
 
 
-class Attribute(object):
-    def __init__(self, value, the_type):
-        self.value = value
-        self.type = the_type
+Attribute = collections.namedtuple('Attribute', ['value', 'type'])
 
 
 if not torch._C._jit_init():