diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 6ea09369ff2..e989a66f5b9 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -291,7 +291,7 @@ class FakeTensorTest(TestCase):
             for ten in out:
                 if i == 1:
                     self.assertTrue(isinstance(ten, FakeTensor))
-                self.assertTrue(ten.device.type == 'cuda')
+                self.assertEqual(ten.device.type, 'cuda')
 
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a4503b1f139..ab33104e830 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -556,8 +556,8 @@ make_fx_failures = {
 
     # ???
     xfail('nn.functional.ctc_loss'),
-    # Sparse tensors are not supported with faketensors for now
-    xfail('to_sparse'),
+    # proxy tensor doesn't support sparse correctly right now
+    skip('to_sparse'),
     # segfaults
     skip('block_diag'),
 }
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e0b50e1b3ed..30bb6f32f10 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -9,7 +9,7 @@ import unittest
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
     do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
-    DeterministicGuard, first_sample
+    DeterministicGuard, first_sample, TEST_WITH_CROSSREF
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from typing import Dict, Any
@@ -25,6 +25,7 @@ from torch.testing._internal.common_dtype import (
     all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
     floating_and_complex_types_and, integral_types, floating_types_and,
 )
+from torch.utils._python_dispatch import TorchDispatchMode
 
 if TEST_SCIPY:
     import scipy.sparse
@@ -40,7 +41,53 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
     IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2"
 ) or (not IS_WINDOWS and CUDA11OrLater)
 
-class TestSparse(TestCase):
+class CrossRefSparseFakeMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def on_tensor(f):
+            def go(t):
+                if isinstance(t, torch.Tensor):
+                    return f(t)
+                else:
+                    return t
+            return go
+
+        # empty_like excluded for now due to sparse complex
+        # aten._to_dense.default this one is getting called with csc
+        if (
+            func not in [
+                torch.ops.aten.lift_fresh.default,
+                torch.ops.aten.empty_like.default,
+                torch.ops.aten.set_.source_Storage_storage_offset,
+                torch.ops.aten.sspaddmm.out,
+                torch.ops.aten._spdiags.default,
+                torch.ops.aten._to_dense.default
+            ]
+            and torch.Tag.dynamic_output_shape not in func.tags
+            and torch.Tag.inplace_view not in func.tags
+        ):
+            from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException
+            from torch.utils._pytree import tree_map
+            try:
+                with FakeTensorMode(allow_meta=True) as fake_mode:
+                    fake_args, fake_kwargs = tree_map(on_tensor(fake_mode.from_tensor), (args, kwargs))
+                    fake_r = func(*fake_args, **fake_kwargs)
+            except UnsupportedFakeTensorException:
+                pass
+
+        r = func(*args, **kwargs)
+        return r
+
+class TestSparseBase(TestCase):
+    def run(self, result=None):
+        if TEST_WITH_CROSSREF:
+            with CrossRefSparseFakeMode():
+                return super().run(result)
+        else:
+            return super().run(result)
+
+class TestSparse(TestSparseBase):
 
     def setUp(self):
         TestCase.setUp(self)
@@ -1641,6 +1688,7 @@ class TestSparse(TestCase):
 
     @coalescedonoff
     @dtypes(torch.double)
+    @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
     def test_sparse_sum(self, device, dtype, coalesced):
 
         def run_tests(S, td=None):
@@ -3413,6 +3461,7 @@ class TestSparse(TestCase):
                   *[torch.bfloat16] if CUDA11OrLater and SM80OrLater else [],
                   *[torch.complex64] if CUDA11OrLater else [],
                   *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+    @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
         """
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e4dbc6c70aa..f126fe182d4 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -129,12 +129,13 @@ class FakeTensorConverter(object):
             return maybe_memo
         existing_device = t.device
         # not yet supported in metatensors
-        if t.is_sparse:
-            raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         with no_dispatch():
-            out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
+            meta_t = self.meta_converter(t)
+            if meta_t.device.type != "meta":
+                raise UnsupportedFakeTensorException("meta converter nyi")
+            out = FakeTensor(fake_mode, meta_t, existing_device)
             if type(t) is torch.nn.Parameter:
                 out = torch.nn.Parameter(out, requires_grad=out.requires_grad)  # type: ignore[assignment]
             if t.grad is not None:
@@ -150,11 +151,21 @@ class FakeTensorConverter(object):
             self.set_tensor_memo(t, out)
         return out
 
+    # There are two ways to call this.  First, you can have manually constructed
+    # a meta tensor and you need to turn it into a fake tensor.  In that case,
+    # pass a meta tensor and a device argument.  Alternately, you can have a
+    # real tensor that you need to convert into a fake tensor; in that case,
+    # omit the device.
+    #
+    # The disallowed case: if you specify the device, it MUST be a meta tensor.
+    # However, you're allowed to pass a meta tensor to be turned into a fake
+    # tensor; although an odd thing to do, this can occur if you're doing
+    # cross ref testing and the inner test is already operating on meta tensors
     def __call__(self, fake_mode, t, device=None):
-        assert t.device.type != "meta" or device is not None
-        if t.device.type != "meta":
+        if device is None:
             return self.from_real_tensor(fake_mode, t)
         else:
+            assert t.device.type == "meta"
             return self.from_meta_and_device(fake_mode, t, device)
 
 
@@ -216,6 +227,12 @@ def resize_as_(fake_mode, func, *args, **kwargs):
         return func(*args, **kwargs)
 
 
+@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
+def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
+    # TODO: remove me
+    return constructors(fake_mode, func, *args, **kwargs)
+
+
 # _to_copy fails when run with FakeTensors to cuda device
 # TODO: debug
 @register_op_impl(aten._to_copy.default)
@@ -345,14 +362,20 @@ class FakeTensor(torch.Tensor):
         )
 
     def __init__(self, fake_mode, elem, device: Union[torch.device, str]):
-        # elem does not need to be recorded, because FakeTensor *is a* elem
-        assert elem.device.type == "meta", elem
+        assert elem.device.type == "meta", elem.device.type
         device = device if isinstance(device, torch.device) else torch.device(device)
-        # normalize cuda device
+        # NB: it is fine, if a little confusing, for device to be meta
+        # (we are faking a meta tensor in that case).  However, it often
+        # indicates some sort of confusion (e.g., you accidentally passed
+        # in a meta tensor when you should have passed in the real tensor).
+        # So by default we disallow meta, and if you are working in a situation
+        # where it is helpful (e.g., crossref testing) you can turn it back
+        # on
+        if not fake_mode.allow_meta:
+            assert device.type != "meta"
+        # normalize cuda device.
         if device.type == "cuda" and device.index is None:
             device = torch.device(f"cuda:{torch.cuda.current_device()}")
-        assert device.type != "meta"
-
         self.fake_device = device
         self.fake_mode = fake_mode
         self.has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem)
@@ -360,11 +383,14 @@ class FakeTensor(torch.Tensor):
     @staticmethod
     def from_tensor(t, fake_mode):
         existing_device = t.device
+        # TODO: this should use meta converter
         return FakeTensor(fake_mode, t.to(device="meta"), existing_device)
 
     # TODO: resolve error in default __repr__
     def __repr__(self):
-        return f"FakeTensor({self.fake_device}, {self.size()}, {self.dtype})"
+        with in_kernel_invocation_manager(self.fake_mode):
+            self_repr = super().__repr__()
+        return f"FakeTensor({self.fake_mode}, {self_repr}, {self.fake_device})"
 
     def stride(self):
         if self.has_sym_ints:
@@ -404,6 +430,14 @@ class FakeTensor(torch.Tensor):
                 return torch.device("meta")
             else:
                 return args[0].fake_device
+        # Need this to handle infinite recursion with sparse tensors.
+        # Sparse tensors have custom stride policy which means that
+        # they will dispatch here on dispatch, and we need to trigger
+        # the default behavior.
+        # TODO: when we get other tensor types online they will also
+        # need to get entries here.
+        elif func == torch.ops.aten.stride.default:
+            return None
 
         # Because fake mode can return NotImplemented (if it sees a subclass
         # it doesn't know how to deal with), this test here is important
@@ -489,9 +523,10 @@ class FakeTensor(torch.Tensor):
 
 
 class FakeTensorMode(TorchDispatchMode):
-    def __init__(self, allow_fallback_kernels=True):
+    def __init__(self, *, allow_fallback_kernels=True, allow_meta=False):
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
+        self.allow_meta = allow_meta
 
     # [in_kernel_invocation]
     # when FakeTensor is invoked in user code, .device should return
@@ -637,7 +672,9 @@ class FakeTensorMode(TorchDispatchMode):
             except NotImplementedError as not_implemented_error:
                 if not self.allow_fallback_kernels:
                     raise not_implemented_error
-                r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
+                return run_fallback_kernel(
+                    self, func, args, kwargs, not_implemented_error
+                )
 
         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -666,7 +703,8 @@ class FakeTensorMode(TorchDispatchMode):
         return self.fake_tensor_converter(self, tensor)
 
 
-def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
+# NB: returns fake tensors
+def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
     # these should all be supported, just to be safe
     # avoid fallback for operators which inplace modify metadata
     # because the input fake tensors would be umodified
@@ -679,6 +717,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
     def to_real_tensor(e):
         if isinstance(e, FakeTensor):
             out = torch.zeros_like(e, device=e.fake_device)
+            if e.is_sparse:
+                out._coalesced_(e.is_coalesced())
             inp_impls[id(out)] = e
             return out
         return e
@@ -693,7 +733,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
 
     for e in tree_flatten((args, kwargs))[0]:
         if isinstance(e, torch.Tensor):
-            storages.add(e.storage()._cdata)
+            if not e.is_sparse:
+                storages.add(e.storage()._cdata)
 
     # TODO: also check metadata change on inputs
     # proper aliasing/metadata relationship between outputs and inputs will
@@ -701,16 +742,20 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
     # input impl
     for e in tree_flatten(r)[0]:
         if id(e) not in inp_impls and (
-            isinstance(e, torch.Tensor) and e.storage()._cdata in storages
+            isinstance(e, torch.Tensor)
+            and not e.is_sparse
+            and e.storage()._cdata in storages
         ):
             raise orig_not_implemented_exception
 
-    # the outputs which are are not reused from impls will be converted
-    # to fake tensors later
-    meta_converter = MetaConverter()
-
-    def map_out(e):
-        return inp_impls.get(id(e), meta_converter(e))
+    def map_out(e):
+        if isinstance(e, torch.Tensor):
+            if id(e) in inp_impls:
+                return inp_impls[id(e)]
+            else:
+                return fake_mode.fake_tensor_converter(fake_mode, e)
+        else:
+            return e
 
     return tree_map(map_out, r)
 
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 5e554fbf5f4..d14685e44bd 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -94,7 +94,10 @@ class MetaConverter:
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        weak_st = StorageWeakRef(t.storage())
+        if t.is_sparse:
+            weak_st = None
+        else:
+            weak_st = StorageWeakRef(t.storage())
         tensor_ref_key = WeakTensorRefKey(t)
 
         def del_ten():
@@ -106,7 +109,7 @@ class MetaConverter:
                 self_ref.tensor_memo.pop(tensor_ref_key, None)
                 if weak_st and weak_st.expired():
                     self_ref.storage_memo.pop(weak_st, None)
-                else:
+                elif weak_st is not None:
                     # [expired-storages]
                     # NB: even though the tensor has died,
                     # the deallocation of its storage can take longer,
@@ -143,7 +146,25 @@ class MetaConverter:
 
         if self.get_tensor_memo(t) is None:
             with torch.inference_mode(t.is_inference()):
-                if t._is_view():
+                if t.is_sparse:
+                    is_leaf = safe_is_leaf(t)
+                    r = torch.ops.aten._sparse_coo_tensor_with_dims(
+                        t.sparse_dim(),
+                        t.dense_dim(),
+                        t.shape,
+                        dtype=t.dtype,
+                        layout=torch.sparse_coo,
+                        device="meta",
+                    )
+                    r._coalesced_(t.is_coalesced())
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
+                            r._coalesced_(t.is_coalesced())
+
+                elif t._is_view():
                     # Construct views in two steps: recursively meta-fy their
                     # base, and then create the view off that.  NB: doing it
                     # directly from storage is WRONG because this won't cause
@@ -211,10 +232,11 @@ class MetaConverter:
             if any(
                 [
                     t.is_sparse_csr,
-                    t.is_sparse,
+                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_mkldnn,
                     t.is_quantized,
                     t.is_nested,
+                    t._is_view() and t._base is not None and t._base.is_sparse,
                     torch._is_functional_tensor(t),
                     # these are supported in meta conversion but the fallbacks
                     # don't work
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 0308d028bdd..493f17637a1 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -409,7 +409,10 @@ def _str_intern(inp, *, tensor_contents=None):
             )
    if self.is_sparse:
         suffixes.append("size=" + str(tuple(self.shape)))
-        suffixes.append("nnz=" + str(self._nnz()))
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        if not self.is_meta and not isinstance(self, FakeTensor):
+            suffixes.append("nnz=" + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append("dtype=" + str(self.dtype))
         if not custom_contents_provided:
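
A minimal usage sketch of the path this patch enables (illustrative only, not part of the diff: the sample COO tensor and the printout are made up; FakeTensorMode, allow_meta, from_tensor, and fake_device are the names touched above, and allow_meta=True simply mirrors how CrossRefSparseFakeMode constructs the mode):

    # Illustrative sketch, assuming a build with this patch applied.
    # Converting a real sparse COO tensor to a FakeTensor previously raised
    # UnsupportedFakeTensorException("sparse nyi in meta tensors"); with the
    # MetaConverter sparse branch above it now produces a meta-backed fake.
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))

    with FakeTensorMode(allow_meta=True) as fake_mode:
        # Same pattern CrossRefSparseFakeMode uses in test_sparse.py above.
        fake_sparse = fake_mode.from_tensor(sparse)

    # The result is a FakeTensor backed by a meta sparse tensor that still
    # carries the original (cpu) device; its repr no longer tries to read nnz.
    print(type(fake_sparse).__name__, fake_sparse.is_sparse, fake_sparse.fake_device)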