pytorch/torch
Yidi Wu af4ba78543 [scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work, where two nested vmaps (vmap(vmap)) with chunk sizes > 1 are invoked. That produces a scan -> vmap -> scan -> vmap chain, so we need to handle both vmap(scan) and scan(vmap).

We handle vmap(scan) by turning it into scan(vmap(combine_fn)). The idea is that combine_fn no longer processes a single slice; instead, we vmap over combine_fn so that each scan step runs multiple combine_fns at once. To do this, we need to know how combine_fn propagates the batched tensors and what the batched dims of its outputs are. For this purpose, we use restore_vmap, which gives us the out_dims information.
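
The rewrite described above can be illustrated with a minimal pure-Python sketch. It does not use the PyTorch implementation: `simple_scan`, `simple_vmap`, and the example `combine_fn` are hypothetical stand-ins that operate on plain lists, and the real code path (including restore_vmap and out_dims handling) is not modeled here. The sketch only checks that running one scan per batch element gives the same result as running a single scan whose step function is batched.

```python
# Hypothetical list-based stand-ins; these are NOT the PyTorch scan/vmap APIs.

def simple_scan(combine_fn, init, xs):
    """Reference scan: thread a carry through xs, collecting per-step outputs."""
    carry, ys = init, []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, ys


def simple_vmap(fn):
    """Apply fn to each batch element; every argument is a list indexed by batch."""
    def batched(*args):
        results = [fn(*sample) for sample in zip(*args)]
        # Turn a list of (carry, y) pairs into (list of carries, list of ys).
        return tuple(list(column) for column in zip(*results))
    return batched


def combine_fn(carry, x):
    """Example step function (an assumption for illustration): running sum."""
    new_carry = carry + x
    return new_carry, new_carry * 2


# A batch of 2 samples, each a sequence of 3 steps.
init = [0, 10]
xs = [[1, 2, 3], [4, 5, 6]]  # layout: [batch][step]

# vmap(scan): one independent scan per batch element.
carries_a, ys_a = simple_vmap(
    lambda i, x: simple_scan(combine_fn, i, x)
)(init, xs)

# scan(vmap(combine_fn)): a single scan whose every step handles the whole batch.
xs_step_major = [list(step) for step in zip(*xs)]  # layout: [step][batch]
carries_b, ys_b = simple_scan(simple_vmap(combine_fn), init, xs_step_major)

# The stacked outputs come back step-major, so move the batch axis to the front.
ys_b_batch_major = [list(sample) for sample in zip(*ys_b)]

assert carries_a == carries_b
assert ys_a == ys_b_batch_major
```

In this toy version the batch axis of the stacked outputs ends up in position 1 rather than 0, which hints at why the rewrite has to track where the batch dimension of each output lands.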

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
2025-10-22 09:46:00 +00:00
_awaits
_C [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882) 2025-10-21 19:43:55 +00:00
_C_flatbuffer
_custom_op [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
_decomp Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910) 2025-10-21 16:36:38 +00:00
_dispatch Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_dynamo [vmap][dynamo] use create_proxy instead of create_node in vmap increate nesting ctx manager (#165675) 2025-10-22 09:46:00 +00:00
_export [torch.export] Rmoving unused constants - add support for corner case (#165205) 2025-10-14 20:26:28 +00:00
_functorch [scan x vmap] support scan in vmap (#165580) 2025-10-22 09:46:00 +00:00
_higher_order_ops [scan x vmap] support scan in vmap (#165580) 2025-10-22 09:46:00 +00:00
_inductor Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
_lazy Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_library Enable all flake8-logging-format rules (#164655) 2025-10-19 00:59:28 +00:00
_logging [annotation] add logging for debugging annotation (#165797) 2025-10-20 21:27:38 +00:00
_numpy Enable PLW0127 in ruff (#165851) 2025-10-21 03:30:57 +00:00
_prims Fix self assignment (#165816) 2025-10-18 18:51:52 +00:00
_prims_common [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
_refs Enable all PIE rules on ruff (#165814) 2025-10-18 07:36:18 +00:00
_strobelight Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_subclasses Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
_vendor
accelerator
amp Revert "[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)" 2025-10-22 00:26:57 +00:00
ao Revert "[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)" 2025-10-21 22:10:19 +00:00
autograd [Code Clean] Clean asserts in torch/autograd. (#165627) 2025-10-20 23:03:47 +00:00
backends Add type suppressions to _inductor/runtime (#165918) 2025-10-21 02:54:22 +00:00
compiler Megacache integration (#163533) 2025-10-15 22:49:15 +00:00
contrib
cpu Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
csrc [Code Clean] Better error handling in torch/csrc/distributed (#165053) 2025-10-22 01:40:36 +00:00
cuda Add type suppressions to _inductor/runtime (#165918) 2025-10-21 02:54:22 +00:00
distributed Revert "shrink_group implementation to expose ncclCommShrink API (#164518)" 2025-10-21 20:24:14 +00:00
distributions [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
export Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910) 2025-10-21 16:36:38 +00:00
fft
func
futures
fx [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882) 2025-10-21 19:43:55 +00:00
headeronly Move toUnderlying to headeronly (#165694) 2025-10-22 05:31:16 +00:00
jit Fix missing brackets (#165138) 2025-10-10 17:23:31 +00:00
legacy
lib [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
linalg Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
masked [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
monitor
mps
mtia Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
multiprocessing Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
nativert [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
nested Enable all PIE rules on ruff (#165814) 2025-10-18 07:36:18 +00:00
nn Enable PLW0127 in ruff (#165851) 2025-10-21 03:30:57 +00:00
numa Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
onnx Enable all flake8-logging-format rules (#164655) 2025-10-19 00:59:28 +00:00
optim [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
package [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
profiler Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
quantization [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
signal Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
sparse [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
special
testing [AMD][gfx1100] test_decompose_mem_bound_mm.py tolerance increase (#165625) 2025-10-22 01:38:48 +00:00
utils Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
xpu Add a new API torch.xpu.is_tf32_supported for Intel GPU (#163141) 2025-10-12 12:11:57 +00:00
__config__.py
__future__.py
__init__.py [BE][Ez]: Update torch.is_tensor documentation (#165841) 2025-10-19 09:24:11 +00:00
_appdirs.py
_classes.py
_compile.py
_custom_ops.py
_environment.py
_guards.py
_jit_internal.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_linalg_utils.py
_lobpcg.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_lowrank.py
_meta_registrations.py Enable all PIE rules on ruff (#165814) 2025-10-18 07:36:18 +00:00
_namedtensor_internals.py
_ops.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py [reland] Allow setting grad_dtype on leaf tensors (#164751) 2025-10-08 20:23:13 +00:00
_tensor_str.py Enable all PIE rules on ruff (#165814) 2025-10-18 07:36:18 +00:00
_tensor.py Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
_thread_safe_fork.py
_torch_docs.py Clarrifying input output angle unit in the docs for trigonometric fun… (#161248) 2025-10-18 11:53:48 +00:00
_utils_internal.py Revert "Call internal log_compilation_event if it exists (#164855)" 2025-10-09 22:38:45 +00:00
_utils.py Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py
CMakeLists.txt
custom_class_detail.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
custom_class.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
extension.h
functional.py [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
header_only_apis.txt Move toUnderlying to headeronly (#165694) 2025-10-22 05:31:16 +00:00
hub.py Enable PLC1802 on ruff (#165813) 2025-10-18 05:44:14 +00:00
library.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
library.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
overrides.py Add scaled_grouped_mm_v2 and python API (#165154) 2025-10-15 17:47:23 +00:00
py.typed
quasirandom.py
random.py Revert "Add device argument to torch.random.get_rng_state (#163034)" 2025-10-04 15:25:45 +00:00
return_types.py
script.h
serialization.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
storage.py [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
torch_version.py
types.py Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
version.py.tpl