Add batching rule for torch.matrix_exp (#155202)

## Summary

Adds the missing batching rule for `torch.matrix_exp` to enable efficient `vmap` support.
Previously, using `vmap` with `matrix_exp` would trigger a performance warning and fall back to a slow loop-based implementation, even though `matrix_exp` natively supports batched inputs.

Fixes #115992

## Details

`torch.matrix_exp` is an alias for `torch.linalg.matrix_exp`. This PR adds vmap support by registering `matrix_exp` with `OP_DECOMPOSE`, which reuses the existing CompositeImplicitAutograd decomposition to automatically generate batching behavior from the operation's simpler component operations.

## Testing

The existing test suite for vmap and matrix_exp should cover this change. The fix enables:
- No performance warning when using `vmap(torch.matrix_exp)`
- Efficient native batched execution instead of loop-based fallback

**Edit:** Updated Details section to accurately reflect the implementation approach (decomposition rather than batch rule registration)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155202
Approved by: https://github.com/zou3519
This commit is contained in:
Ayan Das 2025-06-18 17:35:35 +00:00 committed by PyTorch MergeBot
parent a5f59cc2ea
commit 2620361d19
3 changed files with 1 additions and 2 deletions

View File

@ -193,6 +193,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(_lu_with_info);
OP_DECOMPOSE(matmul);
OP_DECOMPOSE(matrix_H);
OP_DECOMPOSE(matrix_exp);
OP_DECOMPOSE(matrix_power);
OP_DECOMPOSE2(max, other );
OP_DECOMPOSE(max_pool1d);

View File

@ -4474,7 +4474,6 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail("resize_"),
xfail("view_as_complex"),
xfail("matrix_exp"),
xfail("fft.ihfft2"),
xfail("fft.ihfftn"),
xfail("allclose"),

View File

@ -120,7 +120,6 @@ xfail_not_implemented = {
"aten::lu_solve",
"aten::margin_ranking_loss",
"aten::masked_select_backward",
"aten::matrix_exp",
"aten::matrix_exp_backward",
"aten::max.names_dim",
"aten::max.names_dim_max",