Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[functorch] Add batch rule for linalg.lu_unpack (#121811)

Fixes: https://github.com/pytorch/pytorch/issues/119998
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121811
Approved by: https://github.com/peterbell10, https://github.com/zou3519

parent: 773ae817f7
commit: e6a461119a
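With this rule registered, torch.vmap can batch torch.lu_unpack directly instead of hitting the slow per-sample fallback. A minimal usage sketch, assuming a PyTorch build that includes this change (shapes and names here are illustrative, not from the PR):

import torch

# A batch of 4 square matrices, factorized in one shot.
A = torch.randn(4, 5, 5)
LU, pivots = torch.linalg.lu_factor(A)

# vmap maps lu_unpack over the leading batch dimension of both inputs.
P, L, U = torch.vmap(torch.lu_unpack)(LU, pivots)

# Each batched output matches the corresponding per-sample call.
P0, L0, U0 = torch.lu_unpack(LU[0], pivots[0])
assert torch.allclose(P[0], P0) and torch.allclose(L[0], L0) and torch.allclose(U[0], U0)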
@@ -265,6 +265,28 @@ static void expect_at_least_rank(
       rank, " dimensions instead.");
 }
 
+threeOutputs linalg_lu_unpack_batch_rule(
+    const Tensor& LU, optional<int64_t> LU_bdim,
+    const Tensor& pivots, optional<int64_t> pivots_bdim,
+    bool unpack_data, bool unpack_pivots) {
+  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
+  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);
+
+  // LU's first {N-2} dimensions and pivots's first {N-1} dimensions must
+  // match, so if only one of them is being vmapped over, we must expand out
+  // that dimension.
+  if (LU_bdim.has_value() != pivots_bdim.has_value()) {
+    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
+    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
+    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
+    pivots_bdim = 0;
+    LU_bdim = 0;
+  }
+
+  const auto res = at::lu_unpack(LU_, pivots_, unpack_data, unpack_pivots);
+  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
+}
+
 oneOutput linalg_lu_solve_batch_rule(
     const Tensor& LU, optional<int64_t> LU_bdim,
     const Tensor& pivots, optional<int64_t> pivots_bdim,
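For readers following the C++ above: moveBatchDimToFront and ensure_has_bdim are functorch helpers, and the branch handles the case where only one input carries a batch dimension. Below is a rough Python analogue of that alignment step, with hypothetical helper names mirroring the C++ ones — a sketch of the idea, not the actual implementation:

import torch

def move_batch_dim_to_front(t, bdim):
    # No-op when this input is not being vmapped over.
    return t if bdim is None else torch.movedim(t, bdim, 0)

def ensure_has_bdim(t, has_bdim, bdim_size):
    # Materialize a leading batch dim on the unbatched input.
    return t if has_bdim else t.unsqueeze(0).expand(bdim_size, *t.shape)

def lu_unpack_batch_rule(LU, LU_bdim, pivots, pivots_bdim,
                         unpack_data=True, unpack_pivots=True):
    LU_ = move_batch_dim_to_front(LU, LU_bdim)
    pivots_ = move_batch_dim_to_front(pivots, pivots_bdim)
    # If exactly one input carries a batch dim, expand the other to match.
    if (LU_bdim is None) != (pivots_bdim is None):
        batched = LU_ if LU_bdim is not None else pivots_
        bdim_size = batched.shape[0]
        LU_ = ensure_has_bdim(LU_, LU_bdim is not None, bdim_size)
        pivots_ = ensure_has_bdim(pivots_, pivots_bdim is not None, bdim_size)
    # Leading dims now agree; the plain op handles the batch natively.
    return torch.lu_unpack(LU_, pivots_, unpack_data, unpack_pivots)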
@@ -578,6 +600,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(dot, dot_batch_rule);
   VMAP_SUPPORT(mv, mv_batch_rule);
   VMAP_SUPPORT(mm, mm_batch_rule);
+  VMAP_SUPPORT(lu_unpack, linalg_lu_unpack_batch_rule);
   VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule);
   VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);
   VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule);  // custom dim error
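Once the registration above is in place, the rule composes with the rest of the dispatcher, so a batched factor/unpack round trip can be checked end to end. A small sanity check along the lines of what the OpInfo tests exercise (an assumed illustration, not a test from this PR):

import torch

A = torch.randn(8, 4, 4, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.vmap(torch.lu_unpack)(LU, pivots)

# P @ L @ U reconstructs each input matrix.
assert torch.allclose(P @ L @ U, A)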
@@ -1117,7 +1117,6 @@ class TestOperators(TestCase):
     @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
         skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
         xfail('cdouble'),  # RuntimeError: required rank 4 tensor to use channels_last format
-        xfail('lu'),
         xfail('cumprod'),
         xfail('masked_fill'),
         xfail('fill'),

(Note: the PR only adds a batch rule for lu_unpack, yet the xfail removals here and below also cover lu, lu_solve, linalg.lu_factor, and linalg.lu_factor_ex; their gradient formulas appear to unpack the factorization via lu_unpack, so these vmap-of-gradient tests were previously blocked on the missing rule.)
@@ -1126,11 +1125,9 @@ class TestOperators(TestCase):
         xfail('put'),
         xfail('take'),
         xfail('nn.functional.feature_alpha_dropout', 'without_train'),
-        xfail('linalg.lu_factor', ''),
         xfail('nn.functional.dropout2d', ''),
         xfail('pca_lowrank', ''),
         xfail('svd_lowrank', ''),
-        xfail('linalg.lu_factor_ex', ''),
         xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('special.log_ndtr', ''),
         xfail('fft.ihfft2'),  # conj_physical fallback
@@ -1144,11 +1141,9 @@ class TestOperators(TestCase):
         xfail('scatter_reduce', "mean"),  # aten::scatter_reduce.two hit the vmap fallback
         xfail('scatter_reduce', "amin"),  # aten::scatter_reduce.two hit the vmap fallback
         xfail('scatter_reduce', "amax"),  # aten::scatter_reduce.two hit the vmap fallback
-        xfail('lu_unpack'),
         xfail('nn.functional.glu'),
         xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
         xfail('linalg.lu', ''),
-        xfail('linalg.lu_solve', ''),
         xfail('nn.functional.dropout3d', ''),
         xfail('as_strided_scatter', ''),
         xfail('masked.cumprod', ''),
@@ -1190,9 +1185,6 @@ class TestOperators(TestCase):
         xfail('narrow'),  # Batching rule not implemented for `narrow.Tensor` (and view op)
         xfail('special.log_ndtr'),
         xfail('linalg.householder_product'),
-        xfail('lu'),
-        xfail('lu_solve'),
-        xfail('lu_unpack'),
         xfail('masked_fill'),
         xfail('masked_scatter'),
         xfail('masked_select'),
@@ -1220,13 +1212,11 @@ class TestOperators(TestCase):
         xfail('nn.functional.rrelu'),
         xfail('nn.functional.embedding_bag'),
         xfail('nn.functional.fractional_max_pool2d'),
-        xfail('linalg.lu_factor', ''),
         xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('pca_lowrank', ''),
         xfail('nn.functional.dropout2d', ''),
         xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('svd_lowrank', ''),
-        xfail('linalg.lu_factor_ex', ''),
 
         xfail('nn.functional.max_unpool2d', ''),
         xfail('nn.functional.multi_margin_loss', ''),
@@ -1240,7 +1230,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.max_unpool1d', 'grad'),
         xfail('nn.functional.max_unpool2d', 'grad'),
         xfail('linalg.lu', ''),
-        xfail('linalg.lu_solve', ''),
         xfail('cdouble', ''),
         xfail('cfloat', ''),
         xfail('chalf', ''),
@@ -3677,7 +3677,6 @@ class TestVmapOperatorsOpInfo(TestCase):
         # masked index as input which is not supported
         xfail('index_put', ''),
         xfail('isin'),
-        xfail('lu_unpack'),
         xfail('masked_fill'),
         xfail('masked_scatter'),
         xfail('masked_select'),
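With this last xfail gone, the exhaustive vmap tests now cover lu_unpack, including the case the new rule's expand-out branch handles: only one input carrying a batch dimension. A hedged sketch of that case — since lu_unpack merely rearranges its packed input, pairing a random LU batch with pivots from an unrelated factorization is mechanically fine for exercising the rule:

import torch

_, pivots = torch.linalg.lu_factor(torch.randn(4, 4))
LU_batch = torch.randn(8, 4, 4)  # treated as packed LU data

# Batch dim on LU only; the rule broadcasts pivots across the batch.
P, L, U = torch.vmap(torch.lu_unpack, in_dims=(0, None))(LU_batch, pivots)
print(P.shape, L.shape, U.shape)  # each is (8, 4, 4)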