[aotd] coerce_same_metadata_as_tangent with expected_type for e.g. AsyncCollectiveTensor (#139095)
Based on the discussion in https://github.com/pytorch/pytorch/pull/138731, this introduces the ability for a subclass to implement type coercion to an `expected_type`:

```
def __coerce_same_metadata_as_tangent__(
    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
```

Here `expected_type=None` means the subclass's own class is expected. E.g. for `DTensor` we may find a tangent that is an `AsyncCollectiveTensor` where a plain `Tensor` was expected; in this case the hook is called at runtime with `expected_type=Tensor`.

Adds an implementation to `AsyncCollectiveTensor` that simply triggers `wait()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139095
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 8d3d47e439
commit 781c68c865
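For illustration, here is a minimal, hedged sketch of how a wrapper subclass could implement the two-argument hook described above. `MyWrapperTensor` and its `elem` field are hypothetical and are not part of this PR; the actual implementations added by the commit are in the diff below.

```python
from typing import Any, Optional, Type

import torch


class MyWrapperTensor(torch.Tensor):  # hypothetical subclass, for illustration only
    @staticmethod
    def __new__(cls, elem):
        # Wrapper subclass that mirrors elem's metadata and holds it as a field.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        if expected_type is None or expected_type is type(self):
            # The runtime expects this subclass itself; nothing to convert here.
            return self
        if expected_type is torch.Tensor:
            # The runtime expects a plain Tensor: unwrap.
            return self.elem
        # Signal that coercion to any other type is not supported.
        return None
```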
@@ -652,7 +652,6 @@ def forward(self, primals_1):
     return (sin_1, primals_1, wait_tensor)""",
         )
 
-    @unittest.expectedFailure
     @skipIfTorchDynamo()
     def test_unwrap_async_collective_tensor_tangent(self):
         from torch.distributed._functional_collectives import AsyncCollectiveTensor
@@ -78,6 +78,7 @@ from torch.testing._internal.optests import (
     _test_aot_autograd_forwards_backwards_helper,
     aot_autograd_check,
 )
+from torch.testing._internal.subclasses import WrapperSubclass
 from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
 
@@ -6227,6 +6228,32 @@ metadata incorrectly.
         ):
             y2.backward(gradient=torch.randn(2, 3))
 
+    def test_tangent_type_coercion(self):
+        def fn(x):
+            return x.clone()
+
+        ref_y = fn(WrapperSubclass(torch.randn(2, 3, requires_grad=True)))
+        ref_y.sum().backward()
+
+        fn_comp = torch.compile(fn, fullgraph=True)
+
+        x = TwoTensor(
+            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
+        )
+        y = fn_comp(x)
+        y.backward(gradient=TwoTensor(torch.randn(2, 3), torch.randn(2, 3)))
+
+        x2 = TwoTensor(
+            torch.randn(2, 3, requires_grad=True), torch.randn(2, 3, requires_grad=True)
+        )
+        y2 = fn_comp(x2)
+        # Test coercion WrapperSubclass -> TwoTensor
+        y2.backward(gradient=WrapperSubclass(torch.randn(2, 3)))
+
+        y3 = torch.compile(fn, fullgraph=True)(torch.randn(2, 3, requires_grad=True))
+        # Test coercion WrapperSubclass -> Tensor
+        y3.backward(gradient=WrapperSubclass(torch.randn(2, 3)))
+
     @torch._inductor.config.patch({"freezing": True})
     def test_inductor_freezing_with_subclasses(self):
         class M(torch.nn.Module):
@@ -1489,7 +1489,7 @@ class AOTDispatchAutograd:
                 # Backward Compatibility, as some Subclass impls can have original 1-arg function.
                 return x.__coerce_same_metadata_as_tangent__(expected_meta)
 
-            return None
+            return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type)
 
         # Coerce to expected type and metadata
         orig_x = x
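As a side note, here is a hedged sketch of how a caller can stay compatible with both the legacy 1-argument hook and the new 2-argument form. This is illustrative only, not the actual AOTAutograd code from the hunk above; `coerce_tangent_hook` is a hypothetical helper name.

```python
import inspect
from typing import Any, Optional, Type

import torch


def coerce_tangent_hook(x: torch.Tensor, expected_meta: Any, expected_type: Optional[Type]):
    # Hypothetical helper: prefer the new 2-argument hook, but fall back to the
    # legacy 1-argument signature for older subclass implementations.
    hook = getattr(x, "__coerce_same_metadata_as_tangent__", None)
    if hook is None:
        return None
    if len(inspect.signature(hook).parameters) == 1:
        return hook(expected_meta)
    return hook(expected_meta, expected_type)
```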
@@ -2,7 +2,7 @@
 import contextlib
 import sys
 import warnings
-from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, cast, List, Optional, Tuple, Type, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
@@ -601,6 +601,14 @@ class AsyncCollectiveTensor(torch.Tensor):
         elem = inner_tensors["elem"]
         return AsyncCollectiveTensor(elem)
 
+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[Type] = None
+    ):
+        if expected_type is not torch.Tensor:
+            return None
+
+        return self.trigger_wait()
+
     def __repr__(self) -> str:  # type: ignore[override]
         return f"AsyncCollectiveTensor({self.trigger_wait()})"
 
@@ -325,7 +325,10 @@ class DTensor(torch.Tensor):
         ]
         return self.redistribute(device_mesh=self.device_mesh, placements=placements)
 
-    def __coerce_same_metadata_as_tangent__(self, flatten_spec):
+    def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
+        if expected_type is not None:
+            return None
+
         (spec, _) = flatten_spec  # Result of tensor_flatten()
         return self.redistribute(
             device_mesh=self.device_mesh,
torch/testing/_internal/subclasses.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+# mypy: ignore-errors
+from typing import Any, Optional, Type
+
+import torch
+import torch.utils._pytree as pytree
+from torch._subclasses.fake_tensor import is_fake
+from torch.testing._internal.two_tensor import TwoTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+class WrapperSubclass(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, outer_size=None, outer_stride=None):
+        if outer_size is None:
+            outer_size = a.size()
+        if outer_stride is None:
+            outer_stride = a.stride()
+
+        kwargs = {}
+        kwargs["strides"] = a.stride()
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+
+        return out
+
+    def __init__(self, a, outer_size=None, outer_stride=None):
+        self.a = a
+
+    def __repr__(self):
+        return f"WrapperSubclass({repr(self.a)})"
+
+    def __tensor_flatten__(self):
+        return ["a"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        a = inner_tensors["a"]
+        if is_fake(a):
+            assert outer_size is not None
+            assert outer_stride is not None
+        return WrapperSubclass(a, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
+
+        kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_flat = [
+            WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a in out_a_flat
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        if func is cond_op:
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[Type] = None
+    ):
+        if expected_type == type(self.a):
+            return self.a
+        elif expected_type is TwoTensor:
+            return TwoTensor(self.a, self.a.clone())
+
+        return None
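A brief usage sketch of the new `WrapperSubclass` helper (illustrative only; it calls the coercion hook directly rather than going through AOTAutograd):

```python
import torch
from torch.testing._internal.subclasses import WrapperSubclass
from torch.testing._internal.two_tensor import TwoTensor

x = WrapperSubclass(torch.randn(2, 3))
# Coerce to the inner tensor's type (plain torch.Tensor): returns x.a.
plain = x.__coerce_same_metadata_as_tangent__(None, torch.Tensor)
# Coerce to TwoTensor: returns TwoTensor(x.a, x.a.clone()).
two = x.__coerce_same_metadata_as_tangent__(None, TwoTensor)
print(type(plain), type(two))
```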