Sparse fake tensor support (#82172)
Add support for sparse fake tensors.

- The testing strategy is to run a fake tensor cross-ref test on
  `test_sparse.py`. This is necessary because OpInfo sparse coverage is
  completely nonexistent. We could have tried to turn on cross-ref testing
  globally for all files, but that would be very time consuming, and the
  tests I'm interested in are mostly in this file. There are some exclusions
  in testing for things that don't work.
- I make the fake tensor converter raise an UnsupportedFakeTensorException
  if the meta converter fails to do a conversion (which can happen in a
  relatively large number of situations).
- I relax the fake tensor invariants so that you can make a fake tensor
  from a meta tensor. This is useful because in the cross-ref test we
  sometimes operate on meta tensors.
- Fake tensor wrapping is improved to handle the case when a function
  doesn't return any tensors.
- The meta converter is taught how to convert sparse tensors to meta.

There's still a little more cleanup that needs to be done, but this is good
for review.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
Parent: 8092cf60c6
Commit: 42fefd4403
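In broad strokes, the diffs below make the following round trip work. This is a minimal sketch of the intent, not code from the PR, assuming the `FakeTensorMode` / `from_tensor` API as it stands after these changes (import path `torch._subclasses.fake_tensor` is taken from the hunks below):

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# A real sparse COO tensor...
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([3.0, 4.0])
s = torch.sparse_coo_tensor(indices, values, (2, 2))

# ...can now be fake-ified: the fake tensor is backed by meta storage
# (no real values) but still reports sparse layout and the original device.
with FakeTensorMode(allow_meta=True) as mode:
    fake = mode.from_tensor(s)
    assert isinstance(fake, FakeTensor)
    assert fake.is_sparse
    assert fake.device == s.device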
@@ -291,7 +291,7 @@ class FakeTensorTest(TestCase):
             for ten in out:
                 if i == 1:
                     self.assertTrue(isinstance(ten, FakeTensor))
-                    self.assertTrue(ten.device.type == 'cuda')
+                    self.assertEqual(ten.device.type, 'cuda')

     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
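The hunk above is a drive-by test cleanup: `assertEqual` reports both operands on failure, whereas `assertTrue` on a comparison only reports that `False` is not true. An illustration with hypothetical failure values:

# self.assertTrue(ten.device.type == 'cuda')
#   AssertionError: False is not true
# self.assertEqual(ten.device.type, 'cuda')
#   AssertionError: 'cpu' != 'cuda'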
@@ -556,8 +556,8 @@ make_fx_failures = {

     # ???
     xfail('nn.functional.ctc_loss'),
-    # Sparse tensors are not supported with faketensors for now
-    xfail('to_sparse'),
+    # proxy tensor doesn't support sparse correctly right now
+    skip('to_sparse'),
     # segfaults
     skip('block_diag'),
 }
@@ -9,7 +9,7 @@ import unittest
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
     do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
-    DeterministicGuard, first_sample
+    DeterministicGuard, first_sample, TEST_WITH_CROSSREF
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from typing import Dict, Any
@@ -25,6 +25,7 @@ from torch.testing._internal.common_dtype import (
     all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
     floating_and_complex_types_and, integral_types, floating_types_and,
 )
+from torch.utils._python_dispatch import TorchDispatchMode

 if TEST_SCIPY:
     import scipy.sparse
@@ -40,7 +41,53 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
     IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2"
 ) or (not IS_WINDOWS and CUDA11OrLater)

-class TestSparse(TestCase):
+class CrossRefSparseFakeMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def on_tensor(f):
+            def go(t):
+                if isinstance(t, torch.Tensor):
+                    return f(t)
+                else:
+                    return t
+            return go
+
+        # empty_like excluded for now due to sparse complex
+        # aten._to_dense.default this one is getting called with csc
+        if (
+            func not in [
+                torch.ops.aten.lift_fresh.default,
+                torch.ops.aten.empty_like.default,
+                torch.ops.aten.set_.source_Storage_storage_offset,
+                torch.ops.aten.sspaddmm.out,
+                torch.ops.aten._spdiags.default,
+                torch.ops.aten._to_dense.default
+            ]
+            and torch.Tag.dynamic_output_shape not in func.tags
+            and torch.Tag.inplace_view not in func.tags
+        ):
+            from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException
+            from torch.utils._pytree import tree_map
+            try:
+                with FakeTensorMode(allow_meta=True) as fake_mode:
+                    fake_args, fake_kwargs = tree_map(on_tensor(fake_mode.from_tensor), (args, kwargs))
+                    fake_r = func(*fake_args, **fake_kwargs)
+            except UnsupportedFakeTensorException:
+                pass
+
+        r = func(*args, **kwargs)
+        return r
+
+class TestSparseBase(TestCase):
+    def run(self, result=None):
+        if TEST_WITH_CROSSREF:
+            with CrossRefSparseFakeMode():
+                return super().run(result)
+        else:
+            return super().run(result)
+
+class TestSparse(TestSparseBase):

     def setUp(self):
         TestCase.setUp(self)
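For readers unfamiliar with `TorchDispatchMode`: while such a mode is active, every aten op (including ops on sparse tensors) funnels through `__torch_dispatch__`, which is what lets `CrossRefSparseFakeMode` re-run each op on fake-ified inputs before executing it for real. A stripped-down, hypothetical analogue (`OpLoggingMode` is not in the PR) that just logs the ops instead of re-running them:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLoggingMode(TorchDispatchMode):
    # Hypothetical stand-in for CrossRefSparseFakeMode: record each op
    # instead of re-executing it under FakeTensorMode.
    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.seen.append(str(func))
        return func(*args, **kwargs)

s = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
                            torch.tensor([3.0, 4.0]), (2, 2))
with OpLoggingMode() as mode:
    torch.sparse.sum(s)
print(mode.seen)  # every aten op the sparse reduction dispatched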
@@ -1641,6 +1688,7 @@ class TestSparse(TestCase):

     @coalescedonoff
     @dtypes(torch.double)
+    @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
     def test_sparse_sum(self, device, dtype, coalesced):

         def run_tests(S, td=None):
@@ -3413,6 +3461,7 @@ class TestSparse(TestCase):
                     *[torch.bfloat16] if CUDA11OrLater and SM80OrLater else [],
                     *[torch.complex64] if CUDA11OrLater else [],
                     *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+    @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
         """
@@ -129,12 +129,13 @@ class FakeTensorConverter(object):
             return maybe_memo
         existing_device = t.device
         # not yet supported in metatensors
-        if t.is_sparse:
-            raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         with no_dispatch():
-            out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
+            meta_t = self.meta_converter(t)
+            if meta_t.device.type != "meta":
+                raise UnsupportedFakeTensorException("meta converter nyi")
+            out = FakeTensor(fake_mode, meta_t, existing_device)
         if type(t) is torch.nn.Parameter:
             out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment]
         if t.grad is not None:
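With this change, a failed meta conversion surfaces as a catchable `UnsupportedFakeTensorException` rather than an arbitrary error, which is what lets the cross-ref mode above simply skip unconvertible inputs. A hedged sketch of the caller-side pattern (quantized tensors are one case the converter still rejects, per this hunk):

import torch
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    UnsupportedFakeTensorException,
)

mode = FakeTensorMode()

def try_fake(t):
    try:
        return mode.from_tensor(t)
    except UnsupportedFakeTensorException:
        return None  # soft failure: caller can skip this input

assert try_fake(torch.randn(2)) is not None
q = torch.quantize_per_tensor(torch.randn(2), 0.1, 0, torch.qint8)
assert try_fake(q) is None  # "quantized nyi in meta tensors"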
@@ -150,11 +151,21 @@ class FakeTensorConverter(object):
         self.set_tensor_memo(t, out)
         return out

+    # There are two ways to call this. First, you can have manually constructed
+    # a meta tensor and you need to turn it into a fake tensor. In that case,
+    # pass a meta tensor and a device argument. Alternately, you can have a
+    # real tensor that you need to convert into a fake tensor; in that case,
+    # omit the device.
+    #
+    # The disallowed case: if you specify the device, it MUST be a meta tensor.
+    # However, you're allowed to pass a meta tensor to be turned into a fake
+    # tensor; although an odd thing to do, this can occur if you're doing
+    # cross ref testing and the inner test is already operating on meta tensors
     def __call__(self, fake_mode, t, device=None):
-        assert t.device.type != "meta" or device is not None
-        if t.device.type != "meta":
+        if device is None:
             return self.from_real_tensor(fake_mode, t)
         else:
+            assert t.device.type == "meta"
             return self.from_meta_and_device(fake_mode, t, device)
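The new comment block spells out the converter's two entry points; in code, the distinction looks roughly like this (a sketch against the `__call__` signature shown above, using the `fake_tensor_converter` attribute from the `FakeTensorMode` hunk below):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_meta=True)
conv = mode.fake_tensor_converter

# Convention 1: real tensor, device omitted -> from_real_tensor
fake_a = conv(mode, torch.randn(2, 2))

# Convention 2: manually constructed meta tensor plus an explicit
# device -> from_meta_and_device
m = torch.empty(2, 2, device="meta")
fake_b = conv(mode, m, torch.device("cpu"))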
@@ -216,6 +227,12 @@ def resize_as_(fake_mode, func, *args, **kwargs):
     return func(*args, **kwargs)


+@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
+def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
+    # TODO: remove me
+    return constructors(fake_mode, func, *args, **kwargs)
+
+
 # _to_copy fails when run with FakeTensors to cuda device
 # TODO: debug
 @register_op_impl(aten._to_copy.default)
@@ -345,14 +362,20 @@ class FakeTensor(torch.Tensor):
         )

     def __init__(self, fake_mode, elem, device: Union[torch.device, str]):
         # elem does not need to be recorded, because FakeTensor *is a* elem
-        assert elem.device.type == "meta", elem
+        assert elem.device.type == "meta", elem.device.type
         device = device if isinstance(device, torch.device) else torch.device(device)
-        # normalize cuda device
+        # NB: it is fine, if a little confusing, for device to be meta
+        # (we are faking a meta tensor in that case). However, it often
+        # indicates some sort of confusion (e.g., you accidentally passed
+        # in a meta tensor when you should have passed in the real tensor).
+        # So by default we disallow meta, and if you are working in a situation
+        # where it is helpful (e.g., crossref testing) you can turn it back
+        # on
+        if not fake_mode.allow_meta:
+            assert device.type != "meta"
+        # normalize cuda device.
         if device.type == "cuda" and device.index is None:
             device = torch.device(f"cuda:{torch.cuda.current_device()}")
-        assert device.type != "meta"

         self.fake_device = device
         self.fake_mode = fake_mode
         self.has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem)
@@ -360,11 +383,14 @@ class FakeTensor(torch.Tensor):
     @staticmethod
     def from_tensor(t, fake_mode):
         existing_device = t.device
+        # TODO: this should use meta converter
         return FakeTensor(fake_mode, t.to(device="meta"), existing_device)

     # TODO: resolve error in default __repr__
     def __repr__(self):
-        return f"FakeTensor({self.fake_device}, {self.size()}, {self.dtype})"
+        with in_kernel_invocation_manager(self.fake_mode):
+            self_repr = super().__repr__()
+        return f"FakeTensor({self.fake_mode}, {self_repr}, {self.fake_device})"

     def stride(self):
         if self.has_sym_ints:
@@ -404,6 +430,14 @@ class FakeTensor(torch.Tensor):
                 return torch.device("meta")
             else:
                 return args[0].fake_device
+        # Need this to handle infinite recursion with sparse tensors.
+        # Sparse tensors have custom stride policy which means that
+        # they will dispatch here on dispatch, and we need to trigger
+        # the default behavior.
+        # TODO: when we get other tensor types online they will also
+        # need to get entries here.
+        elif func == torch.ops.aten.stride.default:
+            return None

         # Because fake mode can return NotImplemented (if it sees a subclass
         # it doesn't know how to deal with), this test here is important
@@ -489,9 +523,10 @@ class FakeTensor(torch.Tensor):


 class FakeTensorMode(TorchDispatchMode):
-    def __init__(self, allow_fallback_kernels=True):
+    def __init__(self, *, allow_fallback_kernels=True, allow_meta=False):
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
+        self.allow_meta = allow_meta

         # [in_kernel_invocation]
         # when FakeTensor is invoked in user code, .device should return
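`allow_meta` is keyword-only and defaults to off, preserving the old invariant that a fake tensor's advertised device is never `meta`; cross-ref testing opts in. A sketch of the intended semantics (the strict path fails via the `assert` in `FakeTensor.__init__` shown above; the exact failure mode is an assumption here):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

m = torch.empty(2, 2, device="meta")

lenient = FakeTensorMode(allow_meta=True)
fake_meta = lenient.from_tensor(m)  # ok: a fake tensor faking a meta tensor

strict = FakeTensorMode()  # allow_meta=False by default
try:
    strict.from_tensor(m)
except AssertionError:
    pass  # rejected: faking a meta tensor is usually a bug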
@@ -637,7 +672,9 @@ class FakeTensorMode(TorchDispatchMode):
             except NotImplementedError as not_implemented_error:
                 if not self.allow_fallback_kernels:
                     raise not_implemented_error
-                r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
+                return run_fallback_kernel(
+                    self, func, args, kwargs, not_implemented_error
+                )

         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -666,7 +703,8 @@
         return self.fake_tensor_converter(self, tensor)


-def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
+# NB: returns fake tensors
+def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
     # these should all be supported, just to be safe
     # avoid fallback for operators which inplace modify metadata
     # because the input fake tensors would be umodified
@@ -679,6 +717,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
     def to_real_tensor(e):
         if isinstance(e, FakeTensor):
             out = torch.zeros_like(e, device=e.fake_device)
+            if e.is_sparse:
+                out._coalesced_(e.is_coalesced())
             inp_impls[id(out)] = e
             return out
         return e
@@ -693,7 +733,8 @@ def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):

     for e in tree_flatten((args, kwargs))[0]:
         if isinstance(e, torch.Tensor):
-            storages.add(e.storage()._cdata)
+            if not e.is_sparse:
+                storages.add(e.storage()._cdata)

     # TODO: also check metadata change on inputs
     # proper aliasing/metadata relationship between outputs and inputs will
@@ -701,16 +742,20 @@
     # input impl
     for e in tree_flatten(r)[0]:
         if id(e) not in inp_impls and (
-            isinstance(e, torch.Tensor) and e.storage()._cdata in storages
+            isinstance(e, torch.Tensor)
+            and not e.is_sparse
+            and e.storage()._cdata in storages
         ):
             raise orig_not_implemented_exception

     # the outputs which are are not reused from impls will be converted
     # to fake tensors later
-    meta_converter = MetaConverter()

     def map_out(e):
-        return inp_impls.get(id(e), meta_converter(e))
+        if isinstance(e, torch.Tensor):
+            if id(e) in inp_impls:
+                return inp_impls[id(e)]
+            else:
+                return fake_mode.fake_tensor_converter(fake_mode, e)
+        else:
+            return e

     return tree_map(map_out, r)
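The three `run_fallback_kernel` hunks compose into one flow: materialize each fake input as a real zero tensor with matching metadata (copying the coalesced bit for sparse inputs), run the real kernel, then map outputs back, reusing the original fake input where the kernel returned an input and fake-ifying fresh outputs through the mode's converter (rather than the old bare `MetaConverter`, which would have produced plain meta tensors). A condensed sketch with hypothetical helper names, not the PR's exact code:

import torch
from torch.utils._pytree import tree_map

def fallback_sketch(fake_mode, func, args, kwargs):
    # hypothetical condensation of run_fallback_kernel after this PR
    inp_impls = {}

    def to_real(e):
        if isinstance(e, torch.Tensor):  # stands in for the FakeTensor check
            out = torch.zeros_like(e, device=e.device)
            if e.is_sparse:
                out._coalesced_(e.is_coalesced())
            inp_impls[id(out)] = e
            return out
        return e

    real_args, real_kwargs = tree_map(to_real, (args, kwargs))
    r = func(*real_args, **real_kwargs)

    def map_out(e):
        if isinstance(e, torch.Tensor):
            # reuse the fake input if the kernel returned an input,
            # otherwise fake-ify the freshly produced real output
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            return fake_mode.fake_tensor_converter(fake_mode, e)
        return e

    return tree_map(map_out, r)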
@@ -94,7 +94,10 @@ class MetaConverter:
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        weak_st = StorageWeakRef(t.storage())
+        if t.is_sparse:
+            weak_st = None
+        else:
+            weak_st = StorageWeakRef(t.storage())
         tensor_ref_key = WeakTensorRefKey(t)

         def del_ten():
@@ -106,7 +109,7 @@
             self_ref.tensor_memo.pop(tensor_ref_key, None)
             if weak_st and weak_st.expired():
                 self_ref.storage_memo.pop(weak_st, None)
-            else:
+            elif weak_st is not None:
                 # [expired-storages]
                 # NB: even though the tensor has died,
                 # the deallocation of its storage can take longer,
@@ -143,7 +146,25 @@

         if self.get_tensor_memo(t) is None:
             with torch.inference_mode(t.is_inference()):
-                if t._is_view():
+                if t.is_sparse:
+                    is_leaf = safe_is_leaf(t)
+                    r = torch.ops.aten._sparse_coo_tensor_with_dims(
+                        t.sparse_dim(),
+                        t.dense_dim(),
+                        t.shape,
+                        dtype=t.dtype,
+                        layout=torch.sparse_coo,
+                        device="meta",
+                    )
+                    r._coalesced_(t.is_coalesced())
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
+                            r._coalesced_(t.is_coalesced())
+
+                elif t._is_view():
                     # Construct views in two steps: recursively meta-fy their
                     # base, and then create the view off that. NB: doing it
                     # directly from storage is WRONG because this won't cause
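The new branch can be exercised directly; this mirrors the calls in the hunk above on a concrete tensor:

import torch

s = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]), torch.tensor([3.0, 4.0]), (2, 2)
)
m = torch.ops.aten._sparse_coo_tensor_with_dims(
    s.sparse_dim(), s.dense_dim(), s.shape,
    dtype=s.dtype, layout=torch.sparse_coo, device="meta",
)
m._coalesced_(s.is_coalesced())
assert m.is_sparse and m.device.type == "meta"
assert m.shape == s.shape  # shape/dtype preserved, values absent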
@@ -211,10 +232,11 @@
             if any(
                 [
                     t.is_sparse_csr,
-                    t.is_sparse,
+                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                     t.is_mkldnn,
                     t.is_quantized,
                     t.is_nested,
+                    t._is_view() and t._base is not None and t._base.is_sparse,
                     torch._is_functional_tensor(t),
                     # these are supported in meta conversion but the fallbacks
                     # don't work
@@ -409,7 +409,10 @@ def _str_intern(inp, *, tensor_contents=None):
         )
     if self.is_sparse:
         suffixes.append("size=" + str(tuple(self.shape)))
-        suffixes.append("nnz=" + str(self._nnz()))
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        if not self.is_meta and not isinstance(self, FakeTensor):
+            suffixes.append("nnz=" + str(self._nnz()))
     if not has_default_dtype:
         suffixes.append("dtype=" + str(self.dtype))
     if not custom_contents_provided:
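The printing change keeps the `size=` suffix but drops `nnz=` for meta and fake sparse tensors, since there is no real index/value storage to count. A hedged end-to-end check, assuming repr of sparse fake tensors works after this change:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

s = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]), torch.tensor([3.0, 4.0]), (2, 2)
)
print(s)  # suffixes include size=(2, 2) and nnz=2

with FakeTensorMode(allow_meta=True) as mode:
    fake = mode.from_tensor(s)
    print(fake)  # size= still shown; no nnz= suffix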