[Dynamo] Implement sourceless named tuple support (#151266)

Fixes https://github.com/pytorch/pytorch/issues/140903

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151266
Approved by: https://github.com/williamwen42, https://github.com/StrongerXi, https://github.com/anijain2305

parent 49c91b4be9
commit f29fe78cf2
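Background: in Dynamo, a value is "sourceless" when it is created during tracing and has no Source chain back to a program input, so it cannot be guarded on and must instead be reconstructed by SourcelessBuilder. A minimal sketch (plain Python, not Dynamo internals) of why namedtuples are amenable to this: an instance always carries its class and its field names, so it can be rebuilt field by field.

# Plain-Python illustration, not Dynamo internals: a namedtuple
# instance carries its class and its field names, so it can be
# rebuilt value-by-value even with no Source to guard on.
from collections import namedtuple

CustomDtype = namedtuple("CustomDtype", ["dtype", "higher_dtype"])
v = CustomDtype("float32", "bfloat16")

assert type(v)._fields == ("dtype", "higher_dtype")
assert type(v)(*(getattr(v, f) for f in type(v)._fields)) == v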
@@ -9045,6 +9045,65 @@ def ___make_guard_fn():
         self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
         self.assertEqual(res, img1 + torch.sin(img1))

+    def test_sourceless_namedtuple(self):
+        from collections import namedtuple
+
+        CustomDtype = namedtuple("CustomDtype", ["dtype", "higher_dtype"])
+
+        class CustomTensor(torch.Tensor):
+            _data: torch.Tensor
+            custom_dtype: CustomDtype
+            __torch_function__ = torch._C._disabled_torch_function_impl
+            __slots__ = [
+                "_data",
+                "custom_dtype",
+            ]
+
+            def __new__(
+                cls,
+                data: torch.Tensor,
+                custom_dtype: CustomDtype,
+            ):
+                self = torch.Tensor._make_wrapper_subclass(
+                    cls,
+                    data.size(),
+                    strides=data.stride(),
+                    storage_offset=data.storage_offset(),
+                    dtype=custom_dtype.dtype,
+                    layout=data.layout,
+                    requires_grad=data.requires_grad,
+                    device=data.device,
+                )
+                self._data = data
+                self.custom_dtype = custom_dtype
+                return self
+
+            def __tensor_flatten__(self):
+                meta = {
+                    "custom_dtype": self.custom_dtype,
+                }
+                return ["_data"], meta
+
+            @staticmethod
+            def __tensor_unflatten__(
+                inner_tensors: dict, metadata, outer_size, outer_stride
+            ):
+                return CustomTensor(
+                    inner_tensors["_data"],
+                    metadata["custom_dtype"],
+                )
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs={}):
+                return func(*args, **kwargs)
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            y = CustomTensor(x, CustomDtype(torch.float32, torch.bfloat16))
+            return y, y.custom_dtype
+
+        fn(torch.ones(2, 2, device="cpu"))
+
     # Compiling autograd.Function traces fwd function twice, but the same unbacked symints were not identified
     # as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
     # `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure
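Why the namedtuple in this test is sourceless: when Dynamo traces a wrapper tensor subclass, the metadata returned by __tensor_flatten__ is (roughly speaking) a fresh trace-time Python object with no Source, which is how a CustomDtype instance ends up going through SourcelessBuilder. A hedged sketch of the flatten/unflatten round trip in eager Python, reusing CustomTensor and CustomDtype as defined in the test above:

# Eager-mode sketch of the round trip the subclass protocol enables
# (paraphrased; the real tracing path lives inside Dynamo). Reuses
# CustomTensor and CustomDtype from the test above.
t = CustomTensor(torch.ones(2, 2), CustomDtype(torch.float32, torch.bfloat16))
attrs, meta = t.__tensor_flatten__()  # (["_data"], {"custom_dtype": ...})
inner = {name: getattr(t, name) for name in attrs}
t2 = CustomTensor.__tensor_unflatten__(inner, meta, t.size(), t.stride())
assert t2.custom_dtype == t.custom_dtype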
@@ -3324,6 +3324,12 @@ class SourcelessBuilder:
             )
         elif isinstance(value, types.GenericAlias):
             return TypingVariable(value)
+        elif is_namedtuple(value):
+            output = [
+                SourcelessBuilder.create(tx, getattr(value, name))
+                for name in namedtuple_fields(type(value))
+            ]
+            return NamedTupleVariable(output, tuple_cls=type(value))
         unimplemented_v2(
             gb_type="Unexpected type in sourceless builder",
             context=f"{value_type.__module__}.{value_type.__qualname__}",
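The new elif branch builds one VariableTracker per field recursively and rewraps the results in a NamedTupleVariable that remembers the original tuple class, so the value can be reconstructed as the right type later. A hedged sketch of the same logic over plain values, where an identity function stands in for SourcelessBuilder.create and type(value)._fields stands in for Dynamo's namedtuple_fields helper:

# Sketch with plain values standing in for VariableTrackers; the
# recursive call below stands in for SourcelessBuilder.create.
from collections import namedtuple

def create_sourceless(value):
    if isinstance(value, tuple) and hasattr(type(value), "_fields"):
        output = [
            create_sourceless(getattr(value, name))  # one build per field
            for name in type(value)._fields  # ~ namedtuple_fields(type(value))
        ]
        return type(value)(*output)  # ~ NamedTupleVariable(output, tuple_cls=type(value))
    return value  # all other sourceless cases fall through

Pair = namedtuple("Pair", ["a", "b"])
assert create_sourceless(Pair(1, Pair(2, 3))) == Pair(1, Pair(2, 3))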