mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix triangular_solve meta function out parameter names. (#140186)
This PR replaces the parameter names specified in the `triangular_solve_meta`
function (specifically in its `@out_wrapper(...)` decorator) by those written in the
_native_functions.yaml_ file.
This name mismatch caused the operation to fail when using the meta device (see error
below):
```python
Traceback (most recent call last):
File "examples/test.py", line 23, in <module>
torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
File "torch/_decomp/__init__.py", line 100, in _fn
return f(*args, **kwargs, out=None if is_none else out_kwargs)
File "torch/_prims_common/wrappers.py", line 289, in _fn
result = fn(*args, **kwargs)
TypeError: triangular_solve_meta() got an unexpected keyword argument 'X'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140186
Approved by: https://github.com/ezyang
This commit is contained in:
parent
6a368b3fc5
commit
c182c7ccfc
|
|
@ -6,7 +6,7 @@ import os
|
|||
import numpy as np
|
||||
from enum import Enum
|
||||
from torch.overrides import resolve_name
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from torch.utils._pytree import tree_map, tree_map_only, tree_flatten, tree_unflatten
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
|
||||
import torch.utils._python_dispatch
|
||||
|
|
@ -1755,6 +1755,24 @@ class TestMeta(TestCase):
|
|||
meta_tensor = torch.randn(1, device='meta')
|
||||
meta_tensor.item()
|
||||
|
||||
def test_triangular_solve_out(self):
|
||||
# Get what's the expected output for the given example.
|
||||
A = torch.randn(2, 2).triu()
|
||||
b = torch.randn(2, 3)
|
||||
out = torch.triangular_solve(b, A)
|
||||
|
||||
# Call the function again, transforming every tensor input (including the out tensor)
|
||||
# into a meta tensor.
|
||||
meta_out = tree_map_only(torch.Tensor, lambda t: t.to("meta"), out)
|
||||
torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
|
||||
|
||||
self.assertEqual(out[0].shape, meta_out[0].shape)
|
||||
self.assertEqual(out[0].dtype, meta_out[0].dtype)
|
||||
|
||||
self.assertEqual(out[1].shape, meta_out[1].shape)
|
||||
self.assertEqual(out[1].dtype, meta_out[1].dtype)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestMeta, globals())
|
||||
|
||||
def print_op_str_if_not_supported(op_str):
|
||||
|
|
|
|||
|
|
@ -222,6 +222,7 @@ meta_consistency_out_dtype_mismatch_xfails = {
|
|||
xfail("take"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("tril"),
|
||||
xfail("triangular_solve"),
|
||||
xfail("triu"),
|
||||
xfail("trunc"),
|
||||
xfail("unfold_copy"),
|
||||
|
|
|
|||
|
|
@ -2054,7 +2054,6 @@ out_symbolic_tensor_failures = {
|
|||
xfail('scatter_add', ''),
|
||||
xfail('scatter', ''),
|
||||
xfail('take_along_dim', ''),
|
||||
xfail('triangular_solve', ''),
|
||||
|
||||
# SymIntArrayRef expected to contain only concrete
|
||||
xfail('ones', ''),
|
||||
|
|
|
|||
|
|
@ -1495,7 +1495,7 @@ def linalg_solve_triangular_meta(
|
|||
|
||||
|
||||
@register_meta(aten.triangular_solve)
|
||||
@out_wrapper("solution", "cloned_coefficient")
|
||||
@out_wrapper("X", "M")
|
||||
def triangular_solve_meta(
|
||||
self: Tensor,
|
||||
A: Tensor,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user