mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][Easy] export explicitly imported public submodules (#127703)
Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
This commit is contained in:
parent
62311257ad
commit
dcc0093dba
|
|
@ -1,6 +1,4 @@
|
||||||
# mypy: allow-untyped-defs
|
"""
|
||||||
|
|
||||||
r"""
|
|
||||||
The torch package contains data structures for multi-dimensional
|
The torch package contains data structures for multi-dimensional
|
||||||
tensors and defines mathematical operations over these tensors.
|
tensors and defines mathematical operations over these tensors.
|
||||||
Additionally, it provides many utilities for efficient serialization of
|
Additionally, it provides many utilities for efficient serialization of
|
||||||
|
|
@ -10,6 +8,8 @@ It has a CUDA counterpart, that enables you to run your tensor computations
|
||||||
on an NVIDIA GPU with compute capability >= 3.0.
|
on an NVIDIA GPU with compute capability >= 3.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -289,10 +289,6 @@ else:
|
||||||
_load_global_deps()
|
_load_global_deps()
|
||||||
from torch._C import * # noqa: F403
|
from torch._C import * # noqa: F403
|
||||||
|
|
||||||
# Appease the type checker; ordinarily this binding is inserted by the
|
|
||||||
# torch._C module initialization code in C
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from . import _C as _C # noqa: TCH004
|
|
||||||
|
|
||||||
class SymInt:
|
class SymInt:
|
||||||
"""
|
"""
|
||||||
|
|
@ -614,10 +610,9 @@ def sym_not(a):
|
||||||
a (SymBool or bool): Object to negate
|
a (SymBool or bool): Object to negate
|
||||||
"""
|
"""
|
||||||
import sympy
|
import sympy
|
||||||
from .overrides import has_torch_function_unary, handle_torch_function
|
|
||||||
|
|
||||||
if has_torch_function_unary(a):
|
if overrides.has_torch_function_unary(a):
|
||||||
return handle_torch_function(sym_not, (a,), a)
|
return overrides.handle_torch_function(sym_not, (a,), a)
|
||||||
if hasattr(a, '__sym_not__'):
|
if hasattr(a, '__sym_not__'):
|
||||||
return a.__sym_not__()
|
return a.__sym_not__()
|
||||||
if isinstance(a, sympy.Basic):
|
if isinstance(a, sympy.Basic):
|
||||||
|
|
@ -630,10 +625,8 @@ def sym_float(a):
|
||||||
Args:
|
Args:
|
||||||
a (SymInt, SymFloat, or object): Object to cast
|
a (SymInt, SymFloat, or object): Object to cast
|
||||||
"""
|
"""
|
||||||
from .overrides import has_torch_function_unary, handle_torch_function
|
if overrides.has_torch_function_unary(a):
|
||||||
|
return overrides.handle_torch_function(sym_float, (a,), a)
|
||||||
if has_torch_function_unary(a):
|
|
||||||
return handle_torch_function(sym_float, (a,), a)
|
|
||||||
if isinstance(a, SymFloat):
|
if isinstance(a, SymFloat):
|
||||||
return a
|
return a
|
||||||
elif hasattr(a, '__sym_float__'):
|
elif hasattr(a, '__sym_float__'):
|
||||||
|
|
@ -647,10 +640,8 @@ def sym_int(a):
|
||||||
Args:
|
Args:
|
||||||
a (SymInt, SymFloat, or object): Object to cast
|
a (SymInt, SymFloat, or object): Object to cast
|
||||||
"""
|
"""
|
||||||
from .overrides import has_torch_function_unary, handle_torch_function
|
if overrides.has_torch_function_unary(a):
|
||||||
|
return overrides.handle_torch_function(sym_int, (a,), a)
|
||||||
if has_torch_function_unary(a):
|
|
||||||
return handle_torch_function(sym_int, (a,), a)
|
|
||||||
if isinstance(a, SymInt):
|
if isinstance(a, SymInt):
|
||||||
return a
|
return a
|
||||||
elif isinstance(a, SymFloat):
|
elif isinstance(a, SymFloat):
|
||||||
|
|
@ -664,10 +655,8 @@ def sym_max(a, b):
|
||||||
promotes to float if any argument is float (unlike builtins.max, which
|
promotes to float if any argument is float (unlike builtins.max, which
|
||||||
will faithfully preserve the type of the input argument).
|
will faithfully preserve the type of the input argument).
|
||||||
"""
|
"""
|
||||||
from .overrides import has_torch_function, handle_torch_function
|
if overrides.has_torch_function((a, b)):
|
||||||
|
return overrides.handle_torch_function(sym_max, (a, b), a, b)
|
||||||
if has_torch_function((a, b)):
|
|
||||||
return handle_torch_function(sym_max, (a, b), a, b)
|
|
||||||
if isinstance(a, (SymInt, SymFloat)):
|
if isinstance(a, (SymInt, SymFloat)):
|
||||||
return a.__sym_max__(b)
|
return a.__sym_max__(b)
|
||||||
elif isinstance(b, (SymInt, SymFloat)):
|
elif isinstance(b, (SymInt, SymFloat)):
|
||||||
|
|
@ -683,11 +672,9 @@ def sym_max(a, b):
|
||||||
return builtins.max(a, b)
|
return builtins.max(a, b)
|
||||||
|
|
||||||
def sym_min(a, b):
|
def sym_min(a, b):
|
||||||
""" SymInt-aware utility for min()."""
|
"""SymInt-aware utility for min()."""
|
||||||
from .overrides import has_torch_function, handle_torch_function
|
if overrides.has_torch_function((a, b)):
|
||||||
|
return overrides.handle_torch_function(sym_min, (a, b), a, b)
|
||||||
if has_torch_function((a, b)):
|
|
||||||
return handle_torch_function(sym_min, (a, b), a, b)
|
|
||||||
if isinstance(a, (SymInt, SymFloat)):
|
if isinstance(a, (SymInt, SymFloat)):
|
||||||
return a.__sym_min__(b)
|
return a.__sym_min__(b)
|
||||||
elif isinstance(b, (SymInt, SymFloat)):
|
elif isinstance(b, (SymInt, SymFloat)):
|
||||||
|
|
@ -702,10 +689,8 @@ def sym_min(a, b):
|
||||||
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
||||||
def _get_sym_math_fn(name):
|
def _get_sym_math_fn(name):
|
||||||
def fn(a):
|
def fn(a):
|
||||||
from .overrides import has_torch_function_unary, handle_torch_function
|
if overrides.has_torch_function_unary(a):
|
||||||
|
return overrides.handle_torch_function(fn, (a,), a)
|
||||||
if has_torch_function_unary(a):
|
|
||||||
return handle_torch_function(fn, (a,), a)
|
|
||||||
if hasattr(a, f"__sym_{name}__"):
|
if hasattr(a, f"__sym_{name}__"):
|
||||||
return getattr(a, f"__sym_{name}__")()
|
return getattr(a, f"__sym_{name}__")()
|
||||||
return getattr(math, name)(a)
|
return getattr(math, name)(a)
|
||||||
|
|
@ -727,10 +712,8 @@ __all__.append("sym_sqrt")
|
||||||
|
|
||||||
|
|
||||||
def sym_ite(b, t, f):
|
def sym_ite(b, t, f):
|
||||||
from .overrides import has_torch_function, handle_torch_function
|
if overrides.has_torch_function((b, t, f)):
|
||||||
|
return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
|
||||||
if has_torch_function((b, t, f)):
|
|
||||||
return handle_torch_function(sym_ite, (b, t, f), b, t, f)
|
|
||||||
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
|
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
|
||||||
if isinstance(b, SymBool):
|
if isinstance(b, SymBool):
|
||||||
return b.__sym_ite__(t, f)
|
return b.__sym_ite__(t, f)
|
||||||
|
|
@ -760,16 +743,20 @@ except ImportError:
|
||||||
''').strip()) from None
|
''').strip()) from None
|
||||||
raise # If __file__ is not None the cause is unknown, so just re-raise.
|
raise # If __file__ is not None the cause is unknown, so just re-raise.
|
||||||
|
|
||||||
|
# The torch._C submodule is already loaded via `from torch._C import *` above
|
||||||
|
# Make an explicit reference to the _C submodule to appease linters
|
||||||
|
from torch import _C as _C
|
||||||
|
|
||||||
__name, __obj = '', None
|
__name, __obj = '', None
|
||||||
for __name in dir(_C):
|
for __name in dir(_C):
|
||||||
if __name[0] != '_' and not __name.endswith('Base'):
|
if __name[0] != '_' and not __name.endswith('Base'):
|
||||||
__all__.append(__name)
|
__all__.append(__name)
|
||||||
__obj = getattr(_C, __name)
|
__obj = getattr(_C, __name)
|
||||||
if callable(__obj) or inspect.isclass(__obj):
|
if callable(__obj) or inspect.isclass(__obj):
|
||||||
if __obj.__module__ != __name__:
|
if __obj.__module__ != __name__: # "torch"
|
||||||
# TODO: fix their module from C++ side
|
# TODO: fix their module from C++ side
|
||||||
if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
|
if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
|
||||||
__obj.__module__ = __name__
|
__obj.__module__ = __name__ # "torch"
|
||||||
elif __name == 'TensorBase':
|
elif __name == 'TensorBase':
|
||||||
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
|
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
|
||||||
delattr(sys.modules[__name__], __name)
|
delattr(sys.modules[__name__], __name)
|
||||||
|
|
@ -1478,6 +1465,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis'])
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
from ._tensor import Tensor
|
from ._tensor import Tensor
|
||||||
|
from torch import storage as storage
|
||||||
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
|
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
|
||||||
|
|
||||||
# NOTE: New <type>Storage classes should never be added. When adding a new
|
# NOTE: New <type>Storage classes should never be added. When adding a new
|
||||||
|
|
@ -1665,7 +1653,9 @@ _storage_classes = {
|
||||||
_tensor_classes: Set[Type] = set()
|
_tensor_classes: Set[Type] = set()
|
||||||
|
|
||||||
# If you edit these imports, please update torch/__init__.py.in as well
|
# If you edit these imports, please update torch/__init__.py.in as well
|
||||||
|
from torch import random as random
|
||||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
||||||
|
from torch import serialization as serialization
|
||||||
from .serialization import save, load
|
from .serialization import save, load
|
||||||
from ._tensor_str import set_printoptions
|
from ._tensor_str import set_printoptions
|
||||||
|
|
||||||
|
|
@ -1682,6 +1672,7 @@ def _manager_path():
|
||||||
raise RuntimeError("Unable to find torch_shm_manager at " + path)
|
raise RuntimeError("Unable to find torch_shm_manager at " + path)
|
||||||
return path.encode('utf-8')
|
return path.encode('utf-8')
|
||||||
|
|
||||||
|
from torch import amp as amp
|
||||||
from torch.amp import autocast, GradScaler
|
from torch.amp import autocast, GradScaler
|
||||||
|
|
||||||
# Initializing the extension shadows the built-in python float / int classes;
|
# Initializing the extension shadows the built-in python float / int classes;
|
||||||
|
|
@ -1717,7 +1708,7 @@ for __name in dir(_C._VariableFunctions):
|
||||||
if __name.startswith('__') or __name in PRIVATE_OPS:
|
if __name.startswith('__') or __name in PRIVATE_OPS:
|
||||||
continue
|
continue
|
||||||
__obj = getattr(_C._VariableFunctions, __name)
|
__obj = getattr(_C._VariableFunctions, __name)
|
||||||
__obj.__module__ = __name__
|
__obj.__module__ = __name__ # "torch"
|
||||||
# Hide some APIs that should not be public
|
# Hide some APIs that should not be public
|
||||||
if __name == "segment_reduce":
|
if __name == "segment_reduce":
|
||||||
# TODO: Once the undocumented FC window is passed, remove the line bellow
|
# TODO: Once the undocumented FC window is passed, remove the line bellow
|
||||||
|
|
@ -1751,6 +1742,7 @@ from ._compile import _disable_dynamo
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
# needs to be after the above ATen bindings so we can overwrite from Python side
|
# needs to be after the above ATen bindings so we can overwrite from Python side
|
||||||
|
from torch import functional as functional
|
||||||
from .functional import * # noqa: F403
|
from .functional import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1769,10 +1761,8 @@ del _LegacyStorage
|
||||||
def _assert(condition, message):
|
def _assert(condition, message):
|
||||||
r"""A wrapper around Python's assert which is symbolically traceable.
|
r"""A wrapper around Python's assert which is symbolically traceable.
|
||||||
"""
|
"""
|
||||||
from .overrides import has_torch_function, handle_torch_function
|
if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)):
|
||||||
|
return overrides.handle_torch_function(_assert, (condition,), condition, message)
|
||||||
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
|
|
||||||
return handle_torch_function(_assert, (condition,), condition, message)
|
|
||||||
assert condition, message
|
assert condition, message
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
@ -1801,7 +1791,6 @@ from torch import nested as nested
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.signal import windows as windows
|
from torch.signal import windows as windows
|
||||||
from torch import optim as optim
|
from torch import optim as optim
|
||||||
import torch.optim._multi_tensor
|
|
||||||
from torch import multiprocessing as multiprocessing
|
from torch import multiprocessing as multiprocessing
|
||||||
from torch import sparse as sparse
|
from torch import sparse as sparse
|
||||||
from torch import special as special
|
from torch import special as special
|
||||||
|
|
@ -1809,7 +1798,6 @@ import torch.utils.backcompat
|
||||||
from torch import jit as jit
|
from torch import jit as jit
|
||||||
from torch import linalg as linalg
|
from torch import linalg as linalg
|
||||||
from torch import hub as hub
|
from torch import hub as hub
|
||||||
from torch import random as random
|
|
||||||
from torch import distributions as distributions
|
from torch import distributions as distributions
|
||||||
from torch import testing as testing
|
from torch import testing as testing
|
||||||
from torch import backends as backends
|
from torch import backends as backends
|
||||||
|
|
@ -1817,6 +1805,8 @@ import torch.utils.data
|
||||||
from torch import __config__ as __config__
|
from torch import __config__ as __config__
|
||||||
from torch import __future__ as __future__
|
from torch import __future__ as __future__
|
||||||
from torch import profiler as profiler
|
from torch import profiler as profiler
|
||||||
|
from torch import overrides as overrides
|
||||||
|
from torch import types as types
|
||||||
|
|
||||||
# Quantized, sparse, AO, etc. should be last to get imported, as nothing
|
# Quantized, sparse, AO, etc. should be last to get imported, as nothing
|
||||||
# is expected to depend on them.
|
# is expected to depend on them.
|
||||||
|
|
@ -1827,7 +1817,7 @@ import torch.nn.quantized
|
||||||
import torch.nn.qat
|
import torch.nn.qat
|
||||||
import torch.nn.intrinsic
|
import torch.nn.intrinsic
|
||||||
|
|
||||||
_C._init_names(list(torch._storage_classes))
|
_C._init_names(list(_storage_classes))
|
||||||
|
|
||||||
# attach docstrings to torch and tensor functions
|
# attach docstrings to torch and tensor functions
|
||||||
from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs
|
from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs
|
||||||
|
|
@ -1854,7 +1844,7 @@ from torch import quasirandom as quasirandom
|
||||||
# If you are seeing this, it means that this call site was not checked if
|
# If you are seeing this, it means that this call site was not checked if
|
||||||
# the memory format could be preserved, and it was switched to old default
|
# the memory format could be preserved, and it was switched to old default
|
||||||
# behaviour of contiguous
|
# behaviour of contiguous
|
||||||
legacy_contiguous_format = contiguous_format
|
legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
|
||||||
|
|
||||||
# Register fork handler to initialize OpenMP in child processes (see gh-28389)
|
# Register fork handler to initialize OpenMP in child processes (see gh-28389)
|
||||||
from torch.multiprocessing._atfork import register_after_fork
|
from torch.multiprocessing._atfork import register_after_fork
|
||||||
|
|
@ -1876,7 +1866,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||||
# Import experimental masked operations support. See
|
# Import experimental masked operations support. See
|
||||||
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
|
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
|
||||||
# information.
|
# information.
|
||||||
from . import masked
|
from torch import masked as masked
|
||||||
|
|
||||||
# Import removed ops with error message about removal
|
# Import removed ops with error message about removal
|
||||||
from ._linalg_utils import ( # type: ignore[misc]
|
from ._linalg_utils import ( # type: ignore[misc]
|
||||||
|
|
|
||||||
|
|
@ -6,21 +6,22 @@ enough, so that more sophisticated ones can also be easily integrated in the
|
||||||
future.
|
future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import lr_scheduler, swa_utils
|
from torch.optim import lr_scheduler, swa_utils
|
||||||
from .adadelta import Adadelta
|
from torch.optim.adadelta import Adadelta
|
||||||
from .adagrad import Adagrad
|
from torch.optim.adagrad import Adagrad
|
||||||
from .adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from .adamax import Adamax
|
from torch.optim.adamax import Adamax
|
||||||
from .adamw import AdamW
|
from torch.optim.adamw import AdamW
|
||||||
from .asgd import ASGD
|
from torch.optim.asgd import ASGD
|
||||||
from .lbfgs import LBFGS
|
from torch.optim.lbfgs import LBFGS
|
||||||
from .nadam import NAdam
|
from torch.optim.nadam import NAdam
|
||||||
from .optimizer import Optimizer
|
from torch.optim.optimizer import Optimizer
|
||||||
from .radam import RAdam
|
from torch.optim.radam import RAdam
|
||||||
from .rmsprop import RMSprop
|
from torch.optim.rmsprop import RMSprop
|
||||||
from .rprop import Rprop
|
from torch.optim.rprop import Rprop
|
||||||
from .sgd import SGD
|
from torch.optim.sgd import SGD
|
||||||
from .sparse_adam import SparseAdam
|
from torch.optim.sparse_adam import SparseAdam
|
||||||
|
|
||||||
|
|
||||||
del adadelta # type: ignore[name-defined] # noqa: F821
|
del adadelta # type: ignore[name-defined] # noqa: F821
|
||||||
del adagrad # type: ignore[name-defined] # noqa: F821
|
del adagrad # type: ignore[name-defined] # noqa: F821
|
||||||
|
|
@ -36,3 +37,6 @@ del rmsprop # type: ignore[name-defined] # noqa: F821
|
||||||
del optimizer # type: ignore[name-defined] # noqa: F821
|
del optimizer # type: ignore[name-defined] # noqa: F821
|
||||||
del nadam # type: ignore[name-defined] # noqa: F821
|
del nadam # type: ignore[name-defined] # noqa: F821
|
||||||
del lbfgs # type: ignore[name-defined] # noqa: F821
|
del lbfgs # type: ignore[name-defined] # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
import torch.optim._multi_tensor
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
|
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
|
||||||
|
|
||||||
|
|
@ -213,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]):
|
||||||
[0, 1, 4, 3, 2]
|
[0, 1, 4, 3, 2]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
weights: Tensor
|
weights: torch.Tensor
|
||||||
num_samples: int
|
num_samples: int
|
||||||
replacement: bool
|
replacement: bool
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user