mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix issue with namedtuple slicing (#163351)
Fixes #163253 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163351 Approved by: https://github.com/williamwen42, https://github.com/mlazos
This commit is contained in:
parent
093f0642aa
commit
8225a26835
|
|
@ -128,6 +128,20 @@ class LstmModule(torch.nn.Module):
|
|||
class CPUReproTests(TestCase):
|
||||
common = check_model
|
||||
|
||||
def test_torch_linalg_qr_tuple_slice(self):
|
||||
def fn(x):
|
||||
return torch.linalg.qr(x)[:1]
|
||||
|
||||
x = torch.randn(4, 4)
|
||||
compiled = torch.compile(fn, backend="inductor")
|
||||
|
||||
expected = fn(x)
|
||||
actual = compiled(x)
|
||||
|
||||
self.assertIsInstance(actual, tuple)
|
||||
self.assertEqual(len(actual), 1)
|
||||
torch.testing.assert_close(actual[0], expected[0])
|
||||
|
||||
@skipIfRocm
|
||||
def test_conv_stride_constraints(self):
|
||||
for fmt in [torch.contiguous_format, torch.channels_last]:
|
||||
|
|
|
|||
|
|
@ -1317,6 +1317,15 @@ class NamedTupleVariable(TupleVariable):
|
|||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
if isinstance(arg, SliceVariable):
|
||||
# slicing a namedtuple produces a tuple
|
||||
return TupleVariable(
|
||||
self.items[arg.as_python_constant()],
|
||||
source=None,
|
||||
)
|
||||
return super().getitem_const(tx, arg)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def check_and_create_method():
|
||||
method = inspect.getattr_static(self.tuple_cls, name, None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user