mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Don't check for linalg errors on meta tensors
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78467 Approved by: https://github.com/Chillee
This commit is contained in:
parent
59fdb627a3
commit
789115e05e
|
|
@ -1294,6 +1294,9 @@ void _linalg_check_errors(
|
|||
const Tensor& info,
|
||||
const c10::string_view api_name,
|
||||
bool is_matrix) {
|
||||
if (info.is_meta()) {
|
||||
return;
|
||||
}
|
||||
if (is_matrix) {
|
||||
singleCheckErrors(info.item<int64_t>(), api_name);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -603,14 +603,10 @@ meta_function_expected_failures = {
|
|||
torch.linalg.eigvals: {f32, f64},
|
||||
torch.linalg.eigvalsh: {f32, f64}, # aten::linalg_eigvalsh.out
|
||||
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product
|
||||
torch.linalg.inv: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.ldl_factor: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.lstsq: {f32, f64}, # aten::linalg_lstsq.out
|
||||
torch.linalg.lu_factor: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.slogdet: {f32, f64}, # aten::linalg_slogdet
|
||||
torch.linalg.solve: {f32, f64}, # aten::linalg_solve, aten::linalg_solve.out
|
||||
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular
|
||||
torch.linalg.tensorinv: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.tensorsolve: {f32, f64}, # aten::linalg_solve
|
||||
torch.logdet: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero
|
||||
}
|
||||
|
|
@ -640,9 +636,6 @@ meta_function_skips = {
|
|||
torch.nn.functional.cross_entropy: {bf16, f32, f64},
|
||||
torch.nn.functional.interpolate: {bf16, f32, f64, u8},
|
||||
torch.nn.functional.nll_loss: {bf16, f32, f64}, # TODO
|
||||
torch.inverse: {f32, f64},
|
||||
torch.linalg.matrix_power: {f32, f64},
|
||||
torch.linalg.matrix_rank: {f32, f64},
|
||||
torch.linalg.pinv: {f32, f64},
|
||||
torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
|
||||
}
|
||||
|
|
@ -685,11 +678,7 @@ meta_function_device_expected_failures['cuda'] = {
|
|||
torch.linalg.cholesky: {f32, f64}, # aten::linalg_cholesky_ex, aten::linalg_cholesky_ex.L
|
||||
torch.linalg.cholesky_ex: {f32, f64}, # aten::linalg_cholesky_ex
|
||||
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out
|
||||
torch.linalg.inv: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.ldl_factor: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.lu_factor: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
|
||||
torch.linalg.tensorinv: {f32, f64}, # aten::_local_scalar_dense
|
||||
torch.logcumsumexp: {bf16, f16}, # aten::_logcumsumexp, aten::_logcumsumexp.out
|
||||
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
|
||||
torch.median: {f16}, # aten::median, aten::median.dim_values
|
||||
|
|
@ -726,7 +715,6 @@ meta_function_device_skips['cuda'] = {
|
|||
torch.cummin: {f16},
|
||||
torch.functional.tensordot: {f16},
|
||||
torch.inner: {f16},
|
||||
torch.inverse: {f32, f64},
|
||||
torch.linalg.matrix_power: {f32, f64},
|
||||
torch.linalg.matrix_rank: {f32, f64},
|
||||
torch.linalg.svd: {f32, f64},
|
||||
|
|
@ -873,7 +861,6 @@ meta_dispatch_expected_failures = {
|
|||
aten.vdot.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
|
||||
aten.vdot.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
|
||||
aten._det_lu_based_helper.default: {f32, f64}, # aten::_det_lu_based_helper
|
||||
aten._linalg_check_errors.default: {c128, c64, f32, f64}, # aten::_local_scalar_dense
|
||||
aten.cholesky.default: {f32, f64}, # aten::cholesky
|
||||
aten.cholesky.out: {f32, f64}, # aten::cholesky.out
|
||||
aten.cholesky_inverse.default: {f32, f64}, # aten::cholesky_inverse
|
||||
|
|
@ -882,7 +869,6 @@ meta_dispatch_expected_failures = {
|
|||
aten.cholesky_solve.out: {f32, f64}, # aten::_cholesky_solve_helper
|
||||
aten.eig.default: {f32, f64}, # aten::_local_scalar_dense
|
||||
aten.geqrf.default: {f32, f64}, # aten::geqrf
|
||||
aten.inverse.out: {f32, f64}, # aten::_local_scalar_dense
|
||||
aten.linalg_cholesky_ex.L: {f32, f64}, # aten::linalg_cholesky_ex.L
|
||||
aten.linalg_cholesky_ex.default: {f32, f64}, # aten::linalg_cholesky_ex
|
||||
aten.linalg_eig.default: {f32, f64}, # aten::linalg_eig
|
||||
|
|
@ -929,7 +915,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
|
|||
aten._fft_c2r.out: {c32, f16}, # aten::_fft_c2r.out
|
||||
aten._fft_r2c.default: {f16}, # aten::_fft_r2c
|
||||
aten._fft_r2c.out: {f16}, # aten::_fft_r2c.out
|
||||
aten._linalg_check_errors.default: {c128, c64, f32, f64}, # aten::_local_scalar_dense
|
||||
aten._unique2.default: {f16}, # aten::_unique2
|
||||
aten._use_cudnn_ctc_loss.default: {f32, f64}, # aten::_use_cudnn_ctc_loss
|
||||
aten.addbmm.default: {f16}, # aten::addbmm
|
||||
|
|
@ -944,7 +929,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
|
|||
aten.histc.default: {i16, i32, i64, i8}, # aten::histc
|
||||
aten.histc.out: {i16, i32, i64, i8}, # aten::histc.out
|
||||
aten.index.Tensor: {c32}, # aten::index.Tensor
|
||||
aten.inverse.out: {f32, f64}, # aten::_local_scalar_dense
|
||||
aten.kthvalue.default: {f16}, # aten::kthvalue.values
|
||||
aten.linalg_cholesky_ex.L: {f32, f64}, # aten::linalg_cholesky_ex.L
|
||||
aten.linalg_cholesky_ex.default: {f32, f64}, # aten::linalg_cholesky_ex
|
||||
|
|
@ -990,11 +974,9 @@ meta_dispatch_device_expected_failures['cuda'] = {
|
|||
|
||||
meta_dispatch_device_skips['cuda'] = {
|
||||
aten._conj.default: {c32, f16},
|
||||
aten._linalg_svd.default: {f32, f64},
|
||||
aten.cudnn_batch_norm.default: {f32, f64},
|
||||
aten.cummax.default: {f16},
|
||||
aten.cummin.default: {f16},
|
||||
aten.inverse.default: {f32, f64},
|
||||
# ROCm stuff; technically this should be expected failure but it's
|
||||
# not worth it; these should get unified anyway
|
||||
aten.miopen_batch_norm.default: {f32},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user