mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Turn on F401: Unused import warning. (#18598)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598 ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18598 Turn on F401: Unused import warning.** This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work. Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites. All the changes were done by hand. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14687478 fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
This commit is contained in:
parent
96456bfa4c
commit
173f224570
|
|
@ -22,7 +22,7 @@ def handle_missing_graphviz(f):
|
||||||
calls to the draw() method of the returned object to do nothing.
|
calls to the draw() method of the returned object to do nothing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import pygraphviz
|
import pygraphviz # noqa: F401
|
||||||
return f
|
return f
|
||||||
|
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
|
|
||||||
2
.flake8
2
.flake8
|
|
@ -7,7 +7,7 @@ max-line-length = 120
|
||||||
# C408 ignored because we like the dict keyword argument syntax
|
# C408 ignored because we like the dict keyword argument syntax
|
||||||
# E501 is not flexible enough, we're using B950 instead
|
# E501 is not flexible enough, we're using B950 instead
|
||||||
ignore =
|
ignore =
|
||||||
E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
|
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,
|
||||||
# these ignores are from flake8-bugbear; please fix!
|
# these ignores are from flake8-bugbear; please fix!
|
||||||
B007,B008,
|
B007,B008,
|
||||||
# these ignores are from flake8-comprehensions; please fix!
|
# these ignores are from flake8-comprehensions; please fix!
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import numpy
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from .cells import *
|
from .cells import * # noqa: F401
|
||||||
from .factory import *
|
from .factory import * # noqa: F401
|
||||||
|
|
||||||
# (output, next_state) = cell(input, state)
|
# (output, next_state) = cell(input, state)
|
||||||
seqLength = 100
|
seqLength = 100
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import argparse
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .cells import lstm_cell
|
|
||||||
from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
|
from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
|
||||||
from .runner import get_nn_runners
|
from .runner import get_nn_runners
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,8 @@
|
||||||
import os
|
import os
|
||||||
# sys.path.insert(0, os.path.abspath('.'))
|
# sys.path.insert(0, os.path.abspath('.'))
|
||||||
|
|
||||||
import sys
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
import pytorch_sphinx_theme
|
|
||||||
|
|
||||||
# -- General configuration ------------------------------------------------
|
# -- General configuration ------------------------------------------------
|
||||||
|
|
||||||
# If your documentation needs a minimal Sphinx version, state it here.
|
# If your documentation needs a minimal Sphinx version, state it here.
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
try:
|
try:
|
||||||
import torchvision
|
import torchvision # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import warnings
|
import warnings
|
||||||
warnings.warn('unable to load "torchvision" package')
|
warnings.warn('unable to load "torchvision" package')
|
||||||
|
|
|
||||||
3
setup.py
3
setup.py
|
|
@ -142,7 +142,7 @@
|
||||||
# we will search for libraries in these paths
|
# we will search for libraries in these paths
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from setuptools import setup, Extension, distutils, Command, find_packages
|
from setuptools import setup, Extension, distutils, find_packages
|
||||||
from distutils import core, dir_util
|
from distutils import core, dir_util
|
||||||
from distutils.core import Distribution
|
from distutils.core import Distribution
|
||||||
from distutils.errors import DistutilsArgError
|
from distutils.errors import DistutilsArgError
|
||||||
|
|
@ -151,7 +151,6 @@ import setuptools.command.install
|
||||||
import distutils.command.clean
|
import distutils.command.clean
|
||||||
import distutils.sysconfig
|
import distutils.sysconfig
|
||||||
import filecmp
|
import filecmp
|
||||||
import platform
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
from torch._six import inf, nan, istuple
|
from torch._six import inf, istuple
|
||||||
from functools import reduce, wraps
|
from functools import reduce
|
||||||
from operator import mul, itemgetter
|
from operator import mul, itemgetter
|
||||||
import collections
|
import collections
|
||||||
from torch.autograd import Variable, Function, detect_anomaly
|
from torch.autograd import Variable
|
||||||
from torch.testing import make_non_contiguous
|
from torch.testing import make_non_contiguous
|
||||||
from common_utils import (skipIfNoLapack,
|
from common_utils import (skipIfNoLapack,
|
||||||
prod_single_zero, random_square_matrix_of_rank,
|
prod_single_zero, random_square_matrix_of_rank,
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.functional import _Reduction
|
from torch.nn.functional import _Reduction
|
||||||
from common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
|
from common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
|
||||||
TEST_WITH_ROCM, skipIfRocm
|
TEST_WITH_ROCM
|
||||||
from common_cuda import TEST_CUDA
|
from common_cuda import TEST_CUDA
|
||||||
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
|
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import argparse
|
|
||||||
import os.path
|
import os.path
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,10 @@ from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import itertools
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.jit
|
import torch.jit
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
import torch.autograd.function as function
|
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import caffe2.python.onnx.backend as c2
|
import caffe2.python.onnx.backend as c2
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,9 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import numpy as np
|
|
||||||
import onnx.backend.test
|
import onnx.backend.test
|
||||||
import caffe2.python.onnx.backend as c2
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from onnx import numpy_helper
|
|
||||||
from test_caffe2_common import run_generated_test
|
from test_caffe2_common import run_generated_test
|
||||||
import google.protobuf.text_format
|
import google.protobuf.text_format
|
||||||
import test_onnx_common
|
import test_onnx_common
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ import shutil
|
||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import test_pytorch_common
|
|
||||||
import test_onnx_common
|
import test_onnx_common
|
||||||
from common_nn import module_tests
|
from common_nn import module_tests
|
||||||
from test_nn import new_module_tests
|
from test_nn import new_module_tests
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .squeezenet import *
|
from .squeezenet import * # noqa: F401
|
||||||
from .super_resolution import *
|
from .super_resolution import * # noqa: F401
|
||||||
from .op_test import *
|
from .op_test import * # noqa: F401
|
||||||
from .srresnet import *
|
from .srresnet import * # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.init as init
|
import torch.nn.init as init
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.init as init
|
import torch.nn.init as init
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.autograd import Variable
|
|
||||||
|
|
||||||
|
|
||||||
class RNNModel(nn.Module):
|
class RNNModel(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from torchvision.models.resnet import resnet50
|
||||||
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
|
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
|
||||||
|
|
||||||
from model_defs.mnist import MNIST
|
from model_defs.mnist import MNIST
|
||||||
from model_defs.word_language_model import RNNModel
|
|
||||||
from model_defs.squeezenet import SqueezeNet
|
from model_defs.squeezenet import SqueezeNet
|
||||||
from model_defs.super_resolution import SuperResolutionNet
|
from model_defs.super_resolution import SuperResolutionNet
|
||||||
from model_defs.srresnet import SRResNet
|
from model_defs.srresnet import SRResNet
|
||||||
|
|
@ -17,17 +16,9 @@ from test_pytorch_common import TestCase, run_tests, skipIfNoLapack
|
||||||
import torch
|
import torch
|
||||||
import torch.onnx
|
import torch.onnx
|
||||||
import torch.onnx.utils
|
import torch.onnx.utils
|
||||||
from torch.autograd import Variable, Function
|
from torch.autograd import Variable
|
||||||
from torch.nn import Module
|
|
||||||
from torch.onnx import OperatorExportTypes
|
from torch.onnx import OperatorExportTypes
|
||||||
|
|
||||||
import onnx
|
|
||||||
import onnx.checker
|
|
||||||
import onnx.helper
|
|
||||||
|
|
||||||
import google.protobuf.text_format
|
|
||||||
|
|
||||||
import io
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import caffe2.python.onnx.backend as backend
|
import caffe2.python.onnx.backend as backend
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, flatten
|
from test_pytorch_common import TestCase, run_tests, flatten
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.onnx
|
import torch.onnx
|
||||||
|
|
@ -10,11 +10,9 @@ import itertools
|
||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
import inspect
|
import inspect
|
||||||
import argparse
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import common_utils as common
|
import common_utils as common
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import torch.autograd.function as function
|
||||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
sys.path.insert(-1, pytorch_test_dir)
|
sys.path.insert(-1, pytorch_test_dir)
|
||||||
|
|
||||||
from common_utils import *
|
from common_utils import * # noqa: F401
|
||||||
|
|
||||||
torch.set_default_tensor_type('torch.FloatTensor')
|
torch.set_default_tensor_type('torch.FloatTensor')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# Some standard imports
|
# Some standard imports
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.autograd import Variable
|
|
||||||
import torch.onnx
|
import torch.onnx
|
||||||
import torch.nn.init as init
|
import torch.nn.init as init
|
||||||
from caffe2.python.model_helper import ModelHelper
|
from caffe2.python.model_helper import ModelHelper
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from functools import wraps
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
import torch.legacy.optim as optim
|
import torch.legacy.optim as optim
|
||||||
from pprint import pprint
|
|
||||||
|
|
||||||
|
|
||||||
def rosenbrock(tensor):
|
def rosenbrock(tensor):
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import warnings
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from operator import mul, itemgetter
|
from operator import mul
|
||||||
from functools import reduce, wraps
|
from functools import reduce
|
||||||
from torch._six import inf, nan, istuple
|
from torch._six import inf, nan, istuple
|
||||||
from torch.autograd.gradcheck import gradgradcheck, gradcheck
|
from torch.autograd.gradcheck import gradgradcheck, gradcheck
|
||||||
from torch.autograd.function import once_differentiable
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
@ -17,14 +17,11 @@ from torch.autograd.profiler import profile
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
|
from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
|
||||||
suppress_warnings, skipIfRocm,
|
suppress_warnings, skipIfRocm,
|
||||||
prod_single_zero, random_square_matrix_of_rank,
|
load_tests)
|
||||||
random_symmetric_matrix, random_symmetric_psd_matrix,
|
|
||||||
random_symmetric_pd_matrix, make_nonzero_det,
|
|
||||||
random_fullrank_matrix_distinct_singular_value, load_tests)
|
|
||||||
from common_cuda import TEST_CUDA
|
from common_cuda import TEST_CUDA
|
||||||
from torch.autograd import Variable, Function, detect_anomaly
|
from torch.autograd import Variable, Function, detect_anomaly
|
||||||
from torch.autograd.function import InplaceFunction
|
from torch.autograd.function import InplaceFunction
|
||||||
from torch.testing import make_non_contiguous, randn_like
|
from torch.testing import randn_like
|
||||||
from common_methods_invocations import (method_tests,
|
from common_methods_invocations import (method_tests,
|
||||||
create_input, unpack_variables,
|
create_input, unpack_variables,
|
||||||
EXCLUDE_FUNCTIONAL, EXCLUDE_GRADCHECK,
|
EXCLUDE_FUNCTIONAL, EXCLUDE_GRADCHECK,
|
||||||
|
|
@ -32,7 +29,7 @@ from common_methods_invocations import (method_tests,
|
||||||
EXCLUDE_GRADGRADCHECK_BY_TEST_NAME,
|
EXCLUDE_GRADGRADCHECK_BY_TEST_NAME,
|
||||||
exclude_tensor_method,
|
exclude_tensor_method,
|
||||||
mask_not_all_zeros,
|
mask_not_all_zeros,
|
||||||
L, S)
|
S)
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import re
|
|
||||||
import unittest
|
import unittest
|
||||||
import sys
|
import sys
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
|
|
@ -19,9 +17,9 @@ from torch._six import inf, nan
|
||||||
from test_torch import _TestTorchMixin
|
from test_torch import _TestTorchMixin
|
||||||
|
|
||||||
from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
|
from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
|
||||||
run_additional_tri_tests, _compare_trilu_indices, _compare_large_trilu_indices
|
_compare_trilu_indices, _compare_large_trilu_indices
|
||||||
from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
|
from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
|
||||||
PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests, iter_indices
|
PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,10 @@ import sys
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import ctypes
|
import ctypes
|
||||||
import signal
|
|
||||||
import torch
|
import torch
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
import unittest
|
import unittest
|
||||||
import subprocess
|
|
||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from torch import multiprocessing as mp
|
from torch import multiprocessing as mp
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,9 @@ import torch.cuda
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
|
||||||
from common_utils import TestCase, run_tests
|
from common_utils import TestCase, run_tests
|
||||||
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
||||||
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
|
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
|
||||||
import common_utils as common
|
|
||||||
|
|
||||||
BACKEND = os.environ["BACKEND"]
|
BACKEND = os.environ["BACKEND"]
|
||||||
TEMP_DIR = os.environ["TEMP_DIR"]
|
TEMP_DIR = os.environ["TEMP_DIR"]
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,6 @@ import torch
|
||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import ast
|
|
||||||
import _ast
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from common_utils import TestCase, run_tests
|
from common_utils import TestCase, run_tests
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
|
||||||
from torch import tensor
|
from torch import tensor
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,6 @@ from contextlib import contextmanager
|
||||||
from itertools import product, chain
|
from itertools import product, chain
|
||||||
import torch.jit.frontend
|
import torch.jit.frontend
|
||||||
from torch.autograd import Variable, Function
|
from torch.autograd import Variable, Function
|
||||||
from torch.nn import Module
|
|
||||||
from torch.autograd.function import traceable
|
|
||||||
from torch.testing import assert_allclose
|
|
||||||
from torch.onnx import OperatorExportTypes
|
from torch.onnx import OperatorExportTypes
|
||||||
from torch._six import inf, PY2, builtins, StringIO
|
from torch._six import inf, PY2, builtins, StringIO
|
||||||
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
||||||
|
|
@ -25,7 +22,6 @@ from textwrap import dedent
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import itertools
|
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -46,14 +42,13 @@ from common_methods_invocations import method_tests as autograd_method_tests
|
||||||
from common_methods_invocations import create_input, unpack_variables, \
|
from common_methods_invocations import create_input, unpack_variables, \
|
||||||
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
|
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
from torch._C import TensorType
|
||||||
ListType, StringType, DictType
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import random
|
import random
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Dict, Optional, Tuple
|
||||||
from torch.jit.frontend import NotSupportedError
|
from torch.jit.frontend import NotSupportedError
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
|
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
|
||||||
|
|
||||||
# For testing truediv in python 2
|
# For testing truediv in python 2
|
||||||
from test_module.future_div import div_int_future, div_float_future
|
from test_module.future_div import div_int_future, div_float_future
|
||||||
|
|
@ -6727,8 +6722,6 @@ a")
|
||||||
|
|
||||||
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
||||||
def test_type_annotation_py3(self):
|
def test_type_annotation_py3(self):
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
code = dedent("""
|
code = dedent("""
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
@ -9349,8 +9342,6 @@ a")
|
||||||
foo(torch.ones([123])) # wrong size
|
foo(torch.ones([123])) # wrong size
|
||||||
|
|
||||||
def test_builtin_error_messsage(self):
|
def test_builtin_error_messsage(self):
|
||||||
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def close_match(x):
|
def close_match(x):
|
||||||
|
|
@ -11020,8 +11011,6 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _test_super_resolution(self, device, check_export_import=True):
|
def _test_super_resolution(self, device, check_export_import=True):
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
|
|
||||||
def __init__(self, upscale_factor):
|
def __init__(self, upscale_factor):
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,12 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import functools
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
import sys
|
|
||||||
import torch
|
import torch
|
||||||
import torch.autograd.function as function
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from common_utils import TestCase, run_tests, IS_WINDOWS, \
|
from common_utils import IS_WINDOWS, \
|
||||||
skipIfRocm, IS_SANDCASTLE
|
skipIfRocm, IS_SANDCASTLE
|
||||||
from typing import List, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
|
from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
|
||||||
backward_graph
|
backward_graph
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
import torch
|
import torch # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def div_int_future():
|
def div_int_future():
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import torch
|
import torch # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def div_int_nofuture():
|
def div_int_nofuture():
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ import torch.utils.hooks
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN,
|
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN,
|
||||||
load_tests, slowTest)
|
load_tests, slowTest)
|
||||||
from multiprocessing.reduction import ForkingPickler
|
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,6 @@ from itertools import repeat, product
|
||||||
from functools import wraps, reduce
|
from functools import wraps, reduce
|
||||||
from operator import mul
|
from operator import mul
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -29,9 +27,9 @@ from torch.autograd import Variable, gradcheck
|
||||||
from torch.autograd.gradcheck import gradgradcheck
|
from torch.autograd.gradcheck import gradgradcheck
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from torch.nn.parallel._functions import Broadcast
|
from torch.nn.parallel._functions import Broadcast
|
||||||
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, TEST_WITH_ROCM, \
|
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
|
||||||
TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, download_file, PY3, PY34, to_gpu, \
|
TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
|
||||||
get_function_arglist, skipCUDAMemoryLeakCheckIf, load_tests
|
get_function_arglist, load_tests
|
||||||
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
|
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
|
||||||
TEST_CUDNN_VERSION
|
TEST_CUDNN_VERSION
|
||||||
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
|
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.jit
|
import torch.jit
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
from caffe2.python import core
|
from caffe2.python import core
|
||||||
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
from common_utils import TestCase, run_tests
|
||||||
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
|
|
||||||
freeze_rng_state, set_rng_seed
|
|
||||||
|
|
||||||
|
|
||||||
def canonical(graph):
|
def canonical(graph):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import sparse
|
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import functools
|
import functools
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import torch.cuda
|
||||||
import torch.distributed.deprecated as dist
|
import torch.distributed.deprecated as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
|
||||||
from common_utils import TestCase, run_tests
|
from common_utils import TestCase, run_tests
|
||||||
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import io
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import operator
|
|
||||||
import copy
|
import copy
|
||||||
import shutil
|
import shutil
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -17,7 +16,7 @@ import gzip
|
||||||
import types
|
import types
|
||||||
import textwrap
|
import textwrap
|
||||||
import re
|
import re
|
||||||
from torch._utils_internal import get_file_path, get_file_path_2
|
from torch._utils_internal import get_file_path_2
|
||||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||||
from torch._utils import _rebuild_tensor
|
from torch._utils import _rebuild_tensor
|
||||||
from torch._six import inf, nan, string_classes, istuple
|
from torch._six import inf, nan, string_classes, istuple
|
||||||
|
|
@ -2032,7 +2031,6 @@ class _TestTorchMixin(object):
|
||||||
def _test_int_pow(self, cast):
|
def _test_int_pow(self, cast):
|
||||||
if not TEST_NUMPY:
|
if not TEST_NUMPY:
|
||||||
return
|
return
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def check_against_np(tensor, exp):
|
def check_against_np(tensor, exp):
|
||||||
tensor_np = tensor.cpu().numpy()
|
tensor_np = tensor.cpu().numpy()
|
||||||
|
|
@ -4669,7 +4667,6 @@ class _TestTorchMixin(object):
|
||||||
# Test non-contiguous inputs.
|
# Test non-contiguous inputs.
|
||||||
if not TEST_NUMPY:
|
if not TEST_NUMPY:
|
||||||
return
|
return
|
||||||
import numpy
|
|
||||||
from numpy.linalg import solve
|
from numpy.linalg import solve
|
||||||
A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
|
A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
|
||||||
b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
|
b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
|
||||||
|
|
@ -6218,7 +6215,6 @@ class _TestTorchMixin(object):
|
||||||
# Test non-contiguous inputs.
|
# Test non-contiguous inputs.
|
||||||
if not TEST_NUMPY:
|
if not TEST_NUMPY:
|
||||||
return
|
return
|
||||||
import numpy
|
|
||||||
from numpy.linalg import solve
|
from numpy.linalg import solve
|
||||||
A = random_symmetric_pd_matrix(2, 2)
|
A = random_symmetric_pd_matrix(2, 2)
|
||||||
b = torch.randn(2, 2, 2)
|
b = torch.randn(2, 2, 2)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import unittest
|
import unittest
|
||||||
from common_utils import TestCase, run_tests, download_file
|
from common_utils import TestCase, run_tests
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
|
@ -10,7 +10,7 @@ import subprocess
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import mypy
|
import mypy # noqa: F401
|
||||||
HAVE_MYPY = True
|
HAVE_MYPY = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAVE_MYPY = False
|
HAVE_MYPY = False
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,19 @@ from __future__ import print_function
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import math
|
|
||||||
import shutil
|
import shutil
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import traceback
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import warnings
|
|
||||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||||
import torch.hub as hub
|
import torch.hub as hub
|
||||||
from torch.autograd._functions.utils import prepare_onnx_paddings
|
from torch.autograd._functions.utils import prepare_onnx_paddings
|
||||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||||
from common_utils import IS_WINDOWS, IS_PPC, skipIfRocm, load_tests
|
from common_utils import skipIfRocm, load_tests
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
|
|
@ -34,7 +31,7 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||||
|
|
||||||
HAS_CUDA = torch.cuda.is_available()
|
HAS_CUDA = torch.cuda.is_available()
|
||||||
|
|
||||||
from common_utils import TestCase, run_tests, download_file
|
from common_utils import TestCase, run_tests
|
||||||
|
|
||||||
|
|
||||||
class RandomDatasetMock(object):
|
class RandomDatasetMock(object):
|
||||||
|
|
@ -326,7 +323,7 @@ test_dir = os.path.abspath(os.path.dirname(str(__file__)))
|
||||||
class TestFFI(TestCase):
|
class TestFFI(TestCase):
|
||||||
def test_deprecated(self):
|
def test_deprecated(self):
|
||||||
with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
|
with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
|
||||||
from torch.utils.ffi import create_extension
|
from torch.utils.ffi import create_extension # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
|
@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import argparse
|
import argparse
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@
|
||||||
# differentiable subcomponents.
|
# differentiable subcomponents.
|
||||||
#
|
#
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
|
from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
|
||||||
from .gen_autograd import VIEW_FUNCTIONS
|
from .gen_autograd import VIEW_FUNCTIONS
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
from os.path import dirname, abspath
|
from os.path import dirname, abspath
|
||||||
import shlex
|
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# By appending pytorch_root to sys.path, this module can import other torch
|
# By appending pytorch_root to sys.path, this module can import other torch
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,20 @@
|
||||||
from .setup_helpers.env import (IS_64BIT, IS_ARM, IS_DARWIN, IS_LINUX, IS_PPC, IS_WINDOWS,
|
from .setup_helpers.env import (IS_64BIT, IS_DARWIN, IS_WINDOWS,
|
||||||
DEBUG, REL_WITH_DEB_INFO, USE_MKLDNN,
|
DEBUG, REL_WITH_DEB_INFO, USE_MKLDNN,
|
||||||
check_env_flag, check_negative_env_flag, hotpatch_build_env_vars)
|
check_env_flag, check_negative_env_flag)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import distutils
|
import distutils
|
||||||
import distutils.sysconfig
|
import distutils.sysconfig
|
||||||
from distutils.file_util import copy_file
|
from subprocess import check_call, check_output
|
||||||
from distutils.dir_util import copy_tree
|
|
||||||
from subprocess import check_call, call, check_output
|
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from .setup_helpers.cuda import USE_CUDA, CUDA_HOME
|
from .setup_helpers.cuda import USE_CUDA, CUDA_HOME
|
||||||
from .setup_helpers.dist_check import USE_DISTRIBUTED, USE_GLOO_IBVERBS
|
from .setup_helpers.dist_check import USE_DISTRIBUTED, USE_GLOO_IBVERBS
|
||||||
from .setup_helpers.nccl import USE_SYSTEM_NCCL, NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB, USE_NCCL
|
from .setup_helpers.nccl import USE_SYSTEM_NCCL, NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB, USE_NCCL
|
||||||
from .setup_helpers.rocm import ROCM_HOME, ROCM_VERSION, USE_ROCM
|
from .setup_helpers.rocm import USE_ROCM
|
||||||
from .setup_helpers.nnpack import USE_NNPACK
|
from .setup_helpers.nnpack import USE_NNPACK
|
||||||
from .setup_helpers.qnnpack import USE_QNNPACK
|
from .setup_helpers.qnnpack import USE_QNNPACK
|
||||||
from .setup_helpers.cudnn import CUDNN_INCLUDE_DIR, CUDNN_LIB_DIR, CUDNN_LIBRARY, USE_CUDNN
|
from .setup_helpers.cudnn import CUDNN_INCLUDE_DIR, CUDNN_LIBRARY, USE_CUDNN
|
||||||
|
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,7 @@ Only files that are in CLANG_FORMAT_WHITELIST are checked.
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import fnmatch
|
|
||||||
import difflib
|
import difflib
|
||||||
import sys
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from .cwrap import cwrap
|
from .cwrap import cwrap # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from . import CWrapPlugin
|
from . import CWrapPlugin
|
||||||
from string import Template
|
|
||||||
|
|
||||||
|
|
||||||
class ArgumentReferences(CWrapPlugin):
|
class ArgumentReferences(CWrapPlugin):
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
from string import Template
|
from string import Template
|
||||||
import copy
|
import copy
|
||||||
from copy import deepcopy
|
|
||||||
from . import CWrapPlugin
|
from . import CWrapPlugin
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
|
|
||||||
class CuDNNPlugin(CWrapPlugin):
|
class CuDNNPlugin(CWrapPlugin):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from . import CWrapPlugin
|
from . import CWrapPlugin
|
||||||
from string import Template
|
|
||||||
|
|
||||||
|
|
||||||
class GILRelease(CWrapPlugin):
|
class GILRelease(CWrapPlugin):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,4 @@
|
||||||
import os
|
|
||||||
from copy import deepcopy
|
|
||||||
from . import CWrapPlugin
|
from . import CWrapPlugin
|
||||||
from itertools import product
|
|
||||||
from ...shared import cwrap_common
|
from ...shared import cwrap_common
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -420,15 +420,15 @@ class CWrapPlugin(object):
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
from .NNExtension import NNExtension
|
from .NNExtension import NNExtension # noqa: F401
|
||||||
from .NullableArguments import NullableArguments
|
from .NullableArguments import NullableArguments # noqa: F401
|
||||||
from .OptionalArguments import OptionalArguments
|
from .OptionalArguments import OptionalArguments # noqa: F401
|
||||||
from .ArgcountChecker import ArgcountChecker
|
from .ArgcountChecker import ArgcountChecker # noqa: F401
|
||||||
from .ArgumentReferences import ArgumentReferences
|
from .ArgumentReferences import ArgumentReferences # noqa: F401
|
||||||
from .BeforeAfterCall import BeforeAfterCall
|
from .BeforeAfterCall import BeforeAfterCall # noqa: F401
|
||||||
from .ConstantArguments import ConstantArguments
|
from .ConstantArguments import ConstantArguments # noqa: F401
|
||||||
from .ReturnArguments import ReturnArguments
|
from .ReturnArguments import ReturnArguments # noqa: F401
|
||||||
from .GILRelease import GILRelease
|
from .GILRelease import GILRelease # noqa: F401
|
||||||
from .AutoGPU import AutoGPU
|
from .AutoGPU import AutoGPU # noqa: F401
|
||||||
from .CuDNNPlugin import CuDNNPlugin
|
from .CuDNNPlugin import CuDNNPlugin # noqa: F401
|
||||||
from .WrapDim import WrapDim
|
from .WrapDim import WrapDim # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import argparse
|
||||||
import gzip
|
import gzip
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import urllib
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from urllib.error import URLError
|
from urllib.error import URLError
|
||||||
|
|
|
||||||
|
|
@ -12,14 +12,11 @@ generated. In the full build system, OUTPUT_DIR is
|
||||||
torch/csrc/jit/generated/
|
torch/csrc/jit/generated/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
|
||||||
import copy
|
import copy
|
||||||
from itertools import count, combinations, groupby
|
from itertools import groupby
|
||||||
from ..autograd.utils import CodeTemplate, write, uninplace_api_name
|
from ..autograd.utils import CodeTemplate, write
|
||||||
from ..autograd.gen_autograd import load_aten_declarations
|
from ..autograd.gen_autograd import load_aten_declarations
|
||||||
from collections import OrderedDict
|
|
||||||
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
|
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
|
||||||
|
|
||||||
# JIT has a type system of
|
# JIT has a type system of
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from .generate_wrappers import generate_wrappers, wrap_function, import_module
|
from .generate_wrappers import generate_wrappers, wrap_function, import_module # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
from string import Template
|
||||||
from string import Template, ascii_lowercase
|
|
||||||
from ..cwrap import cwrap
|
from ..cwrap import cwrap
|
||||||
from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU
|
from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU
|
||||||
from ..shared import import_module
|
from ..shared import import_module
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import multiprocessing
|
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import inspect
|
|
||||||
import collections
|
import collections
|
||||||
import yaml
|
import yaml
|
||||||
import types
|
|
||||||
import re
|
import re
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@ import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths
|
from .env import IS_CONDA, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths
|
||||||
from .cuda import USE_CUDA
|
|
||||||
|
|
||||||
# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
|
# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
|
||||||
USE_DISTRIBUTED = not check_negative_env_flag("USE_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("USE_ROCM")
|
USE_DISTRIBUTED = not check_negative_env_flag("USE_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("USE_ROCM")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,4 @@
|
||||||
import os
|
from .env import check_env_flag
|
||||||
import glob
|
|
||||||
|
|
||||||
from .env import IS_WINDOWS, IS_CONDA, CONDA_DIR, check_env_flag, gather_paths
|
|
||||||
from .rocm import USE_ROCM, ROCM_HOME
|
from .rocm import USE_ROCM, ROCM_HOME
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
import warnings
|
|
||||||
from itertools import chain
|
|
||||||
|
|
||||||
from .env import IS_WINDOWS, IS_DARWIN, IS_CONDA, CONDA_DIR, check_negative_env_flag, \
|
from .env import IS_WINDOWS, IS_DARWIN, IS_CONDA, CONDA_DIR, check_negative_env_flag, \
|
||||||
gather_paths
|
gather_paths
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import ctypes.util
|
import ctypes.util
|
||||||
from subprocess import Popen, PIPE
|
|
||||||
|
|
||||||
from .cuda import USE_CUDA
|
from .cuda import USE_CUDA
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,2 @@
|
||||||
from .module_loader import import_module
|
from .module_loader import import_module # noqa: F401
|
||||||
from .cwrap_common import set_declaration_defaults, \
|
from .cwrap_common import set_declaration_defaults, sort_by_number_of_options, enumerate_options_due_to_default # noqa: F401
|
||||||
sort_by_number_of_options, enumerate_options_due_to_default
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import sys
|
||||||
import platform
|
import platform
|
||||||
from ._utils import _import_dotted_name
|
from ._utils import _import_dotted_name
|
||||||
from ._utils_internal import get_file_path, prepare_multiprocessing_environment
|
from ._utils_internal import get_file_path, prepare_multiprocessing_environment
|
||||||
from .version import __version__
|
from .version import __version__ # noqa: F401
|
||||||
from ._six import string_classes as _string_classes
|
from ._six import string_classes as _string_classes
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -39,7 +39,7 @@ import os as _dl_flags
|
||||||
# if we have numpy, it *must* be imported before the call to setdlopenflags()
|
# if we have numpy, it *must* be imported before the call to setdlopenflags()
|
||||||
# or there is risk that later c modules will segfault when importing numpy
|
# or there is risk that later c modules will segfault when importing numpy
|
||||||
try:
|
try:
|
||||||
import numpy as _np
|
import numpy as _np # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -281,7 +281,7 @@ del BoolStorageBase
|
||||||
|
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
from torch.autograd import no_grad, enable_grad, set_grad_enabled
|
from torch.autograd import no_grad, enable_grad, set_grad_enabled # noqa: F401
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.multiprocessing
|
import torch.multiprocessing
|
||||||
|
|
@ -309,7 +309,7 @@ def compiled_with_cxx11_abi():
|
||||||
|
|
||||||
|
|
||||||
# Import the ops "namespace"
|
# Import the ops "namespace"
|
||||||
from torch._ops import ops
|
from torch._ops import ops # noqa: F401
|
||||||
|
|
||||||
# Import the quasi random sampler
|
# Import the quasi random sampler
|
||||||
import torch.quasirandom
|
import torch.quasirandom
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,9 @@ else:
|
||||||
|
|
||||||
|
|
||||||
if PY2:
|
if PY2:
|
||||||
import Queue as queue
|
import Queue as queue # noqa: F401
|
||||||
else:
|
else:
|
||||||
import queue
|
import queue # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def with_metaclass(meta, *bases):
|
def with_metaclass(meta, *bases):
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from functools import reduce
|
from torch._six import inf
|
||||||
from sys import float_info
|
|
||||||
from torch._six import inf, nan
|
|
||||||
|
|
||||||
|
|
||||||
class __PrinterOptions(object):
|
class __PrinterOptions(object):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,3 @@
|
||||||
import os
|
|
||||||
import itertools
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# when compiling a cffi extension, this works. When compiling
|
# when compiling a cffi extension, this works. When compiling
|
||||||
# torch itself, it doesn't work because the parent module can't
|
# torch itself, it doesn't work because the parent module can't
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,11 @@ import torch
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .variable import Variable
|
from .variable import Variable
|
||||||
from .function import Function, NestedIOFunction
|
from .function import Function, NestedIOFunction # noqa: F401
|
||||||
from .gradcheck import gradcheck, gradgradcheck
|
from .gradcheck import gradcheck, gradgradcheck # noqa: F401
|
||||||
from .grad_mode import no_grad, enable_grad, set_grad_enabled
|
from .grad_mode import no_grad, enable_grad, set_grad_enabled # noqa: F401
|
||||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
from .anomaly_mode import detect_anomaly, set_detect_anomaly # noqa: F401
|
||||||
from . import profiler
|
from . import profiler # noqa: F401
|
||||||
|
|
||||||
__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
|
__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from .tensor import *
|
from .tensor import * # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import torch
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch._six import container_abcs, istuple
|
from torch._six import container_abcs, istuple
|
||||||
import torch.testing
|
import torch.testing
|
||||||
import sys
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,7 @@
|
||||||
import subprocess
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict, namedtuple
|
from collections import defaultdict, namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import FileNotFoundError
|
|
||||||
|
|
||||||
|
|
||||||
class range(object):
|
class range(object):
|
||||||
|
|
|
||||||
|
|
@ -648,7 +648,7 @@ torch._storage_classes.add(ByteStorage)
|
||||||
torch._storage_classes.add(HalfStorage)
|
torch._storage_classes.add(HalfStorage)
|
||||||
torch._storage_classes.add(BoolStorage)
|
torch._storage_classes.add(BoolStorage)
|
||||||
|
|
||||||
from . import sparse
|
from . import sparse # noqa: F401
|
||||||
from . import profiler
|
from . import profiler # noqa: F401
|
||||||
from . import nvtx
|
from . import nvtx # noqa: F401
|
||||||
from .streams import Stream, Event
|
from .streams import Stream, Event # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from . import nccl
|
from . import nccl
|
||||||
from torch._utils import _accumulate, _take_tensors, _flatten_dense_tensors, \
|
from torch._utils import _take_tensors, _flatten_dense_tensors, \
|
||||||
_flatten_sparse_tensors, _unflatten_dense_tensors, \
|
_unflatten_dense_tensors, _reorder_tensors_as
|
||||||
_unflatten_sparse_tensors, _reorder_tensors_as
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast(tensor, devices):
|
def broadcast(tensor, devices):
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ if is_available() and not torch._C._c10d_init():
|
||||||
|
|
||||||
|
|
||||||
if is_available():
|
if is_available():
|
||||||
from .distributed_c10d import *
|
from .distributed_c10d import * # noqa: F401
|
||||||
# Variables prefixed with underscore are not auto imported
|
# Variables prefixed with underscore are not auto imported
|
||||||
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
|
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
|
||||||
# this.
|
# this.
|
||||||
from .distributed_c10d import _backend
|
from .distributed_c10d import _backend # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,10 @@ import warnings
|
||||||
from torch._six import string_classes
|
from torch._six import string_classes
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from .rendezvous import rendezvous, register_rendezvous_handler
|
# This module is wildcard imported from torch.distributed.
|
||||||
|
# TODO: specify __all__
|
||||||
|
|
||||||
|
from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401
|
||||||
from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
|
from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
|
||||||
ScatterOptions, GatherOptions
|
ScatterOptions, GatherOptions
|
||||||
from . import ReduceOp
|
from . import ReduceOp
|
||||||
|
|
|
||||||
|
|
@ -140,11 +140,8 @@ will not pass ``--local_rank`` when you specify this flag.
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import socket
|
|
||||||
from argparse import ArgumentParser, REMAINDER
|
from argparse import ArgumentParser, REMAINDER
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,8 @@ __all__ = [
|
||||||
'Gamma',
|
'Gamma',
|
||||||
'Geometric',
|
'Geometric',
|
||||||
'Gumbel',
|
'Gumbel',
|
||||||
|
'HalfCauchy',
|
||||||
|
'HalfNormal',
|
||||||
'Independent',
|
'Independent',
|
||||||
'Laplace',
|
'Laplace',
|
||||||
'LogNormal',
|
'LogNormal',
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch._six import nan
|
from torch._six import nan
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.distribution import Distribution
|
from torch.distributions.distribution import Distribution
|
||||||
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property, broadcast_all
|
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
|
||||||
|
|
||||||
|
|
||||||
class Categorical(Distribution):
|
class Categorical(Distribution):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import torch
|
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.gamma import Gamma
|
from torch.distributions.gamma import Gamma
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
import torch
|
import torch
|
||||||
import math
|
|
||||||
from torch._six import nan
|
from torch._six import nan
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.distribution import Distribution
|
from torch.distributions.distribution import Distribution
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from numbers import Number
|
||||||
import torch
|
import torch
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.exp_family import ExponentialFamily
|
from torch.distributions.exp_family import ExponentialFamily
|
||||||
from torch.distributions.utils import broadcast_all, lazy_property
|
from torch.distributions.utils import broadcast_all
|
||||||
|
|
||||||
|
|
||||||
def _standard_gamma(concentration):
|
def _standard_gamma(concentration):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.transforms import AbsTransform
|
from torch.distributions.transforms import AbsTransform
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.transforms import AbsTransform
|
from torch.distributions.transforms import AbsTransform
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ from .gumbel import Gumbel
|
||||||
from .half_normal import HalfNormal
|
from .half_normal import HalfNormal
|
||||||
from .independent import Independent
|
from .independent import Independent
|
||||||
from .laplace import Laplace
|
from .laplace import Laplace
|
||||||
from .logistic_normal import LogisticNormal
|
|
||||||
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
|
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
|
||||||
_batch_lowrank_mahalanobis)
|
_batch_lowrank_mahalanobis)
|
||||||
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
|
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import torch
|
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.transforms import ExpTransform
|
from torch.distributions.transforms import ExpTransform
|
||||||
from torch.distributions.normal import Normal
|
from torch.distributions.normal import Normal
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.normal import Normal
|
from torch.distributions.normal import Normal
|
||||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||||
from torch.distributions.transforms import ComposeTransform, ExpTransform, StickBreakingTransform
|
from torch.distributions.transforms import StickBreakingTransform
|
||||||
|
|
||||||
|
|
||||||
class LogisticNormal(TransformedDistribution):
|
class LogisticNormal(TransformedDistribution):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import torch
|
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.distributions.exponential import Exponential
|
from torch.distributions.exponential import Exponential
|
||||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import math
|
import math
|
||||||
from numbers import Number
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import inf, nan
|
from torch._six import inf, nan
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from functools import update_wrapper
|
from functools import update_wrapper
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from .onnx import *
|
from .onnx import * # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import importlib
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
|
|
@ -10,10 +9,7 @@ if sys.version_info[0] == 2:
|
||||||
from urllib2 import urlopen # noqa f811
|
from urllib2 import urlopen # noqa f811
|
||||||
else:
|
else:
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse # noqa: F401
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.utils.model_zoo as model_zoo
|
|
||||||
|
|
||||||
MASTER_BRANCH = 'master'
|
MASTER_BRANCH = 'master'
|
||||||
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
|
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,31 @@
|
||||||
import torch._C
|
import torch._C
|
||||||
from torch import Tensor
|
|
||||||
from torch.autograd import Variable, function
|
from torch.autograd import Variable, function
|
||||||
from torch.serialization import validate_cuda_device
|
from torch.serialization import validate_cuda_device
|
||||||
from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
|
from torch.nn import Module, ModuleList, Parameter, Sequential
|
||||||
from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
|
from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.jit.annotations
|
import torch.jit.annotations
|
||||||
import torch._jit_internal as _jit_internal
|
import torch._jit_internal as _jit_internal
|
||||||
from torch._six import raise_from, with_metaclass, get_function_from_type, \
|
from torch._six import with_metaclass, get_function_from_type, \
|
||||||
string_classes
|
string_classes
|
||||||
from torch._jit_internal import ignore
|
from torch._jit_internal import ignore # noqa: F401
|
||||||
from torch.jit._pickle import Unpickler
|
from torch.jit._pickle import Unpickler # noqa: F401
|
||||||
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
|
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
|
||||||
_list_with_default
|
_list_with_default
|
||||||
import torch.testing
|
import torch.testing
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict, OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
import textwrap
|
import textwrap
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
import itertools
|
|
||||||
import weakref
|
import weakref
|
||||||
import types
|
import types
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import functools
|
import functools
|
||||||
import copy
|
import copy
|
||||||
import numbers
|
|
||||||
import collections
|
import collections
|
||||||
import re
|
|
||||||
import inspect
|
import inspect
|
||||||
import pickle
|
import pickle
|
||||||
if sys.version_info[0] > 2:
|
if sys.version_info[0] > 2:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import torch
|
|
||||||
import functools
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -181,3 +180,33 @@ def ann_to_type(ann):
|
||||||
elif ann is str:
|
elif ann is str:
|
||||||
return StringType.get()
|
return StringType.get()
|
||||||
raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
|
raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'List',
|
||||||
|
'BroadcastingList1',
|
||||||
|
'BroadcastingList2',
|
||||||
|
'BroadcastingList3',
|
||||||
|
'Tuple',
|
||||||
|
'is_tuple',
|
||||||
|
'is_list',
|
||||||
|
'Dict',
|
||||||
|
'is_dict',
|
||||||
|
'TensorType',
|
||||||
|
'TupleType',
|
||||||
|
'FloatType',
|
||||||
|
'IntType',
|
||||||
|
'ListType',
|
||||||
|
'StringType',
|
||||||
|
'DictType',
|
||||||
|
'Module',
|
||||||
|
# TODO: Consider not exporting these during wildcard import (reserve
|
||||||
|
# that for the types; for idiomatic typing code.)
|
||||||
|
'get_signature',
|
||||||
|
'get_num_params',
|
||||||
|
'parse_type_line',
|
||||||
|
'get_type_line',
|
||||||
|
'split_type_line',
|
||||||
|
'try_real_annotations',
|
||||||
|
'ann_to_type',
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,6 @@ import ast
|
||||||
import inspect
|
import inspect
|
||||||
import string
|
import string
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from functools import partial
|
|
||||||
from collections import namedtuple
|
|
||||||
from torch._six import PY2
|
from torch._six import PY2
|
||||||
from torch._C._jit_tree_views import *
|
from torch._C._jit_tree_views import *
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import copy
|
from typing import Tuple, Optional # noqa: F401
|
||||||
import numbers
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.jit import ScriptModule
|
|
||||||
|
|
||||||
from torch.nn.utils.rnn import PackedSequence
|
|
||||||
from torch.nn import _VF
|
from torch.nn import _VF
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ __all__ = ['set_sharing_strategy', 'get_sharing_strategy',
|
||||||
'get_all_sharing_strategies']
|
'get_all_sharing_strategies']
|
||||||
|
|
||||||
|
|
||||||
from multiprocessing import *
|
from multiprocessing import * # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
__all__ += multiprocessing.__all__
|
__all__ += multiprocessing.__all__
|
||||||
|
|
@ -36,13 +36,13 @@ torch._C._multiprocessing_init()
|
||||||
if sys.version_info < (3, 3):
|
if sys.version_info < (3, 3):
|
||||||
"""Override basic classes in Python 2.7 and Python 3.3 to use ForkingPickler
|
"""Override basic classes in Python 2.7 and Python 3.3 to use ForkingPickler
|
||||||
for serialization. Later versions of Python already use ForkingPickler."""
|
for serialization. Later versions of Python already use ForkingPickler."""
|
||||||
from .queue import Queue, SimpleQueue
|
from .queue import Queue, SimpleQueue # noqa: F401
|
||||||
from .pool import Pool
|
from .pool import Pool # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
"""Add helper function to spawn N processes and wait for completion of any of
|
"""Add helper function to spawn N processes and wait for completion of any of
|
||||||
them. This depends `mp.get_context` which was added in Python 3.4."""
|
them. This depends `mp.get_context` which was added in Python 3.4."""
|
||||||
from .spawn import spawn, SpawnContext
|
from .spawn import spawn, SpawnContext # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
if sys.platform == 'darwin' or sys.platform == 'win32':
|
if sys.platform == 'darwin' or sys.platform == 'win32':
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.hooks
|
import torch.utils.hooks
|
||||||
import os
|
import os
|
||||||
import weakref
|
|
||||||
import threading
|
import threading
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from multiprocessing.reduction import ForkingPickler
|
from multiprocessing.reduction import ForkingPickler
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user