Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Support remaining *_like factory functions for NJT (#144889)
Fixes #144761

This PR adds NJT impls for those *_like functions that were previously missing:

* `full_like()`
* `rand_like()`
* `randint_like()`

It also fixes a bug in existing *_like functions when a new device is specified: the fix is to also transfer `offsets` / `lengths` to the new device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144889
Approved by: https://github.com/soulitzer
This commit is contained in:
parent 3a23d75b37
commit 1ba1b7b597
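For illustration only (not part of the commit), here is a minimal sketch of what the newly supported factory functions look like on a jagged-layout NJT; the component shapes and values below are assumptions:

```python
import torch

# Build a jagged NJT: two components that differ along the ragged dim.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(5, 3)], layout=torch.jagged
)

# The previously missing *_like functions now return NJTs with the same ragged structure.
full_nt = torch.full_like(nt, 4.2)            # every element is 4.2
rand_nt = torch.rand_like(nt)                 # uniform values in [0, 1)
randint_nt = torch.randint_like(nt, high=5)   # integer values in [0, 5)

assert full_nt.layout == torch.jagged
assert rand_nt.shape == nt.shape
```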
@@ -71,7 +71,7 @@ from torch.testing._internal.opinfo.core import (
     SkipRule,
     XFailRule,
 )
-from torch.testing._internal.opinfo.definitions.nested import njt_op_db
+from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
 from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
@@ -6109,17 +6109,72 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
 
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     @parametrize(
-        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
+        "func",
+        [
+            torch.empty_like,
+            torch.full_like,
+            torch.ones_like,
+            torch.rand_like,
+            torch.randint_like,
+            torch.randn_like,
+            torch.zeros_like,
+        ],
+        name_fn=lambda f: f.__name__,
     )
-    def test_like_value(self, func):
-        nt = random_nt_from_dims(
-            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
-        )
-        nt_like = func(nt)
+    def test_like_value(self, func, device):
+        dtype = torch.float32 if func is not torch.randint_like else torch.int32
+        for nt in _sample_njts(device=device, dtype=dtype):
+            extra_kwarg_sets = [{}]
+            if func is torch.full_like:
+                extra_kwarg_sets = [{"fill_value": 4.2}]
+            elif func is torch.randint_like:
+                extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
 
-        for nt_ub in nt_like.unbind():
-            t_like = func(nt_ub)
-            self.assertEqual(nt_ub, t_like)
+            # only test changing dtype / device from CUDA -> CPU because CUDA might not be
+            # available when running this test for CPU
+            change_dtype_device_settings = (
+                [False, True] if "cuda" in device else [False]
+            )
+            for change_dtype_device in change_dtype_device_settings:
+                if change_dtype_device:
+                    new_dtype = (
+                        torch.float64 if func is not torch.randint_like else torch.int64
+                    )
+                    new_device = "cpu" if "cuda" in device else device
+                    new_layout = torch.strided
+                    for extra_kwargs in extra_kwarg_sets:
+                        extra_kwargs.update(
+                            {
+                                "dtype": new_dtype,
+                                "device": new_device,
+                                "layout": new_layout,
+                            }
+                        )
+
+                for extra_kwargs in extra_kwarg_sets:
+                    nt_like = func(nt, **extra_kwargs)
+                    self.assertEqual(nt.shape, nt_like.shape)
+                    if change_dtype_device:
+                        self.assertNotEqual(nt.device, nt_like.device)
+                        self.assertNotEqual(nt.dtype, nt_like.dtype)
+                        # layout should be ignored since only torch.jagged is supported
+                        self.assertEqual(torch.jagged, nt_like.layout)
+                    else:
+                        self.assertEqual(nt.device, nt_like.device)
+                        self.assertEqual(nt.dtype, nt_like.dtype)
+                        self.assertEqual(nt.layout, nt_like.layout)
+                        self.assertEqual(nt.layout, torch.jagged)
+
+                    # don't bother trying to compare random or empty values
+                    if func not in [
+                        torch.empty_like,
+                        torch.rand_like,
+                        torch.randn_like,
+                        torch.randint_like,
+                    ]:
+                        for nt_ub in nt_like.unbind():
+                            t_like = func(nt_ub, **extra_kwargs)
+                            self.assertEqual(nt_ub, t_like)
 
     def test_noncontiguous_pointwise(self, device):
         a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
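As a rough sketch of the behavior the rewritten test asserts (the tensors below are assumptions, not taken from the commit): a *_like call that overrides dtype or layout keeps the ragged structure, and the layout request is effectively ignored because NJTs only support torch.jagged.

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(6, 4)], layout=torch.jagged
)

# Override dtype and (nominally) layout, as the change_dtype_device path does.
converted = torch.zeros_like(nt, dtype=torch.float64, layout=torch.strided)

assert converted.shape == nt.shape        # ragged structure preserved
assert converted.dtype == torch.float64   # dtype override honored
assert converted.layout == torch.jagged   # layout request ignored; result stays jagged
```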
@@ -690,6 +690,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
         torch.ops.aten.empty_like.default,
         torch.ops.aten.ones_like.default,
         torch.ops.aten.zeros_like.default,
+        torch.ops.aten.rand_like.default,
         torch.ops.aten.randn_like.default,
     ],
     "self: jt_all",
@@ -706,7 +707,52 @@ def like_factory_default(func, *args, **kwargs):
     # This should be set to strided for redispatching on values.
     new_kwargs["layout"] = torch.strided
 
-    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+    output_kwargs = extract_kwargs(inp)
+    if "offsets" in output_kwargs:
+        output_kwargs["offsets"] = new_offsets
+    if "lengths" in output_kwargs:
+        output_kwargs["lengths"] = new_lengths
+
+    if inp.device != new_values.device:
+        # Update the nested int registry to indicate that the ragged structure is the same
+        # between the two offsets / lengths on different devices.
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch._subclasses.functional_tensor import (
+            FunctionalTensor,
+            mb_unwrap_functional_tensor,
+        )
+
+        from .nested_tensor import _tensor_symint_registry
+
+        ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+        new_thing = new_offsets if new_lengths is None else new_lengths
+        if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+            # Temporary hack until we have the union find
+            tgt = mb_unwrap_functional_tensor(new_thing)
+            src = mb_unwrap_functional_tensor(ragged_source)
+            tgt.nested_int_memo = src.nested_int_memo
+        else:
+            _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(new_values, **output_kwargs)
+
+
+register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
+    like_factory_default
+)
+
+register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
+    like_factory_default
+)
+
+register_jagged_func(
+    torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
+)(like_factory_default)
 
 
 @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
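To make the device-transfer fix concrete, here is a hedged sketch (requires a CUDA device; the names and shapes are illustrative, not taken from the commit) showing that `offsets` now follow the values to the requested device:

```python
import torch

if torch.cuda.is_available():
    nt_cuda = torch.nested.nested_tensor(
        [torch.randn(2, 8), torch.randn(4, 8)], layout=torch.jagged, device="cuda"
    )

    # Request a *_like tensor on a different device; previously only the values
    # moved and the offsets stayed on the source device, which this commit fixes.
    nt_cpu = torch.ones_like(nt_cuda, device="cpu")

    assert nt_cpu.device.type == "cpu"
    assert nt_cpu.values().device.type == "cpu"
    assert nt_cpu.offsets().device.type == "cpu"  # offsets moved along with the values
```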