pytorch/torch
Henry Tsang ee2d104c05 [cutlass backend] Add (limited) bmm dynamic shape support (#152393)
Differential Revision: D73626732

In this PR, we add (limited) dynamic shape support for bmm, provided that the batch stride is the largest of the strides for each of A, B, and D. For example, for A of size `(B, M, K)`, we support strides `(M*K, K, 1)` and `(M*K, 1, M)`. Under this assumption, we can infer the batch stride from the existing arguments, as sketched below.
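
To make the inference concrete, here is a minimal sketch (a hypothetical helper, not the actual cutlass-backend code) of recovering the batch stride from the problem shape alone under the two supported layouts:

```python
# Hedged sketch: for A of shape (B, M, K), if the batch stride is the
# largest of A's strides, it equals M * K in both supported layouts,
# so it can be derived from the shape instead of being passed as an
# extra runtime param.
def infer_batch_stride(shape, strides):
    B, M, K = shape
    # Supported layouts: row-major (M*K, K, 1) and column-major (M*K, 1, M).
    assert strides in [(M * K, K, 1), (M * K, 1, M)], (
        "batch stride must be the largest stride"
    )
    return M * K

assert infer_batch_stride((8, 64, 32), (2048, 32, 1)) == 2048  # row-major
assert infer_batch_stride((8, 64, 32), (2048, 1, 64)) == 2048  # column-major
```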

The reason is that we don't want to add 2-3 more runtime params. The concerns are added complexity and a possible perf regression, though we didn't verify the latter.

We can revisit this if the need arises.

We also remove `B = 1` for normal mm and addmm. We tested this and didn't see a perf regression, but we are open to revisiting it as well.
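
For context, a minimal sketch of how one might exercise this path end to end; it assumes a CUDA device, a CUTLASS-enabled build, the existing inductor knob `max_autotune_gemm_backends`, and `torch._dynamo.mark_dynamic`:

```python
import torch
import torch._inductor.config as inductor_config

# Hedged sketch: opt into the cutlass backend for GEMM autotuning.
# Assumes a CUTLASS-capable CUDA build.
inductor_config.max_autotune_gemm_backends = "CUTLASS"

def f(a, b):
    return torch.bmm(a, b)

a = torch.randn(8, 64, 32, device="cuda", dtype=torch.float16)
b = torch.randn(8, 32, 64, device="cuda", dtype=torch.float16)

# Mark the batch dimension dynamic so the compiled kernel covers varying B.
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 0)

compiled = torch.compile(f, mode="max-autotune")
out = compiled(a, b)
```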

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152393
Approved by: https://github.com/ColinPeppler
2025-04-30 04:36:24 +00:00
_awaits
_C [Kineto] Enable OOM observer (#152160) 2025-04-27 15:56:44 +00:00
_C_flatbuffer
_custom_op
_decomp Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_dispatch [BE][PYFMT] migrate PYFMT for torch._dynamo to ruff format (#144549) 2025-02-28 03:03:53 +00:00
_dynamo [dynamo] Guard serialization for NAME_MATCH (#152332) 2025-04-29 20:16:00 +00:00
_export [export] Preserve custom metadata for tensor constants (#152241) 2025-04-30 00:30:35 +00:00
_functorch [cudagraphs] Fix issue in collecting static_input_idxs (#152287) 2025-04-30 03:24:05 +00:00
_higher_order_ops [Typing] Enable torch.types.IntLikeType / FloatLikeType / BoolLikeType (#152157) 2025-04-25 19:00:10 +00:00
_inductor [cutlass backend] Add (limited) bmm dynamic shape support (#152393) 2025-04-30 04:36:24 +00:00
_lazy
_library Save/load op profiles (#151817) 2025-04-29 23:11:32 +00:00
_logging [export] Beef up guard_added logs (#149465) 2025-03-20 23:02:07 +00:00
_numpy Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_prims Support torch.compile rng selective activation checkpointing with cudagraph (#146878) 2025-02-28 00:47:03 +00:00
_prims_common [dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127) 2025-04-23 05:42:30 +00:00
_refs [MPSInductor] Fix masked_fill decomp (#152268) 2025-04-27 15:50:46 +00:00
_strobelight Enable strobelight profiling specific compile frame ids using COMPILE_STROBELIGHT_FRAME_FILTER (#147549) 2025-02-22 03:44:53 +00:00
_subclasses [dynamo] Add guard serialization for tensor matches. (#151318) 2025-04-25 14:16:23 +00:00
_vendor
accelerator Add torch.accelerator.device_index as accelerator's device switch context (#148864) 2025-04-25 09:45:25 +00:00
amp [Intel GPU] skip a cuda api call in amp to save some host overhead on xpu (#151111) 2025-04-13 06:37:07 +00:00
ao Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
autograd [BE] Migrate dtype_abbrs into one location (#152229) 2025-04-28 03:52:47 +00:00
backends Expose is_available API for torch.backends.mkldnn (#147432) 2025-04-10 05:05:37 +00:00
compiler [MegaCache] Return None on no compilation (#151921) 2025-04-23 04:32:06 +00:00
contrib
cpu [CPU Stream] Add noop for CPU stream record_event() and wait_event() (#145935) 2025-02-20 18:50:55 +00:00
csrc Revert "[AOTI][reland] Remove typedef for half and bfloat16 (#151109)" 2025-04-29 22:37:16 +00:00
cuda ROCm: Enable tf32 testing on test_nn (#148945) 2025-04-28 23:01:04 +00:00
distributed Add rich support to torch.distributed.tensor.debug.visualize_sharding (#152027) 2025-04-29 03:51:32 +00:00
distributions add generalized pareto distribution (GPD) (#135968) 2025-04-17 18:51:02 +00:00
export Fix additional inputs to error on inconsistent constants (#151970) 2025-04-30 01:38:17 +00:00
fft
func
futures PEP585: More UP006 fixes (#146392) 2025-02-20 06:18:13 +00:00
fx [ez] Make relaxed constraint error message more user friendly (#151407) 2025-04-30 03:51:50 +00:00
jit Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
legacy
lib [1/N] Use internal linkage in torch/csrc C++ files. (#150930) 2025-04-11 02:19:31 +00:00
linalg Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
masked [BE][Easy]: Dedupe a TypeAlias in PrimsCommon (#151565) 2025-04-17 19:59:41 +00:00
monitor
mps [MPS] Make torch.mps.compile_shader public (#148972) 2025-03-11 20:20:58 +00:00
mtia [Kineto] Enable OOM observer (#152160) 2025-04-27 15:56:44 +00:00
multiprocessing
nested [aotd] Guess tangents stride as output strides (#144579) 2025-03-20 15:41:36 +00:00
nn Refactor to use torch.accelerator.device_index instead of torch.cuda.device for generic device context manager (#148880) 2025-04-25 09:45:25 +00:00
onnx Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
optim Add lr_lambda type check in MultiplicativeLR (#151973) 2025-04-29 08:21:41 +00:00
package
profiler [profiler][retry] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#151124) 2025-04-15 16:11:49 +00:00
quantization
signal
sparse Fix spelling (#149277) 2025-03-20 01:02:32 +00:00
special
testing add xfail for distributed tests on Jetson (#152224) 2025-04-29 23:48:40 +00:00
utils Remove conda refs in tools (#152368) 2025-04-29 02:45:47 +00:00
xpu xpu: torch.xpu.get_arch_list() to return [] if xpu not compiled (#147431) 2025-02-24 01:35:54 +00:00
__config__.py
__future__.py
__init__.py [profiler][retry] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#151124) 2025-04-15 16:11:49 +00:00
_appdirs.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_classes.py
_compile.py
_custom_ops.py
_deploy.py
_environment.py
_guards.py [dynamo] Add guard serialization for tensor matches. (#151318) 2025-04-25 14:16:23 +00:00
_jit_internal.py [BE][CI] bump ruff to 0.9.2: multiline assert statements (#144546) 2025-02-27 20:46:16 +00:00
_linalg_utils.py
_lobpcg.py Add scripts to check xrefs and urls (#151844) 2025-04-28 09:30:07 +00:00
_lowrank.py
_meta_registrations.py [inductor] align replicationpad on processing bool dtype with eager (#147666) 2025-04-28 21:54:31 +00:00
_namedtensor_internals.py
_ops.py Introduce unsafe way to mark functions as cacheable (#151603) 2025-04-21 17:37:38 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_tensor_str.py add torch.float4_e2m1fn_x2 to PyTorch (#148791) 2025-03-27 17:32:20 +00:00
_tensor.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_thread_safe_fork.py
_torch_docs.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_utils_internal.py [profiler][retry] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#151124) 2025-04-15 16:11:49 +00:00
_utils.py Allow torch.load under FakeTensorMode to load FakeTensors with correct devices (for plain Tensors) (#147786) 2025-03-06 12:04:32 +00:00
_VF.py
_vmap_internals.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_weights_only_unpickler.py Add sparse tensors constructed via legacy constructor to _sparse_tensors_to_validate (#147759) 2025-02-25 23:51:12 +00:00
CMakeLists.txt Add new dependences for gen_pyi.py (#150391) 2025-04-03 14:18:18 +00:00
custom_class_detail.h
custom_class.h Remove unneeded Clang-tidy suppression (#148246) 2025-03-01 16:51:54 +00:00
extension.h
functional.py Optimize cdist param description (#151178) 2025-04-14 13:53:10 +00:00
hub.py [BE][CI][Easy] bump ruff to 0.9.0: long statements in docstrings (#146509) 2025-02-24 19:56:08 +00:00
library.h Overload Library::def rather than templating it (#151626) 2025-04-18 22:51:16 +00:00
library.py fix spammy library deinit errors when user passes an invalid TORCH_LOGS argument (#151678) 2025-04-22 20:13:52 +00:00
overrides.py [CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs (#150812) 2025-04-18 01:53:26 +00:00
py.typed
quasirandom.py
random.py Update description for torch.random.fork_rng (#151881) 2025-04-23 16:59:29 +00:00
README.md Rename README.txt to README.md (#149811) 2025-03-24 22:33:33 +00:00
return_types.py
script.h
serialization.py Move get accelerator to use build time flags when possible (#146098) 2025-03-10 13:17:58 +00:00
storage.py add torch.float4_e2m1fn_x2 to PyTorch (#148791) 2025-03-27 17:32:20 +00:00
torch_version.py
types.py
version.py.tpl

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.