Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00.
Summary: This PR makes the following improvements:

1. Add a `forward_with_indices` method to all C++ MaxPool modules. It returns the max indices along with the outputs. (We can't add a second `forward` method whose return type depends on its input, because that would break the type deduction of `torch::detail::return_type_of_forward_t`.)
2. Add `max_poolNd_with_indices` to `torch::nn::functional`, for use when the indices of the max values are needed. (This can't be merged into `torch::nn::functional::max_poolNd`, because the return type of `max_poolNd` must be defined statically.)
3. Improve the `pretty_print` output of the C++ MaxPoolNd and AvgPoolNd modules to match the Python `extra_repr`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26521
Differential Revision: D17507358
Pulled By: yf225
fbshipit-source-id: b6c0e2b27b38378cdc0c75f4bfc797b3c6b17cd9
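To make concrete what returning "the max indices along with the outputs" means, here is a plain C++ sketch. It does not use LibTorch, and the function name `max_pool1d_with_indices_sketch` is hypothetical. It assumes a 1-D input, no padding, no dilation, and a floor-mode output length, and it returns the pooled values together with the position in the input where each maximum was found.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of LibTorch: a 1-D max pool over a plain
// vector with no padding and no dilation. Returns the pooled values and, for
// each output element, the index into `input` where that maximum was found.
std::pair<std::vector<float>, std::vector<std::size_t>>
max_pool1d_with_indices_sketch(const std::vector<float>& input,
                               std::size_t kernel_size,
                               std::size_t stride) {
  std::vector<float> values;
  std::vector<std::size_t> indices;
  // Degenerate parameters or an input shorter than one window: no output.
  if (kernel_size == 0 || stride == 0 || input.size() < kernel_size) {
    return {values, indices};
  }
  // Number of full windows: floor((n - k) / s) + 1 (floor mode).
  const std::size_t out_len = (input.size() - kernel_size) / stride + 1;
  for (std::size_t o = 0; o < out_len; ++o) {
    const std::size_t start = o * stride;
    std::size_t best = start;
    for (std::size_t i = start + 1; i < start + kernel_size; ++i) {
      if (input[i] > input[best]) {
        best = i;  // strict '>' keeps the first maximum on ties
      }
    }
    values.push_back(input[best]);
    indices.push_back(best);
  }
  return {values, indices};
}
```

For example, pooling `{1, 3, 2, 5, 4, 0}` with `kernel_size = 2` and `stride = 2` gives the windows `[1, 3]`, `[2, 5]`, `[4, 0]`, so the values are `{3, 5, 4}` and the indices are `{1, 3, 4}`. Because the sketch always returns a pair, it is a separately named function rather than a variant of a values-only pool. This loosely mirrors the PR's choice of a distinct `forward_with_indices` method instead of a second `forward`.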
| Name |
|---|
| api |
| common |
| dist_autograd |
| jit |
| __init__.py |