Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[pt2] add meta function for solve_triangular (#100829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
commit 1e591a8b64
parent cd8b82e5c6
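Note: a minimal usage sketch, not part of the commit itself. With this meta
kernel registered, torch.linalg.solve_triangular can run on "meta" tensors,
so output shape and dtype are inferred without doing any real computation:

    import torch

    # batched triangular systems A X = B; "meta" tensors carry no data
    A = torch.empty(3, 4, 4, device="meta")
    B = torch.empty(3, 4, 2, device="meta")
    X = torch.linalg.solve_triangular(A, B, upper=True)
    print(X.shape, X.device)  # torch.Size([3, 4, 2]) meta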
@@ -2490,7 +2490,6 @@ symbolic_aot_autograd_failures = {
     xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
     xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/deco...
     xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
@@ -2511,7 +2510,6 @@ symbolic_aot_autograd_failures = {
     xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decom...
     xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomp...
     xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/dec...
-    xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic me...
     xfail('linalg.tensorinv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -598,7 +598,6 @@ meta_function_expected_failures = {
     torch.functional.istft : {f64, c64, c128, f32},
     torch.geqrf : {f64, c64, c128, f32},
     torch.linalg.householder_product : {f64, c64, c128, f32},
-    torch.linalg.solve_triangular : {f64, c64, c128, f32},
     torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
     torch.matrix_exp : {f64, c128, c64, bf16, f32},
     torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
@@ -718,7 +717,6 @@ meta_function_device_expected_failures['cuda'] = {
     torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
     torch.kthvalue: {f16}, # aten::kthvalue.values
     torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out
-    torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
     torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
     torch.median: {f16}, # aten::median, aten::median.dim_values
     torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
@@ -830,8 +828,6 @@ meta_dispatch_expected_failures = {
     aten.linalg_householder_product.out : {c64, c128, f64, f32},
     aten.linalg_lstsq.default : {c64, c128, f64, f32},
     aten.linalg_matrix_exp.default : {c64, bf16, f32, f64, c128},
-    aten.linalg_solve_triangular.default : {c64, c128, f64, f32},
-    aten.linalg_solve_triangular.out : {c64, c128, f64, f32},
     aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
@@ -929,8 +925,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
     aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product
     aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out
     aten.linalg_matrix_exp.default: {f16}, # aten::linalg_matrix_exp
-    aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular
-    aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
     aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
     aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices
@@ -1458,7 +1458,6 @@ symbolic_tensor_failures = {
     xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de...
     xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -18,7 +18,12 @@ from torch._prims_common import (
     TensorLike,
 )
 
-from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
+from torch._prims_common.wrappers import (
+    _maybe_resize_out,
+    _resize_output_check,
+    _safe_copy_out,
+    out_wrapper,
+)
 from torch._refs import _broadcast_shapes
 
 from torch.utils._pytree import tree_map
@@ -315,6 +320,48 @@ def squareCheckInputs(self: Tensor, f_name: str):
     ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
 
 
+# Validates input shapes and devices
+# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
+# From aten/src/ATen/native/LinearAlgebraUtils.h
+def linearSolveCheckInputs(
+    self: Tensor,
+    A: Tensor,
+    name: str,
+):
+    check(
+        self.device == A.device,
+        lambda: (
+            f"Expected b and A to be on the same device, but found b on "
+            f"{self.device} and A on {A.device} instead."
+        ),
+    )
+
+    check(
+        self.dtype == A.dtype,
+        lambda: (
+            f"Expected b and A to have the same dtype, but found b of type "
+            f"{self.dtype} and A of type {A.dtype} instead."
+        ),
+    )
+
+    check(
+        A.size(-1) == A.size(-2),
+        lambda: (
+            f"A must be batches of square matrices, "
+            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
+        ),
+    )
+
+    check(
+        A.size(-1) == self.size(-2),
+        lambda: (
+            f"Incompatible matrix sizes for {name}: each A "
+            f"matrix is {A.size(-1)} by {A.size(-1)}"
+            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
+        ),
+    )
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def checkFloatingOrComplex(
     t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
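Note: linearSolveCheckInputs above is a Python port of the ATen helper of the
same name. A hedged sketch of how its dtype check surfaces through a public
solver (the exact message text may vary across versions):

    import torch

    A = torch.eye(3, dtype=torch.float64)
    b = torch.ones(3, dtype=torch.float32)
    try:
        torch.linalg.solve(A, b)  # mixed f64/f32 operands are rejected
    except RuntimeError as e:
        print(e)  # e.g. "Expected b and A to have the same dtype ..."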
@@ -339,6 +386,24 @@ def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
     )
 
 
+def checkInputsSolver(
+    A: Tensor,
+    B: Tensor,
+    left: bool,
+    f_name: str,
+):
+    squareCheckInputs(A, f_name)
+    checkIsMatrix(B, f_name)
+    check(
+        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
+        lambda: (
+            f"{f_name}: Incompatible shapes of A and B for the equation "
+            f"{'AX = B' if left else 'XA = B'}"
+            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
+        ),
+    )
+
+
 def checkUplo(uplo: str):
     uplo_uppercase = uplo.upper()
     assert (
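Note: a small sketch, assuming a build with this commit, of the shape check in
checkInputsSolver: for a left solve AX = B, A's columns must match B's rows.

    import torch

    A = torch.empty(4, 4, device="meta")
    B = torch.empty(3, 2, device="meta")  # 4x4 A cannot left-solve a 3-row B
    try:
        torch.linalg.solve_triangular(A, B, upper=True)
    except RuntimeError as e:
        print(e)  # "linalg.solve_triangular: Incompatible shapes of A and B ..."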
@@ -483,6 +548,66 @@ def _linalg_svd_meta(
     return U, S, V
 
 
+def _linalg_broadcast_batch_dims(
+    arg1: Tensor, arg2: Tensor
+) -> Tuple[List[int], List[int]]:
+    # broadcast the batch dimensions of arg1 and arg2.
+    arg1_batch_sizes = arg1.shape[:-2]
+    arg2_batch_sizes = arg2.shape[:-2]
+    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
+
+    arg1_expand_size = list(expand_batch_portion)
+    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
+
+    arg2_expand_size = list(expand_batch_portion)
+    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
+    return arg1_expand_size, arg2_expand_size
+
+
+def _linalg_broadcast_batch_dims_name(
+    arg1: Tensor, arg2: Tensor, name: Optional[str]
+) -> Tuple[Tensor, Tensor]:
+    # If there's no name we assume we don't want to check the errors
+    if name:
+        linearSolveCheckInputs(arg1, arg2, name)
+
+    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
+
+    arg1_broadcasted = (
+        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
+    )
+    arg2_broadcasted = (
+        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
+    )
+    return arg1_broadcasted, arg2_broadcasted
+
+
+@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
+def linalg_solve_triangular_meta(
+    A: Tensor,
+    B: Tensor,
+    *,
+    upper: bool,
+    left: bool = True,
+    unitriangular: bool = False,
+    out: Tensor = None,
+) -> Tensor:
+    if out is None:
+        out = A.new_empty([0])
+    assert isinstance(out, TensorLike)
+    checkInputsSolver(A, B, left, "linalg.solve_triangular")
+    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
+    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
+    if avoid_copy_A:
+        out = _maybe_resize_out(out, B_.shape)
+    else:
+        # reimplementation of resize_output with result F-contig
+        if _resize_output_check(out, B_.shape):
+            out.resize_(B_.transpose(-2, -1).shape)
+            out.transpose_(-2, -1)
+    return out  # type: ignore[return-value]
+
+
 # From aten/src/ATen/native/LinearAlgebra.cpp
 @register_meta(aten._linalg_det.default)
 def _linalg_det_meta(A):
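Note: a sketch of the batch broadcasting that _linalg_broadcast_batch_dims
implements for the meta kernel: the batch dimensions of A and B broadcast
against each other, while the trailing matrix dimensions come from each
operand (assuming a build with this commit):

    import torch

    A = torch.empty(5, 1, 4, 4, device="meta")  # batch dims (5, 1)
    B = torch.empty(3, 4, 2, device="meta")     # batch dims (3,)
    X = torch.linalg.solve_triangular(A, B, upper=False)
    print(X.shape)  # torch.Size([5, 3, 4, 2]); batches broadcast to (5, 3)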
@@ -139,22 +139,29 @@ class elementwise_type_promotion_wrapper:
     return _fn
 
 
-# TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+# Returns True if resize is necessary
+def _resize_output_check(out: TensorLikeType, shape: ShapeType):
     # If the shapes are correct there's nothing to do
     if utils.same_shape(out.shape, shape):
-        return out
-    else:
-        if out.numel() != 0:
-            msg = (
-                f"An output with one or more elements was resized since it had shape {str(out.shape)} "
-                "which does not match the required output shape {str(shape)}. "
-                "This behavior is deprecated, and in a future PyTorch release outputs will not "
-                "be resized unless they have zero elements. "
-                "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
-            )
-            warnings.warn(msg)
-        return out.resize_(shape)
+        return False
+    if out.numel() != 0:
+        msg = (
+            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
+            "which does not match the required output shape {str(shape)}. "
+            "This behavior is deprecated, and in a future PyTorch release outputs will not "
+            "be resized unless they have zero elements. "
+            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
+        )
+        warnings.warn(msg)
+    return True
+
+
+# TODO: handle tuples of tensors
+def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+    if _resize_output_check(out, shape):
+        return out.resize_(shape)
+    else:
+        return out
 
 
 def _safe_copy_out(
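Note: a minimal sketch of the out= semantics that the refactored helpers
implement: a zero-element out tensor is resized silently, while a non-empty
out tensor of the wrong shape is resized with a deprecation warning.

    import torch

    out = torch.empty(0)
    torch.add(torch.ones(2, 3), 1, out=out)  # zero-element out: resized silently
    print(out.shape)  # torch.Size([2, 3])

    out2 = torch.empty(5)
    torch.add(torch.ones(2, 3), 1, out=out2)  # wrong shape: warns, then resizes
    print(out2.shape)  # torch.Size([2, 3])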