[aotd] coerce_same_metadata_as_tangent with expected_type for e.g. AsyncCollectiveTensor (#139095)

Based on discussion here: https://github.com/pytorch/pytorch/pull/138731

This introduces the ability for a subclass to implement type conversion to an `expected_type`:
```
    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
```
Here, `expected_type=None` means a tangent of the subclass's own type is expected.

For example, with `DTensor` the runtime tangent may contain an `AsyncCollectiveTensor` where a plain `Tensor` was expected; in that case the hook is invoked at runtime with `expected_type=Tensor`.

This also adds an implementation to `AsyncCollectiveTensor` that simply triggers `wait()`.
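
For illustration only (not code from this PR): the runtime side roughly dispatches between the old 1-arg form and the new 2-arg form as sketched below. The helper name and the signature-based check are assumptions; the actual logic lives in `AOTDispatchAutograd` (see the diff further down).
```
import inspect


def coerce_runtime_tangent(x, expected_meta, expected_type):
    # Illustrative helper, not the exact code in AOTDispatchAutograd.
    hook = getattr(x, "__coerce_same_metadata_as_tangent__", None)
    if hook is None:
        return None
    # Backward compatibility: some subclass impls still have the original 1-arg form.
    if len(inspect.signature(hook).parameters) == 1:
        return hook(expected_meta)
    # New protocol: the subclass returns a coerced tensor, or None if it
    # cannot handle the requested expected_type.
    return hook(expected_meta, expected_type)
```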

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139095
Approved by: https://github.com/bdhirsh
IvanKobzarev, 2024-11-06 05:30:43 -08:00, committed by PyTorch MergeBot
commit 781c68c865 (parent 8d3d47e439)
6 changed files with 119 additions and 4 deletions

@@ -652,7 +652,6 @@ def forward(self, primals_1):
     return (sin_1, primals_1, wait_tensor)""",
         )

-    @unittest.expectedFailure
     @skipIfTorchDynamo()
     def test_unwrap_async_collective_tensor_tangent(self):
         from torch.distributed._functional_collectives import AsyncCollectiveTensor

@@ -78,6 +78,7 @@ from torch.testing._internal.optests import (
     _test_aot_autograd_forwards_backwards_helper,
     aot_autograd_check,
 )
+from torch.testing._internal.subclasses import WrapperSubclass
 from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode

@@ -6227,6 +6228,32 @@ metadata incorrectly.
         ):
             y2.backward(gradient=torch.randn(2, 3))

+    def test_tangent_type_coercion(self):
+        def fn(x):
+            return x.clone()
+
+        ref_y = fn(WrapperSubclass(torch.randn(2, 3, requires_grad=True)))
+        ref_y.sum().backward()
+
+        fn_comp = torch.compile(fn, fullgraph=True)
+
+        x = TwoTensor(
+            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
+        )
+        y = fn_comp(x)
+        y.backward(gradient=TwoTensor(torch.randn(2, 3), torch.randn(2, 3)))
+
+        x2 = TwoTensor(
+            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
+        )
+        y2 = fn_comp(x2)
+        # Test coercion WrapperSubclass -> TwoTensor
+        y2.backward(gradient=WrapperSubclass(torch.randn(2, 3)))
+
+        y3 = torch.compile(fn, fullgraph=True)(torch.randn(2, 3, requires_grad=True))
+        # Test coercion WrapperSubclass -> Tensor
+        y3.backward(gradient=WrapperSubclass(torch.randn(2, 3)))
+
     @torch._inductor.config.patch({"freezing": True})
     def test_inductor_freezing_with_subclasses(self):
         class M(torch.nn.Module):

@@ -1489,7 +1489,7 @@ class AOTDispatchAutograd:
                     # Backward Compatibility, as some Subclass impls can have original 1-arg function.
                     return x.__coerce_same_metadata_as_tangent__(expected_meta)
-                return None
+                return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type)

             # Coerce to expected type and metadata
             orig_x = x

@@ -2,7 +2,7 @@
 import contextlib
 import sys
 import warnings
-from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, cast, List, Optional, Tuple, Type, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist

@@ -601,6 +601,14 @@ class AsyncCollectiveTensor(torch.Tensor):
         elem = inner_tensors["elem"]
         return AsyncCollectiveTensor(elem)

+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[Type] = None
+    ):
+        if expected_type is not torch.Tensor:
+            return None
+
+        return self.trigger_wait()
+
     def __repr__(self) -> str:  # type: ignore[override]
         return f"AsyncCollectiveTensor({self.trigger_wait()})"

@@ -325,7 +325,10 @@ class DTensor(torch.Tensor):
         ]
         return self.redistribute(device_mesh=self.device_mesh, placements=placements)

-    def __coerce_same_metadata_as_tangent__(self, flatten_spec):
+    def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
+        if expected_type is not None:
+            return None
+
         (spec, _) = flatten_spec  # Result of tensor_flatten()
         return self.redistribute(
             device_mesh=self.device_mesh,

@@ -0,0 +1,78 @@ (new file)
# mypy: ignore-errors

from typing import Any, Optional, Type

import torch
import torch.utils._pytree as pytree
from torch._subclasses.fake_tensor import is_fake
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing


class WrapperSubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, a, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        kwargs = {}
        kwargs["strides"] = a.stride()
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
        return out

    def __init__(self, a, outer_size=None, outer_stride=None):
        self.a = a

    def __repr__(self):
        return f"WrapperSubclass({repr(self.a)})"

    def __tensor_flatten__(self):
        return ["a"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a = inner_tensors["a"]
        if is_fake(a):
            assert outer_size is not None
            assert outer_stride is not None
        return WrapperSubclass(a, outer_size, outer_stride)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
        kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)

        out_a = func(*args_a, **kwargs_a)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_flat = [
            WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
            for o_a in out_a_flat
        ]
        out = pytree.tree_unflatten(out_flat, spec)

        from torch._higher_order_ops.cond import cond_op

        if func is cond_op:
            return out
        else:
            return return_and_correct_aliasing(func, args, kwargs, out)

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        if expected_type == type(self.a):
            return self.a
        elif expected_type is TwoTensor:
            return TwoTensor(self.a, self.a.clone())

        return None
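
For illustration only (not part of the PR): the return contract of the new hook can be exercised directly on `WrapperSubclass`. `expected_metadata` is passed as `None` here purely because `WrapperSubclass` ignores it; in real use AOTAutograd supplies the metadata captured at trace time.
```
import torch
from torch.testing._internal.subclasses import WrapperSubclass
from torch.testing._internal.two_tensor import TwoTensor

w = WrapperSubclass(torch.randn(2, 3))
# expected_type=Tensor: unwrap to the inner plain tensor.
assert type(w.__coerce_same_metadata_as_tangent__(None, torch.Tensor)) is torch.Tensor
# expected_type=TwoTensor: convert to a TwoTensor built from the inner tensor.
assert isinstance(w.__coerce_same_metadata_as_tangent__(None, TwoTensor), TwoTensor)
# Unsupported expected_type: None signals "cannot coerce".
assert w.__coerce_same_metadata_as_tangent__(None, int) is None
```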