import weakref

import torch
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._mode_utils import no_dispatch


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


# torch.Tensors cannot be used as keys in a dictionary
# because they define a custom __eq__ function which, when used
# to resolve hash collisions, will throw when comparing tensors:
# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
# To avoid that, we use an object which will hold a Tensor and use
# its id for both hashing and equality.
# In order to use this as a weak key reference, we cannot
# simply use weakref.WeakKeyDictionary, because the newly constructed
# WeakTensorRefKey's only use would be as a dictionary key, so nothing
# would hold a strong reference to it.
# To get around this issue, we can use it as a normal key, and then set
# `weakref.finalize` to delete the key when its contained tensor dies.


class WeakTensorRefKey(object):
    def __init__(self, ten):
        self.ten = weakref.ref(ten)
        # store id since as soon as ten is deallocated
        # the old id will no longer be recoverable, and
        # we need to be able to remove the WeakTensorRefKey
        # from the dictionary by hashing it to the same
        # value it had when ten was alive
        self.id = id(self.ten())

    def __hash__(self):
        return self.id

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.id == other.id
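
# A minimal sketch of the failure mode WeakTensorRefKey avoids and of how the
# key behaves.  The function below is illustrative only and is not part of
# this module's API.
def _weak_tensor_ref_key_demo():
    t = torch.ones(2)
    u = torch.ones(2)
    # Plain tensors are hazardous dict keys: on a hash collision the dict
    # falls back to __eq__, and Tensor.__eq__ returns a multi-element tensor
    # whose truth value is ambiguous.
    try:
        bool(t == u)
    except RuntimeError as e:
        print(e)
    # WeakTensorRefKey keys by id() and holds only a weak reference, so two
    # keys built from the same tensor hash and compare as equal.
    memo = {WeakTensorRefKey(t): "meta-of-t"}
    assert memo[WeakTensorRefKey(t)] == "meta-of-t"
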
# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure.  The operational model is that you
# allocate one of these, and then call it repeatedly on all the tensors you
# want to convert.  It's important to use the same object for tensors you
# want to share storage, because this is how we correlate shared storages to
# the same meta storages.  This class will hold weak references to cached
# tensors and tensor storages.
class MetaConverter:
    def __init__(self):
        self.storage_memo = {}
        self.tensor_memo = {}
        self.maybe_storages_to_delete = []
        self.check_expired_frequency = 128
        self.check_expired_count = 0
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0

    def successful(self):
        return self.hit > 0 and self.miss == 0

    def check_for_expired_weak_storages(self):
        new_li = []
        stor_to_delete = []
        for obj in self.maybe_storages_to_delete:
            if not obj.expired():
                new_li.append(obj)
            else:
                stor_to_delete.append(obj)
        for obj in stor_to_delete:
            self.storage_memo.pop(obj, None)
        self.maybe_storages_to_delete = new_li
        # if for some reason we have acquired many storages which have not
        # expired even though a tensor with their storage has expired
        # (aliasing or otherwise), check for expired storages less often so
        # as to bound the amount of work we do checking for expired storages
        self.check_expired_frequency = max(
            self.check_expired_frequency, len(self.maybe_storages_to_delete)
        )

    def get_tensor_memo(self, t):
        return self.tensor_memo.get(WeakTensorRefKey(t), None)

    def set_tensor_memo(self, t, v):
        # hold a weak ref to self, otherwise it will be kept alive
        # by the del_ten closure
        self_weak_ref = weakref.ref(self)
        if t.is_sparse:
            weak_st = None
        else:
            weak_st = StorageWeakRef(t.storage())
        tensor_ref_key = WeakTensorRefKey(t)

        def del_ten():
            # tensor outlives the converter
            self_ref = self_weak_ref()
            if self_ref is None:
                return
            # on shutdown, tensor_ref_key may not be in memo
            self_ref.tensor_memo.pop(tensor_ref_key, None)
            if weak_st and weak_st.expired():
                self_ref.storage_memo.pop(weak_st, None)
            elif weak_st is not None:
                # [expired-storages]
                # NB: even though the tensor has died,
                # the deallocation of its storage can take longer,
                # even when the storage has no other uses/views.
                # In this case, the StorageWeakRef object will be kept alive
                # longer than it needs to be, however the storage itself
                # will be deallocated.  We retain the possibly dead storages
                # and periodically check if any of them are expired and
                # can be freed.
                self_ref.maybe_storages_to_delete.append(weak_st)

        weakref.finalize(t, del_ten)
        self.tensor_memo[tensor_ref_key] = v
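
    # An informal sketch of the memo lifecycle (illustrative comment only):
    # after set_tensor_memo(t, r), a later get_tensor_memo(t) hits because
    # WeakTensorRefKey(t) hashes to id(t) both times.  Once t is garbage
    # collected, weakref.finalize runs del_ten, which evicts the tensor
    # entry and either drops the storage entry immediately or parks it in
    # maybe_storages_to_delete for the periodic expiry sweep above.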

    # NB: doesn't actually return a storage, because meta storage is
    # not supported
    def meta_storage(self, s):
        # NB: TypedStorage is freshly allocated and cannot be used as hash
        # key index.
        # Use a Weak Ref to s in order to not leak memory
        swr = StorageWeakRef(s)
        if swr not in self.storage_memo:
            self.storage_memo[swr] = torch.empty(
                s.size(), dtype=s.dtype, device="meta"
            )
        return self.storage_memo[swr]

    # This function assumes that it's possible to do the conversion
    def meta_tensor(self, t, shape_env=None):
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        make_symbolic = shape_env is not None

        def sym(x):
            if make_symbolic:
                return shape_env.create_symbol(x)
            else:
                return x

        def sym_sizes_strides(t):
            if make_symbolic:
                return shape_env.create_symbolic_sizes_strides(t)
            return (t.size(), t.stride())

        # see [expired-storages]
        self.check_expired_count += 1
        if self.check_expired_count >= self.check_expired_frequency:
            self.check_for_expired_weak_storages()
            self.check_expired_count = 0

        if self.get_tensor_memo(t) is None:
            with torch.inference_mode(t.is_inference()):
                if t.is_sparse:
                    assert shape_env is None, "symbolic on sparse NYI"
                    is_leaf = safe_is_leaf(t)
                    r = torch.ops.aten._sparse_coo_tensor_with_dims(
                        t.sparse_dim(),
                        t.dense_dim(),
                        t.shape,
                        dtype=t.dtype,
                        layout=torch.sparse_coo,
                        device="meta",
                    )
                    r._coalesced_(t.is_coalesced())
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                            r._coalesced_(t.is_coalesced())
                elif t._is_view():
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create the view off that.  NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.
                    assert t._is_view()
                    base = self.meta_tensor(t._base)

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    if base.dtype == t.dtype:
                        pass
                    elif is_c_of_r(base.dtype, t.dtype):
                        base = torch.view_as_real(base)
                    elif is_c_of_r(t.dtype, base.dtype):
                        base = torch.view_as_complex(base)
                    else:
                        # This is not guaranteed to succeed.  If it fails, it
                        # means there is another dtype-converting view function
                        # that hasn't been handled here
                        base = base.view(t.dtype)

                    with torch.enable_grad():
                        sizes, strides = sym_sizes_strides(t)
                        r = base.as_strided(
                            sizes, strides, sym(t.storage_offset())
                        )
                else:
                    is_leaf = safe_is_leaf(t)
                    # Fake up some autograd history.
                    if t.requires_grad:
                        r = torch.empty(
                            (0,), dtype=t.dtype, device="meta", requires_grad=True
                        )
                        if not is_leaf:
                            with torch.enable_grad():
                                # The backward function here will be wrong, but
                                # that's OK; our goal is just to get the
                                # metadata looking as close as possible; we're
                                # not going to actually try to backward() on
                                # these produced metas.  TODO: would be safer
                                # to install some sort of unsupported grad_fn
                                # here
                                r = r.clone()
                    else:
                        r = torch.empty((0,), dtype=t.dtype, device="meta")
                    # As long as meta storage is not supported, need to prevent
                    # redispatching on set_(Storage, ...) which will choke with
                    # meta storage
                    s = self.meta_storage(t.storage())
                    with no_dispatch():
                        sizes, strides = sym_sizes_strides(t)
                        with torch.no_grad():
                            r.set_(s, sym(t.storage_offset()), sizes, strides)
                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())
            self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)
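
    # For intuition on the view branch above (illustrative comment only):
    # if t = torch.view_as_real(c) for some complex tensor c, meta_tensor
    # first recursively meta-fies c, then re-derives the real view with
    # view_as_real/as_strided on the meta base, rather than rebuilding t
    # from raw storage; this keeps the version counter shared between base
    # and view.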

    def __call__(self, t, shape_env=None):
        # TODO: zero tensors?  We appear to have eliminated them by
        # excluding complex for now
        from torch._subclasses.fake_tensor import FakeTensor

        if (
            type(t) is torch.Tensor
            or type(t) is torch.nn.Parameter
            or isinstance(t, FakeTensor)
        ):
            if any(
                [
                    t.is_sparse_csr,
                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),
                    # these are supported in meta conversion but the fallbacks
                    # don't work
                    t.is_neg(),
                    t.is_conj(),
                    t.device.type in ("lazy", "meta"),
                    # We need a way to test if a tensor is batched but there
                    # is no official API to do it
                    # torch._C._is_batched(t),
                ]
            ):
                # TODO: sparse should support meta
                # NB: technically to('meta') does work, but our logging
                # instrumentation will see the meta conversions and the
                # tests all break, so we just exclude this.  In any case
                # the to() conversion isn't really right anyhow.
                self.miss += 1
                return t
            else:
                self.hit += 1
                r = self.meta_tensor(t, shape_env=shape_env)
                if type(t) is torch.nn.Parameter:
                    r = torch.nn.Parameter(r, requires_grad=r.requires_grad)
                return r
        elif torch.overrides.is_tensor_like(t):
            # Blindly converting tensor subclasses to meta can cause
            # unpredictable problems; e.g., FX tests will trace meta
            # tensors into their trace / some subclasses don't correctly
            # support meta.  Trying to YOLO this is more trouble than it's
            # worth.
            self.miss += 1
            return t
        else:
            # non-Tensor types don't count as hit or miss
            return t


# NB: imported at the bottom, presumably to avoid a circular import at module
# load time; utils is only used inside meta_tensor
import torch._prims_common as utils
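

if __name__ == "__main__":
    # A minimal usage sketch (not part of the module): one converter is
    # called on a tensor and on a view of it, so the storage relationship
    # survives the trip to the meta device.
    x = torch.randn(4, requires_grad=True)
    y = x[1:]  # a view sharing x's storage
    conv = MetaConverter()
    mx = conv(x)
    my = conv(y)
    print(mx.device, my.device)  # meta meta
    print(conv.hit, conv.miss, conv.successful())  # 2 0 True
    print(my._is_view())  # True: the view structure was reconstructed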