mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[NJT] add aten.to.dtype support (#134164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
b6abac68ec
commit
44fa9f991c
|
|
@ -5950,6 +5950,22 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
|
||||||
self.assertFalse(clone.is_contiguous())
|
self.assertFalse(clone.is_contiguous())
|
||||||
check_nt_equality(detached, transposed)
|
check_nt_equality(detached, transposed)
|
||||||
|
|
||||||
|
def test_to_dtype(self, device):
    """`Tensor.to(dtype)` on a jagged NJT casts the dense values buffer to
    the requested dtype while the int64 offsets are left untouched, for both
    contiguous and non-contiguous (transposed) layouts. The source tensor
    itself is not modified.
    """
    nt = random_nt_from_dims(
        [2, None, 3], device, torch.float32, layout=torch.jagged
    )

    converted = nt.to(torch.float64)
    # The original NJT keeps its dtype; only the returned tensor is cast.
    self.assertEqual(torch.float32, nt.dtype)
    self.assertEqual(torch.float64, converted.dtype)
    self.assertEqual(torch.float64, converted.values().dtype)
    # Offsets are structural metadata and must stay int64 after the cast.
    self.assertEqual(torch.int64, converted.offsets().dtype)

    # Same contract must hold for a non-contiguous NJT.
    transposed = nt.transpose(1, 2)
    transposed_converted = transposed.to(torch.bfloat16)
    self.assertEqual(torch.bfloat16, transposed_converted.dtype)
    self.assertEqual(torch.bfloat16, transposed_converted.values().dtype)
    self.assertEqual(torch.int64, transposed_converted.offsets().dtype)
||||||
def test_to_copy(self, device):
|
def test_to_copy(self, device):
|
||||||
nt = torch.nested.nested_tensor(
|
nt = torch.nested.nested_tensor(
|
||||||
[
|
[
|
||||||
|
|
|
||||||
|
|
@ -490,6 +490,17 @@ def linear_backward_default(func, *args, **kwargs):
|
||||||
return (ds, dw, db)
|
return (ds, dw, db)
|
||||||
|
|
||||||
|
|
||||||
|
@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
def to_dtype(func, *args, **kwargs):
    """Jagged-layout handler for ``aten.to.dtype``.

    Applies the dtype conversion to the NJT's dense values buffer only and
    rebuilds a NestedTensor around the result, carrying over the nested
    metadata (e.g. offsets) from the input unchanged.
    """
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # Pull the NJT out so only the remaining kwargs (dtype, etc.) are
    # forwarded to the underlying op on the values buffer.
    nt_inp = new_kwargs.pop("input")

    converted_values = func(nt_inp._values, **new_kwargs)
    return NestedTensor(converted_values, **extract_kwargs(nt_inp))
|
||||||
|
|
||||||
|
|
||||||
@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
|
@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
|
||||||
def to_copy_default(func, *args, **kwargs):
|
def to_copy_default(func, *args, **kwargs):
|
||||||
from .nested_tensor import _tensor_symint_registry
|
from .nested_tensor import _tensor_symint_registry
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user