[BE][Easy] export explicitly imported public submodules (#127703)
Export the explicitly imported top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}` as public attributes via the `from torch import foo as foo` re-export idiom. The function-local `from .overrides import ...` imports in `torch/__init__.py` are replaced with call-time attribute access on the now-public `overrides` submodule, `torch/optim` switches from relative to absolute imports, and `torch/utils/data/sampler.py` drops `from torch import Tensor` in favor of `torch.Tensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
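The mechanism is the PEP 484 re-export idiom: a binding spelled `from torch import amp as amp` is treated by type checkers (e.g. mypy under `--no-implicit-reexport`) as an explicitly public attribute, so `torch.amp` no longer needs to appear as a side effect of some other import. A quick check of the newly exported attributes, assuming a torch build that includes this commit:

import torch

# Each submodule is now an explicit public attribute of the torch package:
print(torch.amp.autocast)                      # via `from torch import amp as amp`
print(torch.types.Number)                      # via `from torch import types as types`
print(torch.overrides.has_torch_function)      # via `from torch import overrides as overrides`
print(torch.serialization.load is torch.load)  # -> True; torch.load comes from here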
parent 62311257ad
commit dcc0093dba
diff --git a/torch/__init__.py b/torch/__init__.py

@@ -1,6 +1,4 @@
-# mypy: allow-untyped-defs
-
 r"""
 The torch package contains data structures for multi-dimensional
 tensors and defines mathematical operations over these tensors.
 Additionally, it provides many utilities for efficient serialization of
@@ -10,6 +8,8 @@ It has a CUDA counterpart, that enables you to run your tensor computations
 on an NVIDIA GPU with compute capability >= 3.0.
 """
 
+# mypy: allow-untyped-defs
+
 import math
 import os
 import sys
@@ -289,10 +289,6 @@ else:
     _load_global_deps()
 from torch._C import *  # noqa: F403
 
-# Appease the type checker; ordinarily this binding is inserted by the
-# torch._C module initialization code in C
-if TYPE_CHECKING:
-    from . import _C as _C  # noqa: TCH004
 
 class SymInt:
     """
@@ -614,10 +610,9 @@ def sym_not(a):
         a (SymBool or bool): Object to negate
     """
     import sympy
-    from .overrides import has_torch_function_unary, handle_torch_function
 
-    if has_torch_function_unary(a):
-        return handle_torch_function(sym_not, (a,), a)
+    if overrides.has_torch_function_unary(a):
+        return overrides.handle_torch_function(sym_not, (a,), a)
     if hasattr(a, '__sym_not__'):
         return a.__sym_not__()
     if isinstance(a, sympy.Basic):
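This hunk sets the pattern for the rest of the file: each function-local `from .overrides import ...` (a lazy import that sidestepped a circular dependency while `torch/__init__.py` was still loading) becomes attribute access on `overrides`, which the bottom of this same file binds via `from torch import overrides as overrides`. Both forms defer the lookup until the function is actually called; the attribute form just avoids re-running the import machinery on every call. A self-contained sketch of the call-time binding this relies on (the names are illustrative, not PyTorch's):

def sym_demo(x):
    # `helpers` does not exist yet: globals are resolved when the
    # function is called, not when it is defined.
    return helpers.double(x)

class _Helpers:          # stand-in for the `overrides` submodule
    @staticmethod
    def double(x):
        return 2 * x

helpers = _Helpers()     # bound after the definition, like the
                         # `from torch import overrides as overrides`
                         # near the bottom of torch/__init__.py
print(sym_demo(21))      # -> 42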
@@ -630,10 +625,8 @@ def sym_float(a):
     Args:
         a (SymInt, SymFloat, or object): Object to cast
     """
-    from .overrides import has_torch_function_unary, handle_torch_function
-
-    if has_torch_function_unary(a):
-        return handle_torch_function(sym_float, (a,), a)
+    if overrides.has_torch_function_unary(a):
+        return overrides.handle_torch_function(sym_float, (a,), a)
     if isinstance(a, SymFloat):
         return a
     elif hasattr(a, '__sym_float__'):
@@ -647,10 +640,8 @@ def sym_int(a):
     Args:
         a (SymInt, SymFloat, or object): Object to cast
     """
-    from .overrides import has_torch_function_unary, handle_torch_function
-
-    if has_torch_function_unary(a):
-        return handle_torch_function(sym_int, (a,), a)
+    if overrides.has_torch_function_unary(a):
+        return overrides.handle_torch_function(sym_int, (a,), a)
     if isinstance(a, SymInt):
         return a
     elif isinstance(a, SymFloat):
@@ -664,10 +655,8 @@ def sym_max(a, b):
     promotes to float if any argument is float (unlike builtins.max, which
     will faithfully preserve the type of the input argument).
     """
-    from .overrides import has_torch_function, handle_torch_function
-
-    if has_torch_function((a, b)):
-        return handle_torch_function(sym_max, (a, b), a, b)
+    if overrides.has_torch_function((a, b)):
+        return overrides.handle_torch_function(sym_max, (a, b), a, b)
     if isinstance(a, (SymInt, SymFloat)):
         return a.__sym_max__(b)
     elif isinstance(b, (SymInt, SymFloat)):
@@ -683,11 +672,9 @@ def sym_max(a, b):
     return builtins.max(a, b)
 
 def sym_min(a, b):
-    """ SymInt-aware utility for min()."""
-    from .overrides import has_torch_function, handle_torch_function
-
-    if has_torch_function((a, b)):
-        return handle_torch_function(sym_min, (a, b), a, b)
+    """SymInt-aware utility for min()."""
+    if overrides.has_torch_function((a, b)):
+        return overrides.handle_torch_function(sym_min, (a, b), a, b)
     if isinstance(a, (SymInt, SymFloat)):
         return a.__sym_min__(b)
     elif isinstance(b, (SymInt, SymFloat)):
@@ -702,10 +689,8 @@ def sym_min(a, b):
 # Drop in replacement for math.sqrt, math.sin, math.cos etc
 def _get_sym_math_fn(name):
     def fn(a):
-        from .overrides import has_torch_function_unary, handle_torch_function
-
-        if has_torch_function_unary(a):
-            return handle_torch_function(fn, (a,), a)
+        if overrides.has_torch_function_unary(a):
+            return overrides.handle_torch_function(fn, (a,), a)
         if hasattr(a, f"__sym_{name}__"):
            return getattr(a, f"__sym_{name}__")()
         return getattr(math, name)(a)
@@ -727,10 +712,8 @@ __all__.append("sym_sqrt")
 
 
 def sym_ite(b, t, f):
-    from .overrides import has_torch_function, handle_torch_function
-
-    if has_torch_function((b, t, f)):
-        return handle_torch_function(sym_ite, (b, t, f), b, t, f)
+    if overrides.has_torch_function((b, t, f)):
+        return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
     assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
     if isinstance(b, SymBool):
         return b.__sym_ite__(t, f)
@@ -760,16 +743,20 @@ except ImportError:
         ''').strip()) from None
     raise  # If __file__ is not None the cause is unknown, so just re-raise.
 
+# The torch._C submodule is already loaded via `from torch._C import *` above
+# Make an explicit reference to the _C submodule to appease linters
+from torch import _C as _C
+
+__name, __obj = '', None
 for __name in dir(_C):
     if __name[0] != '_' and not __name.endswith('Base'):
         __all__.append(__name)
         __obj = getattr(_C, __name)
         if callable(__obj) or inspect.isclass(__obj):
-            if __obj.__module__ != __name__:
+            if __obj.__module__ != __name__:  # "torch"
                 # TODO: fix their module from C++ side
                 if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
-                    __obj.__module__ = __name__
+                    __obj.__module__ = __name__  # "torch"
     elif __name == 'TensorBase':
         # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
         delattr(sys.modules[__name__], __name)
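The loop's job is unchanged; the new `# "torch"` comments just spell out what `__name__` is here. C-extension callables are created with `__module__ == "torch._C"`, and re-homing them to `"torch"` is what makes `repr()`, doc generators, and pickling by qualified name present them as `torch.foo` rather than `torch._C.foo`. A toy version of the same fixup (not PyTorch code):

import inspect

class _FakeC:                      # stand-in for the torch._C extension module
    @staticmethod
    def add(x, y):
        return x + y

_FakeC.add.__module__ = "_FakeC"   # how an extension binding reports itself

for name in dir(_FakeC):
    if name[0] != "_":
        obj = getattr(_FakeC, name)
        if callable(obj) or inspect.isclass(obj):
            obj.__module__ = "mypkg"  # re-home the public binding

print(_FakeC.add.__module__)  # -> mypkg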
@@ -1478,6 +1465,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis'])
 ################################################################################
 
 from ._tensor import Tensor
+from torch import storage as storage
 from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
 
 # NOTE: New <type>Storage classes should never be added. When adding a new
@@ -1665,7 +1653,9 @@ _storage_classes = {
 _tensor_classes: Set[Type] = set()
 
 # If you edit these imports, please update torch/__init__.py.in as well
+from torch import random as random
 from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
+from torch import serialization as serialization
 from .serialization import save, load
 from ._tensor_str import set_printoptions
 
@@ -1682,6 +1672,7 @@ def _manager_path():
         raise RuntimeError("Unable to find torch_shm_manager at " + path)
     return path.encode('utf-8')
 
+from torch import amp as amp
 from torch.amp import autocast, GradScaler
 
 # Initializing the extension shadows the built-in python float / int classes;
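For context on the pair that stays imported at top level: `autocast` and `GradScaler` are the device-generic mixed-precision duo and are normally used together in a training step. A schematic loop, assuming a CUDA device; the model, data, and hyperparameters below are placeholders, not part of this commit:

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randn(4, 1, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # scale up to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales, skips the step on inf/nan grads
    scaler.update()                # adapts the scale factor for the next step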
@@ -1717,7 +1708,7 @@ for __name in dir(_C._VariableFunctions):
     if __name.startswith('__') or __name in PRIVATE_OPS:
         continue
     __obj = getattr(_C._VariableFunctions, __name)
-    __obj.__module__ = __name__
+    __obj.__module__ = __name__  # "torch"
     # Hide some APIs that should not be public
     if __name == "segment_reduce":
         # TODO: Once the undocumented FC window is passed, remove the line bellow
@@ -1751,6 +1742,7 @@ from ._compile import _disable_dynamo
 ################################################################################
 
 # needs to be after the above ATen bindings so we can overwrite from Python side
+from torch import functional as functional
 from .functional import *  # noqa: F403
 
 
@@ -1769,10 +1761,8 @@ del _LegacyStorage
 def _assert(condition, message):
     r"""A wrapper around Python's assert which is symbolically traceable.
     """
-    from .overrides import has_torch_function, handle_torch_function
-
-    if type(condition) is not torch.Tensor and has_torch_function((condition,)):
-        return handle_torch_function(_assert, (condition,), condition, message)
+    if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)):
+        return overrides.handle_torch_function(_assert, (condition,), condition, message)
     assert condition, message
 
 ################################################################################
@@ -1801,7 +1791,6 @@ from torch import nested as nested
 from torch import nn as nn
 from torch.signal import windows as windows
 from torch import optim as optim
-import torch.optim._multi_tensor
 from torch import multiprocessing as multiprocessing
 from torch import sparse as sparse
 from torch import special as special
@@ -1809,7 +1798,6 @@ import torch.utils.backcompat
 from torch import jit as jit
 from torch import linalg as linalg
 from torch import hub as hub
-from torch import random as random
 from torch import distributions as distributions
 from torch import testing as testing
 from torch import backends as backends
@@ -1817,6 +1805,8 @@ import torch.utils.data
 from torch import __config__ as __config__
 from torch import __future__ as __future__
 from torch import profiler as profiler
+from torch import overrides as overrides
+from torch import types as types
 
 # Quantized, sparse, AO, etc. should be last to get imported, as nothing
 # is expected to depend on them.
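These two re-exports are what the rewritten call sites earlier in the file lean on: every `overrides.has_torch_function...` check now resolves through the public `torch.overrides` module. A minimal sketch of the dispatch protocol those helpers implement, using a hypothetical `Scalar` wrapper:

import torch

class Scalar:
    # toy object that opts in to the __torch_function__ protocol
    def __init__(self, v):
        self.v = v

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # unwrap Scalar arguments, report the interception, run the real op
        unwrapped = [a.v if isinstance(a, Scalar) else a for a in args]
        print(f"intercepted {func.__name__}")
        return func(*unwrapped, **kwargs)

print(torch.overrides.has_torch_function((Scalar(2.0),)))  # -> True
print(torch.sym_float(Scalar(2.0)))  # prints "intercepted sym_float", then 2.0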
@@ -1827,7 +1817,7 @@ import torch.nn.quantized
 import torch.nn.qat
 import torch.nn.intrinsic
 
-_C._init_names(list(torch._storage_classes))
+_C._init_names(list(_storage_classes))
 
 # attach docstrings to torch and tensor functions
 from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs
@@ -1854,7 +1844,7 @@ from torch import quasirandom as quasirandom
 # If you are seeing this, it means that this call site was not checked if
 # the memory format could be preserved, and it was switched to old default
 # behaviour of contiguous
-legacy_contiguous_format = contiguous_format
+legacy_contiguous_format = contiguous_format  # defined by _C._initExtension()
 
 # Register fork handler to initialize OpenMP in child processes (see gh-28389)
 from torch.multiprocessing._atfork import register_after_fork
@@ -1876,7 +1866,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
 # Import experimental masked operations support. See
 # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
 # information.
-from . import masked
+from torch import masked as masked
 
 # Import removed ops with error message about removal
 from ._linalg_utils import (  # type: ignore[misc]

diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
@@ -6,21 +6,22 @@ enough, so that more sophisticated ones can also be easily integrated in the
 future.
 """
 
-from . import lr_scheduler, swa_utils
-from .adadelta import Adadelta
-from .adagrad import Adagrad
-from .adam import Adam
-from .adamax import Adamax
-from .adamw import AdamW
-from .asgd import ASGD
-from .lbfgs import LBFGS
-from .nadam import NAdam
-from .optimizer import Optimizer
-from .radam import RAdam
-from .rmsprop import RMSprop
-from .rprop import Rprop
-from .sgd import SGD
-from .sparse_adam import SparseAdam
+from torch.optim import lr_scheduler, swa_utils
+from torch.optim.adadelta import Adadelta
+from torch.optim.adagrad import Adagrad
+from torch.optim.adam import Adam
+from torch.optim.adamax import Adamax
+from torch.optim.adamw import AdamW
+from torch.optim.asgd import ASGD
+from torch.optim.lbfgs import LBFGS
+from torch.optim.nadam import NAdam
+from torch.optim.optimizer import Optimizer
+from torch.optim.radam import RAdam
+from torch.optim.rmsprop import RMSprop
+from torch.optim.rprop import Rprop
+from torch.optim.sgd import SGD
+from torch.optim.sparse_adam import SparseAdam
+
 
 del adadelta  # type: ignore[name-defined]  # noqa: F821
 del adagrad  # type: ignore[name-defined]  # noqa: F821
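The switch to absolute imports is behavior-neutral; either spelling binds, e.g., `Adam` into `torch.optim`. What the hunk makes easier to see is why the `del adadelta`-style cleanup around it exists at all: `from torch.optim.adam import Adam` still imports the `adam` submodule and attaches it to the package, leaking lowercase module names into `torch.optim`. A small demonstration of that side effect with a throwaway package:

import os, sys, tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "pkg"))
with open(os.path.join(root, "pkg", "__init__.py"), "w") as f:
    f.write("")
with open(os.path.join(root, "pkg", "adam.py"), "w") as f:
    f.write("class Adam: ...\n")
sys.path.insert(0, root)

from pkg.adam import Adam  # imports the pkg.adam submodule as a side effect

import pkg
print(hasattr(pkg, "adam"))  # -> True: the submodule leaked into the package
del pkg.adam                 # the diff's `del adadelta`-style cleanup
print(hasattr(pkg, "adam"))  # -> False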
@@ -36,3 +37,6 @@ del rmsprop  # type: ignore[name-defined]  # noqa: F821
 del optimizer  # type: ignore[name-defined]  # noqa: F821
 del nadam  # type: ignore[name-defined]  # noqa: F821
 del lbfgs  # type: ignore[name-defined]  # noqa: F821
+
+
+import torch.optim._multi_tensor

diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import torch
-from torch import Tensor
 
 from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
 
@@ -213,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]):
         [0, 1, 4, 3, 2]
     """
 
-    weights: Tensor
+    weights: torch.Tensor
     num_samples: int
     replacement: bool
 
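The sampler change is the same housekeeping in miniature: with `from torch import Tensor` gone, the annotation is spelled out as `torch.Tensor`, leaving `import torch` as the file's only torch import. The two spellings are equivalent at runtime; a compact sketch:

import torch

class WeightedThing:
    weights: torch.Tensor  # fully qualified; no `from torch import Tensor`
    num_samples: int

    def __init__(self, weights: torch.Tensor, num_samples: int) -> None:
        self.weights = weights
        self.num_samples = num_samples

w = WeightedThing(torch.tensor([0.1, 0.9]), num_samples=5)
print(w.weights.dtype)  # -> torch.float32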