Support remaining *_like factory functions for NJT (#144889)

Fixes #144761

This PR adds NJT impls for the *_like factory functions that were previously missing (a brief usage sketch follows the list):
* `full_like()`
* `rand_like()`
* `randint_like()`
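A minimal sketch of how the newly supported factories are expected to behave on a jagged-layout NJT (illustrative only; the example shapes and fill value are not taken from the PR):

```python
import torch

# Build a small jagged-layout nested tensor (NJT) with a ragged dim 1.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
)

full_nt = torch.full_like(nt, 2.0)            # same ragged structure, filled with 2.0
rand_nt = torch.rand_like(nt)                 # uniform random values, same structure
randint_nt = torch.randint_like(nt, high=10)  # integer-valued entries in [0, 10)

# Each output preserves the jagged layout and per-component shapes.
assert full_nt.layout == torch.jagged
for comp in full_nt.unbind():
    assert (comp == 2.0).all()
```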

It also fixes a bug in the existing *_like functions when a new device is specified: the fix is to also transfer `offsets` / `lengths` to the new device.
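As a sanity check on the device fix, here is a hedged sketch (it assumes a CUDA device is available; `values()` / `offsets()` are the public NJT accessors, not code from this PR):

```python
import torch

# NJT created on CUDA, then a *_like call that asks for a CPU result.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged, device="cuda"
)
out = torch.zeros_like(nt, device="cpu")

# With the fix, the ragged metadata follows the values to the new device.
assert out.device.type == "cpu"
assert out.values().device.type == "cpu"
assert out.offsets().device.type == "cpu"  # previously left on the original device
```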
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144889
Approved by: https://github.com/soulitzer
Joel Schlosser 2025-01-27 13:32:24 -05:00 committed by PyTorch MergeBot
parent 3a23d75b37
commit 1ba1b7b597
2 changed files with 112 additions and 11 deletions


@@ -71,7 +71,7 @@ from torch.testing._internal.opinfo.core import (
SkipRule,
XFailRule,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
@@ -6109,17 +6109,72 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@parametrize(
"func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
"func",
[
torch.empty_like,
torch.full_like,
torch.ones_like,
torch.rand_like,
torch.randint_like,
torch.randn_like,
torch.zeros_like,
],
name_fn=lambda f: f.__name__,
)
def test_like_value(self, func):
nt = random_nt_from_dims(
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
)
nt_like = func(nt)
def test_like_value(self, func, device):
dtype = torch.float32 if func is not torch.randint_like else torch.int32
for nt in _sample_njts(device=device, dtype=dtype):
extra_kwarg_sets = [{}]
if func is torch.full_like:
extra_kwarg_sets = [{"fill_value": 4.2}]
elif func is torch.randint_like:
extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
for nt_ub in nt_like.unbind():
t_like = func(nt_ub)
self.assertEqual(nt_ub, t_like)
# only test changing dtype / device from CUDA -> CPU because CUDA might not be
# available when running this test for CPU
change_dtype_device_settings = (
[False, True] if "cuda" in device else [False]
)
for change_dtype_device in change_dtype_device_settings:
if change_dtype_device:
new_dtype = (
torch.float64 if func is not torch.randint_like else torch.int64
)
new_device = "cpu" if "cuda" in device else device
new_layout = torch.strided
for extra_kwargs in extra_kwarg_sets:
extra_kwargs.update(
{
"dtype": new_dtype,
"device": new_device,
"layout": new_layout,
}
)
for extra_kwargs in extra_kwarg_sets:
nt_like = func(nt, **extra_kwargs)
self.assertEqual(nt.shape, nt_like.shape)
if change_dtype_device:
self.assertNotEqual(nt.device, nt_like.device)
self.assertNotEqual(nt.dtype, nt_like.dtype)
# layout should be ignored since only torch.jagged is supported
self.assertEqual(torch.jagged, nt_like.layout)
else:
self.assertEqual(nt.device, nt_like.device)
self.assertEqual(nt.dtype, nt_like.dtype)
self.assertEqual(nt.layout, nt_like.layout)
self.assertEqual(nt.layout, torch.jagged)
# don't bother trying to compare random or empty values
if func not in [
torch.empty_like,
torch.rand_like,
torch.randn_like,
torch.randint_like,
]:
for nt_ub in nt_like.unbind():
t_like = func(nt_ub, **extra_kwargs)
self.assertEqual(nt_ub, t_like)
def test_noncontiguous_pointwise(self, device):
a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)


@@ -690,6 +690,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
torch.ops.aten.empty_like.default,
torch.ops.aten.ones_like.default,
torch.ops.aten.zeros_like.default,
torch.ops.aten.rand_like.default,
torch.ops.aten.randn_like.default,
],
"self: jt_all",
@@ -706,7 +707,52 @@ def like_factory_default(func, *args, **kwargs):
# This should be set to strided for redispatching on values.
new_kwargs["layout"] = torch.strided
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
new_values = func(inp._values, **new_kwargs)
new_offsets = inp._offsets.to(device=new_values.device)
new_lengths = None
if inp._lengths is not None:
new_lengths = inp._lengths.to(device=new_values.device)
output_kwargs = extract_kwargs(inp)
if "offsets" in output_kwargs:
output_kwargs["offsets"] = new_offsets
if "lengths" in output_kwargs:
output_kwargs["lengths"] = new_lengths
if inp.device != new_values.device:
# Update the nested int registry to indicate that the ragged structure is the same
# between the two offsets / lengths on different devices.
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import (
FunctionalTensor,
mb_unwrap_functional_tensor,
)
from .nested_tensor import _tensor_symint_registry
ragged_source = inp._offsets if inp._lengths is None else inp._lengths
new_thing = new_offsets if new_lengths is None else new_lengths
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
return NestedTensor(new_values, **output_kwargs)
register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
like_factory_default
)
register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
like_factory_default
)
register_jagged_func(
torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
)(like_factory_default)
@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")