pytorch/torch
Nikita Vedeneev 7f256fff77 Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title.

Additionally, this introduces support for:
- Rectangular block sizes which are powers of 2 and at least 16 (triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-17 21:43:20 +00:00
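A minimal sketch (not taken from the PR) of the tensor shapes involved: a BSR tensor with one of the rectangular, power-of-2 block sizes (at least 16 per side, per Triton's `dot` limitation) multiplied by a strided tensor. The example runs on CPU via densification for the reference result; the fast path this change adds is the bfloat16/half CUDA route noted in the comments.

```python
import torch

# Build a BSR tensor with rectangular 16 x 32 blocks (powers of 2, >= 16).
dense = torch.randn(64, 64)
bsr = dense.to_sparse_bsr(blocksize=(16, 32))

strided = torch.randn(64, 48)

# Reference result via densification. On a CUDA device with Triton available,
# `bsr.to(torch.bfloat16) @ strided.to(torch.bfloat16)` is the kind of
# bsr @ strided product this change accelerates.
reference = bsr.to_dense() @ strided
assert torch.allclose(reference, dense @ strided, atol=1e-5)
```

The batch-with-broadcasting support mentioned above applies when either operand carries leading batch dimensions.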
_C [autograd.Function] Kill the extension feature flag (#92026) 2023-01-17 13:36:42 +00:00
_C_flatbuffer
_decomp Rewrite out-of-place decompositions in terms of out-of-place ops (#92003) 2023-01-17 16:53:27 +00:00
_dispatch
_dynamo Revert "[FSDP] Do not clean FQNs even for use_orig_params=True (#91767)" 2023-01-17 20:04:52 +00:00
_functorch [autograd.Function] update error messages for vmap to point to docs (#92030) 2023-01-17 13:36:42 +00:00
_inductor Fix model accuracy issue caused by vectorized transpose (#92299) 2023-01-17 17:53:45 +00:00
_lazy
_prims squeeze: allow squeezing multiple dimensions at once (#89017) 2023-01-17 14:20:15 +00:00
_prims_common Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_refs Assorted decomposition fixes (#87183) 2023-01-17 16:53:31 +00:00
_subclasses Revert "[Modes] Add assert that the mode isn't already on the stack (#90770)" 2023-01-12 16:44:29 +00:00
amp
ao [quant][fx] Add support for GRU in fx graph mode quantization (#91976) 2023-01-13 07:00:12 +00:00
autograd [autograd.Function] Kill the extension feature flag (#92026) 2023-01-17 13:36:42 +00:00
backends Dynamo benchmark: add CPU specific changes (#88477) 2023-01-07 09:26:06 +00:00
contrib
cpu
csrc Revert "Add sym_size/stride/numel/storage_offset to native_function.yaml (#91919)" 2023-01-17 21:03:18 +00:00
cuda Use binary units for CUDA memory summary (#91854) 2023-01-14 05:10:51 +00:00
distributed Revert "[FSDP] Do not clean FQNs even for use_orig_params=True (#91767)" 2023-01-17 20:04:52 +00:00
distributions Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
fft
func [functorch] move batch_norm_replacement to torch.func (#91412) 2023-01-12 19:15:41 +00:00
futures Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
fx Revert "[fx] rewrite FloorDiv to match Python better (#90906)" 2023-01-17 19:26:38 +00:00
jit Add shape function for movedim op (#91696) 2023-01-06 18:24:52 +00:00
legacy
lib Some CMake and CUDA cleanup given recent update to C++17 (#90599) 2022-12-30 11:19:26 +00:00
linalg Fix terminology within linalg.slogdet docs (#91129) 2022-12-20 01:55:27 +00:00
masked unify reduction types from different operators: scatter, scatter_reduce, segment_reduce (#91499) 2023-01-13 04:32:34 +00:00
monitor Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
multiprocessing
nested
nn Update Module.__setattr__ to respect property setters (#92044) 2023-01-17 20:00:06 +00:00
onnx [ONNX] Raise Unsupported for Grid Sample with volumetric 5D input (#92212) 2023-01-16 03:34:05 +00:00
optim [optim] abstract out _default_to_foreach_util (#92305) 2023-01-17 19:42:20 +00:00
package Minor fix in package exporter (#90306) 2022-12-27 18:01:59 +00:00
profiler Call profiler step via optimizer post hook (#90101) 2023-01-13 18:07:40 +00:00
quantization [ao] making _is_activation_post_process private with BC (#90554) 2022-12-16 08:09:33 +00:00
signal Reland "Add torch.utils.device_mode" (#91796) 2023-01-09 20:57:12 +00:00
sparse Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078) 2023-01-17 21:43:20 +00:00
special
testing Deprecate .mT,.T,.mH,.H on 0D tensors (#92143) 2023-01-17 16:54:35 +00:00
utils [Reland] Clean Up MobileOptimizerType Rewrite Flags Public API and Documentation (#92081) 2023-01-14 17:06:00 +00:00
__config__.py
__future__.py
__init__.py Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078) 2023-01-17 21:43:20 +00:00
_appdirs.py
_classes.py
_deploy.py
_guards.py Properly resolve source_ref when constructing shape guards (#91058) 2022-12-30 05:56:56 +00:00
_jit_internal.py [JIT] Skip builtins while enumerating class methods (#91805) 2023-01-06 21:45:09 +00:00
_linalg_utils.py
_lobpcg.py Fix typo in _lobpcg.py (#91641) 2023-01-04 15:19:05 +00:00
_lowrank.py
_meta_registrations.py Return empty attention weights when need_atten_weights = False (#91782) 2023-01-06 19:06:48 +00:00
_namedtensor_internals.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_ops.py
_python_dispatcher.py
_six.py
_sources.py
_storage_docs.py
_tensor_docs.py fix in-place geometric pmf (#92049) 2023-01-12 19:56:44 +00:00
_tensor_str.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_tensor.py Make torch.split take symint as arg (#91724) 2023-01-07 00:00:03 +00:00
_torch_docs.py squeeze: allow squeezing multiple dimensions at once (#89017) 2023-01-17 14:20:15 +00:00
_utils_internal.py
_utils.py [follow-up] Python Attr Serialization (#88913) 2023-01-13 17:38:51 +00:00
_VF.py
_vmap_internals.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_weights_only_unpickler.py
abi-check.cpp
CMakeLists.txt Revert "[cuDNN][cuDNN V8 API] Always build assuming cuDNN >= 8.0 (#91527)" 2023-01-16 13:28:09 +00:00
custom_class_detail.h
custom_class.h
extension.h
functional.py Update version numbers in torch.{stft,istft} deprecations (#91761) 2023-01-05 22:17:37 +00:00
hub.py Preventing crashing incase of no network by loading from cache (#91569) 2023-01-11 11:56:46 +00:00
library.h
library.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
overrides.py Reland "Add torch.utils.device_mode" (#91796) 2023-01-09 20:57:12 +00:00
py.typed
quasirandom.py
random.py
README.txt
return_types.py
script.h
serialization.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
storage.py Rename Tensor._storage to Tensor.untyped_storage and update docs (#91414) 2022-12-28 19:21:34 +00:00
torch_version.py
types.py

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.