[functorch] Add batch rule for linalg.lu_unpack (#121811)

Fixes: https://github.com/pytorch/pytorch/issues/119998
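
A minimal usage sketch (not part of the commit) of what the new batch rule enables: `torch.vmap` over `torch.lu_unpack` should now work on the case reported in the issue above instead of hitting a missing batch rule.

```python
import torch
from torch.func import vmap

A = torch.randn(3, 4, 4, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Before this change, vmap over lu_unpack had no batch rule; now each batch
# element is unpacked into its permutation, lower- and upper-triangular factors.
P, L, U = vmap(torch.lu_unpack)(LU, pivots)
assert torch.allclose(P @ L @ U, A)
```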

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121811
Approved by: https://github.com/peterbell10, https://github.com/zou3519
Guilherme Leobas 2024-03-18 11:35:35 -03:00 committed by PyTorch MergeBot
parent 773ae817f7
commit e6a461119a
3 changed files with 23 additions and 12 deletions

View File

@@ -265,6 +265,28 @@ static void expect_at_least_rank(
      rank, " dimensions instead.");
}
threeOutputs linalg_lu_unpack_batch_rule(
    const Tensor& LU, optional<int64_t> LU_bdim,
    const Tensor& pivots, optional<int64_t> pivots_bdim,
    bool unpack_data, bool unpack_pivots) {
  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);
  // The first N-2 dimensions of LU and the first N-1 dimensions of pivots must
  // match, so if only one of them is being vmapped over, we must expand out
  // that dimension on the other.
  if (LU_bdim.has_value() != pivots_bdim.has_value()) {
    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
    pivots_bdim = 0;
    LU_bdim = 0;
  }
  const auto res = at::lu_unpack(LU_, pivots_, unpack_data, unpack_pivots);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
}
oneOutput linalg_lu_solve_batch_rule(
    const Tensor& LU, optional<int64_t> LU_bdim,
    const Tensor& pivots, optional<int64_t> pivots_bdim,
@@ -578,6 +600,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(dot, dot_batch_rule);
  VMAP_SUPPORT(mv, mv_batch_rule);
  VMAP_SUPPORT(mm, mm_batch_rule);
  VMAP_SUPPORT(lu_unpack, linalg_lu_unpack_batch_rule);
  VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule);
  VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);
  VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule);  // custom dim error
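
For readers not familiar with the functorch batch-rule helpers, the rule above can be read roughly as the Python sketch below (illustrative only; the function name and signature here are hypothetical, not the functorch API): move any existing batch dimension to the front, broadcast the un-vmapped input when only one of the two carries a batch dimension, call the plain `lu_unpack`, and report batch dimension 0 for every output.

```python
import torch

# Rough Python analogue of the C++ batch rule above (illustrative only; this is
# not the functorch API). A "bdim" of None means the input is not being vmapped.
def lu_unpack_batch_rule(LU, LU_bdim, pivots, pivots_bdim,
                         unpack_data=True, unpack_pivots=True):
    # Move any existing batch dimension to the front.
    if LU_bdim is not None:
        LU = LU.movedim(LU_bdim, 0)
    if pivots_bdim is not None:
        pivots = pivots.movedim(pivots_bdim, 0)
    # The leading dims of LU and pivots must match, so if only one of them is
    # vmapped, broadcast the other along a new leading batch dimension.
    if (LU_bdim is None) != (pivots_bdim is None):
        bdim_size = LU.shape[0] if LU_bdim is not None else pivots.shape[0]
        if LU_bdim is None:
            LU = LU.expand(bdim_size, *LU.shape)
        if pivots_bdim is None:
            pivots = pivots.expand(bdim_size, *pivots.shape)
    # Call the plain op; every output now carries its batch dim at position 0.
    P, L, U = torch.lu_unpack(LU, pivots, unpack_data, unpack_pivots)
    return (P, 0), (L, 0), (U, 0)
```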

View File

@@ -1117,7 +1117,6 @@ class TestOperators(TestCase):
    @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('cdouble'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('lu'),
        xfail('cumprod'),
        xfail('masked_fill'),
        xfail('fill'),
@@ -1126,11 +1125,9 @@ class TestOperators(TestCase):
        xfail('put'),
        xfail('take'),
        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
        xfail('linalg.lu_factor', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('pca_lowrank', ''),
        xfail('svd_lowrank', ''),
        xfail('linalg.lu_factor_ex', ''),
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('special.log_ndtr', ''),
        xfail('fft.ihfft2'),  # conj_physical fallback
@@ -1144,11 +1141,9 @@ class TestOperators(TestCase):
        xfail('scatter_reduce', "mean"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amin"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amax"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('lu_unpack'),
        xfail('nn.functional.glu'),
        xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
        xfail('linalg.lu', ''),
        xfail('linalg.lu_solve', ''),
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('masked.cumprod', ''),
@@ -1190,9 +1185,6 @@ class TestOperators(TestCase):
        xfail('narrow'),  # Batching rule not implemented for `narrow.Tensor` (and view op)
        xfail('special.log_ndtr'),
        xfail('linalg.householder_product'),
        xfail('lu'),
        xfail('lu_solve'),
        xfail('lu_unpack'),
        xfail('masked_fill'),
        xfail('masked_scatter'),
        xfail('masked_select'),
@@ -1220,13 +1212,11 @@ class TestOperators(TestCase):
        xfail('nn.functional.rrelu'),
        xfail('nn.functional.embedding_bag'),
        xfail('nn.functional.fractional_max_pool2d'),
        xfail('linalg.lu_factor', ''),
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('pca_lowrank', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
        xfail('svd_lowrank', ''),
        xfail('linalg.lu_factor_ex', ''),
        xfail('nn.functional.max_unpool2d', ''),
        xfail('nn.functional.multi_margin_loss', ''),
@@ -1240,7 +1230,6 @@ class TestOperators(TestCase):
        xfail('nn.functional.max_unpool1d', 'grad'),
        xfail('nn.functional.max_unpool2d', 'grad'),
        xfail('linalg.lu', ''),
        xfail('linalg.lu_solve', ''),
        xfail('cdouble', ''),
        xfail('cfloat', ''),
        xfail('chalf', ''),
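
Most of the xfails removed above sit in vmap-of-jvp and has-batch-rule checks for `lu_unpack` and for LU-based ops that exercise it. A rough, hypothetical sketch of the kind of composition those tests cover (assuming, as the removed `test_vmapjvpall_has_batch_rule` xfail suggests, that `lu_unpack` supports forward-mode AD; this is not the actual test code):

```python
import torch
from torch.func import jvp, vmap

LU, pivots = torch.linalg.lu_factor(torch.randn(3, 5, 5, dtype=torch.float64))
tangent = torch.randn_like(LU)

def L_factor_jvp(LU_data, LU_pivots, t):
    # Forward-mode derivative of the L factor with respect to the packed LU data.
    return jvp(lambda x: torch.lu_unpack(x, LU_pivots)[1], (LU_data,), (t,))

# Composing vmap with jvp now dispatches to the new lu_unpack batch rule
# instead of failing on the batched call.
L_primal, L_tangent = vmap(L_factor_jvp)(LU, pivots, tangent)
```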

View File

@@ -3677,7 +3677,6 @@ class TestVmapOperatorsOpInfo(TestCase):
        # masked index as input which is not supported
        xfail('index_put', ''),
        xfail('isin'),
        xfail('lu_unpack'),
        xfail('masked_fill'),
        xfail('masked_scatter'),
        xfail('masked_select'),