pytorch/tools
Xiaodong Wang 0a94bb432e [ROCm] CK Flash Attention Backend (#143695)
Replaces https://github.com/pytorch/pytorch/pull/138947 for re-import.

Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with a composable_kernel (CK) backend. The CK path can be forced by calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". If this option is not set, aotriton is used as the backend. When CK is selected and PyTorch deems flash attention usable, the CK path is used in all the same places where aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math, etc.). The CK path is only called when flash attention is both enabled at build time (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.
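The selection rules described above can be modelled as a small decision function. This is a hypothetical pure-Python sketch, not PyTorch's dispatcher: `select_flash_backend` and its parameters are invented for illustration, it accepts only the preference strings named in this PR, and it assumes the CK build flag is on. On a ROCm build, the real switch is the `torch.backends.cuda.preferred_rocm_fa_library("ck")` call mentioned above.

```python
from typing import Optional

# Preference strings named in the PR description; "default" means aotriton.
_VALID_PREFERENCES = {"ck", "aotriton", "default"}


def select_flash_backend(
    preference: Optional[str],
    built_with_flash: bool,
    heuristics_chose_flash: bool,
) -> Optional[str]:
    """Hypothetical model of which ROCm flash-attention library runs.

    Returns None when flash attention is not used at all; the existing SDPA
    heuristics then pick another scheme (memory-efficient, math, ...), which
    the CK backend does not change.
    Assumption: PyTorch was built with USE_CK_FLASH_ATTENTION=1; behaviour
    without that flag is not modelled here.
    """
    if preference is not None and preference not in _VALID_PREFERENCES:
        raise ValueError(f"unknown flash attention library: {preference!r}")
    # CK is only reachable when flash attention is compiled in
    # (USE_FLASH_ATTENTION) and chosen at runtime by the existing heuristics.
    if not (built_with_flash and heuristics_chose_flash):
        return None
    # An unset preference and "default" both resolve to aotriton.
    if preference == "ck":
        return "ck"
    return "aotriton"


if __name__ == "__main__":
    print(select_flash_backend(None, True, True))   # aotriton
    print(select_flash_backend("ck", True, True))   # ck
    print(select_flash_backend("ck", True, False))  # None
```

The point the model makes explicit is that the library preference only matters after the flash-vs-other decision has already been made.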

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of the hard work of @tridao, who is the co-author.

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.
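Build-time switches such as USE_CK_FLASH_ATTENTION are plain environment variables read when the build starts. The sketch below shows one way a build script can interpret such a flag as a boolean and forward it to CMake as a `-D` cache definition. It is a hypothetical illustration: `env_flag` and `cmake_defines` are invented names, and the set of accepted truthy spellings is an assumption, not a description of PyTorch's setup_helpers.

```python
import os
from typing import Mapping

# Assumed truthy spellings; a real build script may accept a different set.
_TRUTHY = {"1", "ON", "YES", "TRUE", "Y"}


def env_flag(name: str, env: Mapping[str, str] = os.environ) -> bool:
    """Return True if environment variable `name` holds a truthy value."""
    return env.get(name, "").strip().upper() in _TRUTHY


def cmake_defines(env: Mapping[str, str], prefix: str = "USE_") -> list[str]:
    """Turn every `prefix`-ed variable into a sorted CMake -D ON/OFF argument.

    Simplification: every matching variable is treated as a boolean switch.
    """
    return [
        f"-D{key}={'ON' if env_flag(key, env) else 'OFF'}"
        for key in sorted(env)
        if key.startswith(prefix)
    ]


if __name__ == "__main__":
    build_env = {
        "USE_CK_FLASH_ATTENTION": "1",
        "USE_FLASH_ATTENTION": "1",
        "PATH": "/usr/bin",  # ignored: no USE_ prefix
    }
    print(cmake_defines(build_env))
    # ['-DUSE_CK_FLASH_ATTENTION=ON', '-DUSE_FLASH_ATTENTION=ON']
```

In practice the variable is simply exported in the shell that launches the PyTorch build, so it is visible to whatever reads the environment.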

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143695
Approved by: https://github.com/malfet

Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
2025-01-03 22:01:36 +00:00
alerts
amd_build [ROCm] CK Flash Attention Backend (#143695) 2025-01-03 22:01:36 +00:00
autograd [gen_autograd_functions] rename some variables (#143166) 2024-12-16 23:18:55 +00:00
bazel_tools
build/bazel Add networkx as bazel dep to fix CI failure (#143995) 2025-01-02 19:42:18 +00:00
build_defs [lint] Remove unnecessary BUCKRESTRICTEDSYNTAX suppressions 2024-07-19 07:19:11 -07:00
code_analyzer [3/N] Apply py39 ruff fixes (#142115) 2024-12-11 17:50:10 +00:00
code_coverage Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
config
coverage_plugins_package
dynamo
flight_recorder [fr] recognize all_reduce_barrier as a valid op (#143354) 2024-12-17 21:09:18 +00:00
gdb Add gdb print methods support same as pytorch-lldb (#140935) 2024-11-19 01:28:30 +00:00
github
iwyu
jit Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
linter Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
lite_interpreter c10::string_view -> std::string_view in more places (#142517) 2024-12-12 19:45:59 +00:00
lldb
onnx Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
packaging tools: Add a tool to build wheels for multiple python versions (#143361) 2024-12-17 21:56:06 +00:00
pyi [3/N] Apply py39 ruff fixes (#142115) 2024-12-11 17:50:10 +00:00
rules
rules_cc [BE] Fix incompatible-std-redefinition warning (#141630) 2024-11-27 05:06:36 +00:00
setup_helpers Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
shared
stats Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
test Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
testing Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
__init__.py
bazel.bzl
BUCK.bzl [lint] Remove unnecessary BUCKRESTRICTEDSYNTAX suppressions 2024-07-19 07:19:11 -07:00
BUCK.oss
build_libtorch.py [BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374) 2024-12-29 17:23:13 +00:00
build_pytorch_libs.py Fix access to _msvccompiler from newer distutils (#141363) 2024-11-25 01:50:47 +00:00
build_with_debinfo.py Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
download_mnist.py
extract_scripts.py [3/N] Apply py39 ruff fixes (#142115) 2024-12-11 17:50:10 +00:00
gen_flatbuffers.sh
gen_vulkan_spv.py [BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374) 2024-12-29 17:23:13 +00:00
generate_torch_version.py Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
generated_dirs.txt
git_add_generated_dirs.sh
git_reset_generated_dirs.sh
nightly_hotpatch.py [3/N] Apply py39 ruff fixes (#142115) 2024-12-11 17:50:10 +00:00
nightly.py [Easy] add quotes to shell activation commands (#143902) 2024-12-27 19:17:46 +00:00
nvcc_fix_deps.py Use absolute path path.resolve() -> path.absolute() (#129409) 2025-01-03 20:03:40 +00:00
README.md
render_junit.py
substitute.py
update_masked_docs.py
vscode_settings.py [BE][Easy][5/19] enforce style for empty lines in import segments in tools/ and torchgen/ (#129756) 2024-07-17 06:44:35 +00:00

This directory contains a number of scripts used as part of the PyTorch build process. It also doubles as a Python package (hence the __init__.py).

Overview

Modern infrastructure:

  • autograd - Code generation for autograd. This includes definitions of all our derivatives.
  • jit - Code generation for the JIT.
  • shared - Generic infrastructure that scripts in tools may find useful.
    • module_loader.py - Makes it easier to import arbitrary Python files in a script, without having to add them to the PYTHONPATH first.
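module_loader.py exists so a script can load a Python file by path. The standard-library technique behind that idea looks roughly like the sketch below; `import_from_path` is a hypothetical name used for illustration and is not necessarily what module_loader.py exports.

```python
import importlib.util
import tempfile
from pathlib import Path
from types import ModuleType
from typing import Union


def import_from_path(name: str, path: Union[str, Path]) -> ModuleType:
    """Load the Python file at `path` as a module called `name`.

    The file's directory does not need to be on sys.path / PYTHONPATH.
    Note: the module is not registered in sys.modules by this sketch.
    """
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {str(path)!r} as a module")
    module = importlib.util.module_from_spec(spec)
    # Run the file's top-level code inside the new module object.
    spec.loader.exec_module(module)
    return module


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        helper = Path(tmp) / "helper.py"
        helper.write_text("ANSWER = 42\n")
        mod = import_from_path("helper", helper)
        print(mod.ANSWER)  # 42
```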

Build system pieces:

  • setup_helpers - Helper code for searching for third-party dependencies on the user system.
  • build_pytorch_libs.py - Cross-platform script that builds all of the constituent libraries of PyTorch, but not the PyTorch Python extension itself.
  • build_libtorch.py - Script for building libtorch, a standalone C++ library without Python support. This build script is tested in CI.
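setup_helpers searches the user's system for third-party dependencies. As a generic illustration of that idea only (the README does not describe setup_helpers' actual logic), a build script can probe PATH for programs and the system library search path for shared libraries using the standard library. `find_dependencies` is a hypothetical helper.

```python
import ctypes.util
import shutil
from typing import Optional


def find_dependencies(
    programs: list[str], libraries: list[str]
) -> dict[str, Optional[str]]:
    """Report where each program/library was found, or None if missing."""
    found: dict[str, Optional[str]] = {}
    for prog in programs:
        # shutil.which searches PATH the way a shell would.
        found[prog] = shutil.which(prog)
    for lib in libraries:
        # find_library takes a bare name such as "m" (for libm) and returns a
        # platform-specific library name, or None if nothing was found.
        found[f"lib:{lib}"] = ctypes.util.find_library(lib)
    return found


if __name__ == "__main__":
    for dep, location in find_dependencies(["cmake", "ninja"], ["m"]).items():
        print(f"{dep}: {location or 'not found'}")
```

Returning None instead of raising lets the caller decide which missing dependencies are fatal and which only disable optional features.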

Developer tools which you might find useful:

Important if you want to run on AMD GPUs:

  • amd_build - HIPify scripts, for transpiling CUDA into AMD HIP. Right now, PyTorch and Caffe2 share logic for how to do this transpilation, but have separate entry points for transpiling PyTorch or Caffe2 code.
    • build_amd.py - Top-level entry point for HIPifying our codebase.
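build_amd.py is run as a script (for example `python tools/amd_build/build_amd.py` from the repository root). At its core, HIPify-style transpilation renames CUDA identifiers to their HIP equivalents. The sketch below shows that idea with a tiny, hand-picked mapping table; it is a simplified illustration, not PyTorch's HIPify implementation, which covers far more identifiers and cases.

```python
import re

# A small subset of CUDA -> HIP renames, chosen for illustration only.
CUDA_TO_HIP = {
    "cuda_runtime.h": "hip/hip_runtime.h",
    "cudaMalloc": "hipMalloc",
    "cudaFree": "hipFree",
    "cudaMemcpy": "hipMemcpy",
    "cudaMemcpyHostToDevice": "hipMemcpyHostToDevice",
    "cudaDeviceSynchronize": "hipDeviceSynchronize",
}

# Longest keys first so a longer identifier wins at the same position; the \b
# anchors also stop matches inside larger, unmapped identifiers
# (e.g. cudaMallocManaged is left untouched).
_PATTERN = re.compile(
    r"\b("
    + "|".join(re.escape(k) for k in sorted(CUDA_TO_HIP, key=len, reverse=True))
    + r")\b"
)


def hipify_source(source: str) -> str:
    """Rewrite known CUDA identifiers in `source` to their HIP equivalents."""
    return _PATTERN.sub(lambda m: CUDA_TO_HIP[m.group(1)], source)


if __name__ == "__main__":
    cuda_src = (
        "#include <cuda_runtime.h>\n"
        "cudaMalloc(&d_x, n);\n"
        "cudaMemcpy(d_x, h_x, n, cudaMemcpyHostToDevice);\n"
    )
    print(hipify_source(cuda_src))
```

Doing all renames in a single regex pass avoids one substitution rewriting the output of another.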

Tools which are only situationally useful: