pytorch/torch
David Berard df1e855313 [fake_impls] fix max_seqlen return values in efficient_attention_forward (#120842)
To match the actual implementation, we should return max_seqlen_q/k, not M and N, in the sparse case:

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L981-L996)

Note that although the .cu file sets max_seqlen_k = 0 in the sparse case, what it actually returns is max_seqlen_k or N (i.e., N when max_seqlen_k is 0):

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L1224-L1231)
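
A minimal sketch of the return-value selection described above (illustrative only; the helper name and argument list are hypothetical, not the actual fake_impls signature):

```python
# Illustrative sketch only -- not the actual fake_impls code.  It mirrors the
# behavior described above: the dense path reports M and N as the max sequence
# lengths, while the sparse (cu_seqlens) path reports the caller-provided
# max_seqlen_q, and max_seqlen_k falls back to N because the kernel zeroed it out.
def _returned_max_seqlens(cu_seqlens_q, max_seqlen_q, max_seqlen_k, M, N):
    if cu_seqlens_q is None:
        # dense path: lengths come from the input shapes
        return M, N
    # sparse path: echo the caller-provided max_seqlen_q; max_seqlen_k was set
    # to 0 in the kernel, so "max_seqlen_k or N" reduces to N here
    return max_seqlen_q, max_seqlen_k if max_seqlen_k else N
```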

Tests are added in the next PR (#102839), which also fixes other parts of the test_fake tests so that we can un-xfail them and actually run them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120842
Approved by: https://github.com/YuqingJ
ghstack dependencies: #120682
2024-02-29 07:12:27 +00:00
_awaits
_C Let torch dynamo inline torch.func.grad (#118407) 2024-02-28 20:05:00 +00:00
_C_flatbuffer
_custom_op [inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (#117298) 2024-01-21 18:47:01 +00:00
_decomp Revert "Avoid COW materialization in at::parallel_for/parallel_reduce (#120455)" 2024-02-28 22:30:18 +00:00
_dispatch
_dynamo [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
_export [pytorch] Support output types that are non tensors (#120804) 2024-02-29 02:49:10 +00:00
_functorch [Compiled Autograd] Introduce BackwardState capture (#120382) 2024-02-28 20:36:47 +00:00
_higher_order_ops [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
_inductor Add equal_to_1 to triton_meta for user-written Triton kernels (#120579) 2024-02-29 05:19:39 +00:00
_lazy
_library Fix FallbackKernel behavior on mutable ops (#118649) 2024-02-09 19:01:54 +00:00
_logging Add TORCH_LOGS_FORMAT=short alias (#120757) 2024-02-28 04:40:48 +00:00
_numpy Fix dynamo failure w/ astype (#117952) 2024-02-03 08:10:15 +00:00
_prims add decomposition for frexp (#119217) 2024-02-23 21:52:42 +00:00
_prims_common [Dynamic] Fix dynamic shape size inspection bug (#120341) 2024-02-22 21:08:28 +00:00
_refs fix decomposition of aten.diag_embed (#120549) 2024-02-28 18:48:01 +00:00
_subclasses [fake_impls] fix max_seqlen return values in efficient_attention_forward (#120842) 2024-02-29 07:12:27 +00:00
_vendor
amp Remove device assert in Gradscaler (#119362) 2024-02-22 08:02:18 +00:00
ao [quant][pt2e] Relax model_is_exported input (#120720) 2024-02-28 18:32:03 +00:00
autograd [XPU][Profiler] Add Logic To The Profiler For Processing XPU-backend Data (#120185) 2024-02-28 17:50:32 +00:00
backends [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663) 2024-02-14 22:02:06 +00:00
compiler [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
contrib
cpu add GradScaler on CPU (#109993) 2024-01-29 23:42:35 +00:00
csrc [C10D] Add ProcessGroup op_id to track ops inside coalescing region (#120745) 2024-02-29 01:03:31 +00:00
cuda refactor code to share across different devices (#120602) 2024-02-28 09:42:58 +00:00
distributed [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
distributions Bugfix to MixtureSameFamily's _pad_mixture_dimension (#118947) 2024-02-06 16:24:22 +00:00
export [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
fft
func Let torch dynamo inline torch.func.grad (#118407) 2024-02-28 20:05:00 +00:00
futures
fx Reduce create_env log level to DEBUG (#120772) 2024-02-29 01:33:16 +00:00
jit add export to torch.jit.__all__ (#120432) 2024-02-23 20:37:09 +00:00
legacy
lib Remove unneeded linking of torch_shm_manager in CMake (#119540) 2024-02-11 06:33:35 +00:00
linalg Fix error in examples of torch.linalg.lu_factor (#120484) 2024-02-23 13:19:04 +00:00
masked Enable possibly-undefined error code (#118533) 2024-01-30 21:07:01 +00:00
monitor
mps
multiprocessing
nested Rename singleton int to nested int (#119661) 2024-02-16 19:21:17 +00:00
nn [DDP][PT2D] Ignore gradient sync if the gradient is not defined (#120419) 2024-02-29 00:27:54 +00:00
onnx Remove monkey-patch for torch.utils._rebuild_tensor (#120446) 2024-02-23 20:42:50 +00:00
optim [BE][optim] Simplify _init_group. (#120055) 2024-02-22 22:15:01 +00:00
package
profiler [XPU][Profiler] Add Logic To The Profiler For Processing XPU-backend Data (#120185) 2024-02-28 17:50:32 +00:00
quantization Enable possibly-undefined error code (#118533) 2024-01-30 21:07:01 +00:00
signal Clarifying windows cosine behaviour in the documentation (#119444) 2024-02-09 05:57:44 +00:00
sparse [sparse] semi-structured sparse refactor (#117302) 2024-02-14 01:10:40 +00:00
special
testing Fix a potential race condition in the test decorators for enabling/disabling native funcol (#120833) 2024-02-29 03:19:44 +00:00
utils [pytree][reland] Require pytree serialized_type_name (#120636) 2024-02-27 06:53:33 +00:00
xpu [DeviceIndex][7/N] Use DeviceIndex in XPU (#120576) 2024-02-29 05:54:23 +00:00
__config__.py
__future__.py Update nn.Module._apply to not gate on should_use_set_data when swap_tensors is set (#120659) 2024-02-28 00:59:34 +00:00
__init__.py Update _constrain_as_size docs (#120728) 2024-02-28 15:03:10 +00:00
_appdirs.py
_classes.py
_compile.py
_custom_ops.py
_deploy.py [Lint] replace [assigment] with [method-assign] for methods (#119706) 2024-02-13 02:06:04 +00:00
_guards.py [Compiled Autograd] Introduce BackwardState capture (#120382) 2024-02-28 20:36:47 +00:00
_jit_internal.py [jit][perf] Reduce lookupInModule overhead. (#119145) 2024-02-05 18:01:00 +00:00
_linalg_utils.py
_lobpcg.py [Lint] replace [assigment] with [method-assign] for methods (#119706) 2024-02-13 02:06:04 +00:00
_lowrank.py
_meta_registrations.py Move attention kernels from meta_registrations to fake_impls (#120682) 2024-02-28 21:49:13 +00:00
_namedtensor_internals.py
_ops.py Enable local_partial_types (#118467) 2024-01-28 13:38:22 +00:00
_python_dispatcher.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py update the tensor.scatter_ doc (#120169) 2024-02-23 02:51:55 +00:00
_tensor_str.py Revert "Add meta device support to sparse compressed tensors (#120498)" 2024-02-26 15:59:36 +00:00
_tensor.py add complex32 to v3_dtypes (#120388) 2024-02-28 02:32:29 +00:00
_torch_docs.py Fix the default value of side in torch.searchsorted (#120066) 2024-02-22 19:35:17 +00:00
_utils_internal.py Revert "Add structured trace logs (#120289)" 2024-02-27 19:49:05 +00:00
_utils.py [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py additional support for float8_e4m3fnuz and _e5m2fnuz (#115214) 2024-01-22 18:33:41 +00:00
abi-check.cpp
CMakeLists.txt [3/4] Intel GPU Runtime Upstreaming for Device (#116850) 2024-02-01 12:31:26 +00:00
custom_class_detail.h
custom_class.h
extension.h
functional.py Fix typo in istft docstring (#119776) 2024-02-15 21:20:00 +00:00
hub.py Enable possibly-undefined error code (#118533) 2024-01-30 21:07:01 +00:00
library.h Add way to actually delete a torch.library.Library object (#118318) 2024-01-26 22:30:51 +00:00
library.py Enable local_partial_types (#118467) 2024-01-28 13:38:22 +00:00
overrides.py Integrate swap_tensors into nn.Module.load_state_dict (#117913) 2024-02-09 22:32:29 +00:00
py.typed
quasirandom.py
random.py [2/2] Intel GPU Runtime Upstreaming for Generator (#118613) 2024-02-28 05:28:11 +00:00
README.txt
return_types.py register torch.return_types in torch.fx._pytree (#120027) 2024-02-23 21:52:42 +00:00
script.h
serialization.py Add FakeTensor support to torch._utils._rebuild_tensor (#108186) 2024-02-16 23:42:50 +00:00
storage.py Add hpu device support in storage/resize (#119761) 2024-02-17 01:04:27 +00:00
torch_version.py Replace follow_imports = silent with normal (#118414) 2024-01-27 02:44:11 +00:00
types.py
version.py.tpl

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.