Sparse fake tensor support (#82172)

Add support for sparse fake tensors.

- The testing strategy is to run a fake tensor cross ref test on `test_sparse.py`. This is necessary because OpInfo sparse coverage is essentially nonexistent. We could have turned on cross ref testing globally for all files, but that would be very time consuming, and the tests I'm interested in are mostly in this file. Some tests are excluded for things that don't work yet; see the sketch after this list for how the cross ref mode converts tensors to fake tensors.
- I make the fake tensor converter raise an UnsupportedFakeTensorException if the meta converter fails to do a conversion (which can happen in a relatively large number of situations).
- I relax the fake tensor invariants so that you can make a fake tensor from a meta tensor. This is useful because the cross ref test sometimes operates on meta tensors.
- Fake tensor wrapping is improved to handle the case where a function doesn't return any tensors.
- The meta converter is taught how to convert sparse tensors to meta.
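A minimal sketch of the conversion path the cross ref test exercises (import paths are as of this PR; the toy sparse tensor and the assert are illustrative only, not part of the test suite):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException

# toy sparse COO input (illustrative only)
t = torch.tensor([[0.0, 1.0], [2.0, 0.0]]).to_sparse()

try:
    # allow_meta=True relaxes the invariant so that fake tensors may also be
    # built from meta tensors, which the cross ref test needs when the inner
    # test is already operating on meta tensors
    with FakeTensorMode(allow_meta=True) as fake_mode:
        fake_t = fake_mode.from_tensor(t)
        # the fake tensor is backed by a sparse meta tensor; the original
        # device is remembered on fake_t.fake_device
        assert fake_t.is_sparse
except UnsupportedFakeTensorException:
    # conversions the meta converter still can't handle are reported this
    # way and simply skipped by the cross ref test
    pass
```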

There's still a little more cleanup that needs to be done, but this is good for review.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
Edward Z. Yang 2022-08-02 13:22:19 -07:00 committed by PyTorch MergeBot
parent 8092cf60c6
commit 42fefd4403
6 changed files with 150 additions and 31 deletions

View File

@@ -291,7 +291,7 @@ class FakeTensorTest(TestCase):
for ten in out:
if i == 1:
self.assertTrue(isinstance(ten, FakeTensor))
self.assertTrue(ten.device.type == 'cuda')
self.assertEqual(ten.device.type, 'cuda')
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")

View File

@@ -556,8 +556,8 @@ make_fx_failures = {
# ???
xfail('nn.functional.ctc_loss'),
# Sparse tensors are not supported with faketensors for now
xfail('to_sparse'),
# proxy tensor doesn't support sparse correctly right now
skip('to_sparse'),
# segfaults
skip('block_diag'),
}

View File

@@ -9,7 +9,7 @@ import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample
DeterministicGuard, first_sample, TEST_WITH_CROSSREF
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
@@ -25,6 +25,7 @@ from torch.testing._internal.common_dtype import (
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
floating_and_complex_types_and, integral_types, floating_types_and,
)
from torch.utils._python_dispatch import TorchDispatchMode
if TEST_SCIPY:
import scipy.sparse
@@ -40,7 +41,53 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2"
) or (not IS_WINDOWS and CUDA11OrLater)
class TestSparse(TestCase):
class CrossRefSparseFakeMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
def on_tensor(f):
def go(t):
if isinstance(t, torch.Tensor):
return f(t)
else:
return t
return go
# empty_like excluded for now due to sparse complex
# aten._to_dense.default this one is getting called with csc
if (
func not in [
torch.ops.aten.lift_fresh.default,
torch.ops.aten.empty_like.default,
torch.ops.aten.set_.source_Storage_storage_offset,
torch.ops.aten.sspaddmm.out,
torch.ops.aten._spdiags.default,
torch.ops.aten._to_dense.default
]
and torch.Tag.dynamic_output_shape not in func.tags
and torch.Tag.inplace_view not in func.tags
):
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException
from torch.utils._pytree import tree_map
try:
with FakeTensorMode(allow_meta=True) as fake_mode:
fake_args, fake_kwargs = tree_map(on_tensor(fake_mode.from_tensor), (args, kwargs))
fake_r = func(*fake_args, **fake_kwargs)
except UnsupportedFakeTensorException:
pass
r = func(*args, **kwargs)
return r
class TestSparseBase(TestCase):
def run(self, result=None):
if TEST_WITH_CROSSREF:
with CrossRefSparseFakeMode():
return super().run(result)
else:
return super().run(result)
class TestSparse(TestSparseBase):
def setUp(self):
TestCase.setUp(self)
@@ -1641,6 +1688,7 @@ class TestSparse(TestCase):
@coalescedonoff
@dtypes(torch.double)
@unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
def test_sparse_sum(self, device, dtype, coalesced):
def run_tests(S, td=None):
@@ -3413,6 +3461,7 @@ class TestSparse(TestCase):
*[torch.bfloat16] if CUDA11OrLater and SM80OrLater else [],
*[torch.complex64] if CUDA11OrLater else [],
*[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
@unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
@precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
def test_sparse_matmul(self, device, dtype, coalesced):
"""

View File

@@ -129,12 +129,13 @@ class FakeTensorConverter(object):
return maybe_memo
existing_device = t.device
# not yet supported in metatensors
if t.is_sparse:
raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
if t.is_quantized:
raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
with no_dispatch():
out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
meta_t = self.meta_converter(t)
if meta_t.device.type != "meta":
raise UnsupportedFakeTensorException("meta converter nyi")
out = FakeTensor(fake_mode, meta_t, existing_device)
if type(t) is torch.nn.Parameter:
out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment]
if t.grad is not None:
@@ -150,11 +151,21 @@ class FakeTensorConverter(object):
self.set_tensor_memo(t, out)
return out
# There are two ways to call this. First, you can have manually constructed
# a meta tensor and you need to turn it into a fake tensor. In that case,
# pass a meta tensor and a device argument. Alternately, you can have a
# real tensor that you need to convert into a fake tensor; in that case,
# omit the device.
#
# The disallowed case: if you specify the device, it MUST be a meta tensor.
# However, you're allowed to pass a meta tensor to be turned into a fake
# tensor; although an odd thing to do, this can occur if you're doing
# cross ref testing and the inner test is already operating on meta tensors
def __call__(self, fake_mode, t, device=None):
assert t.device.type != "meta" or device is not None
if t.device.type != "meta":
if device is None:
return self.from_real_tensor(fake_mode, t)
else:
assert t.device.type == "meta"
return self.from_meta_and_device(fake_mode, t, device)
@@ -216,6 +227,12 @@ def resize_as_(fake_mode, func, *args, **kwargs):
return func(*args, **kwargs)
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
# TODO: remove me
return constructors(fake_mode, func, *args, **kwargs)
# _to_copy fails when run with FakeTensors to cuda device
# TODO: debug
@register_op_impl(aten._to_copy.default)
@@ -345,14 +362,20 @@ class FakeTensor(torch.Tensor):
)
def __init__(self, fake_mode, elem, device: Union[torch.device, str]):
# elem does not need to be recorded, because FakeTensor *is a* elem
assert elem.device.type == "meta", elem
assert elem.device.type == "meta", elem.device.type
device = device if isinstance(device, torch.device) else torch.device(device)
# normalize cuda device
# NB: it is fine, if a little confusing, for device to be meta
# (we are faking a meta tensor in that case). However, it often
# indicates some sort of confusion (e.g., you accidentally passed
# in a meta tensor when you should have passed in the real tensor).
# So by default we disallow meta, and if you are working in a situation
# where it is helpful (e.g., crossref testing) you can turn it back
# on
if not fake_mode.allow_meta:
assert device.type != "meta"
# normalize cuda device.
if device.type == "cuda" and device.index is None:
device = torch.device(f"cuda:{torch.cuda.current_device()}")
assert device.type != "meta"
self.fake_device = device
self.fake_mode = fake_mode
self.has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem)
@@ -360,11 +383,14 @@ class FakeTensor(torch.Tensor):
@staticmethod
def from_tensor(t, fake_mode):
existing_device = t.device
# TODO: this should use meta converter
return FakeTensor(fake_mode, t.to(device="meta"), existing_device)
# TODO: resolve error in default __repr__
def __repr__(self):
return f"FakeTensor({self.fake_device}, {self.size()}, {self.dtype})"
with in_kernel_invocation_manager(self.fake_mode):
self_repr = super().__repr__()
return f"FakeTensor({self.fake_mode}, {self_repr}, {self.fake_device})"
def stride(self):
if self.has_sym_ints:
@@ -404,6 +430,14 @@ class FakeTensor(torch.Tensor):
return torch.device("meta")
else:
return args[0].fake_device
# Need this to handle infinite recursion with sparse tensors.
# Sparse tensors have custom stride policy which means that
# they will dispatch here on dispatch, and we need to trigger
# the default behavior.
# TODO: when we get other tensor types online they will also
# need to get entries here.
elif func == torch.ops.aten.stride.default:
return None
# Because fake mode can return NotImplemented (if it sees a subclass
# it doesn't know how to deal with), this test here is important
@@ -489,9 +523,10 @@ class FakeTensor(torch.Tensor):
class FakeTensorMode(TorchDispatchMode):
def __init__(self, allow_fallback_kernels=True):
def __init__(self, *, allow_fallback_kernels=True, allow_meta=False):
self.allow_fallback_kernels = allow_fallback_kernels
self.fake_tensor_converter = FakeTensorConverter()
self.allow_meta = allow_meta
# [in_kernel_invocation]
# when FakeTensor is invoked in user code, .device should return
@@ -637,7 +672,9 @@ class FakeTensorMode(TorchDispatchMode):
except NotImplementedError as not_implemented_error:
if not self.allow_fallback_kernels:
raise not_implemented_error
r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
return run_fallback_kernel(
self, func, args, kwargs, not_implemented_error
)
# TODO: handle non-kwarg devices
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -666,7 +703,8 @@ class FakeTensorMode(TorchDispatchMode):
return self.fake_tensor_converter(self, tensor)
def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
# NB: returns fake tensors
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
# these should all be supported, just to be safe
# avoid fallback for operators which inplace modify metadata
# because the input fake tensors would be unmodified
@@ -679,6 +717,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
def to_real_tensor(e):
if isinstance(e, FakeTensor):
out = torch.zeros_like(e, device=e.fake_device)
if e.is_sparse:
out._coalesced_(e.is_coalesced())
inp_impls[id(out)] = e
return out
return e
@@ -693,7 +733,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
for e in tree_flatten((args, kwargs))[0]:
if isinstance(e, torch.Tensor):
storages.add(e.storage()._cdata)
if not e.is_sparse:
storages.add(e.storage()._cdata)
# TODO: also check metadata change on inputs
# proper aliasing/metadata relationship between outputs and inputs will
@@ -701,16 +742,20 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
# input impl
for e in tree_flatten(r)[0]:
if id(e) not in inp_impls and (
isinstance(e, torch.Tensor) and e.storage()._cdata in storages
isinstance(e, torch.Tensor)
and not e.is_sparse
and e.storage()._cdata in storages
):
raise orig_not_implemented_exception
# the outputs which are not reused from impls will be converted
# to fake tensors later
meta_converter = MetaConverter()
def map_out(e):
return inp_impls.get(id(e), meta_converter(e))
if isinstance(e, torch.Tensor):
if id(e) in inp_impls:
return inp_impls[id(e)]
else:
return fake_mode.fake_tensor_converter(fake_mode, e)
else:
return e
return tree_map(map_out, r)

View File

@@ -94,7 +94,10 @@ class MetaConverter:
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
weak_st = StorageWeakRef(t.storage())
if t.is_sparse:
weak_st = None
else:
weak_st = StorageWeakRef(t.storage())
tensor_ref_key = WeakTensorRefKey(t)
def del_ten():
@@ -106,7 +109,7 @@ class MetaConverter:
self_ref.tensor_memo.pop(tensor_ref_key, None)
if weak_st and weak_st.expired():
self_ref.storage_memo.pop(weak_st, None)
else:
elif weak_st is not None:
# [expired-storages]
# NB: even though the tensor has died,
# the deallocation of its storage can take longer,
@@ -143,7 +146,25 @@ class MetaConverter:
if self.get_tensor_memo(t) is None:
with torch.inference_mode(t.is_inference()):
if t._is_view():
if t.is_sparse:
is_leaf = safe_is_leaf(t)
r = torch.ops.aten._sparse_coo_tensor_with_dims(
t.sparse_dim(),
t.dense_dim(),
t.shape,
dtype=t.dtype,
layout=torch.sparse_coo,
device="meta",
)
r._coalesced_(t.is_coalesced())
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
r._coalesced_(t.is_coalesced())
elif t._is_view():
# Construct views in two steps: recursively meta-fy their
# base, and then create the view off that. NB: doing it
# directly from storage is WRONG because this won't cause
@@ -211,10 +232,11 @@ class MetaConverter:
if any(
[
t.is_sparse_csr,
t.is_sparse,
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
t.is_mkldnn,
t.is_quantized,
t.is_nested,
t._is_view() and t._base is not None and t._base.is_sparse,
torch._is_functional_tensor(t),
# these are supported in meta conversion but the fallbacks
# don't work

View File

@@ -409,7 +409,10 @@ def _str_intern(inp, *, tensor_contents=None):
)
if self.is_sparse:
suffixes.append("size=" + str(tuple(self.shape)))
suffixes.append("nnz=" + str(self._nnz()))
from torch._subclasses.fake_tensor import FakeTensor
if not self.is_meta and not isinstance(self, FakeTensor):
suffixes.append("nnz=" + str(self._nnz()))
if not has_default_dtype:
suffixes.append("dtype=" + str(self.dtype))
if not custom_contents_provided: