[pt2] add meta function for solve_triangular (#100829)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
This commit is contained in:
Nikita Karetnikov 2023-05-07 23:17:37 +02:00 committed by PyTorch MergeBot
parent cd8b82e5c6
commit 1e591a8b64
5 changed files with 146 additions and 23 deletions

View File

@ -2490,7 +2490,6 @@ symbolic_aot_autograd_failures = {
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/deco...
xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
@ -2511,7 +2510,6 @@ symbolic_aot_autograd_failures = {
xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decom...
xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomp...
xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/dec...
xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic me...
xfail('linalg.tensorinv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides

View File

@ -598,7 +598,6 @@ meta_function_expected_failures = {
torch.functional.istft : {f64, c64, c128, f32},
torch.geqrf : {f64, c64, c128, f32},
torch.linalg.householder_product : {f64, c64, c128, f32},
torch.linalg.solve_triangular : {f64, c64, c128, f32},
torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
torch.matrix_exp : {f64, c128, c64, bf16, f32},
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
@ -718,7 +717,6 @@ meta_function_device_expected_failures['cuda'] = {
torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
torch.kthvalue: {f16}, # aten::kthvalue.values
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
@ -830,8 +828,6 @@ meta_dispatch_expected_failures = {
aten.linalg_householder_product.out : {c64, c128, f64, f32},
aten.linalg_lstsq.default : {c64, c128, f64, f32},
aten.linalg_matrix_exp.default : {c64, bf16, f32, f64, c128},
aten.linalg_solve_triangular.default : {c64, c128, f64, f32},
aten.linalg_solve_triangular.out : {c64, c128, f64, f32},
aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
@ -929,8 +925,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product
aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out
aten.linalg_matrix_exp.default: {f16}, # aten::linalg_matrix_exp
aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular
aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices

View File

@ -1458,7 +1458,6 @@ symbolic_tensor_failures = {
xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de...
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition

View File

@ -18,7 +18,12 @@ from torch._prims_common import (
TensorLike,
)
from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
from torch._prims_common.wrappers import (
_maybe_resize_out,
_resize_output_check,
_safe_copy_out,
out_wrapper,
)
from torch._refs import _broadcast_shapes
from torch.utils._pytree import tree_map
@ -315,6 +320,48 @@ def squareCheckInputs(self: Tensor, f_name: str):
), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
# Validates input shapes and devices
# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
# From aten/src/ATen/native/LinearAlgebraUtils.h
def linearSolveCheckInputs(
    self: Tensor,
    A: Tensor,
    name: str,
):
    """Run the shared operand checks for a linear solve of ``A`` against ``self``.

    Four conditions are handed to ``check`` in this order, each paired with a
    lazily built message:

    1. ``self`` and ``A`` are on the same device;
    2. ``self`` and ``A`` have the same dtype;
    3. the last two dimensions of ``A`` are equal (square matrices);
    4. the last dimension of ``A`` equals ``self.size(-2)``.

    Args:
        self: right-hand side operand (called "b" in the messages).
        A: coefficient operand.
        name: operation name, interpolated into the size-mismatch message.
    """

    # The message builders are only invoked by `check`, so nothing below is
    # formatted unless `check` decides to call them.
    def _device_mismatch() -> str:
        return (
            f"Expected b and A to be on the same device, but found b on "
            f"{self.device} and A on {A.device} instead."
        )

    def _dtype_mismatch() -> str:
        return (
            f"Expected b and A to have the same dtype, but found b of type "
            f"{self.dtype} and A of type {A.dtype} instead."
        )

    def _not_square() -> str:
        return (
            f"A must be batches of square matrices, "
            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
        )

    def _size_mismatch() -> str:
        return (
            f"Incompatible matrix sizes for {name}: each A "
            f"matrix is {A.size(-1)} by {A.size(-1)}"
            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
        )

    check(self.device == A.device, _device_mismatch)
    check(self.dtype == A.dtype, _dtype_mismatch)
    check(A.size(-1) == A.size(-2), _not_square)
    check(A.size(-1) == self.size(-2), _size_mismatch)
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
@ -339,6 +386,24 @@ def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
)
def checkInputsSolver(
    A: Tensor,
    B: Tensor,
    left: bool,
    f_name: str,
):
    """Validate the operands of a solver for ``AX = B`` or ``XA = B``.

    Args:
        A: coefficient operand; passed to ``squareCheckInputs``.
        B: right-hand side operand; passed to ``checkIsMatrix``.
        left: selects which equation is being solved. When true the check
            compares ``A.size(-2)`` with ``B.size(-2)`` (``AX = B``);
            otherwise it compares ``A.size(-1)`` with ``B.size(-1)``
            (``XA = B``).
        f_name: function name used as the prefix of the diagnostics.
    """
    squareCheckInputs(A, f_name)
    # checkIsMatrix defaults its `arg_name` to "A"; name the operand
    # explicitly so a diagnostic about B does not point at A.
    checkIsMatrix(B, f_name, "B")
    check(
        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{'AX = B' if left else 'XA = B'}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )
def checkUplo(uplo: str):
uplo_uppercase = uplo.upper()
assert (
@ -483,6 +548,66 @@ def _linalg_svd_meta(
return U, S, V
def _linalg_broadcast_batch_dims(
    arg1: Tensor, arg2: Tensor
) -> Tuple[List[int], List[int]]:
    """Compute the expand sizes that give both arguments a common batch shape.

    Everything except the last two dimensions of each argument is treated as
    the batch portion and broadcast with ``_broadcast_shapes``. Each returned
    list is that broadcast batch portion followed by the argument's own last
    two sizes.

    Returns:
        ``(arg1_expand_size, arg2_expand_size)`` as two independent lists.
    """
    # broadcast the batch dimensions of arg1 and arg2.
    batch_portion = list(_broadcast_shapes(arg1.shape[:-2], arg2.shape[:-2]))

    # `+` builds a new list each time, so the two results never alias.
    first_size = batch_portion + [arg1.size(-2), arg1.size(-1)]
    second_size = batch_portion + [arg2.size(-2), arg2.size(-1)]
    return first_size, second_size
def _linalg_broadcast_batch_dims_name(
    arg1: Tensor, arg2: Tensor, name: Optional[str]
) -> Tuple[Tensor, Tensor]:
    """Return ``arg1`` and ``arg2`` expanded to a shared batch shape.

    Args:
        arg1: first operand.
        arg2: second operand.
        name: operation name. When truthy, ``linearSolveCheckInputs`` is run
            on the operands first; when ``None`` or empty, it is skipped.

    Returns:
        ``(arg1_broadcasted, arg2_broadcasted)``.
    """
    # If there's no name we assume we don't want to check the errors
    if name:
        linearSolveCheckInputs(arg1, arg2, name)

    first_size, second_size = _linalg_broadcast_batch_dims(arg1, arg2)

    # NOTE(review): the expand sizes are lists while `.shape` is presumably a
    # torch.Size (a tuple subclass), so these equality tests look like they are
    # always False and expand() always runs - confirm whether that is intended.
    if first_size == arg1.shape:
        first_out = arg1
    else:
        first_out = arg1.expand(first_size)

    if second_size == arg2.shape:
        second_out = arg2
    else:
        second_out = arg2.expand(second_size)

    return first_out, second_out
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
def linalg_solve_triangular_meta(
    A: Tensor,
    B: Tensor,
    *,
    upper: bool,
    left: bool = True,
    unitriangular: bool = False,
    out: Tensor = None,
) -> Tensor:
    """Meta function for ``linalg_solve_triangular`` (default and out overloads).

    Validates ``A`` and ``B`` with ``checkInputsSolver``, broadcasts their
    batch dimensions, and shapes ``out`` to the broadcast shape of ``B``.

    Args:
        A: coefficient operand.
        B: right-hand side operand.
        upper: accepted for signature compatibility; not read in this body.
        left: forwarded to ``checkInputsSolver``.
        unitriangular: accepted for signature compatibility; not read in
            this body.
        out: optional output tensor. When ``None`` an empty tensor is
            created from ``A`` via ``new_empty([0])``.

    Returns:
        The (possibly resized) output tensor.
    """
    if out is None:
        out = A.new_empty([0])
    assert isinstance(out, TensorLike)

    checkInputsSolver(A, B, left, "linalg.solve_triangular")
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)

    # `and` short-circuits, so is_conj() is only consulted when the transposed
    # view of A_ is contiguous.
    must_copy_A = not (A_.transpose(-2, -1).is_contiguous() and A_.is_conj())
    if must_copy_A:
        # reimplementation of resize_output with result F-contig
        if _resize_output_check(out, B_.shape):
            # Resize to the transposed shape, then transpose in place so the
            # final shape equals B_.shape.
            out.resize_(B_.transpose(-2, -1).shape)
            out.transpose_(-2, -1)
    else:
        out = _maybe_resize_out(out, B_.shape)
    return out  # type: ignore[return-value]
# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):

View File

@ -139,22 +139,29 @@ class elementwise_type_promotion_wrapper:
return _fn
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
# Returns True if resize is necessary
def _resize_output_check(out: TensorLikeType, shape: ShapeType):
    """Decide whether ``out`` has to be resized to ``shape``.

    Args:
        out: the output tensor supplied by the caller.
        shape: the shape the output is required to have.

    Returns:
        ``False`` when ``out`` already has ``shape``; ``True`` otherwise.

    Warns:
        Emits a warning via ``warnings.warn`` when a resize is needed and
        ``out`` has one or more elements.
    """
    # If the shapes are correct there's nothing to do
    if utils.same_shape(out.shape, shape):
        return False
    if out.numel() != 0:
        # Every fragment with a placeholder needs its own `f` prefix;
        # otherwise the braces are emitted literally.
        msg = (
            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
            f"which does not match the required output shape {str(shape)}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
        )
        warnings.warn(msg)
    return True
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    """Return ``out``, resized in place to ``shape`` when a resize is needed.

    ``_resize_output_check`` makes the decision. When it reports that no
    resize is needed, ``out`` is returned untouched; otherwise the result of
    ``out.resize_(shape)`` is returned.
    """
    needs_resize = _resize_output_check(out, shape)
    return out.resize_(shape) if needs_resize else out
def _safe_copy_out(