pytorch/torch
IvanKobzarev 585b9dbb5e [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)
Add ag+mm (all-gather + matmul) support for the case where gather_dim is the last dim of mat_A, i.e. the matmul reduction dim.

When the matmul is decomposed along the reduction dimension, each sub-matmul produces a partial result that needs an additional reduction (sum),
so we allocate memory for an accumulator.
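
To illustrate why an accumulator is needed, here is a minimal pure-Python sketch of a matmul split along the reduction dimension K. It is not the code from this PR: the helper names `matmul` and `split_k_matmul` are hypothetical, and nested lists stand in for tensors so the example has no dependencies.

```python
def matmul(a, b):
    """Plain-Python matrix product of two row-major nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_k_matmul(a_shards, b):
    """Compute A @ B where A arrives as shards split along its last dim (K).

    Each shard has shape (M, K_i). Multiplying it by the matching rows of B
    gives an (M, N) partial product; the full result is the sum of all
    partials, which is why an accumulator buffer is required.
    """
    m, n = len(a_shards[0]), len(b[0])
    acc = [[0] * n for _ in range(m)]  # accumulator for the partial products
    k_start = 0
    for a_shard in a_shards:
        k_len = len(a_shard[0])
        b_rows = b[k_start:k_start + k_len]  # rows of B matching this K slice
        partial = matmul(a_shard, b_rows)
        for i in range(m):
            for j in range(n):
                acc[i][j] += partial[i][j]
        k_start += k_len
    return acc


# A is 2x4, split into two 2x2 shards along K; B is 4x2.
A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [2, 1], [1, 2]]
shards = [[row[:2] for row in A], [row[2:] for row in A]]
result = split_k_matmul(shards, B)
```

In the ag+mm setting, each shard would correspond to one rank's slice of mat_A, so partial products can be computed as shards arrive instead of waiting for the full all-gather.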

The decomposition should not produce small (thin) mms that cannot load the GPU efficiently, so the minimum shard size is limited to 1024 (found empirically by testing in torchtitan).
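
A size guard of this kind could be sketched as follows. This is an assumption-laden illustration, not the check from this PR: the function name `should_decompose_along_k` is hypothetical, and it assumes the 1024 limit applies to the per-rank shard length along the gathered dimension and that this dimension divides evenly across ranks.

```python
MIN_SHARD_SIZE = 1024  # value from the commit message; its exact use here is assumed


def should_decompose_along_k(k_total, world_size, min_shard_size=MIN_SHARD_SIZE):
    """Return True if each rank's K shard is large enough to be worth a separate mm.

    k_total:    full size of the gathered (reduction) dimension
    world_size: number of ranks taking part in the all-gather
    """
    if world_size <= 0 or k_total % world_size != 0:
        return False  # assumed: uneven or invalid sharding is not decomposed
    return k_total // world_size >= min_shard_size


ok_large = should_decompose_along_k(8192, 8)  # shard of 1024 per rank
ok_small = should_decompose_along_k(4096, 8)  # shard of 512 per rank
```

When the guard rejects a shape, the natural fallback would be the unfused path: all-gather first, then a single matmul.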

scaled_mm is not supported yet for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
2025-10-16 20:14:39 +00:00
..
_awaits
_C [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079) 2025-10-15 22:26:47 +00:00
_C_flatbuffer
_custom_op [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
_decomp Ensure rms_norm decomp generates add.Scalar for pattern match BC (#165437) 2025-10-14 19:56:37 +00:00
_dispatch Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_dynamo [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_export [torch.export] Rmoving unused constants - add support for corner case (#165205) 2025-10-14 20:26:28 +00:00
_functorch [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_higher_order_ops [hop] run local_map with interpreter to preserve fx_traceback annotations (#165336) 2025-10-16 02:53:17 +00:00
_inductor [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068) 2025-10-16 20:14:39 +00:00
_lazy Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_library [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_logging Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_numpy Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_prims Add pyrefly suppressions 2/n (#164513) 2025-10-03 02:46:13 +00:00
_prims_common [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
_refs Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_strobelight Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_subclasses Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939) 2025-10-11 01:03:55 +00:00
_vendor
accelerator
amp Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)" 2025-10-10 15:12:46 +00:00
ao [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
autograd [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
backends Revert "Add SVE128 ISA (#158932)" 2025-10-10 01:17:02 +00:00
compiler Megacache integration (#163533) 2025-10-15 22:49:15 +00:00
contrib
cpu Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
csrc refactor: replace runtime_error with TORCH_CHECK for better error handling (#163628) 2025-10-16 11:09:48 +00:00
cuda [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079) 2025-10-15 22:26:47 +00:00
distributed [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068) 2025-10-16 20:14:39 +00:00
distributions [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
export [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
fft
func
futures [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
fx [Bugfix][Precompile][vLLM] Support for pickling einops for aot_autograd serialization in vLLM (#165359) 2025-10-15 20:00:24 +00:00
headeronly Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) 2025-10-16 00:55:43 +00:00
jit Fix missing brackets (#165138) 2025-10-10 17:23:31 +00:00
legacy
lib [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
linalg Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
masked [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
monitor
mps Add type annotations to MPS profiler utilities (#163486) 2025-09-27 23:00:53 +00:00
mtia Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
multiprocessing Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
nativert [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
nested [NJT] Fix schema validation error in jagged functions (#165307) 2025-10-13 17:59:18 +00:00
nn [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
numa Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
onnx [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
optim [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
package [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
profiler Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
quantization [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
signal Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
sparse [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
special
testing add the option to disable functionalization in AOTDispatcher (#164577) 2025-10-16 15:44:11 +00:00
utils [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
xpu Add a new API torch.xpu.is_tf32_supported for Intel GPU (#163141) 2025-10-12 12:11:57 +00:00
__config__.py
__future__.py
__init__.py [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
_appdirs.py
_classes.py
_compile.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_custom_ops.py
_environment.py
_guards.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_jit_internal.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_linalg_utils.py
_lobpcg.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_lowrank.py
_meta_registrations.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_namedtensor_internals.py
_ops.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py [reland] Allow setting grad_dtype on leaf tensors (#164751) 2025-10-08 20:23:13 +00:00
_tensor_str.py Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
_tensor.py Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
_thread_safe_fork.py
_torch_docs.py
_utils_internal.py Revert "Call internal log_compilation_event if it exists (#164855)" 2025-10-09 22:38:45 +00:00
_utils.py Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_VF.py
_vmap_internals.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_weights_only_unpickler.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
CMakeLists.txt
custom_class_detail.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
custom_class.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
extension.h
functional.py [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
header_only_apis.txt Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) 2025-10-16 00:55:43 +00:00
hub.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
library.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
library.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
overrides.py Add scaled_grouped_mm_v2 and python API (#165154) 2025-10-15 17:47:23 +00:00
py.typed
quasirandom.py
random.py Revert "Add device argument to torch.random.get_rng_state (#163034)" 2025-10-04 15:25:45 +00:00
return_types.py
script.h
serialization.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
storage.py [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
torch_version.py
types.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
version.py.tpl