mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix B903 lint: save memory for data classes with slots/namedtuple (#18184)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18184 ghimport-source-id: 2ce860b07c58d06dc10cd7e5b97d4ef7c709a50d Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18184 Fix B903 lint: save memory for data classes with slots/namedtuple** * #18181 Fix B902 lint error: invalid first argument. * #18178 Fix B006 lint errors: using mutable structure in default argument. * #18177 Fix lstrip bug revealed by B005 lint Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14530872 fbshipit-source-id: e26cecab3a8545e7638454c28e654e7b82a3c08a
This commit is contained in:
parent
ba81074c40
commit
d1497debf2
2
.flake8
2
.flake8
|
|
@ -6,5 +6,5 @@ max-line-length = 120
|
||||||
ignore =
|
ignore =
|
||||||
E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
|
E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
|
||||||
# ignores below are temporary, fix them and remove please!
|
# ignores below are temporary, fix them and remove please!
|
||||||
B007,B008,B903
|
B007,B008
|
||||||
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
|
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
|
||||||
|
|
|
||||||
|
|
@ -194,6 +194,8 @@ CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
|
||||||
class NYIError(Exception):
|
class NYIError(Exception):
|
||||||
"""Indicates we don't support this declaration yet"""
|
"""Indicates we don't support this declaration yet"""
|
||||||
|
|
||||||
|
__slots__ = ['reason']
|
||||||
|
|
||||||
def __init__(self, reason):
|
def __init__(self, reason):
|
||||||
self.reason = reason
|
self.reason = reason
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
from torch._six import inf, nan, istuple
|
from torch._six import inf, nan, istuple
|
||||||
from functools import reduce, wraps
|
from functools import reduce, wraps
|
||||||
from operator import mul, itemgetter
|
from operator import mul, itemgetter
|
||||||
|
import collections
|
||||||
from torch.autograd import Variable, Function, detect_anomaly
|
from torch.autograd import Variable, Function, detect_anomaly
|
||||||
from torch.testing import make_non_contiguous
|
from torch.testing import make_non_contiguous
|
||||||
from common_utils import (skipIfNoLapack,
|
from common_utils import (skipIfNoLapack,
|
||||||
|
|
@ -72,9 +73,7 @@ def prod_zeros(dim_size, dim_select):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class non_differentiable(object):
|
non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
|
||||||
def __init__(self, tensor):
|
|
||||||
self.tensor = tensor
|
|
||||||
|
|
||||||
|
|
||||||
class dont_convert(tuple):
|
class dont_convert(tuple):
|
||||||
|
|
|
||||||
|
|
@ -13942,7 +13942,7 @@ class TestClassType(JitTestCase):
|
||||||
self.assertEqual(fn(input), input)
|
self.assertEqual(fn(input), input)
|
||||||
|
|
||||||
def test_get_attr(self):
|
def test_get_attr(self):
|
||||||
@torch.jit.script
|
@torch.jit.script # noqa: B903
|
||||||
class FooTest:
|
class FooTest:
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
self.foo = x
|
self.foo = x
|
||||||
|
|
@ -14005,7 +14005,7 @@ class TestClassType(JitTestCase):
|
||||||
|
|
||||||
def test_type_annotations(self):
|
def test_type_annotations(self):
|
||||||
with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
|
with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
|
||||||
@torch.jit.script
|
@torch.jit.script # noqa: B903
|
||||||
class FooTest:
|
class FooTest:
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
# type: (bool) -> None
|
# type: (bool) -> None
|
||||||
|
|
@ -14026,7 +14026,7 @@ class TestClassType(JitTestCase):
|
||||||
self.attr = x
|
self.attr = x
|
||||||
|
|
||||||
def test_class_type_as_param(self):
|
def test_class_type_as_param(self):
|
||||||
@torch.jit.script
|
@torch.jit.script # noqa: B903
|
||||||
class FooTest:
|
class FooTest:
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
self.attr = x
|
self.attr = x
|
||||||
|
|
@ -14094,7 +14094,7 @@ class TestClassType(JitTestCase):
|
||||||
self.assertEqual(input, output)
|
self.assertEqual(input, output)
|
||||||
|
|
||||||
def test_save_load_with_classes_nested(self):
|
def test_save_load_with_classes_nested(self):
|
||||||
@torch.jit.script
|
@torch.jit.script # noqa: B903
|
||||||
class FooNestedTest:
|
class FooNestedTest:
|
||||||
def __init__(self, y):
|
def __init__(self, y):
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import re
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict, namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import FileNotFoundError
|
from torch._six import FileNotFoundError
|
||||||
|
|
@ -366,11 +366,7 @@ class Interval(object):
|
||||||
return self.end - self.start
|
return self.end - self.start
|
||||||
|
|
||||||
|
|
||||||
class Kernel(object):
|
Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
|
||||||
def __init__(self, name, device, interval):
|
|
||||||
self.name = name
|
|
||||||
self.device = device
|
|
||||||
self.interval = interval
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: record TID too
|
# TODO: record TID too
|
||||||
|
|
|
||||||
|
|
@ -1564,10 +1564,7 @@ def annotate(the_type, the_value):
|
||||||
return the_value
|
return the_value
|
||||||
|
|
||||||
|
|
||||||
class Attribute(object):
|
Attribute = collections.namedtuple('Attribute', ['value', 'type'])
|
||||||
def __init__(self, value, the_type):
|
|
||||||
self.value = value
|
|
||||||
self.type = the_type
|
|
||||||
|
|
||||||
|
|
||||||
if not torch._C._jit_init():
|
if not torch._C._jit_init():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user