mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add batching rule for torch.matrix_exp (#155202)
## Summary Adds the missing batching rule for `torch.matrix_exp` to enable efficient `vmap` support. Previously, using `vmap` with `matrix_exp` would trigger a performance warning and fall back to a slow loop-based implementation, even though `matrix_exp` natively supports batched inputs. Fixes #115992 ## Details `torch.matrix_exp` is an alias for `torch.linalg.matrix_exp`. This PR adds vmap support by registering `matrix_exp` with `OP_DECOMPOSE`, which reuses the existing CompositeImplicitAutograd decomposition to automatically generate batching behavior from the operation's simpler component operations. ## Testing The existing test suite for vmap and matrix_exp should cover this change. The fix enables: - No performance warning when using `vmap(torch.matrix_exp)` - Efficient native batched execution instead of loop-based fallback **Edit:** Updated Details section to accurately reflect the implementation approach (decomposition rather than batch rule registration) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155202 Approved by: https://github.com/zou3519
This commit is contained in:
parent
a5f59cc2ea
commit
2620361d19
|
|
@ -193,6 +193,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
|
|||
OP_DECOMPOSE(_lu_with_info);
|
||||
OP_DECOMPOSE(matmul);
|
||||
OP_DECOMPOSE(matrix_H);
|
||||
OP_DECOMPOSE(matrix_exp);
|
||||
OP_DECOMPOSE(matrix_power);
|
||||
OP_DECOMPOSE2(max, other );
|
||||
OP_DECOMPOSE(max_pool1d);
|
||||
|
|
|
|||
|
|
@ -4474,7 +4474,6 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
||||
xfail("resize_"),
|
||||
xfail("view_as_complex"),
|
||||
xfail("matrix_exp"),
|
||||
xfail("fft.ihfft2"),
|
||||
xfail("fft.ihfftn"),
|
||||
xfail("allclose"),
|
||||
|
|
|
|||
|
|
@ -120,7 +120,6 @@ xfail_not_implemented = {
|
|||
"aten::lu_solve",
|
||||
"aten::margin_ranking_loss",
|
||||
"aten::masked_select_backward",
|
||||
"aten::matrix_exp",
|
||||
"aten::matrix_exp_backward",
|
||||
"aten::max.names_dim",
|
||||
"aten::max.names_dim_max",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user