# Summary
### Updated API
```python
from typing import NamedTuple, Optional

from torch import Tensor


class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False


class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None


out_only = flex_attention(query, key, value, score_mod)
out_max, aux_max = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(max_scores=True),
)
out_both, aux_both = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(lse=True, max_scores=True),
)
```
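The shape of this API can be exercised without torch; a minimal sketch where plain floats stand in for tensors and `toy_flex_attention` is a hypothetical stand-in for the real dispatch, showing how aux fields are computed only when requested:

```python
from math import exp, log
from typing import NamedTuple, Optional

# Plain-float stand-ins for the tensor-valued fields in the real API.
class AuxRequest(NamedTuple):
    lse: bool = False
    max_scores: bool = False

class AuxOutput(NamedTuple):
    lse: Optional[float] = None
    max_scores: Optional[float] = None

def toy_flex_attention(scores, return_aux: Optional[AuxRequest] = None):
    out = sum(scores) / len(scores)  # placeholder for the attention output
    if return_aux is None:
        return out  # BC-preserving path: single return value, as today
    aux = AuxOutput(
        lse=log(sum(exp(s) for s in scores)) if return_aux.lse else None,
        max_scores=max(scores) if return_aux.max_scores else None,
    )
    return out, aux

out, aux = toy_flex_attention([0.1, 2.0, 1.0], AuxRequest(max_scores=True))
```

Callers that don't pass `return_aux` keep the old single-return behavior, which is the BC story the next section discusses.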
This PR returns the max post-mod scores from flex_attention.
Not being able to break BC is kind of annoying here, since we end up with a combinatorial problem: every new return value needs a new kwarg gating whether it gets returned, and we have to support the 2**N possible return groups. Ideally there isn't much more we need to return, but we should think about how best to set this up for future expansion. For now I added the new argument as keyword-only.
Maybe we make an `ExtraReturns`-type kwarg that can grow, so we don't need to keep adding new top-level args.
We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.
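One way the `ExtraReturns`-style kwarg could stay extensible is by relying on defaulted NamedTuple fields, so each future auxiliary output is one appended field rather than another top-level kwarg. A sketch under that assumption (the `ExtraReturnsRequest` name is hypothetical):

```python
from typing import NamedTuple

# Hypothetical extensible request type: each future auxiliary output becomes
# one appended field with a False default, so existing call sites (and the
# flex_attention signature) never have to change.
class ExtraReturnsRequest(NamedTuple):
    lse: bool = False
    max_scores: bool = False
    # A future addition would just be one more defaulted field, e.g.:
    # attn_entropy: bool = False

# A call site written before `max_scores` existed still constructs fine,
# because the new field defaults to False:
old_style = ExtraReturnsRequest(lse=True)
```

This avoids the 2**N kwarg explosion: the request type grows, the function signature does not.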
### Req Grad
I currently don't return a `max_scores` that supports backpropagating gradients. This might be feasible, but since max is essentially a one-hot reduction over the inputs, we would either need to save an additional `max_location` from the forward, or recompute the max score in the backward and apply the gradient only to the first occurrence when multiple scores tie (we'd need to check whether that matches what we define for the vanilla max op in torch).
For now there is no grad support; we can revisit if needed.
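To make the trade-off concrete, here's a plain-Python sketch of the save-`max_location`-in-forward option (names hypothetical; this uses first-occurrence tie-breaking, which, as noted above, would still need to be checked against torch's vanilla max op):

```python
# The gradient of a max reduction is one-hot on the argmax, so the forward
# must remember where the max came from (a hypothetical saved `max_location`).
def max_forward(scores):
    max_val = max(scores)
    max_location = scores.index(max_val)  # first occurrence wins on ties
    return max_val, max_location

def max_backward(grad_out, max_location, n):
    grad = [0.0] * n
    grad[max_location] = grad_out  # all of grad_out routes to one position
    return grad

val, loc = max_forward([1.0, 3.0, 3.0, 2.0])
grad = max_backward(1.0, loc, 4)
```

The alternative (recomputing the argmax in the backward) trades the saved buffer for an extra pass over the scores.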
## Perf
I am going to disable this for flex_decode, since at least initially the motivation is training. It is also harder than it should be to have ops return None or optional tensors. If return max is false, we should probably just create a tensor of size zero so that we don't slow down the hot path.
```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ causal ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 161.031318 ┆ 158.597808 ┆ 2.43351 ┆ 1.534391 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘
🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, ┆ 175.546923 ┆ 177.81205 ┆ -2.265127 ┆ -1.273888 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4, ┆ 156.282597 ┆ 158.209134 ┆ -1.926537 ┆ -1.217715 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16, ┆ 232.542929 ┆ 235.140136 ┆ -2.597207 ┆ -1.104536 │
│ ┆ ┆ 2048, 128) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 169.652791 ┆ 171.475986 ┆ -1.823195 ┆ -1.063236 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng