diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi
index 38a6bb8b39f..422e59984d0 100644
--- a/torch/_C/_functions.pyi
+++ b/torch/_C/_functions.pyi
@@ -1,4 +1,4 @@
-from typing import AnyStr
+from typing import AnyStr, overload, Tuple

 from torch import Tensor

@@ -8,4 +8,12 @@ class UndefinedGrad:

 class DelayedError:
     def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
-    def __call__(self, inputs: list[Tensor]) -> list[Tensor]: ...
+
+    # __call__ should really be a higher-kinded type:
+    # def __call__(self, arg: Tensor) -> Tensor: ...
+    # def __call__(self, *args: Tensor * num_inputs) -> Tuple[Tensor * num_inputs]: ...
+
+    @overload
+    def __call__(self, i0: Tensor) -> Tensor: ...
+    @overload
+    def __call__(self, *args: Tensor) -> Tuple[Tensor, ...]: ...
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index afc3b0fb5fe..5a33484b5a3 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -166,7 +166,7 @@ def invalidate_eager_modules():
             e_t = ErasedTensor(tensor, attr_name, mod)
             if isinstance(tensor, torch.nn.Parameter):
                 e_t.requires_grad_(True)
-                e_t._is_param = True  # type: ignore[attr-defined]
+                e_t._is_param = True
             setattr(mod, attr_name, e_t)


@@ -181,7 +181,7 @@ def discard_traced_gm_params(mod: torch.fx.GraphModule):
         e_t = ErasedTensor(tensor, attr_name, mod)
         if isinstance(tensor, torch.nn.Parameter):
             e_t.requires_grad_(True)
-            e_t._is_param = True  # type: ignore[attr-defined]
+            e_t._is_param = True
         setattr(mod, attr_name, e_t)


diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index bbea8710a07..6815f3e5ef9 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -1,8 +1,8 @@
-# mypy: allow-untyped-defs
 from __future__ import annotations

 import contextlib
 import dataclasses
+import typing
 import warnings
 import weakref
 from dataclasses import dataclass
@@ -12,14 +12,18 @@ from typing import (
     ClassVar,
     ContextManager,
     Dict,
+    Generic,
     List,
+    NewType,
     Optional,
+    Set,
     Tuple,
     Type,
     TYPE_CHECKING,
+    TypeVar,
     Union,
 )
-from typing_extensions import TypeAlias, TypeGuard
+from typing_extensions import TypeGuard

 import torch
 from torch._C._autograd import CreationMeta
@@ -47,16 +51,17 @@ if TYPE_CHECKING:
     from torch._guards import Source

-    # Import here to avoid cycle
-    from torch._subclasses.fake_tensor import FakeTensorMode
-
     # Import the following modules during type checking to enable code intelligence features,
     # Do not import unconditionally, as they import sympy and importing sympy is very slow
     from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

 DimList = List

+_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)
+_T = TypeVar("_T")
+_TensorT = TypeVar("_TensorT", bound=torch.Tensor)
+

-def safe_is_leaf(t):
+def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
     try:
         return t.is_leaf
@@ -64,28 +69,37 @@
     except RuntimeError:
         # inference mode can trigger this
         return False


-def safe_grad(t):
+def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
         return t.grad


-def assert_eq(a, b):
+def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT:
+    grad = safe_grad(t)
+    assert grad is not None
+    return grad
+
+
+def assert_eq(a: _T, b: _T) -> None:
     assert a == b, f"{a} != {b}"


 def assert_metadata_eq(
-    assert_eq,
+    assert_eq: Callable[[object, object], None],
     m1: Union[MetaTensorDesc, torch.Tensor],
     m2: torch.Tensor,
     *,
-    skip_symbolic=False,
-    skip_leaf=False,
-):
-    if isinstance(m1, torch.Tensor):
-        m1 = MetaTensorDescriber().describe_tensor(m1)
+    skip_symbolic: bool = False,
+    skip_leaf: bool = False,
+) -> None:
+    m1 = (
+        MetaTensorDescriber().describe_tensor(m1)
+        if isinstance(m1, torch.Tensor)
+        else m1
+    )

-    def go(m1, m2):
+    def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None:
         assert_eq(m1.dtype, m2.dtype)
         if not skip_symbolic:
             assert_eq(m1.shape, m2.shape)
@@ -100,7 +114,7 @@ def assert_metadata_eq(
         assert_eq(m1.is_neg, m2.is_neg())
         assert_eq(m1.grad is not None, safe_grad(m2) is not None)
         if m1.grad is not None:
-            go(m1.grad, safe_grad(m2))
+            go(m1.grad, _expect_safe_grad(m2))
         # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
         # branches (but not ready for prime time yet)...
         if m1.is_sparse:
@@ -118,6 +132,8 @@ def assert_metadata_eq(
             assert_eq(m1.storage_offset, m2.storage_offset())
             assert_eq(m1.is_view, m2._is_view())
             if m1.is_view:
+                assert m1.base is not None
+                assert m2._base is not None
                 go(m1.base, m2._base)
             # TODO: test if is resizable (no direct query for this atm)
             # TODO: audit AutogradMeta to see if it matches
@@ -126,11 +142,12 @@ def assert_metadata_eq(
     return go(m1, m2)


-def is_sparse_coo(t):
+# TypeGuard (not TypeIs): False does not imply !torch.Tensor
+def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
     return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo


-def is_sparse_compressed_layout(layout):
+def is_sparse_compressed_layout(layout: torch.layout) -> bool:
     return layout in {
         torch.sparse_csr,
         torch.sparse_csc,
@@ -139,20 +156,38 @@ def is_sparse_compressed_layout(layout):
     }


-def is_sparse_compressed(t):
+# TypeGuard (not TypeIs): False does not imply !torch.Tensor
+def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]:
     return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)


+# TypeGuard (not TypeIs): False does not imply !torch.Tensor
 def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
     return is_sparse_coo(t) or is_sparse_compressed(t)


+def _checked_cast(ty: Type[_T], obj: object) -> _T:
+    assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
+    return obj
+
+
+def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage:
+    return base.real_storage  # type: ignore[attr-defined]
+
+
+def _set_real_storage(
+    base: torch.UntypedStorage, real_storage: torch.UntypedStorage
+) -> None:
+    base.real_storage = real_storage  # type: ignore[attr-defined]
+
+
 # Don't use id() directly, because those can get reallocated over time.
-MetaStorageId: TypeAlias = int
-MetaTensorId: TypeAlias = int
+MetaStorageId = NewType("MetaStorageId", int)
+MetaTensorId = NewType("MetaTensorId", int)

-DESCRIBER_NEXT_ID = 0
+_DescriberId = NewType("_DescriberId", int)
+DESCRIBER_NEXT_ID = _DescriberId(0)


 class MetaTensorDescriber:
@@ -166,33 +201,35 @@ class MetaTensorDescriber:
     the same ID when we see the same tensor/storage.
     """

-    def __init__(self, *, copy_data=False):
+    def __init__(self, *, copy_data: bool = False) -> None:
         global DESCRIBER_NEXT_ID
         self.id = DESCRIBER_NEXT_ID
-        DESCRIBER_NEXT_ID += 1
-        self.next_tensor_id: MetaTensorId = 0
-        self.next_storage_id: MetaStorageId = 0
+        DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1)
+        self.next_tensor_id: MetaTensorId = MetaTensorId(0)
+        self.next_storage_id: MetaStorageId = MetaStorageId(0)
         # Tensor -> int
         self.lookup_tensor = WeakIdKeyDictionary()
         # Storage -> int
         self.lookup_storage = WeakIdKeyDictionary()
         self.copy_data = copy_data
-        self.traced_tensors = set()
-        self.traced_storages = set()
+        self.traced_tensors: Set[int] = set()
+        self.traced_storages: Set[int] = set()

-    def get_tensor_id(self, t: torch.Tensor):
+    def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
         if t not in self.lookup_tensor:
             self.lookup_tensor[t] = self.next_tensor_id
-            self.next_tensor_id += 1
+            self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1)
         return self.lookup_tensor[t]

-    def get_storage_id(self, s: torch.UntypedStorage):
+    def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
         if s not in self.lookup_storage:
             self.lookup_storage[s] = self.next_storage_id
-            self.next_storage_id += 1
+            self.next_storage_id = MetaStorageId(self.next_storage_id + 1)
         return self.lookup_storage[s]

-    def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
+    def describe_storage(
+        self, s: torch.UntypedStorage, *, trace: bool = False
+    ) -> MetaStorageDesc:
         r = MetaStorageDesc(
             id=self.get_storage_id(s),
             size=s.size(),
@@ -210,7 +247,7 @@ class MetaTensorDescriber:

     def describe_tensor(
         self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
-    ):
+    ) -> MetaTensorDesc:
         is_leaf = safe_is_leaf(t)
         is_view = t._is_view()
         is_sparse = t.is_sparse
@@ -381,8 +418,8 @@ class MetaTensorDescriber:
                 else None
             ),
             grad=(
-                self.describe_tensor(safe_grad(t), trace=trace)
-                if safe_grad(t) is not None
+                self.describe_tensor(grad, trace=trace)
+                if (grad := safe_grad(t)) is not None
                 else None
             ),
             creation_meta=(
@@ -430,7 +467,7 @@ class MetaStorageDesc:
     # serializable in JSON, you want to do something special here anyway
     data: Optional[torch.UntypedStorage]

-    def as_json(self, describer_id):
+    def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
         return {
             "id": self.id,
             "describer_id": describer_id,
@@ -439,7 +476,7 @@ class MetaStorageDesc:


 @dataclass(frozen=True)
-class MetaTensorDesc:
+class MetaTensorDesc(Generic[_TensorT]):
     id: MetaTensorId
     ndim: int
     dtype: torch.dtype
@@ -520,15 +557,15 @@ class MetaTensorDesc:
     ctx: Optional[object] = None  # is_traceable_wrapper_subclass
     type: Optional[Type] = None  # is_traceable_wrapper_subclass
-    fake_mode: Optional[FakeTensorMode] = None
+    fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode] = None
     view_func: Optional[
         Callable[
             [
                 torch.Tensor,
                 Callable[[int], int],
-                Callable[[torch.Tensor], torch.Tensor],
+                Callable[[torch.Tensor], _TensorT],
             ],
-            torch.Tensor,
+            _TensorT,
         ]
     ] = None
     # level looks serializable, but actually it is meaningless without
@@ -555,8 +592,8 @@ class MetaTensorDesc:
     # NB: This will reference numeric IDs, and it is assumed that you've
     # already serialized everything this recursively references
-    def as_json(self, describer_id):
-        def json(k, v):
+    def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
+        def json(k: str, v: object) -> object:
             # Some best-effort debugging serialization for unserializable
             # fields (feel free to add other special cases as appropriate)
             if k in ["data", "autograd_meta_from"]:
                 return f"<{type(v).__name__}>"
@@ -592,7 +629,7 @@ class MetaTensorDesc:
         return r

     @property
-    def shape(self):
+    def shape(self) -> Tuple[int, ...]:
         return self.size


@@ -608,13 +645,13 @@ class MetaTensorDesc:
 # FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
 # to do this would be to not use no_dispatch and instead just disable fake
 # tensor mode only (allowing for subclass dispatch to occur)
-def _safe_copy(dst, src):
+def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None:
     if type(src) is not torch.Tensor:
         return
     dst.copy_(src)


-def _safe_clone(src):
+def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]:
     if type(src) is not torch.Tensor:
         return None
     return src.clone()
@@ -627,13 +664,17 @@ def _safe_clone(src):
 # share storage because this is how we correlate shared storages to the same
 # meta storages. This class will hold weak references to cached tenosrs
 # and tensor storages.
-class MetaConverter:
-    def __init__(self, *, copy_data: bool = False):
+class MetaConverter(Generic[_TensorT]):
+    def __init__(self, *, copy_data: bool = False) -> None:
         # Maps MetaStorageId to UntypedStorage
-        self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+        self.storage_memo: weakref.WeakValueDictionary[
+            MetaStorageId, torch.UntypedStorage
+        ] = weakref.WeakValueDictionary()
         # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
         # FakeTensor)
-        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+        self.tensor_memo: weakref.WeakValueDictionary[
+            MetaTensorId, _TensorT
+        ] = weakref.WeakValueDictionary()
         self.hit = 0
         self.miss = 0
         self.del_hook = None
@@ -645,25 +686,34 @@ class MetaConverter:
         self.copy_data = copy_data
         self.describer = MetaTensorDescriber(copy_data=copy_data)

-    def successful(self):
+    def successful(self) -> bool:
         return self.hit > 0 and self.miss == 0

-    def get_tensor_memo(self, t: MetaTensorDesc):
+    def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]:
         return self.tensor_memo.get(t.id, None)

-    def set_tensor_memo(self, t: MetaTensorDesc, v):
+    def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT:
+        r = self.tensor_memo.get(t.id, None)
+        assert r is not None
+        return r
+
+    def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None:
         self.tensor_memo[t.id] = v

-    def get_storage_memo(self, s: MetaStorageDesc):
+    def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]:
         return self.storage_memo.get(s.id, None)

-    def set_storage_memo(self, s: MetaStorageDesc, v):
+    def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
         self.storage_memo[s.id] = v

-    def meta_storage(self, s: MetaStorageDesc, callback):
+    def meta_storage(
+        self,
+        s: MetaStorageDesc,
+        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+    ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
-        if self.get_storage_memo(s) is None:
+        if (memo := self.get_storage_memo(s)) is None:
             r_s = callback(
                 lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
             ).untyped_storage()
@@ -672,11 +722,29 @@ class MetaConverter:
             # implemented as Tensor operations
             with torch.no_grad(), no_dispatch():
                 assert s.data is not None
-                r_s.real_storage = s.data.clone()
+                _set_real_storage(r_s, s.data.clone())
             self.set_storage_memo(s, r_s)
             return r_s
         else:
-            return self.get_storage_memo(s)
+            return memo
+
+    @classmethod
+    def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
+        # TODO: how to check _TensorT?
+        return typing.cast(_TensorT, t)
+
+    @classmethod
+    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+        return cls._checked_cast_tensor_t(t())
+
+    @classmethod
+    def _backward_error(cls, t: _TensorT) -> _TensorT:
+        errfn = torch._C._functions.DelayedError(
+            "Internal error: Tried to backward() through example input",
+            1,
+        )
+        err = errfn(t)
+        return typing.cast(_TensorT, err)

     # This function assumes that it's possible to do the conversion
     # NB: name here is used in a conventional way by Dynamo; it corresponds
@@ -687,11 +755,11 @@ class MetaConverter:
     def meta_tensor(
         self,
         t: MetaTensorDesc,
-        shape_env: Optional[ShapeEnv] = None,
-        callback=lambda t: t(),
-        source: Optional[Source] = None,
-        symbolic_context: Optional[SymbolicContext] = None,
-    ):
+        shape_env: Optional[ShapeEnv],
+        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        source: Optional[Source],
+        symbolic_context: Optional[SymbolicContext],
+    ) -> _TensorT:
         if source is None:
             from torch._dynamo.source import ConstantSource

@@ -739,7 +807,11 @@ class MetaConverter:
             maybe_suppress = shape_env.suppress_guards

         def sym_sizes_strides_storage_offset(
-            t: MetaTensorDesc, src, symbolic_context=symbolic_context
+            t: MetaTensorDesc,
+            src: torch._guards.Source,
+            symbolic_context: Optional[
+                torch.fx.experimental.symbolic_shapes.SymbolicContext
+            ] = symbolic_context,
         ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
             assert t.stride is not None
             if shape_env is not None:
@@ -773,8 +845,12 @@ class MetaConverter:
             return (t.size, t.stride, t.storage_offset)

         def empty_create(
-            inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
-        ):
+            inner_t: MetaTensorDesc,
+            inner_src: torch._guards.Source,
+            symbolic_context: Optional[
+                torch.fx.experimental.symbolic_shapes.SymbolicContext
+            ] = symbolic_context,
+        ) -> torch.Tensor:
             (
                 inner_sizes,
                 inner_strides,
@@ -791,12 +867,13 @@ class MetaConverter:
         # symbolic context.
         def empty_create_subclass(
             t: MetaTensorDesc,
-            outer_size,
-            outer_stride,
-            symbolic_context=symbolic_context,
-            callback=callback,
-            source=source,
-        ):
+            outer_size: Tuple[int, ...],
+            outer_stride: Tuple[int, ...],
+            symbolic_context: Optional[
+                torch.fx.experimental.symbolic_shapes.SymbolicContext
+            ] = symbolic_context,
+            source: Optional[torch._guards.Source] = source,
+        ) -> _TensorT:
             from torch._dynamo.source import AttrSource
             from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
@@ -822,24 +899,38 @@ class MetaConverter:
             )

             def _empty_create_subclass(
-                t, outer_size, outer_stride, symbolic_context, callback, source
-            ):
+                t: MetaTensorDesc,
+                outer_size: Optional[Tuple[int, ...]],
+                outer_stride: Optional[Tuple[int, ...]],
+                symbolic_context: Optional[
+                    torch.fx.experimental.symbolic_shapes.SymbolicContext
+                ],
+                callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+                source: torch._guards.Source,
+            ) -> _TensorT:
                 # We are hitting plain meta_desc tensor so actually
                 # create a tensor here.
                 if t.attrs is None:
                     return self.meta_tensor(
                         t,
-                        shape_env=shape_env,
-                        callback=callback,
-                        source=source,
-                        symbolic_context=symbolic_context,
+                        shape_env,
+                        callback,
+                        source,
+                        symbolic_context,
                     )

                 inner_tensors = {}
                 for attr, meta_tensor_desc in t.attrs.items():
                     current_context = None
                     if symbolic_context is not None:
-                        current_context = symbolic_context.inner_contexts[attr]
+                        assert isinstance(symbolic_context, SubclassSymbolicContext)
+                        if (
+                            current_context_ := symbolic_context.inner_contexts[attr]
+                        ) is not None:
+                            current_context = _checked_cast(
+                                torch.fx.experimental.symbolic_shapes.SymbolicContext,
+                                current_context_,
+                            )

                     current_source = AttrSource(source, attr)
                     new_empty_tensor = _empty_create_subclass(
@@ -852,10 +943,12 @@ class MetaConverter:
                     )
                     inner_tensors[attr] = new_empty_tensor

+                assert t.type is not None
                 return t.type.__tensor_unflatten__(
                     inner_tensors, t.ctx, outer_size, outer_stride
                 )

+            assert source is not None
             sub = _empty_create_subclass(
                 t, outer_size, outer_stride, symbolic_context, callback, source
             )
@@ -879,8 +972,11 @@ class MetaConverter:
         # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
         # don't want to over-specialize during view replay.
         def all_dynamic_symbolic_context(
-            t: MetaTensorDesc, source, shape_env, callback
-        ):
+            t: MetaTensorDesc,
+            source: torch._guards.Source,
+            shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
+            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import (
                DimDynamic,
@@ -888,18 +984,22 @@ class MetaConverter:
                 SubclassSymbolicContext,
             )

-            view_base_context: Optional[SymbolicContext] = None
+            view_base_context: Optional[
+                torch.fx.experimental.symbolic_shapes.SymbolicContext
+            ] = None
             if t.is_view:
                 assert t.base is not None
                 view_base_context = all_dynamic_symbolic_context(
                     t.base, AttrSource(source, "_base"), shape_env, callback
                 )

-            t_symbolic_context: SymbolicContext
+            t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext
             t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
             if t.is_traceable_wrapper_subclass:
                 assert t.attrs is not None
-                inner_contexts: Dict[str, SymbolicContext] = {}
+                inner_contexts: Dict[
+                    str, torch.fx.experimental.symbolic_shapes.SymbolicContext
+                ] = {}
                 for attr, inner in t.attrs.items():
                     assert isinstance(attr, str)
                     inner_contexts[attr] = all_dynamic_symbolic_context(
@@ -951,8 +1051,12 @@ class MetaConverter:
         # Then view replay is done, swapping in the fake offsets so the view replay output
         # is fully fake with no invalid specialization.
         def view_from_base(
-            base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
-        ):
+            base: _TensorT,
+            t: MetaTensorDesc,
+            shape_env: Optional[
+                torch.fx.experimental.symbolic_shapes.ShapeEnv
+            ] = shape_env,
+        ) -> _TensorT:
             # fake-ify t's metadata according to the outer symbolic context
             (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
                 t, source
@@ -965,7 +1069,9 @@ class MetaConverter:
                 # TODO: Change this logic to use view replay for consistency?
                 # It's likely there is no view func available.
                 with maybe_suppress():
-                    return base.as_strided(sizes, strides, storage_offset)
+                    return self._checked_cast_tensor_t(
+                        base.as_strided(sizes, strides, storage_offset)
+                    )

             from torch._dynamo.source import EphemeralSource
             from torch.fx.experimental.symbolic_shapes import (
@@ -973,7 +1079,7 @@ class MetaConverter:
                 sym_eq,
             )

-            def symint_visitor_fn(s):
+            def symint_visitor_fn(s: int) -> int:
                 nonlocal symbolic_context
                 from torch.fx.experimental.symbolic_shapes import DimDynamic

@@ -1017,10 +1123,10 @@ class MetaConverter:
                 # want a view of values with the offsets closed over. As the offsets component
                 # is needed to describe the output view, it's important that it's fakeified
                 # correctly.
-                fake_t = empty_create_subclass(
+                fake_t: _TensorT = empty_create_subclass(
                     t, outer_size=sizes, outer_stride=strides
                 )
-                attrs, _ = fake_t.__tensor_flatten__()
+                attrs, _ = fake_t.__tensor_flatten__()  # type: ignore[attr-defined]
                 for attr in attrs:
                     real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

@@ -1028,9 +1134,11 @@ class MetaConverter:
                 visited_t: torch.Tensor,
                 # These arguments are never passed, we just use them to close
                 # over these relevant values
-                shape_env=shape_env,
-                callback=callback,
-            ):
+                shape_env: Optional[
+                    torch.fx.experimental.symbolic_shapes.ShapeEnv
+                ] = shape_env,
+                callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+            ) -> torch.Tensor:
                 # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                 if visited_t is None:
                     return None
@@ -1057,8 +1165,8 @@ class MetaConverter:
                     visited_desc,
                     shape_env,
                     callback,
-                    source=temp_source,
-                    symbolic_context=all_dynamic_symbolic_context(
+                    temp_source,
+                    all_dynamic_symbolic_context(
                         visited_desc, temp_source, shape_env, callback
                     ),
                 )
@@ -1102,6 +1210,9 @@ class MetaConverter:
                     # Pray that sparse clone doesn't lose information
                     assert t.data is not None
                     with torch.no_grad(), no_dispatch():
+                        assert isinstance(
+                            r, torch._subclasses.fake_tensor.FakeTensor
+                        )
                         r.real_tensor = _safe_clone(t.data)
                 assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                 # Note [is_coalesced is dispatched]
@@ -1109,7 +1220,7 @@ class MetaConverter:
                 # which means that it will get caught by fake tensor mode.
                 # Ordinarily this would error, but there's some logic in
                 # fake tensor ensure this doesn't happen.
-                r._coalesced_(t.is_coalesced)
+                r._coalesced_(bool(t.is_coalesced))
                 if t.requires_grad:
                     r.requires_grad = True
                 if t.requires_grad and not is_leaf:
@@ -1117,9 +1228,9 @@ class MetaConverter:
                     # but clone is fine for now for sparse tensors.
                     # (DelayedError does not work for sparse because it causes
                     # the Fake sparse tensor to "lose" its fakeness)
-                    r = r.clone()
+                    r = self._checked_cast_tensor_t(r.clone())
                     with torch.enable_grad():
-                        r._coalesced_(t.is_coalesced)
+                        r._coalesced_(bool(t.is_coalesced))
             elif is_sparse_compressed_layout(t.layout):
                 is_leaf = t.is_leaf

@@ -1154,15 +1265,15 @@ class MetaConverter:
                     # Pray sparse clone doesn't lose information
                     assert t.data is not None
                     with torch.no_grad(), no_dispatch():
+                        assert isinstance(
+                            r, torch._subclasses.fake_tensor.FakeTensor
+                        )
                         r.real_tensor = _safe_clone(t.data)
                 assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                 if t.requires_grad:
                     r.requires_grad = True
                 if t.requires_grad and not is_leaf:
-                    r = torch._C._functions.DelayedError(
-                        "Internal error: Tried to backward() through example input",
-                        1,
-                    )(r)
+                    r = self._backward_error(r)
             elif t.is_nested and not t.is_traceable_wrapper_subclass:
                 # TODO: Handle this better in Dynamo?
                 # There are checks there now, but this can still be triggered by a dense
@@ -1174,9 +1285,11 @@ class MetaConverter:
                 )
             elif t.is_mkldnn:
                 is_leaf = t.is_leaf
-                sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
-                    t, source
-                )
+                (
+                    sizes,
+                    strides,
+                    _storage_offset,
+                ) = sym_sizes_strides_storage_offset(t, source)
                 # TODO: This doesn't seem right, where's the MKLDNN'ness
                 # lol
                 r = callback(
@@ -1188,6 +1301,9 @@ class MetaConverter:
                     with torch.no_grad(), no_dispatch():
                         assert t.size is not None
                         assert t.stride is not None
+                        assert isinstance(
+                            r, torch._subclasses.fake_tensor.FakeTensor
+                        )
                         r.real_tensor = torch.empty_strided(
                             t.size, t.stride, dtype=t.dtype, device=t.device
                         )
@@ -1197,10 +1313,7 @@ class MetaConverter:
                 if t.requires_grad:
                     r.requires_grad = True
                 if t.requires_grad and not is_leaf:
-                    r = torch._C._functions.DelayedError(
-                        "Internal error: Tried to backward() through example input",
-                        1,
-                    )(r)
+                    r = self._backward_error(r)
             elif t.is_functorch_wrapped:
                 if t.is_view:
                     from torch._dynamo.exc import unimplemented
@@ -1211,9 +1324,10 @@ class MetaConverter:

                 # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                 # in a FakeTensor
-                def _to_fake_tensor(t: MetaTensorDesc):
+                def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT:
                     # TODO: why aren't the recursive calls going to
                     # meta_tensor
+                    r: _TensorT
                     if t.is_batchedtensor:
                         assert t.unwrapped is not None
                         assert t.level is not None
@@ -1228,7 +1342,9 @@ class MetaConverter:
                         with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                             t.functorch_stack
                         ):
-                            r = _add_batch_dim(ft, bdim, lvl)
+                            r = self._checked_cast_tensor_t(
+                                _add_batch_dim(ft, bdim, lvl)
+                            )
                     elif t.is_gradtrackingtensor:
                         assert t.unwrapped is not None
                         assert t.level is not None
@@ -1242,33 +1358,32 @@ class MetaConverter:
                         with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                             t.functorch_stack
                         ):
-                            r = torch._C._functorch._wrap_for_grad(ft, lvl)
+                            r = self._checked_cast_tensor_t(
+                                torch._C._functorch._wrap_for_grad(ft, lvl),
+                            )

                         is_leaf = t.is_leaf
                         if t.requires_grad and safe_is_leaf(r):
                             r.requires_grad = True
                         elif t.requires_grad and not is_leaf:
-                            r = torch._C._functions.DelayedError(  # type: ignore[assignment]
-                                "Internal error: Tried to backward() through example input",
-                                1,
-                            )(
-                                r  # type: ignore[arg-type]
-                            )
+                            r = self._backward_error(r)
                     elif t.is_functional:
                         assert t.unwrapped is not None
                         assert t.current_level is not None
                         ft = self.meta_tensor(
                             t.unwrapped,
-                            shape_env=shape_env,
-                            callback=callback,
+                            shape_env,
+                            callback,
                             # NB: reuse these exactly, we treat the
                             # functional tensor as "invisible".
                             # TODO: Actually this all probably doesn't
                             # work, take a closer look.
-                            source=source,
-                            symbolic_context=symbolic_context,
+                            source,
+                            symbolic_context,
+                        )
+                        r = self._checked_cast_tensor_t(
+                            _wrap_functional_tensor(ft, t.current_level),
                         )
-                        r = _wrap_functional_tensor(ft, t.current_level)
                         # TODO: is_leaf/requires_grad?
                     else:
                         assert t.stride is not None
@@ -1302,12 +1417,14 @@ class MetaConverter:
                 assert not t.is_functorch_wrapped  # handled above
                 unwrapped = self.meta_tensor(
                     t.unwrapped,
-                    shape_env=shape_env,
-                    callback=callback,
-                    source=source,
-                    symbolic_context=symbolic_context,
+                    shape_env,
+                    callback,
+                    source,
+                    symbolic_context,
+                )
+                r = self._checked_cast_tensor_t(
+                    torch._to_functional_tensor(unwrapped)
                 )
-                r = torch._to_functional_tensor(unwrapped)
                 torch._mirror_autograd_meta_to(t.autograd_meta_from, r)  # type: ignore[attr-defined]

             elif t.is_view:
@@ -1335,11 +1452,13 @@ class MetaConverter:
                     t.base,
                     shape_env,
                     callback,
-                    source=torch._dynamo.source.AttrSource(source, "_base"),
-                    symbolic_context=base_symbolic_context,
+                    torch._dynamo.source.AttrSource(source, "_base"),
+                    base_symbolic_context,
                 )

-                def is_c_of_r(complex_dtype, real_dtype):
+                def is_c_of_r(
+                    complex_dtype: torch.dtype, real_dtype: torch.dtype
+                ) -> bool:
                     return (
                         utils.is_complex_dtype(complex_dtype)
                         and utils.corresponding_real_dtype(complex_dtype)
@@ -1361,14 +1480,16 @@ class MetaConverter:
                 if base.dtype == t.dtype:
                     pass
                 elif is_c_of_r(base.dtype, t.dtype):
-                    base = torch.view_as_real(base)
+                    base = self._checked_cast_tensor_t(torch.view_as_real(base))
                 elif is_c_of_r(t.dtype, base.dtype):
-                    base = torch.view_as_complex(base)
+                    base = self._checked_cast_tensor_t(
+                        torch.view_as_complex(base)
+                    )
                 else:
                     # This is not guaranteed to succeed. If it fails, it
                     # means there is another dtype-converting view function
                     # that hasn't been handled here
-                    base = base.view(t.dtype)
+                    base = self._checked_cast_tensor_t(base.view(t.dtype))

                 # This is very tricky. Naively, you might expect this
                 # to hold:
@@ -1410,7 +1531,9 @@ class MetaConverter:
                         # NB: Can't have a non-leaf without requiring grad!
                         assert t.requires_grad
                         with torch.no_grad():
-                            mid = base.view(base.shape)
+                            mid = self._checked_cast_tensor_t(
+                                base.view(base.shape)
+                            )
                         mid.requires_grad = t.requires_grad
                         with torch.enable_grad():
                             r = view_from_base(mid, t)
@@ -1459,6 +1582,9 @@ class MetaConverter:
                     with torch.no_grad(), no_dispatch():
                         assert t.size is not None
                         assert t.stride is not None
+                        assert isinstance(
+                            r, torch._subclasses.fake_tensor.FakeTensor
+                        )
                         r.real_tensor = torch.empty_strided(
                             t.size, t.stride, dtype=t.dtype, device=t.device
                         )
@@ -1477,10 +1603,7 @@ class MetaConverter:
                     # the metadata of the inner tensor.
                     # So instead, we now have a dedicated fn to set autograd history,
                     # without inadvertently changing other metadata.
-                    r = torch._C._functions.DelayedError(
-                        "Internal error: Tried to backward() through example input",
-                        1,
-                    )(r)
+                    r = self._backward_error(r)

             s = t.storage
             assert s is not None
@@ -1494,8 +1617,12 @@ class MetaConverter:
                 # You're normal and happy, install the fresh storage into the memo
                 self.set_storage_memo(s, r.untyped_storage())
                 if self.copy_data:
-                    r.untyped_storage().real_storage = (
-                        r.real_tensor.untyped_storage()
+                    assert isinstance(
+                        r, torch._subclasses.fake_tensor.FakeTensor
+                    )
+                    assert r.real_tensor is not None
+                    _set_real_storage(
+                        r.untyped_storage(), r.real_tensor.untyped_storage()
                     )
             else:
                 # You're in crazy town; somehow you gave us a tensor
@@ -1540,8 +1667,13 @@ class MetaConverter:
                         r.set_(r_s, storage_offset, sizes, strides)
                     if self.copy_data:
                         with torch.no_grad(), no_dispatch():
+                            assert isinstance(
+                                r, torch._subclasses.fake_tensor.FakeTensor
+                            )
+                            assert r.real_tensor is not None
+                            assert t.stride is not None
                             r.real_tensor.set_(
-                                r_s.real_storage,
+                                _get_real_storage(r_s),
                                 t.storage_offset,
                                 t.size,
                                 t.stride,
@@ -1556,8 +1688,8 @@ class MetaConverter:
                     t.grad,
                     shape_env,
                     callback,
-                    source=AttrSource(source, "grad"),
-                    symbolic_context=symbolic_context,
+                    AttrSource(source, "grad"),
+                    symbolic_context,
                 )
             torch._C._set_conj(r, t.is_conj)
             torch._C._set_neg(r, t.is_neg)
@@ -1577,27 +1709,33 @@ class MetaConverter:
         # See Note: [Creating symbolic nested int]
         if t.nested_int is not None:
+            assert isinstance(r, torch._subclasses.fake_tensor.FakeTensor)
             r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
                 nt_tensor_id=t.nested_int
             )

         self.set_tensor_memo(t, r)

-        return self.get_tensor_memo(t)
+        return self._checked_get_tensor_memo(t)

     def __call__(
         self,
-        t,
-        shape_env=None,
+        t: torch.Tensor,
+        shape_env: Optional[ShapeEnv] = None,
         *,
-        callback=lambda t: t(),
-        source=None,
-        symbolic_context=None,
+        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        source: Optional[Source] = None,
+        symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
         # when source is not None. Because we refakify after Dynamo is done,
         # we don't want to dump info again from AOTAutograd, it is redundant.
-        trace=True,
-    ):
+        trace: bool = True,
+    ) -> _TensorT:
+        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        if callback is None:
+            callback_ = self._identity_callable
+        else:
+            callback_ = callback
         # TODO: zero tensors? We appear to have eliminated them by
         # excluding complex for now
@@ -1637,6 +1775,7 @@ class MetaConverter:
             t_desc = self.describer.describe_tensor(t, trace=trace)

             if trace:
+                assert source is not None
                 trace_structured(
                     "describe_source",
                     metadata_fn=lambda: {
@@ -1659,10 +1798,10 @@ class MetaConverter:

             r = self.meta_tensor(
                 t_desc,
-                shape_env=shape_env,
-                callback=callback,
-                source=source,
-                symbolic_context=symbolic_context,
+                shape_env,
+                callback_,
+                source,
+                symbolic_context,
             )

             if type(t) is torch.nn.Parameter:
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 18c309ec876..3ca7cfb435a 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -78,6 +78,8 @@ def _rebuild_from_type_v2(func, new_type, args, state):
 # torch/_C/__init__.pyi.in to add a type annotation for your method;
 # otherwise, it will not show up in autocomplete.
 class Tensor(torch._C.TensorBase):
+    _is_param: bool
+
     def _clear_non_serializable_cached_data(self):
         r"""Clears any data cached in the tensor's ``__dict__`` that would
         prevent the tensor from being serialized.
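
Reviewer note (not part of the patch): the typing machinery this diff applies -- NewType for MetaStorageId/MetaTensorId, a TypeVar bound to torch.Tensor, and Generic memo containers -- can be hard to read inline. Below is a minimal, self-contained sketch of the same patterns in plain Python (no torch; all names here are illustrative stand-ins, not from the patch):

    from typing import Dict, Generic, NewType, Optional, TypeVar

    # Like MetaStorageId/MetaTensorId: both are plain ints at runtime,
    # but mypy rejects passing one where the other is expected.
    StorageId = NewType("StorageId", int)
    TensorId = NewType("TensorId", int)

    _T = TypeVar("_T")


    class Memo(Generic[_T]):
        """Toy analogue of MetaConverter's tensor_memo: the stored value
        type is a type parameter, so a Memo[str] hands back str, not object."""

        def __init__(self) -> None:
            self._data: Dict[TensorId, _T] = {}

        def get(self, key: TensorId) -> Optional[_T]:
            return self._data.get(key)

        def set(self, key: TensorId, value: _T) -> None:
            self._data[key] = value


    memo: Memo[str] = Memo()
    tid = TensorId(0)
    memo.set(tid, "value")
    assert memo.get(tid) == "value"

    # Arithmetic strips the NewType wrapper, so the result must be re-wrapped --
    # the same reason the patch writes MetaTensorId(self.next_tensor_id + 1).
    tid = TensorId(tid + 1)

    sid = StorageId(1)
    # memo.get(sid)  # mypy error: StorageId is not TensorId

The same motivation appears to drive the DelayedError.__call__ overloads in _functions.pyi: a dedicated single-Tensor overload lets a call like errfn(t) keep the precise Tensor return type instead of widening to Tuple[Tensor, ...].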