mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility library [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future) and `torch._six`. We only support Python 3.8+ now. It's time to retire them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
parent
39511697d4
commit
b005ec62b9
|
|
@ -36,11 +36,6 @@ flatbuffers==2.0
|
|||
#Pinned versions: 2.0
|
||||
#test that import:
|
||||
|
||||
#future #this breaks linux-bionic-rocm4.5-py3.7
|
||||
#Description: compatibility layer between python 2 and python 3
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
hypothesis==5.35.1
|
||||
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
|
||||
#Description: advanced library for generating parametrized tests
|
||||
|
|
|
|||
2
.circleci/config.yml
generated
2
.circleci/config.yml
generated
|
|
@ -1101,7 +1101,7 @@ jobs:
|
|||
cd ${PROJ_ROOT}/ios/TestApp/benchmark
|
||||
mkdir -p ../models
|
||||
if [ ${USE_COREML_DELEGATE} == 1 ]; then
|
||||
pip install coremltools==5.0b5 protobuf==3.20.1 six==1.16.0
|
||||
pip install coremltools==5.0b5 protobuf==3.20.1
|
||||
python coreml_backend.py
|
||||
else
|
||||
cd "${PROJ_ROOT}"
|
||||
|
|
|
|||
|
|
@ -82,8 +82,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
|
|||
mkl>=2018 \
|
||||
ninja \
|
||||
typing-extensions \
|
||||
${PROTOBUF_PACKAGE} \
|
||||
six
|
||||
${PROTOBUF_PACKAGE}
|
||||
if [[ "$DESIRED_CUDA" == 'cpu' ]]; then
|
||||
retry conda install -c pytorch -y cpuonly
|
||||
else
|
||||
|
|
@ -100,7 +99,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
|
|||
)
|
||||
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
|
||||
pip install "\$pkg" --extra-index-url "https://download.pytorch.org/whl/nightly/${DESIRED_CUDA}"
|
||||
retry pip install -q future numpy protobuf typing-extensions six
|
||||
retry pip install -q numpy protobuf typing-extensions
|
||||
fi
|
||||
if [[ "$PACKAGE_TYPE" == libtorch ]]; then
|
||||
pkg="\$(ls /final_pkgs/*-latest.zip)"
|
||||
|
|
|
|||
|
|
@ -626,7 +626,7 @@
|
|||
cd ${PROJ_ROOT}/ios/TestApp/benchmark
|
||||
mkdir -p ../models
|
||||
if [ ${USE_COREML_DELEGATE} == 1 ]; then
|
||||
pip install coremltools==5.0b5 protobuf==3.20.1 six==1.16.0
|
||||
pip install coremltools==5.0b5 protobuf==3.20.1
|
||||
python coreml_backend.py
|
||||
else
|
||||
cd "${PROJ_ROOT}"
|
||||
|
|
|
|||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
9cbcdb4008c14ad8251c5d4d7723aa616f659edb
|
||||
d29eb67c27af0f18d4f487d76b86f43b0a69aade
|
||||
|
|
|
|||
1
.github/requirements/conda-env-macOS-ARM64
vendored
1
.github/requirements/conda-env-macOS-ARM64
vendored
|
|
@ -5,7 +5,6 @@ cmake=3.22.*
|
|||
typing-extensions=4.3.0
|
||||
dataclasses=0.8
|
||||
pip=22.2.2
|
||||
six=1.16.0
|
||||
pillow=9.2.0
|
||||
pkg-config=0.29.2
|
||||
wheel=0.37.1
|
||||
|
|
|
|||
1
.github/requirements/conda-env-macOS-X64
vendored
1
.github/requirements/conda-env-macOS-X64
vendored
|
|
@ -7,7 +7,6 @@ cmake=3.22.*
|
|||
typing-extensions=4.3.0
|
||||
dataclasses=0.8
|
||||
pip=22.2.2
|
||||
six=1.16.0
|
||||
pillow=9.2.0
|
||||
libuv=1.40.0
|
||||
pkg-config=0.29.2
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# iOS simulator requirements
|
||||
coremltools==5.0b5
|
||||
protobuf==3.20.2
|
||||
six==1.16.0
|
||||
|
|
|
|||
2
.github/workflows/run_torchbench.yml
vendored
2
.github/workflows/run_torchbench.yml
vendored
|
|
@ -41,7 +41,7 @@ jobs:
|
|||
conda activate pr-ci
|
||||
conda install -y numpy="${NUMPY_VERSION}" requests ninja pyyaml mkl mkl-include \
|
||||
setuptools cmake=3.22.* typing-extensions boto3 \
|
||||
six pillow pytest tabulate gitpython git-lfs tqdm psutil
|
||||
pillow pytest tabulate gitpython git-lfs tqdm psutil
|
||||
pip install --pre torch torchvision torchtext -f https://download.pytorch.org/whl/nightly/cu116/torch_nightly.html
|
||||
- name: Setup TorchBench branch
|
||||
run: |
|
||||
|
|
|
|||
|
|
@ -145,7 +145,6 @@ init_command = [
|
|||
'expecttest==0.1.3',
|
||||
'mypy==0.960',
|
||||
'types-requests==2.27.25',
|
||||
'types-six==1.16.15',
|
||||
'types-PyYAML==6.0.7',
|
||||
'types-tabulate==0.8.8',
|
||||
'types-protobuf==3.19.18',
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ build-deps: clone-deps
|
|||
# conda create --name torchdynamo -y python=3.8
|
||||
# conda activate torchdynamo
|
||||
conda install -y astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \
|
||||
typing-extensions six requests protobuf numba cython scikit-learn
|
||||
typing-extensions requests protobuf numba cython scikit-learn
|
||||
conda install -y -c pytorch magma-cuda116
|
||||
conda install -y -c conda-forge librosa
|
||||
(cd ../../../torchvision && python setup.py clean && python setup.py develop)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ import itertools
|
|||
import logging
|
||||
import os
|
||||
|
||||
from six import add_metaclass
|
||||
import numpy as np
|
||||
|
||||
from caffe2.python import workspace, core
|
||||
|
|
@ -46,8 +45,7 @@ class BenchmarkMeta(type):
|
|||
return cls
|
||||
|
||||
|
||||
@add_metaclass(BenchmarkMeta)
|
||||
class Benchmark:
|
||||
class Benchmark(metaclass=BenchmarkMeta):
|
||||
|
||||
def __init__(self):
|
||||
self.results = []
|
||||
|
|
|
|||
|
|
@ -58,10 +58,6 @@ Note that you might need to uninstall existing Eigen and pybind11 packages due t
|
|||
|
||||
## Python support
|
||||
|
||||
To use Caffe2 in Python, you need two libraries, future and six.
|
||||
|
||||
pip install future six
|
||||
|
||||
To run the tutorials, download additional source from GitHub.
|
||||
|
||||
git clone --recursive https://github.com/caffe2/tutorials.git caffe2_tutorials
|
||||
|
|
|
|||
|
|
@ -6,4 +6,3 @@ docutils==0.16
|
|||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
|
||||
bs4
|
||||
lxml
|
||||
six
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ requires = [
|
|||
"setuptools",
|
||||
"cmake",
|
||||
"typing-extensions",
|
||||
"six",
|
||||
"requests",
|
||||
]
|
||||
# Use legacy backend to import local packages in setup.py
|
||||
|
|
|
|||
|
|
@ -41,10 +41,6 @@ sudo apt-get install \
|
|||
# the one provided by apt-get is quite old so we install it via pip
|
||||
sudo pip install hypothesis
|
||||
|
||||
# Install the six module, which includes Python 2 and 3 compatibility utilities,
|
||||
# and is required for Caffe2
|
||||
sudo pip install six
|
||||
|
||||
# Now, actually build the android target.
|
||||
echo "Building caffe2"
|
||||
cd $BUILD_ROOT
|
||||
|
|
|
|||
|
|
@ -95,10 +95,6 @@ sudo zypper install \
|
|||
# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that
|
||||
# the one provided by zypper is quite old so we install it via pip
|
||||
sudo pip install hypothesis
|
||||
|
||||
# Install the six module, which includes Python 2 and 3 compatibility utilities,
|
||||
# and is required for Caffe2
|
||||
sudo pip install six
|
||||
}
|
||||
|
||||
caffe2_full_build(){
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import sys
|
|||
import tarfile
|
||||
import tempfile
|
||||
|
||||
from six.moves.urllib.request import urlretrieve
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import tempfile
|
|||
|
||||
import boto3
|
||||
|
||||
from six.moves.urllib.request import urlretrieve
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ if not dist.is_available():
|
|||
sys.exit(0)
|
||||
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch._six import string_classes
|
||||
from torch.testing._internal.common_distributed import (
|
||||
skip_if_win32,
|
||||
create_tcp_store
|
||||
|
|
@ -336,7 +335,7 @@ class MyPythonStore(dist.Store):
|
|||
self.store = {}
|
||||
|
||||
def set(self, key, value):
|
||||
if not isinstance(key, string_classes):
|
||||
if not isinstance(key, str):
|
||||
raise AssertionError("Expected set to be called with string key")
|
||||
if type(value) is not bytes:
|
||||
raise AssertionError("Expected set to be called with bytes value")
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ import torch
|
|||
# Distributions tests use double as the default dtype
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests,
|
||||
gradcheck, skipIfTorchDynamo)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import random
|
|||
import itertools
|
||||
import math
|
||||
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, set_default_dtype, \
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from functools import reduce, partial
|
|||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.autograd.profiler import (profile, record_function, emit_nvtx, emit_itt)
|
||||
from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import operator
|
|||
from functools import partial
|
||||
|
||||
import torch.autograd.forward_ad as fwAD
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
slowTest,
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ from random import randint
|
|||
import torch
|
||||
import torch.cuda
|
||||
import torch.cuda.comm as comm
|
||||
from torch import inf, nan
|
||||
from torch.nn.parallel import scatter_gather
|
||||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
from torch._six import inf, nan
|
||||
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
|
||||
NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
|
||||
slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \
|
||||
|
|
@ -1595,7 +1595,7 @@ class TestCuda(TestCase):
|
|||
p = subprocess.Popen([sys.executable, '-c', f"""\
|
||||
import sys
|
||||
import torch
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
try:
|
||||
with torch.random.fork_rng(devices=[0]):
|
||||
torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
from torch.nn import Parameter
|
||||
from torch.testing._internal import opinfo
|
||||
from torch.testing._internal.common_utils import \
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import torch
|
|||
# NN tests use double as the default dtype
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
import torch.autograd.forward_ad as fwAD
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from functools import partial
|
|||
from itertools import product, combinations, permutations
|
||||
import warnings
|
||||
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from functools import partial
|
|||
import random
|
||||
import warnings
|
||||
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
import random
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from itertools import permutations, product
|
||||
|
||||
from torch.testing import make_tensor
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import textwrap
|
|||
import subprocess
|
||||
import weakref
|
||||
import sys
|
||||
from torch._six import inf, nan, string_classes
|
||||
from torch import inf, nan
|
||||
from itertools import product, combinations, permutations
|
||||
from functools import partial
|
||||
from torch import multiprocessing as mp
|
||||
|
|
@ -8288,7 +8288,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
ns_name = ns.__name__
|
||||
skip_regexes = []
|
||||
for r in skips:
|
||||
if isinstance(r, string_classes):
|
||||
if isinstance(r, str):
|
||||
skip_regexes.append(re.compile('^{}$'.format(re.escape(r))))
|
||||
else:
|
||||
skip_regexes.append(r)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from numbers import Number
|
|||
import random
|
||||
import unittest
|
||||
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
run_tests,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# ${generated_comment}
|
||||
|
||||
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
|
||||
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import torch
|
||||
from torch.package import PackageExporter
|
||||
from torch import Tensor
|
||||
from torch import Tensor, inf
|
||||
from torch.autograd.graph import Node as _Node
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
|
@ -10,7 +10,6 @@ from typing import (
|
|||
Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
|
||||
NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
|
||||
Literal, Generic, Set, AnyStr)
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import (
|
||||
_int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt, _dispatchkey
|
||||
|
|
@ -150,11 +149,11 @@ per_channel_symmetric: qscheme = ...
|
|||
per_channel_affine_float_qparams: qscheme = ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
class _FunctionBase:
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
|
||||
class _LegacyVariableBase(object):
|
||||
class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional[Tensor]=...,
|
||||
|
|
@ -168,7 +167,7 @@ class IODescriptor: ...
|
|||
|
||||
class JITException: ...
|
||||
|
||||
class Future(object):
|
||||
class Future:
|
||||
def __init__(self, devices: List[device]) -> None: ...
|
||||
def done(self) -> _bool: ...
|
||||
def value(self) -> Any: ...
|
||||
|
|
@ -178,7 +177,7 @@ class Future(object):
|
|||
def set_result(self, result: Any) -> None: ...
|
||||
def _set_unwrap_func(self, callback: Callable) -> None: ...
|
||||
|
||||
class _Await(object):
|
||||
class _Await:
|
||||
def __init__(self) -> None: ...
|
||||
def fn(self) -> Callable: ...
|
||||
def args(self) -> Tuple[Any, ...]: ...
|
||||
|
|
@ -700,7 +699,7 @@ def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> N
|
|||
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class ScriptModuleSerializer(object):
|
||||
class ScriptModuleSerializer:
|
||||
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
|
||||
def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
|
||||
def write_files(self) -> None: ...
|
||||
|
|
@ -708,14 +707,14 @@ class ScriptModuleSerializer(object):
|
|||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class SerializationStorageContext(object):
|
||||
class SerializationStorageContext:
|
||||
def __init__(self) -> None: ...
|
||||
def has_storage(self, storage: Storage) -> _bool: ...
|
||||
def get_or_add_storage(self, storage: Storage) -> _int: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class DeserializationStorageContext(object):
|
||||
class DeserializationStorageContext:
|
||||
def __init__(self) -> None: ...
|
||||
def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
|
||||
def has_storage(self, name: str) -> _bool: ...
|
||||
|
|
@ -971,7 +970,7 @@ def _pop_torch_dispatch_stack() -> Any: ...
|
|||
def _get_dispatch_stack_at(idx: _int) -> Any: ...
|
||||
def _len_torch_dispatch_stack() -> _int: ...
|
||||
|
||||
class _InferenceMode(object):
|
||||
class _InferenceMode:
|
||||
def __init__(self, mode: _bool) -> None: ...
|
||||
|
||||
class _DisableFuncTorch:
|
||||
|
|
@ -987,7 +986,7 @@ class _ViewReplayEnabled:
|
|||
def __init__(self, mode: _bool) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class LoggerBase(object):
|
||||
class LoggerBase:
|
||||
...
|
||||
|
||||
class NoopLogger(LoggerBase):
|
||||
|
|
@ -1000,7 +999,7 @@ class AggregationType(Enum):
|
|||
SUM = 0
|
||||
AVG = 1
|
||||
|
||||
class FileCheck(object):
|
||||
class FileCheck:
|
||||
def run(self, test_string: str) -> None: ...
|
||||
def check(self, test_string: str) -> 'FileCheck': ...
|
||||
def check_not(self, test_string: str) -> 'FileCheck': ...
|
||||
|
|
@ -1012,7 +1011,7 @@ class FileCheck(object):
|
|||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
class PyTorchFileReader(object):
|
||||
class PyTorchFileReader:
|
||||
@overload
|
||||
def __init__(self, name: str) -> None: ...
|
||||
@overload
|
||||
|
|
@ -1020,7 +1019,7 @@ class PyTorchFileReader(object):
|
|||
def get_record(self, name: str) -> bytes: ...
|
||||
...
|
||||
|
||||
class PyTorchFileWriter(object):
|
||||
class PyTorchFileWriter:
|
||||
@overload
|
||||
def __init__(self, name: str) -> None: ...
|
||||
@overload
|
||||
|
|
@ -1048,7 +1047,7 @@ def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
|||
def _rename_privateuse1_backend(backend: str) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/Generator.cpp
|
||||
class Generator(object):
|
||||
class Generator:
|
||||
device: _device
|
||||
def __init__(self, device: Union[_device, str, None] = None) -> None: ...
|
||||
def get_state(self) -> Tensor: ...
|
||||
|
|
@ -1127,28 +1126,28 @@ def _dispatch_get_registrations_for_dispatch_key(dispatch_key: str = "") -> List
|
|||
def _are_functorch_transforms_active() -> _bool: ...
|
||||
|
||||
# Define in torch/csrc/autograd/init.cpp
|
||||
class _DisablePythonDispatcher(object):
|
||||
class _DisablePythonDispatcher:
|
||||
pass
|
||||
|
||||
class _EnablePythonDispatcher(object):
|
||||
class _EnablePythonDispatcher:
|
||||
pass
|
||||
|
||||
def _set_python_dispatcher(dispatcher: object) -> None: ...
|
||||
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
class BenchmarkConfig:
|
||||
num_calling_threads: _int
|
||||
num_worker_threads: _int
|
||||
num_warmup_iters: _int
|
||||
num_iters: _int
|
||||
profiler_output_path: str
|
||||
|
||||
class BenchmarkExecutionStats(object):
|
||||
class BenchmarkExecutionStats:
|
||||
latency_avg_ms: _float
|
||||
num_iters: _int
|
||||
|
||||
class ThroughputBenchmark(object):
|
||||
class ThroughputBenchmark:
|
||||
def __init__(self, module: Any) -> None: ...
|
||||
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
|
@ -1162,7 +1161,9 @@ ${legacy_class_hints}
|
|||
|
||||
# Defined in torch/csrc/autograd/python_engine.cpp
|
||||
class _ImperativeEngine:
|
||||
...
|
||||
def queue_callback(self, callback: Callable[[], None]) -> None: ...
|
||||
def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ...
|
||||
def is_checkpoint_valid(self) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_variable.cpp
|
||||
class _TensorMeta(type):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# ${generated_comment}
|
||||
|
||||
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
|
||||
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout
|
||||
|
||||
|
|
|
|||
|
|
@ -28,8 +28,6 @@ if sys.executable == 'torch_deploy':
|
|||
else:
|
||||
from .torch_version import __version__ as __version__
|
||||
|
||||
from ._six import string_classes as _string_classes
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
|
||||
import builtins
|
||||
|
||||
|
|
@ -593,7 +591,7 @@ def set_default_tensor_type(t):
|
|||
torch.float64
|
||||
|
||||
"""
|
||||
if isinstance(t, _string_classes):
|
||||
if isinstance(t, str):
|
||||
t = _import_dotted_name(t)
|
||||
_C._set_default_tensor_type(t)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,41 +0,0 @@
|
|||
# Copyright (c) 2010-2017 Benjamin Peterson
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
|
||||
inf = math.inf
|
||||
nan = math.nan
|
||||
string_classes = (str, bytes)
|
||||
|
||||
|
||||
def with_metaclass(meta: type, *bases) -> type:
|
||||
"""Create a base class with a metaclass."""
|
||||
# This requires a bit of explanation: the basic idea is to make a dummy
|
||||
# metaclass for one level of class instantiation that replaces itself with
|
||||
# the actual metaclass.
|
||||
class metaclass(meta): # type: ignore[misc, valid-type]
|
||||
def __new__(cls, name, this_bases, d):
|
||||
return meta(name, bases, d)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, name, this_bases):
|
||||
return meta.__prepare__(name, bases)
|
||||
|
||||
return type.__new__(metaclass, "temporary_class", (), {})
|
||||
|
|
@ -3,7 +3,7 @@ import textwrap
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
|
||||
|
||||
class __PrinterOptions:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import torch._C as _C
|
|||
from torch._C import _functions
|
||||
import torch._functorch as _functorch
|
||||
import torch.utils.hooks as hooks
|
||||
from torch._six import with_metaclass
|
||||
import functools
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
|
@ -294,8 +293,7 @@ class FunctionMeta(type):
|
|||
super(FunctionMeta, cls).__init__(name, bases, attrs)
|
||||
|
||||
|
||||
# mypy doesn't understand `with_metaclass` from torch._six
|
||||
class _SingleLevelFunction(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc]
|
||||
class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
|
||||
@staticmethod
|
||||
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""
|
||||
|
|
@ -505,7 +503,7 @@ class Function(_SingleLevelFunction):
|
|||
if not torch._C._are_functorch_transforms_active():
|
||||
# See NOTE: [functorch vjp and autograd interaction]
|
||||
args = _functorch.utils.unwrap_dead_wrappers(args)
|
||||
return super().apply(*args, **kwargs)
|
||||
return super().apply(*args, **kwargs) # type: ignore[misc]
|
||||
|
||||
if cls.setup_context == _SingleLevelFunction.setup_context:
|
||||
raise RuntimeError(
|
||||
|
|
@ -680,14 +678,14 @@ class NestedIOFunction(Function):
|
|||
def _do_forward(self, *input):
|
||||
self._nested_input = input
|
||||
flat_input = tuple(_iter_tensors(input))
|
||||
flat_output = super()._do_forward(*flat_input)
|
||||
flat_output = super()._do_forward(*flat_input) # type: ignore[misc]
|
||||
nested_output = self._nested_output
|
||||
nested_tensors = _unflatten(flat_output, self._nested_output)
|
||||
return nested_tensors
|
||||
|
||||
def _do_backward(self, gradients, retain_variables):
|
||||
self.retain_variables = retain_variables
|
||||
result = super()._do_backward(gradients, retain_variables)
|
||||
result = super()._do_backward(gradients, retain_variables) # type: ignore[misc]
|
||||
if not retain_variables:
|
||||
del self._nested_output
|
||||
del self._to_save_nested
|
||||
|
|
@ -713,7 +711,7 @@ class NestedIOFunction(Function):
|
|||
|
||||
@property
|
||||
def saved_tensors(self):
|
||||
flat_tensors = super().saved_tensors
|
||||
flat_tensors = super().saved_tensors # type: ignore[misc]
|
||||
return _unflatten(flat_tensors, self._to_save_nested)
|
||||
|
||||
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
import torch
|
||||
from torch._six import with_metaclass
|
||||
from torch._C import _ImperativeEngine as ImperativeEngine
|
||||
|
||||
|
||||
__all__ = ["VariableMeta", "Variable"]
|
||||
|
||||
|
||||
class VariableMeta(type):
|
||||
def __instancecheck__(cls, other):
|
||||
return isinstance(other, torch.Tensor)
|
||||
|
||||
# mypy doesn't understand torch._six.with_metaclass
|
||||
class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
from torch._C import _ImperativeEngine as ImperativeEngine
|
||||
Variable._execution_engine = ImperativeEngine()
|
||||
class Variable(torch._C._LegacyVariableBase, metaclass=VariableMeta): # type: ignore[misc]
|
||||
_execution_engine = ImperativeEngine()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ try:
|
|||
HAS_NUMPY = True
|
||||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
from torch._six import string_classes
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
|
||||
|
|
@ -48,7 +47,7 @@ def _cast(value, dtype):
|
|||
if isinstance(value, torch.Tensor):
|
||||
is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
|
||||
return value.to(dtype) if is_eligible else value
|
||||
elif isinstance(value, string_classes):
|
||||
elif isinstance(value, str):
|
||||
return value
|
||||
elif HAS_NUMPY and isinstance(value, np.ndarray):
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class _DDPSink(Function):
|
|||
# Enqueue delay allreduce for static graph training on the first
|
||||
# iteration.
|
||||
if state_dict["static_graph"] and state_dict["num_iterations"] == 1:
|
||||
Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)
|
||||
Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce) # type: ignore[call-arg,misc]
|
||||
|
||||
return (None, None, *grad_outputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from torch._C._distributed_c10d import (
|
|||
get_debug_level,
|
||||
Work
|
||||
)
|
||||
from torch._six import string_classes
|
||||
from torch.autograd.profiler import record_function
|
||||
from .constants import default_pg_timeout
|
||||
from .c10d_error_logger import _get_or_create_logger
|
||||
|
|
@ -178,7 +177,7 @@ class Backend:
|
|||
backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
|
||||
|
||||
def __new__(cls, name: str):
|
||||
if not isinstance(name, string_classes):
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("Backend name must be a string, but got: {}".format(name))
|
||||
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import sys
|
|||
from datetime import timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch._six as six
|
||||
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
|
||||
|
||||
from .constants import default_pg_timeout
|
||||
|
|
@ -91,7 +90,7 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
|
|||
|
||||
|
||||
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||
if not isinstance(url, six.string_classes):
|
||||
if not isinstance(url, str):
|
||||
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
|
||||
|
||||
if not isinstance(rank, numbers.Integral):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from numbers import Number
|
||||
|
||||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from numbers import Number
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from numbers import Number
|
||||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.gamma import Gamma
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transforms import AbsTransform
|
||||
from torch.distributions.cauchy import Cauchy
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transforms import AbsTransform
|
||||
from torch.distributions.normal import Normal
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from functools import total_ordering
|
|||
from typing import Type, Dict, Callable, Tuple
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
|
||||
from .bernoulli import Bernoulli
|
||||
from .beta import Beta
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.uniform import Uniform
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
from torch.distributions.binomial import Binomial
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions import Categorical
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
from torch.distributions import Chi2, constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import _standard_normal, broadcast_all
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from numbers import Number
|
||||
|
||||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from numbers import Number
|
|||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch._six import nan
|
||||
from torch import nan
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import lazy_property
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import six
|
||||
|
||||
from .utils import typename
|
||||
|
||||
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
|
||||
|
|
@ -72,7 +70,7 @@ class VariadicSignatureMeta(type):
|
|||
)
|
||||
|
||||
|
||||
class Variadic(six.with_metaclass(VariadicSignatureMeta)):
|
||||
class Variadic(metaclass=VariadicSignatureMeta):
|
||||
"""A class whose getitem method can be used to generate a new type
|
||||
representing a specific variadic signature.
|
||||
Examples
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_method
|
|||
from torch.nn import Module
|
||||
from torch.jit._state import _enabled
|
||||
from torch.jit._builtins import _register_builtin
|
||||
from torch._six import with_metaclass
|
||||
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
|
||||
from torch._jit_internal import _qualified_name
|
||||
from torch.jit._fuser import _graph_for, _script_method_graph_for
|
||||
|
|
@ -484,7 +483,7 @@ if _enabled:
|
|||
# did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward
|
||||
# which always throws an exception.
|
||||
|
||||
class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore[misc]
|
||||
class ScriptModule(Module, metaclass=ScriptMeta):
|
||||
r"""
|
||||
A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
|
||||
contain methods, attributes, parameters, and
|
||||
|
|
@ -495,7 +494,7 @@ if _enabled:
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
forward = _CachedForward()
|
||||
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if "_actual_script_module" not in self.__dict__:
|
||||
|
|
@ -650,11 +649,11 @@ if _enabled:
|
|||
modules = {}
|
||||
for name, cpp_module in torch._C.ModuleDict(self._c).items():
|
||||
modules[name] = wrap_cpp_module(cpp_module)
|
||||
self._modules = OrderedModuleDict(self._c, modules)
|
||||
self._modules = OrderedModuleDict(self._c, modules) # type: ignore[assignment]
|
||||
|
||||
# Copy parameters and buffers.
|
||||
self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))
|
||||
self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))
|
||||
self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) # type: ignore[assignment]
|
||||
self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) # type: ignore[assignment]
|
||||
|
||||
# Get rid of the functions from the old C++ module.
|
||||
self.__dict__ = {
|
||||
|
|
@ -679,7 +678,7 @@ if _enabled:
|
|||
``forward`` method. This graph will be preprocessed to inline all function and method calls.
|
||||
See :ref:`interpreting-graphs` for details.
|
||||
"""
|
||||
return self.forward.inlined_graph
|
||||
return self.forward.inlined_graph # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
|
|
@ -688,7 +687,7 @@ if _enabled:
|
|||
the internal graph for the ``forward`` method. See
|
||||
:ref:`inspecting-code` for details.
|
||||
"""
|
||||
return self.forward.code
|
||||
return self.forward.code # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def code_with_constants(self):
|
||||
|
|
@ -702,7 +701,7 @@ if _enabled:
|
|||
|
||||
See :ref:`inspecting-code` for details.
|
||||
"""
|
||||
r = self.forward.code_with_constants
|
||||
r = self.forward.code_with_constants # type: ignore[attr-defined]
|
||||
return (r[0], ConstMap(r[1]))
|
||||
|
||||
def save(self, f, **kwargs):
|
||||
|
|
@ -740,7 +739,7 @@ if _enabled:
|
|||
return "original_name={}".format(self.original_name)
|
||||
|
||||
def graph_for(self, *args, **kwargs):
|
||||
return self.forward.graph_for(self, *args, **kwargs)
|
||||
return self.forward.graph_for(self, *args, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def original_name(self):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import os
|
|||
import pathlib
|
||||
|
||||
import torch
|
||||
from torch._six import string_classes
|
||||
from torch.jit._recursive import wrap_cpp_module
|
||||
from torch.serialization import validate_cuda_device
|
||||
|
||||
|
|
@ -148,7 +147,7 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
|||
os.remove("scriptmodule.pt")
|
||||
"""
|
||||
|
||||
if isinstance(f, string_classes):
|
||||
if isinstance(f, str):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||
if os.path.isdir(f):
|
||||
|
|
@ -197,7 +196,7 @@ def get_ff_module():
|
|||
|
||||
def jit_module_from_flatbuffer(f):
|
||||
ff = get_ff_module()
|
||||
if isinstance(f, string_classes):
|
||||
if isinstance(f, str):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||
if os.path.isdir(f):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import functools
|
|||
import warnings
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from torch.jit._state import _python_cu, _enabled
|
||||
from torch.jit._script import ScriptModule, _CachedForward, script
|
||||
|
|
@ -1198,7 +1198,7 @@ class TracedModule(ScriptModule):
|
|||
|
||||
|
||||
class TopLevelTracedModule(TracedModule):
|
||||
forward = _CachedForward()
|
||||
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
|
||||
|
||||
def _reconstruct(self, cpp_module):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -512,7 +512,7 @@ class Module:
|
|||
if '_buffers' not in self.__dict__:
|
||||
raise AttributeError(
|
||||
"cannot assign buffer before Module.__init__() call")
|
||||
elif not isinstance(name, torch._six.string_classes):
|
||||
elif not isinstance(name, str):
|
||||
raise TypeError("buffer name should be a string. "
|
||||
"Got {}".format(torch.typename(name)))
|
||||
elif '.' in name:
|
||||
|
|
@ -553,7 +553,7 @@ class Module:
|
|||
raise AttributeError(
|
||||
"cannot assign parameter before Module.__init__() call")
|
||||
|
||||
elif not isinstance(name, torch._six.string_classes):
|
||||
elif not isinstance(name, str):
|
||||
raise TypeError("parameter name should be a string. "
|
||||
"Got {}".format(torch.typename(name)))
|
||||
elif '.' in name:
|
||||
|
|
@ -595,7 +595,7 @@ class Module:
|
|||
if not isinstance(module, Module) and module is not None:
|
||||
raise TypeError("{} is not a Module subclass".format(
|
||||
torch.typename(module)))
|
||||
elif not isinstance(name, torch._six.string_classes):
|
||||
elif not isinstance(name, str):
|
||||
raise TypeError("module name should be a string. Got {}".format(
|
||||
torch.typename(name)))
|
||||
elif hasattr(self, name) and name not in self._modules:
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ class _DDPSink(Function):
|
|||
ctx.state_dict["static_graph"]
|
||||
and ctx.state_dict["num_iterations"] == 1
|
||||
):
|
||||
Variable._execution_engine.queue_callback(
|
||||
Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
|
||||
ctx.reducer._delay_all_reduce
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import warnings
|
|||
from typing import Union, Iterable, List, Dict, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._six import inf
|
||||
from torch import Tensor, inf
|
||||
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support
|
||||
|
||||
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
|
||||
|
|
|
|||
|
|
@ -310,7 +310,7 @@ def _create_node(
|
|||
@_beartype.beartype
|
||||
def _is_onnx_list(value):
|
||||
return (
|
||||
not isinstance(value, torch._six.string_classes)
|
||||
not isinstance(value, str)
|
||||
and not isinstance(value, torch.Tensor)
|
||||
and isinstance(value, Iterable)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -959,7 +959,7 @@ def _create_jit_graph(
|
|||
|
||||
if isinstance(model, torch.jit.ScriptModule):
|
||||
try:
|
||||
graph = model.forward.graph
|
||||
graph = model.forward.graph # type: ignore[attr-defined]
|
||||
except AttributeError as e:
|
||||
raise RuntimeError("'forward' method must be a script method") from e
|
||||
_C._jit_pass_onnx_function_substitution(graph)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import types
|
||||
import math
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
from functools import wraps
|
||||
import warnings
|
||||
import weakref
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import tempfile
|
|||
import warnings
|
||||
from contextlib import closing, contextmanager
|
||||
from ._utils import _import_dotted_name
|
||||
from ._six import string_classes as _string_classes
|
||||
from torch._sources import get_source_lines_and_file
|
||||
from torch.types import Storage
|
||||
from torch.storage import _get_dtype_from_pickle_storage_type
|
||||
|
|
@ -1079,7 +1078,7 @@ def _get_restore_location(map_location):
|
|||
def restore_location(storage, location):
|
||||
location = map_location.get(location, location)
|
||||
return default_restore_location(storage, location)
|
||||
elif isinstance(map_location, _string_classes):
|
||||
elif isinstance(map_location, str):
|
||||
def restore_location(storage, location):
|
||||
return default_restore_location(storage, map_location)
|
||||
elif isinstance(map_location, torch.device):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import enum
|
|||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch._six import inf, nan
|
||||
from torch import inf, nan
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Union, Sequence
|
||||
from torch.testing import make_tensor
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ import torch.backends.xnnpack
|
|||
import torch.cuda
|
||||
from torch import Tensor
|
||||
from torch._C import ScriptDict, ScriptList # type: ignore[attr-defined]
|
||||
from torch._six import string_classes
|
||||
from torch._utils_internal import get_writable_path
|
||||
from torch.nn import (
|
||||
ModuleDict,
|
||||
|
|
@ -589,7 +588,7 @@ def shell(command, cwd=None, env=None, stdout=None, stderr=None):
|
|||
# `p.wait()` in a `final` block for the code to be portable.
|
||||
#
|
||||
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
|
||||
assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
|
||||
assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
|
||||
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
|
||||
return wait_for_process(p)
|
||||
|
||||
|
|
@ -1924,7 +1923,7 @@ class UnittestPair(Pair):
|
|||
|
||||
|
||||
class StringPair(UnittestPair):
|
||||
CLS = string_classes
|
||||
CLS = str
|
||||
TYPE_NAME = "string"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from typing import Any, Dict, List, Union
|
|||
import math # noqa: F401
|
||||
|
||||
# Testing utils
|
||||
from torch._six import inf
|
||||
from torch import inf
|
||||
|
||||
# TODO: include files like this should not set the default dtype
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import re
|
|||
import torch
|
||||
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, Union
|
||||
from torch._six import string_classes
|
||||
|
||||
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
||||
|
||||
|
|
@ -70,7 +69,7 @@ def default_convert(data):
|
|||
return elem_type(*(default_convert(d) for d in data))
|
||||
elif isinstance(data, tuple):
|
||||
return [default_convert(d) for d in data] # Backwards compatibility.
|
||||
elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes):
|
||||
elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
|
||||
try:
|
||||
return elem_type([default_convert(d) for d in data])
|
||||
except TypeError:
|
||||
|
|
@ -198,7 +197,7 @@ with contextlib.suppress(ImportError):
|
|||
default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
|
||||
default_collate_fn_map[float] = collate_float_fn
|
||||
default_collate_fn_map[int] = collate_int_fn
|
||||
default_collate_fn_map[string_classes] = collate_str_fn
|
||||
default_collate_fn_map[str] = collate_str_fn
|
||||
|
||||
|
||||
def default_collate(batch):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import collections
|
|||
import queue
|
||||
|
||||
import torch
|
||||
from torch._six import string_classes
|
||||
from . import MP_STATUS_CHECK_INTERVAL
|
||||
from torch._utils import ExceptionWrapper
|
||||
|
||||
|
|
@ -54,7 +53,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
|
|||
def pin_memory(data, device=None):
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.pin_memory(device)
|
||||
elif isinstance(data, string_classes):
|
||||
elif isinstance(data, str):
|
||||
return data
|
||||
elif isinstance(data, collections.abc.Mapping):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import torch.multiprocessing as multiprocessing
|
|||
import torch.utils.data.graph_settings
|
||||
|
||||
from torch._utils import ExceptionWrapper
|
||||
from torch._six import string_classes
|
||||
|
||||
from . import (
|
||||
IterDataPipe,
|
||||
|
|
@ -396,7 +395,7 @@ class DataLoader(Generic[T_co]):
|
|||
def multiprocessing_context(self, multiprocessing_context):
|
||||
if multiprocessing_context is not None:
|
||||
if self.num_workers > 0:
|
||||
if isinstance(multiprocessing_context, string_classes):
|
||||
if isinstance(multiprocessing_context, str):
|
||||
valid_start_methods = multiprocessing.get_all_start_methods()
|
||||
if multiprocessing_context not in valid_start_methods:
|
||||
raise ValueError(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user