[NJT] add aten.to.dtype support (#134164)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164
Approved by: https://github.com/davidberard98
This commit is contained in:
yuqingj 2024-08-21 15:33:57 -07:00 committed by PyTorch MergeBot
parent b6abac68ec
commit 44fa9f991c
2 changed files with 27 additions and 0 deletions

View File

@ -5950,6 +5950,22 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
self.assertFalse(clone.is_contiguous()) self.assertFalse(clone.is_contiguous())
check_nt_equality(detached, transposed) check_nt_equality(detached, transposed)
def test_to_dtype(self, device):
nt = random_nt_from_dims(
[2, None, 3], device, torch.float32, layout=torch.jagged
)
nt_after = nt.to(torch.float64)
self.assertEqual(torch.float32, nt.dtype)
self.assertEqual(torch.float64, nt_after.dtype)
self.assertEqual(torch.float64, nt_after.values().dtype)
self.assertEqual(torch.int64, nt_after.offsets().dtype)
noncontiguous_nt = nt.transpose(1, 2)
noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)
def test_to_copy(self, device): def test_to_copy(self, device):
nt = torch.nested.nested_tensor( nt = torch.nested.nested_tensor(
[ [

View File

@ -490,6 +490,17 @@ def linear_backward_default(func, *args, **kwargs):
return (ds, dw, db) return (ds, dw, db)
@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
def to_dtype(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all") @register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
def to_copy_default(func, *args, **kwargs): def to_copy_default(func, *args, **kwargs):
from .nested_tensor import _tensor_symint_registry from .nested_tensor import _tensor_symint_registry