Michael Lazos 2025-04-17 08:43:03 +00:00 committed by PyTorch MergeBot
parent 49c91b4be9
commit f29fe78cf2
2 changed files with 65 additions and 0 deletions

View File

@ -9045,6 +9045,65 @@ def ___make_guard_fn():
self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
self.assertEqual(res, img1 + torch.sin(img1))
def test_sourceless_namedtuple(self):
from collections import namedtuple
CustomDtype = namedtuple("CustomDtype", ["dtype", "higher_dtype"])
class CustomTensor(torch.Tensor):
_data: torch.Tensor
custom_dtype: CustomDtype
__torch_function__ = torch._C._disabled_torch_function_impl
__slots__ = [
"_data",
"custom_dtype",
]
def __new__(
cls,
data: torch.Tensor,
custom_dtype: CustomDtype,
):
self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=custom_dtype.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data = data
self.custom_dtype = custom_dtype
return self
def __tensor_flatten__(self):
meta = {
"custom_dtype": self.custom_dtype,
}
return ["_data"], meta
@staticmethod
def __tensor_unflatten__(
inner_tensors: dict, metadata, outer_size, outer_stride
):
return CustomTensor(
inner_tensors["_data"],
metadata["custom_dtype"],
)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs={}):
return func(*args, **kwargs)
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
y = CustomTensor(x, CustomDtype(torch.float32, torch.bfloat16))
return y, y.custom_dtype
fn(torch.ones(2, 2, device="cpu"))
# Compiling autograd.Function traces fwd function twice, but the same unbacked symints were not identified
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure

View File

@ -3324,6 +3324,12 @@ class SourcelessBuilder:
)
elif isinstance(value, types.GenericAlias):
return TypingVariable(value)
elif is_namedtuple(value):
output = [
SourcelessBuilder.create(tx, getattr(value, name))
for name in namedtuple_fields(type(value))
]
return NamedTupleVariable(output, tuple_cls=type(value))
unimplemented_v2(
gb_type="Unexpected type in sourceless builder",
context=f"{value_type.__module__}.{value_type.__qualname__}",