Previously, FlopCounterMode would ignore any custom ops registered
through `register_flop_formula`. The problem was:
- register_flop_formula(target) requires target to be an OpOverloadPacket.
- register_flop_formula used register_decomposition to populate its registry.
- register_decomposition decomposes the OpOverloadPacket into an OpOverload before putting it into the registry.
- FlopCounterMode ignores OpOverloads in its registry (it assumes the registry is a dictionary mapping OpOverloadPacket to flop formula).
register_decomposition is too heavy a hammer for this, and this isn't actually a decomposition, so I changed the registration mechanism; a sketch of the new path follows.
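As a hedged illustration of the now-working path (the custom op `mylib::scale` is hypothetical, defined only for this sketch; `register_flop_formula` shape-wraps the formula by default, so it receives shapes rather than tensors):
```
import torch
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

# Hypothetical custom op, defined only for this sketch.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Register against the OpOverloadPacket (torch.ops.mylib.scale), as required.
@register_flop_formula(torch.ops.mylib.scale)
def scale_flop(x_shape, factor, *args, out_shape=None, **kwargs) -> int:
    n = 1
    for d in x_shape:
        n *= d
    return n  # one multiply per element

with FlopCounterMode(display=False) as fc:
    scale(torch.randn(4, 8), 2.0)
print(fc.get_total_flops())  # 32
```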
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131777
Approved by: https://github.com/Chillee
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
This adds implementations for:
* _flash_attention_forward
* _efficient_attention_forward
* _flash_attention_backward
* _efficient_attention_backward
These flop counts are implemented as follows:
* Unbind the batch elements
* Calculate flops individually for each element in the batch
* Sum the final result
This means that we access the concrete sequence lengths (which could be slow, and may trigger a GPU/CPU sync); but the FLOP numbers will vary with the sparsity of the NestedTensor, which is more accurate than if we just assumed everything was padded. A sketch of the scheme follows.
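A hedged sketch of the unbind-and-sum scheme (not the actual PyTorch code; the per-head cost of 4·s_q·s_k·d is the standard two-matmul estimate):
```
# Sketch only: per-batch-element forward FLOPs for SDPA on a NestedTensor.
# Reading the concrete per-element sequence lengths is what may trigger the
# GPU/CPU sync mentioned above.
def nested_sdpa_fwd_flops(seq_lens_q, seq_lens_k, num_heads, head_dim):
    total = 0
    for s_q, s_k in zip(seq_lens_q, seq_lens_k):  # "unbind" the batch
        qk = 2 * s_q * s_k * head_dim   # q @ k^T
        av = 2 * s_q * s_k * head_dim   # softmax(qk) @ v
        total += num_heads * (qk + av)  # per-element cost
    return total  # sum over the batch

# Two elements with very different lengths contribute very different FLOPs:
print(nested_sdpa_fwd_flops([128, 1024], [128, 1024], num_heads=8, head_dim=64))
```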
Differential Revision: [D57120139](https://our.internmc.facebook.com/intern/diff/D57120139)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125776
Approved by: https://github.com/Chillee
This does a few things that were originally separate PRs, but I am on a new machine and don't have ghstack.
If it is too problematic to review, I can re-split, just let me know.
This does:
- Cleanup context manager use in test_flop_counter
- Remove the need for the mod argument in FlopCounterMode, adding a warning about it
- Re-implement a Module tracker from scratch using global forward Module hooks and multi_grad_hook (we cannot use global backward Module hooks because they don't look for nested Tensors, and they are custom-Function based rather than multi_grad_hook based); see the usage sketch after this list
- Update FlopCounterMode to use the new ModuleTracker. The entire existing test suite passes as-is (the only changes there are the new tests and the refactoring mentioned above)
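A usage sketch, assuming the new tracker is the `ModuleTracker` that now lives in `torch.utils.module_tracker`:
```
import torch
from torch.utils.module_tracker import ModuleTracker

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
tracker = ModuleTracker()

def peek(module, args, output):
    # During forward, tracker.parents names every module currently on the
    # stack, e.g. {"Global", "Sequential", "Sequential.0"}.
    print(sorted(tracker.parents))

net[0].register_forward_hook(peek)
with tracker:
    net(torch.randn(2, 4)).sum().backward()
```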
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125352
Approved by: https://github.com/mikaylagawarecki
Summary:
TorchScript modules do not support forward hooks, so they can't work with the flop_counter context manager's hierarchical output (enabled by passing a module to FlopCounterMode on construction).
Currently, any module that contains a script module causes an exception to be thrown, so this adds a try/except to skip script modules when registering forward hooks, as sketched below.
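A minimal sketch of the guard (`register_hooks_skipping_script` is an illustrative name, not the actual helper):
```
import torch.nn as nn

def register_hooks_skipping_script(root: nn.Module, hook):
    """Attach a forward hook to every submodule, ignoring script modules."""
    handles = []
    for mod in root.modules():
        try:
            handles.append(mod.register_forward_hook(hook))
        except RuntimeError:
            # Script modules raise here; skip them rather than fail.
            continue
    return handles
```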
Test Plan: CI Signals
Differential Revision: D56850661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125346
Approved by: https://github.com/842974287
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.
- `assert(a == b)` -> `assert a == b`
- `if(x > y or y < z):`->`if x > y or y < z:`
- And `return('...')` -> `return '...'`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
* Enable PERF402. Makes code more efficient and succinct by removing useless list copies that could instead be accomplished via a list constructor or an extend call; a short illustration follows. All test cases have noqa added since performance is not as sensitive in that folder.
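For illustration, the PERF402 pattern looks like this (a generic example, not one of the actual diff sites):
```
items = ["a", "b", "c"]

# Before: a manual copy loop flagged by PERF402.
copy = []
for item in items:
    copy.append(item)

# After: the same result via the list constructor (or copy.extend(items)).
copy = list(items)
```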
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115505
Approved by: https://github.com/malfet
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
These were reverted due to a conflict with the internal source repo.
Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional); a short illustration follows.
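For illustration, the violation class looks like this (a generic example, not one of the actual fix sites):
```
from typing import Optional

# PEP 484 violation: default is None, but the type is not Optional.
def bad(x: int = None):  # rejected by mypy's implicit-Optional check
    return x

# Fixed: annotate the parameter as Optional explicitly.
def good(x: Optional[int] = None):
    return x
```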
Plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` to squash the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR fixes https://github.com/pytorch/pytorch/issues/103684.
- Instead of registering forward hooks in `__init__()`, do it upon `__enter__()`.
- De-register those forward hooks upon `__exit__()`.
- Achieve this by saving an additional mapping `_module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles]`. Only the values in the mapping (i.e. not the keys) are useful for this change. (A `List[_ForwardHookHandles]` would suffice.)
- The unit test accesses private attributes `_forward_hooks` and `_forward_pre_hooks` :/
Note that this PR is technically not backward compatible, since it no longer registers the hooks in `__init__()`, which means you will not get FLOP counting without using the context manager. A sketch of the new lifecycle follows.
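A minimal sketch of the enter/exit lifecycle (names are illustrative, not the actual private attributes):
```
import torch.nn as nn

class CountingContext:
    """Illustrative only: register hooks on __enter__, de-register on __exit__."""

    def __init__(self, mod: nn.Module):
        self._mod = mod
        self._handles = []  # stands in for _module_to_forward_hook_handles

    def __enter__(self):
        for submod in self._mod.modules():
            self._handles.append(submod.register_forward_pre_hook(self._pre))
            self._handles.append(submod.register_forward_hook(self._post))
        return self

    def __exit__(self, *exc):
        for handle in self._handles:
            handle.remove()  # leaves _forward_hooks/_forward_pre_hooks clean
        self._handles.clear()

    def _pre(self, module, args):
        pass  # e.g. push the module onto the hierarchy stack

    def _post(self, module, args, output):
        pass  # e.g. pop the module off the hierarchy stack
```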
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103744
Approved by: https://github.com/Chillee
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.
The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```
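For example, at seq_len = 1024, num_heads = 16, and batch_size = 8, the ratio is 1024 / (16 * 8) = 8 > 6, so the heuristic selects FlashAttention.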
TL;DR: On my dataset, in the cases where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy; my new heuristic achieves 94% accuracy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
Overall, here is an example usage. Note that this *also* captures backwards FLOPs.
```
import torchvision.models as models
import torch
from torch.utils.flop_counter import FlopCounterMode
inp = torch.randn(1, 3, 224, 224, device='cpu')
mod = models.resnet18()
flop_counter = FlopCounterMode(mod, depth=1)
with flop_counter:
    mod(inp).sum().backward()
```
<img width="326" alt="image" src="https://user-images.githubusercontent.com/6355099/222023068-3491e405-f195-4e11-b679-36b19a1380c7.png">
You can control the depth of the module hierarchy with the `depth` attribute (which defaults to 2). For example, if I don't limit it, this is what it outputs.
<img width="366" alt="image" src="https://user-images.githubusercontent.com/6355099/222023306-3d880bb6-f534-4f98-bf10-83c4353acefc.png">
## Other APIs
- `FlopCounterMode(custom_mapping=...)`: Allows for custom flop counting functions
- `FlopCounterMode.get_table(depth=...)`: Explicitly get the table as a string
- `FlopCounterMode.flop_counts`: Contains the flop information as a `Dict[hierarchy: str, Dict[Op, int]]`
- `FlopCounterMode.register_hierarchy(f, name)`: Allows you to register additional "hierarchies" for a function
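As a hedged sketch of `custom_mapping` (the shape-passing formula signature and the `"Global"` key follow the conventions of `torch.utils.flop_counter`; treat the details as assumptions):
```
import torch
from torch.utils.flop_counter import FlopCounterMode

# Formulas receive input/output shapes rather than tensors.
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    m, k = a_shape
    _, n = b_shape
    return 2 * m * k * n  # one multiply + one add per (m, n, k) triple

with FlopCounterMode(custom_mapping={torch.ops.aten.mm: mm_flop}) as fc:
    torch.mm(torch.randn(8, 16), torch.randn(16, 4))
print(fc.flop_counts["Global"][torch.ops.aten.mm])  # 2 * 8 * 16 * 4 = 1024
```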
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95751
Approved by: https://github.com/ngimel, https://github.com/albanD