[dynamo] Fix issue with namedtuple slicing (#163351)

Fixes #163253

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163351
Approved by: https://github.com/williamwen42, https://github.com/mlazos
This commit is contained in:
Jason Ansel 2025-09-19 14:44:44 -07:00 committed by PyTorch MergeBot
parent 093f0642aa
commit 8225a26835
2 changed files with 23 additions and 0 deletions

View File

@ -128,6 +128,20 @@ class LstmModule(torch.nn.Module):
class CPUReproTests(TestCase):
common = check_model
def test_torch_linalg_qr_tuple_slice(self):
def fn(x):
return torch.linalg.qr(x)[:1]
x = torch.randn(4, 4)
compiled = torch.compile(fn, backend="inductor")
expected = fn(x)
actual = compiled(x)
self.assertIsInstance(actual, tuple)
self.assertEqual(len(actual), 1)
torch.testing.assert_close(actual[0], expected[0])
@skipIfRocm
def test_conv_stride_constraints(self):
for fmt in [torch.contiguous_format, torch.channels_last]:

View File

@ -1317,6 +1317,15 @@ class NamedTupleVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
if isinstance(arg, SliceVariable):
# slicing a namedtuple produces a tuple
return TupleVariable(
self.items[arg.as_python_constant()],
source=None,
)
return super().getitem_const(tx, arg)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
method = inspect.getattr_static(self.tuple_cls, name, None)