Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not."

I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work.

Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
This commit is contained in:
parent 96456bfa4c
commit 173f224570
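For readers unfamiliar with the lint being enabled: F401 is flake8's "imported but unused" check, and the mechanics the commit message describes look roughly like the sketch below. This is an illustrative example with made-up imports; it is not code taken from this diff.

```python
# Hypothetical re-export module, for illustration only (not a file from this diff).
# The imports below are never referenced in this module, so with F401 enabled
# flake8 reports each one as "imported but unused".

# Approach taken throughout this commit: suppress the warning per site.
from os.path import join as path_join  # noqa: F401
from functools import reduce            # noqa: F401

# Alternative mentioned in the commit message and left as future work: list
# re-exported names in __all__, which pyflakes counts as a use of the import.
# __all__ = ['path_join', 'reduce']

# The message also warns that an import referenced only from a "# type:"
# comment is reported as unused by flake8 on Python 3 (but not by flake8
# on Python 2), so those sites were noqa'd as well:
from typing import Optional  # noqa: F401


def add_one(x):
    # type: (Optional[int]) -> int
    return (x or 0) + 1
```

With the F401 entry removed from the ignore list in .flake8 (the first hunk below), any import that is neither used nor silenced this way now fails lint.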
.flake8
@@ -7,7 +7,7 @@ max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!

setup.py
@@ -142,7 +142,7 @@
# we will search for libraries in these paths

from __future__ import print_function
from setuptools import setup, Extension, distutils, Command, find_packages
from setuptools import setup, Extension, distutils, find_packages
from distutils import core, dir_util
from distutils.core import Distribution
from distutils.errors import DistutilsArgError
@@ -151,7 +151,6 @@ import setuptools.command.install
import distutils.command.clean
import distutils.sysconfig
import filecmp
import platform
import subprocess
import shutil
import sys
@ -1,9 +1,9 @@
|
|||
import torch
|
||||
from torch._six import inf, nan, istuple
|
||||
from functools import reduce, wraps
|
||||
from torch._six import inf, istuple
|
||||
from functools import reduce
|
||||
from operator import mul, itemgetter
|
||||
import collections
|
||||
from torch.autograd import Variable, Function, detect_anomaly
|
||||
from torch.autograd import Variable
|
||||
from torch.testing import make_non_contiguous
|
||||
from common_utils import (skipIfNoLapack,
|
||||
prod_single_zero, random_square_matrix_of_rank,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch.nn.functional import _Reduction
|
||||
from common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
|
||||
TEST_WITH_ROCM, skipIfRocm
|
||||
TEST_WITH_ROCM
|
||||
from common_cuda import TEST_CUDA
|
||||
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
|
||||
from torch.autograd import Variable
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import argparse
|
||||
import os.path
|
||||
import tempfile
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -4,12 +4,10 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
from torch.autograd import Variable
|
||||
import torch.autograd.function as function
|
||||
|
||||
import onnx
|
||||
import caffe2.python.onnx.backend as c2
|
||||
|
|
|
|||
|
|
@ -5,12 +5,9 @@ from __future__ import unicode_literals
|
|||
|
||||
import argparse
|
||||
import glob
|
||||
import numpy as np
|
||||
import onnx.backend.test
|
||||
import caffe2.python.onnx.backend as c2
|
||||
import os
|
||||
import shutil
|
||||
from onnx import numpy_helper
|
||||
from test_caffe2_common import run_generated_test
|
||||
import google.protobuf.text_format
|
||||
import test_onnx_common
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import shutil
|
|||
import torch
|
||||
import traceback
|
||||
|
||||
import test_pytorch_common
|
||||
import test_onnx_common
|
||||
from common_nn import module_tests
|
||||
from test_nn import new_module_tests
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from .squeezenet import *
|
||||
from .super_resolution import *
|
||||
from .op_test import *
|
||||
from .srresnet import *
|
||||
from .squeezenet import * # noqa: F401
|
||||
from .super_resolution import * # noqa: F401
|
||||
from .op_test import * # noqa: F401
|
||||
from .srresnet import * # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class RNNModel(nn.Module):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from torchvision.models.resnet import resnet50
|
|||
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
|
||||
|
||||
from model_defs.mnist import MNIST
|
||||
from model_defs.word_language_model import RNNModel
|
||||
from model_defs.squeezenet import SqueezeNet
|
||||
from model_defs.super_resolution import SuperResolutionNet
|
||||
from model_defs.srresnet import SRResNet
|
||||
|
|
@ -17,17 +16,9 @@ from test_pytorch_common import TestCase, run_tests, skipIfNoLapack
|
|||
import torch
|
||||
import torch.onnx
|
||||
import torch.onnx.utils
|
||||
from torch.autograd import Variable, Function
|
||||
from torch.nn import Module
|
||||
from torch.autograd import Variable
|
||||
from torch.onnx import OperatorExportTypes
|
||||
|
||||
import onnx
|
||||
import onnx.checker
|
||||
import onnx.helper
|
||||
|
||||
import google.protobuf.text_format
|
||||
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import caffe2.python.onnx.backend as backend
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, flatten
|
||||
from test_pytorch_common import TestCase, run_tests, flatten
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
|
|
@ -10,11 +10,9 @@ import itertools
|
|||
import io
|
||||
import unittest
|
||||
import inspect
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import common_utils as common
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import torch.autograd.function as function
|
|||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.insert(-1, pytorch_test_dir)
|
||||
|
||||
from common_utils import *
|
||||
from common_utils import * # noqa: F401
|
||||
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Some standard imports
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
import torch.onnx
|
||||
import torch.nn.init as init
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from functools import wraps
|
||||
import numpy as np
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
import torch
|
||||
import torch.legacy.optim as optim
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def rosenbrock(tensor):
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import warnings
|
|||
from copy import deepcopy
|
||||
from collections import OrderedDict
|
||||
from itertools import product
|
||||
from operator import mul, itemgetter
|
||||
from functools import reduce, wraps
|
||||
from operator import mul
|
||||
from functools import reduce
|
||||
from torch._six import inf, nan, istuple
|
||||
from torch.autograd.gradcheck import gradgradcheck, gradcheck
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
|
@ -17,14 +17,11 @@ from torch.autograd.profiler import profile
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
|
||||
suppress_warnings, skipIfRocm,
|
||||
prod_single_zero, random_square_matrix_of_rank,
|
||||
random_symmetric_matrix, random_symmetric_psd_matrix,
|
||||
random_symmetric_pd_matrix, make_nonzero_det,
|
||||
random_fullrank_matrix_distinct_singular_value, load_tests)
|
||||
load_tests)
|
||||
from common_cuda import TEST_CUDA
|
||||
from torch.autograd import Variable, Function, detect_anomaly
|
||||
from torch.autograd.function import InplaceFunction
|
||||
from torch.testing import make_non_contiguous, randn_like
|
||||
from torch.testing import randn_like
|
||||
from common_methods_invocations import (method_tests,
|
||||
create_input, unpack_variables,
|
||||
EXCLUDE_FUNCTIONAL, EXCLUDE_GRADCHECK,
|
||||
|
|
@ -32,7 +29,7 @@ from common_methods_invocations import (method_tests,
|
|||
EXCLUDE_GRADGRADCHECK_BY_TEST_NAME,
|
||||
exclude_tensor_method,
|
||||
mask_not_all_zeros,
|
||||
L, S)
|
||||
S)
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import io
|
||||
import math
|
||||
import tempfile
|
||||
import re
|
||||
import unittest
|
||||
import sys
|
||||
from itertools import repeat
|
||||
|
|
@ -19,9 +17,9 @@ from torch._six import inf, nan
|
|||
from test_torch import _TestTorchMixin
|
||||
|
||||
from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
|
||||
run_additional_tri_tests, _compare_trilu_indices, _compare_large_trilu_indices
|
||||
_compare_trilu_indices, _compare_large_trilu_indices
|
||||
from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
|
||||
PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests, iter_indices
|
||||
PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
|
|||
|
|
@ -3,13 +3,10 @@ import sys
|
|||
import errno
|
||||
import os
|
||||
import ctypes
|
||||
import signal
|
||||
import torch
|
||||
import gc
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import subprocess
|
||||
import itertools
|
||||
import warnings
|
||||
from torch import multiprocessing as mp
|
||||
|
|
|
|||
|
|
@ -16,11 +16,9 @@ import torch.cuda
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from common_utils import TestCase, run_tests
|
||||
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
||||
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
|
||||
import common_utils as common
|
||||
|
||||
BACKEND = os.environ["BACKEND"]
|
||||
TEMP_DIR = os.environ["TEMP_DIR"]
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import torch
|
|||
import unittest
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
import _ast
|
||||
import textwrap
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from common_utils import TestCase, run_tests
|
||||
import torch
|
||||
import warnings
|
||||
from torch import tensor
|
||||
import unittest
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,6 @@ from contextlib import contextmanager
|
|||
from itertools import product, chain
|
||||
import torch.jit.frontend
|
||||
from torch.autograd import Variable, Function
|
||||
from torch.nn import Module
|
||||
from torch.autograd.function import traceable
|
||||
from torch.testing import assert_allclose
|
||||
from torch.onnx import OperatorExportTypes
|
||||
from torch._six import inf, PY2, builtins, StringIO
|
||||
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
||||
|
|
@ -25,7 +22,6 @@ from textwrap import dedent
|
|||
from functools import wraps
|
||||
import os
|
||||
import io
|
||||
import itertools
|
||||
import sys
|
||||
import unittest
|
||||
import inspect
|
||||
|
|
@ -46,14 +42,13 @@ from common_methods_invocations import method_tests as autograd_method_tests
|
|||
from common_methods_invocations import create_input, unpack_variables, \
|
||||
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
|
||||
from torch.testing import FileCheck
|
||||
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
||||
ListType, StringType, DictType
|
||||
from torch._C import TensorType
|
||||
from copy import deepcopy
|
||||
import random
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from torch.jit.frontend import NotSupportedError
|
||||
from torch import Tensor
|
||||
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
|
||||
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
|
||||
|
||||
# For testing truediv in python 2
|
||||
from test_module.future_div import div_int_future, div_float_future
|
||||
|
|
@ -6727,8 +6722,6 @@ a")
|
|||
|
||||
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
||||
def test_type_annotation_py3(self):
|
||||
import importlib.util
|
||||
|
||||
code = dedent("""
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -9349,8 +9342,6 @@ a")
|
|||
foo(torch.ones([123])) # wrong size
|
||||
|
||||
def test_builtin_error_messsage(self):
|
||||
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
||||
@torch.jit.script
|
||||
def close_match(x):
|
||||
|
|
@ -11020,8 +11011,6 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||
|
||||
@staticmethod
|
||||
def _test_super_resolution(self, device, check_export_import=True):
|
||||
import torch.nn.init as init
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self, upscale_factor):
|
||||
|
|
|
|||
|
|
@ -3,17 +3,12 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import functools
|
||||
import os
|
||||
import unittest
|
||||
import sys
|
||||
import torch
|
||||
import torch.autograd.function as function
|
||||
from torch import Tensor
|
||||
|
||||
from common_utils import TestCase, run_tests, IS_WINDOWS, \
|
||||
from common_utils import IS_WINDOWS, \
|
||||
skipIfRocm, IS_SANDCASTLE
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
|
||||
from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
|
||||
backward_graph
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import division
|
||||
import torch
|
||||
import torch # noqa: F401
|
||||
|
||||
|
||||
def div_int_future():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import torch
|
||||
import torch # noqa: F401
|
||||
|
||||
|
||||
def div_int_nofuture():
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import torch.utils.hooks
|
|||
from torch.nn import Parameter
|
||||
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN,
|
||||
load_tests, slowTest)
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ from itertools import repeat, product
|
|||
from functools import wraps, reduce
|
||||
from operator import mul
|
||||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
|
@ -29,9 +27,9 @@ from torch.autograd import Variable, gradcheck
|
|||
from torch.autograd.gradcheck import gradgradcheck
|
||||
from torch.nn import Parameter
|
||||
from torch.nn.parallel._functions import Broadcast
|
||||
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, TEST_WITH_ROCM, \
|
||||
TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, download_file, PY3, PY34, to_gpu, \
|
||||
get_function_arglist, skipCUDAMemoryLeakCheckIf, load_tests
|
||||
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
|
||||
TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
|
||||
get_function_arglist, load_tests
|
||||
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
|
||||
TEST_CUDNN_VERSION
|
||||
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import unittest
|
||||
from caffe2.python import core
|
||||
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
||||
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
|
||||
freeze_rng_state, set_rng_seed
|
||||
from common_utils import TestCase, run_tests
|
||||
|
||||
|
||||
def canonical(graph):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
from torch import sparse
|
||||
|
||||
import itertools
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import torch.cuda
|
|||
import torch.distributed.deprecated as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from common_utils import TestCase, run_tests
|
||||
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import io
|
|||
import os
|
||||
import math
|
||||
import random
|
||||
import operator
|
||||
import copy
|
||||
import shutil
|
||||
import torch
|
||||
|
|
@ -17,7 +16,7 @@ import gzip
|
|||
import types
|
||||
import textwrap
|
||||
import re
|
||||
from torch._utils_internal import get_file_path, get_file_path_2
|
||||
from torch._utils_internal import get_file_path_2
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
from torch._utils import _rebuild_tensor
|
||||
from torch._six import inf, nan, string_classes, istuple
|
||||
|
|
@ -2032,7 +2031,6 @@ class _TestTorchMixin(object):
|
|||
def _test_int_pow(self, cast):
|
||||
if not TEST_NUMPY:
|
||||
return
|
||||
import numpy as np
|
||||
|
||||
def check_against_np(tensor, exp):
|
||||
tensor_np = tensor.cpu().numpy()
|
||||
|
|
@ -4669,7 +4667,6 @@ class _TestTorchMixin(object):
|
|||
# Test non-contiguous inputs.
|
||||
if not TEST_NUMPY:
|
||||
return
|
||||
import numpy
|
||||
from numpy.linalg import solve
|
||||
A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
|
||||
b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
|
||||
|
|
@ -6218,7 +6215,6 @@ class _TestTorchMixin(object):
|
|||
# Test non-contiguous inputs.
|
||||
if not TEST_NUMPY:
|
||||
return
|
||||
import numpy
|
||||
from numpy.linalg import solve
|
||||
A = random_symmetric_pd_matrix(2, 2)
|
||||
b = torch.randn(2, 2, 2)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import print_function
|
||||
import unittest
|
||||
from common_utils import TestCase, run_tests, download_file
|
||||
from common_utils import TestCase, run_tests
|
||||
import tempfile
|
||||
import torch
|
||||
import re
|
||||
|
|
@ -10,7 +10,7 @@ import subprocess
|
|||
import inspect
|
||||
|
||||
try:
|
||||
import mypy
|
||||
import mypy # noqa: F401
|
||||
HAVE_MYPY = True
|
||||
except ImportError:
|
||||
HAVE_MYPY = False
|
||||
|
|
|
|||
|
|
@ -2,22 +2,19 @@ from __future__ import print_function
|
|||
import sys
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
import shutil
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
import traceback
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
import torch.cuda
|
||||
import warnings
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
import torch.hub as hub
|
||||
from torch.autograd._functions.utils import prepare_onnx_paddings
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from common_utils import IS_WINDOWS, IS_PPC, skipIfRocm, load_tests
|
||||
from common_utils import skipIfRocm, load_tests
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
@ -34,7 +31,7 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
|||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
|
||||
from common_utils import TestCase, run_tests, download_file
|
||||
from common_utils import TestCase, run_tests
|
||||
|
||||
|
||||
class RandomDatasetMock(object):
|
||||
|
|
@ -326,7 +323,7 @@ test_dir = os.path.abspath(os.path.dirname(str(__file__)))
|
|||
class TestFFI(TestCase):
|
||||
def test_deprecated(self):
|
||||
with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
|
||||
from torch.utils.ffi import create_extension
|
||||
from torch.utils.ffi import create_extension # noqa: F401
|
||||
|
||||
|
||||
@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
from functools import reduce
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@
|
|||
# differentiable subcomponents.
|
||||
#
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
|
||||
from .gen_autograd import VIEW_FUNCTIONS
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
import argparse
|
||||
import os
|
||||
from os.path import dirname, abspath
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# By appending pytorch_root to sys.path, this module can import other torch
|
||||
|
|
|
|||
|
|
@ -1,22 +1,20 @@
|
|||
from .setup_helpers.env import (IS_64BIT, IS_ARM, IS_DARWIN, IS_LINUX, IS_PPC, IS_WINDOWS,
|
||||
from .setup_helpers.env import (IS_64BIT, IS_DARWIN, IS_WINDOWS,
|
||||
DEBUG, REL_WITH_DEB_INFO, USE_MKLDNN,
|
||||
check_env_flag, check_negative_env_flag, hotpatch_build_env_vars)
|
||||
check_env_flag, check_negative_env_flag)
|
||||
|
||||
import os
|
||||
import sys
|
||||
import distutils
|
||||
import distutils.sysconfig
|
||||
from distutils.file_util import copy_file
|
||||
from distutils.dir_util import copy_tree
|
||||
from subprocess import check_call, call, check_output
|
||||
from subprocess import check_call, check_output
|
||||
from distutils.version import LooseVersion
|
||||
from .setup_helpers.cuda import USE_CUDA, CUDA_HOME
|
||||
from .setup_helpers.dist_check import USE_DISTRIBUTED, USE_GLOO_IBVERBS
|
||||
from .setup_helpers.nccl import USE_SYSTEM_NCCL, NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB, USE_NCCL
|
||||
from .setup_helpers.rocm import ROCM_HOME, ROCM_VERSION, USE_ROCM
|
||||
from .setup_helpers.rocm import USE_ROCM
|
||||
from .setup_helpers.nnpack import USE_NNPACK
|
||||
from .setup_helpers.qnnpack import USE_QNNPACK
|
||||
from .setup_helpers.cudnn import CUDNN_INCLUDE_DIR, CUDNN_LIB_DIR, CUDNN_LIBRARY, USE_CUDNN
|
||||
from .setup_helpers.cudnn import CUDNN_INCLUDE_DIR, CUDNN_LIBRARY, USE_CUDNN
|
||||
|
||||
|
||||
from pprint import pprint
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ Only files that are in CLANG_FORMAT_WHITELIST are checked.
|
|||
import subprocess
|
||||
import os
|
||||
import argparse
|
||||
import fnmatch
|
||||
import difflib
|
||||
import sys
|
||||
import re
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .cwrap import cwrap
|
||||
from .cwrap import cwrap # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class ArgumentReferences(CWrapPlugin):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from string import Template
|
||||
import copy
|
||||
from copy import deepcopy
|
||||
from . import CWrapPlugin
|
||||
from itertools import product
|
||||
|
||||
|
||||
class CuDNNPlugin(CWrapPlugin):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from . import CWrapPlugin
|
||||
from string import Template
|
||||
|
||||
|
||||
class GILRelease(CWrapPlugin):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from . import CWrapPlugin
|
||||
from itertools import product
|
||||
from ...shared import cwrap_common
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -420,15 +420,15 @@ class CWrapPlugin(object):
|
|||
return template
|
||||
|
||||
|
||||
from .NNExtension import NNExtension
|
||||
from .NullableArguments import NullableArguments
|
||||
from .OptionalArguments import OptionalArguments
|
||||
from .ArgcountChecker import ArgcountChecker
|
||||
from .ArgumentReferences import ArgumentReferences
|
||||
from .BeforeAfterCall import BeforeAfterCall
|
||||
from .ConstantArguments import ConstantArguments
|
||||
from .ReturnArguments import ReturnArguments
|
||||
from .GILRelease import GILRelease
|
||||
from .AutoGPU import AutoGPU
|
||||
from .CuDNNPlugin import CuDNNPlugin
|
||||
from .WrapDim import WrapDim
|
||||
from .NNExtension import NNExtension # noqa: F401
|
||||
from .NullableArguments import NullableArguments # noqa: F401
|
||||
from .OptionalArguments import OptionalArguments # noqa: F401
|
||||
from .ArgcountChecker import ArgcountChecker # noqa: F401
|
||||
from .ArgumentReferences import ArgumentReferences # noqa: F401
|
||||
from .BeforeAfterCall import BeforeAfterCall # noqa: F401
|
||||
from .ConstantArguments import ConstantArguments # noqa: F401
|
||||
from .ReturnArguments import ReturnArguments # noqa: F401
|
||||
from .GILRelease import GILRelease # noqa: F401
|
||||
from .AutoGPU import AutoGPU # noqa: F401
|
||||
from .CuDNNPlugin import CuDNNPlugin # noqa: F401
|
||||
from .WrapDim import WrapDim # noqa: F401
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import argparse
|
|||
import gzip
|
||||
import os
|
||||
import sys
|
||||
import urllib
|
||||
|
||||
try:
|
||||
from urllib.error import URLError
|
||||
|
|
|
|||
|
|
@ -12,14 +12,11 @@ generated. In the full build system, OUTPUT_DIR is
|
|||
torch/csrc/jit/generated/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import re
|
||||
import copy
|
||||
from itertools import count, combinations, groupby
|
||||
from ..autograd.utils import CodeTemplate, write, uninplace_api_name
|
||||
from itertools import groupby
|
||||
from ..autograd.utils import CodeTemplate, write
|
||||
from ..autograd.gen_autograd import load_aten_declarations
|
||||
from collections import OrderedDict
|
||||
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
|
||||
|
||||
# JIT has a type system of
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .generate_wrappers import generate_wrappers, wrap_function, import_module
|
||||
from .generate_wrappers import generate_wrappers, wrap_function, import_module # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
from string import Template, ascii_lowercase
|
||||
from string import Template
|
||||
from ..cwrap import cwrap
|
||||
from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU
|
||||
from ..shared import import_module
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import multiprocessing
|
||||
import sys
|
||||
import os
|
||||
import inspect
|
||||
import collections
|
||||
import yaml
|
||||
import types
|
||||
import re
|
||||
import argparse
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import os
|
|||
import subprocess
|
||||
import glob
|
||||
|
||||
from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths
|
||||
from .cuda import USE_CUDA
|
||||
from .env import IS_CONDA, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths
|
||||
|
||||
# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
|
||||
USE_DISTRIBUTED = not check_negative_env_flag("USE_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("USE_ROCM")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
import os
|
||||
import glob
|
||||
|
||||
from .env import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_env_flag, gather_paths
|
||||
from .env import check_env_flag
|
||||
from .rocm import USE_ROCM, ROCM_HOME
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import os
|
||||
import glob
|
||||
import warnings
|
||||
from itertools import chain
|
||||
|
||||
from .env import IS_WINDOWS, IS_DARWIN, IS_CONDA, CONDA_DIR, check_negative_env_flag, \
|
||||
gather_paths
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import platform
|
||||
import ctypes.util
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
from .cuda import USE_CUDA
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
from .module_loader import import_module
|
||||
from .cwrap_common import set_declaration_defaults, \
|
||||
sort_by_number_of_options, enumerate_options_due_to_default
|
||||
from .module_loader import import_module # noqa: F401
|
||||
from .cwrap_common import set_declaration_defaults, sort_by_number_of_options, enumerate_options_due_to_default # noqa: F401
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import sys
|
|||
import platform
|
||||
from ._utils import _import_dotted_name
|
||||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment
|
||||
from .version import __version__
|
||||
from .version import __version__ # noqa: F401
|
||||
from ._six import string_classes as _string_classes
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -39,7 +39,7 @@ import os as _dl_flags
|
|||
# if we have numpy, it *must* be imported before the call to setdlopenflags()
|
||||
# or there is risk that later c modules will segfault when importing numpy
|
||||
try:
|
||||
import numpy as _np
|
||||
import numpy as _np # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
|
@ -281,7 +281,7 @@ del BoolStorageBase
|
|||
|
||||
import torch.cuda
|
||||
import torch.autograd
|
||||
from torch.autograd import no_grad, enable_grad, set_grad_enabled
|
||||
from torch.autograd import no_grad, enable_grad, set_grad_enabled # noqa: F401
|
||||
import torch.nn
|
||||
import torch.optim
|
||||
import torch.multiprocessing
|
||||
|
|
@ -309,7 +309,7 @@ def compiled_with_cxx11_abi():
|
|||
|
||||
|
||||
# Import the ops "namespace"
|
||||
from torch._ops import ops
|
||||
from torch._ops import ops # noqa: F401
|
||||
|
||||
# Import the quasi random sampler
|
||||
import torch.quasirandom
|
||||
|
|
|
|||
|
|
@ -53,9 +53,9 @@ else:
|
|||
|
||||
|
||||
if PY2:
|
||||
import Queue as queue
|
||||
import Queue as queue # noqa: F401
|
||||
else:
|
||||
import queue
|
||||
import queue # noqa: F401
|
||||
|
||||
|
||||
def with_metaclass(meta, *bases):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import math
|
||||
import torch
|
||||
from functools import reduce
|
||||
from sys import float_info
|
||||
from torch._six import inf, nan
|
||||
from torch._six import inf
|
||||
|
||||
|
||||
class __PrinterOptions(object):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
import os
|
||||
import itertools
|
||||
import importlib
|
||||
|
||||
try:
|
||||
# when compiling a cffi extension, this works. When compiling
|
||||
# torch itself, it doesn't work because the parent module can't
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ import torch
|
|||
import warnings
|
||||
|
||||
from .variable import Variable
|
||||
from .function import Function, NestedIOFunction
|
||||
from .gradcheck import gradcheck, gradgradcheck
|
||||
from .grad_mode import no_grad, enable_grad, set_grad_enabled
|
||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
||||
from . import profiler
|
||||
from .function import Function, NestedIOFunction # noqa: F401
|
||||
from .gradcheck import gradcheck, gradgradcheck # noqa: F401
|
||||
from .grad_mode import no_grad, enable_grad, set_grad_enabled # noqa: F401
|
||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly # noqa: F401
|
||||
from . import profiler # noqa: F401
|
||||
|
||||
__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .tensor import *
|
||||
from .tensor import * # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
from functools import reduce
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
from torch._six import container_abcs, istuple
|
||||
import torch.testing
|
||||
import sys
|
||||
from itertools import product
|
||||
import warnings
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
import subprocess
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import itertools
|
||||
from collections import defaultdict, namedtuple
|
||||
|
||||
import torch
|
||||
from torch._six import FileNotFoundError
|
||||
|
||||
|
||||
class range(object):
|
||||
|
|
|
|||
|
|
@ -648,7 +648,7 @@ torch._storage_classes.add(ByteStorage)
|
|||
torch._storage_classes.add(HalfStorage)
|
||||
torch._storage_classes.add(BoolStorage)
|
||||
|
||||
from . import sparse
|
||||
from . import profiler
|
||||
from . import nvtx
|
||||
from .streams import Stream, Event
|
||||
from . import sparse # noqa: F401
|
||||
from . import profiler # noqa: F401
|
||||
from . import nvtx # noqa: F401
|
||||
from .streams import Stream, Event # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import torch
|
||||
from . import nccl
|
||||
from torch._utils import _accumulate, _take_tensors, _flatten_dense_tensors, \
|
||||
_flatten_sparse_tensors, _unflatten_dense_tensors, \
|
||||
_unflatten_sparse_tensors, _reorder_tensors_as
|
||||
from torch._utils import _take_tensors, _flatten_dense_tensors, \
|
||||
_unflatten_dense_tensors, _reorder_tensors_as
|
||||
|
||||
|
||||
def broadcast(tensor, devices):
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ if is_available() and not torch._C._c10d_init():
|
|||
|
||||
|
||||
if is_available():
|
||||
from .distributed_c10d import *
|
||||
from .distributed_c10d import * # noqa: F401
|
||||
# Variables prefixed with underscore are not auto imported
|
||||
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
|
||||
# this.
|
||||
from .distributed_c10d import _backend
|
||||
from .distributed_c10d import _backend # noqa: F401
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ import warnings
|
|||
from torch._six import string_classes
|
||||
from datetime import timedelta
|
||||
|
||||
from .rendezvous import rendezvous, register_rendezvous_handler
|
||||
# This module is wildcard imported from torch.distributed.
|
||||
# TODO: specify __all__
|
||||
|
||||
from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401
|
||||
from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
|
||||
ScatterOptions, GatherOptions
|
||||
from . import ReduceOp
|
||||
|
|
|
|||
|
|
@ -140,11 +140,8 @@ will not pass ``--local_rank`` when you specify this flag.
|
|||
import sys
|
||||
import subprocess
|
||||
import os
|
||||
import socket
|
||||
from argparse import ArgumentParser, REMAINDER
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -124,6 +124,8 @@ __all__ = [
|
|||
'Gamma',
|
||||
'Geometric',
|
||||
'Gumbel',
|
||||
'HalfCauchy',
|
||||
'HalfNormal',
|
||||
'Independent',
|
||||
'Laplace',
|
||||
'LogNormal',
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch._six import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property, broadcast_all
|
||||
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
|
||||
|
||||
|
||||
class Categorical(Distribution):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.gamma import Gamma
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from numbers import Number
|
||||
import torch
|
||||
import math
|
||||
from torch._six import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from numbers import Number
|
|||
import torch
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all, lazy_property
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
||||
def _standard_gamma(concentration):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transforms import AbsTransform
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transforms import AbsTransform
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from .gumbel import Gumbel
|
|||
from .half_normal import HalfNormal
|
||||
from .independent import Independent
|
||||
from .laplace import Laplace
|
||||
from .logistic_normal import LogisticNormal
|
||||
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
|
||||
_batch_lowrank_mahalanobis)
|
||||
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transforms import ExpTransform
|
||||
from torch.distributions.normal import Normal
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch.distributions import constraints
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import ComposeTransform, ExpTransform, StickBreakingTransform
|
||||
from torch.distributions.transforms import StickBreakingTransform
|
||||
|
||||
|
||||
class LogisticNormal(TransformedDistribution):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exponential import Exponential
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import math
|
||||
from numbers import Number
|
||||
|
||||
import torch
|
||||
from torch._six import inf, nan
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from functools import update_wrapper
|
||||
from numbers import Number
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .onnx import *
|
||||
from .onnx import * # noqa: F401
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import importlib
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
|
@ -10,10 +9,7 @@ if sys.version_info[0] == 2:
|
|||
from urllib2 import urlopen # noqa f811
|
||||
else:
|
||||
from urllib.request import urlopen
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from urllib.parse import urlparse # noqa: F401
|
||||
|
||||
MASTER_BRANCH = 'master'
|
||||
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
|
||||
|
|
|
|||
|
|
@ -1,35 +1,31 @@
|
|||
import torch._C
|
||||
from torch import Tensor
|
||||
from torch.autograd import Variable, function
|
||||
from torch.serialization import validate_cuda_device
|
||||
from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential
|
||||
from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.jit.annotations
|
||||
import torch._jit_internal as _jit_internal
|
||||
from torch._six import raise_from, with_metaclass, get_function_from_type, \
|
||||
from torch._six import with_metaclass, get_function_from_type, \
|
||||
string_classes
|
||||
from torch._jit_internal import ignore
|
||||
from torch.jit._pickle import Unpickler
|
||||
from torch._jit_internal import ignore # noqa: F401
|
||||
from torch.jit._pickle import Unpickler # noqa: F401
|
||||
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
|
||||
_list_with_default
|
||||
import torch.testing
|
||||
|
||||
import math
|
||||
from collections import defaultdict, OrderedDict, namedtuple
|
||||
from collections import OrderedDict, namedtuple
|
||||
import textwrap
|
||||
import sys
|
||||
import warnings
|
||||
import itertools
|
||||
import weakref
|
||||
import types
|
||||
import contextlib
|
||||
import os
|
||||
import functools
|
||||
import copy
|
||||
import numbers
|
||||
import collections
|
||||
import re
|
||||
import inspect
|
||||
import pickle
|
||||
if sys.version_info[0] > 2:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import torch
|
||||
import functools
|
||||
import pickle
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import re
|
||||
import sys
|
||||
import ast
|
||||
import inspect
|
||||
|
|
@ -181,3 +180,33 @@ def ann_to_type(ann):
|
|||
elif ann is str:
|
||||
return StringType.get()
|
||||
raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
|
||||
|
||||
|
||||
__all__ = [
|
||||
'List',
|
||||
'BroadcastingList1',
|
||||
'BroadcastingList2',
|
||||
'BroadcastingList3',
|
||||
'Tuple',
|
||||
'is_tuple',
|
||||
'is_list',
|
||||
'Dict',
|
||||
'is_dict',
|
||||
'TensorType',
|
||||
'TupleType',
|
||||
'FloatType',
|
||||
'IntType',
|
||||
'ListType',
|
||||
'StringType',
|
||||
'DictType',
|
||||
'Module',
|
||||
# TODO: Consider not exporting these during wildcard import (reserve
|
||||
# that for the types; for idiomatic typing code.)
|
||||
'get_signature',
|
||||
'get_num_params',
|
||||
'parse_type_line',
|
||||
'get_type_line',
|
||||
'split_type_line',
|
||||
'try_real_annotations',
|
||||
'ann_to_type',
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ import ast
|
|||
import inspect
|
||||
import string
|
||||
from textwrap import dedent
|
||||
from functools import partial
|
||||
from collections import namedtuple
|
||||
from torch._six import PY2
|
||||
from torch._C._jit_tree_views import *
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
import torch
|
||||
import copy
|
||||
import numbers
|
||||
from typing import Tuple, Optional
|
||||
from typing import Tuple, Optional # noqa: F401
|
||||
from torch import Tensor
|
||||
from torch.jit import ScriptModule
|
||||
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
from torch.nn import _VF
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ __all__ = ['set_sharing_strategy', 'get_sharing_strategy',
|
|||
'get_all_sharing_strategies']
|
||||
|
||||
|
||||
from multiprocessing import *
|
||||
from multiprocessing import * # noqa: F401
|
||||
|
||||
|
||||
__all__ += multiprocessing.__all__
|
||||
|
|
@ -36,13 +36,13 @@ torch._C._multiprocessing_init()
|
|||
if sys.version_info < (3, 3):
|
||||
"""Override basic classes in Python 2.7 and Python 3.3 to use ForkingPickler
|
||||
for serialization. Later versions of Python already use ForkingPickler."""
|
||||
from .queue import Queue, SimpleQueue
|
||||
from .pool import Pool
|
||||
from .queue import Queue, SimpleQueue # noqa: F401
|
||||
from .pool import Pool # noqa: F401
|
||||
|
||||
|
||||
"""Add helper function to spawn N processes and wait for completion of any of
|
||||
them. This depends `mp.get_context` which was added in Python 3.4."""
|
||||
from .spawn import spawn, SpawnContext
|
||||
from .spawn import spawn, SpawnContext # noqa: F401
|
||||
|
||||
|
||||
if sys.platform == 'darwin' or sys.platform == 'win32':
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
import torch.utils.hooks
|
||||
import os
|
||||
import weakref
|
||||
import threading
|
||||
import multiprocessing
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff.