Fix triangular_solve meta function out parameter names. (#140186)

This PR replaces the parameter names specified in the `triangular_solve_meta`
function (specifically in its `@out_wrapper(...)` decorator) by those written in the
_native_functions.yaml_ file.

This name mismatch caused the operation to fail when using the meta device (see error
below):

```python
Traceback (most recent call last):
  File "examples/test.py", line 23, in <module>
    torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
  File "torch/_decomp/__init__.py", line 100, in _fn
    return f(*args, **kwargs, out=None if is_none else out_kwargs)
  File "torch/_prims_common/wrappers.py", line 289, in _fn
    result = fn(*args, **kwargs)
TypeError: triangular_solve_meta() got an unexpected keyword argument 'X'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140186
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi 2024-11-11 21:01:39 -03:00 committed by PyTorch MergeBot
parent 6a368b3fc5
commit c182c7ccfc
4 changed files with 21 additions and 3 deletions

View File

@ -6,7 +6,7 @@ import os
import numpy as np
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils._pytree import tree_map, tree_map_only, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
import torch.utils._python_dispatch
@ -1755,6 +1755,24 @@ class TestMeta(TestCase):
meta_tensor = torch.randn(1, device='meta')
meta_tensor.item()
def test_triangular_solve_out(self):
# Get what's the expected output for the given example.
A = torch.randn(2, 2).triu()
b = torch.randn(2, 3)
out = torch.triangular_solve(b, A)
# Call the function again, transforming every tensor input (including the out tensor)
# into a meta tensor.
meta_out = tree_map_only(torch.Tensor, lambda t: t.to("meta"), out)
torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
self.assertEqual(out[0].shape, meta_out[0].shape)
self.assertEqual(out[0].dtype, meta_out[0].dtype)
self.assertEqual(out[1].shape, meta_out[1].shape)
self.assertEqual(out[1].dtype, meta_out[1].dtype)
instantiate_device_type_tests(TestMeta, globals())
def print_op_str_if_not_supported(op_str):

View File

@ -222,6 +222,7 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("take"),
xfail("transpose_copy"),
xfail("tril"),
xfail("triangular_solve"),
xfail("triu"),
xfail("trunc"),
xfail("unfold_copy"),

View File

@ -2054,7 +2054,6 @@ out_symbolic_tensor_failures = {
xfail('scatter_add', ''),
xfail('scatter', ''),
xfail('take_along_dim', ''),
xfail('triangular_solve', ''),
# SymIntArrayRef expected to contain only concrete
xfail('ones', ''),

View File

@ -1495,7 +1495,7 @@ def linalg_solve_triangular_meta(
@register_meta(aten.triangular_solve)
@out_wrapper("solution", "cloned_coefficient")
@out_wrapper("X", "M")
def triangular_solve_meta(
self: Tensor,
A: Tensor,