mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add sparsity (#148513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148513
Approved by: https://github.com/danielvegamyhre
This commit is contained in:
parent
b4430c3a6d
commit
127bd5a02d
|
|
@ -80,7 +80,10 @@ class Experiment:
|
|||
|
||||
|
||||
def calculate_tflops(
|
||||
config: ExperimentConfig, time_us: float, is_backward: bool = False
|
||||
config: ExperimentConfig,
|
||||
time_us: float,
|
||||
is_backward: bool = False,
|
||||
sparsity: float = 0.0,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate TFLOPS for scaled dot product attention.
|
||||
|
|
@ -89,6 +92,7 @@ def calculate_tflops(
|
|||
- config: The experiment configuration
|
||||
- time_us: The execution time in microseconds
|
||||
- is_backward: Whether to calculate for backward pass (includes gradient computation)
|
||||
- sparsity: Sparsity factor between 0.0 and 1.0, where 0.0 means no sparsity and 1.0 means fully sparse
|
||||
|
||||
Returns:
|
||||
- TFLOPS value
|
||||
|
|
@ -99,6 +103,9 @@ def calculate_tflops(
|
|||
N = config.kv_seq_len
|
||||
D = config.head_dim
|
||||
|
||||
# Calculate density factor (1.0 - sparsity)
|
||||
density = 1.0 - sparsity
|
||||
|
||||
# Forward pass FLOPs
|
||||
qk_flops = (
|
||||
M * N * D * 2
|
||||
|
|
@ -110,6 +117,9 @@ def calculate_tflops(
|
|||
|
||||
total_flops = B * H * (qk_flops + softmax_flops + av_flops)
|
||||
|
||||
# Apply density factor to account for sparsity
|
||||
total_flops *= density
|
||||
|
||||
# Flash attention's backward pass uses 2.5x the forward FLOPs, so scale by that factor
|
||||
if is_backward:
|
||||
total_flops *= 2.5
|
||||
|
|
@ -168,8 +178,11 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
|
|||
)
|
||||
|
||||
# Calculate TFLOPS for forward and backward passes
|
||||
forward_tflops = calculate_tflops(config, forward_time)
|
||||
backward_tflops = calculate_tflops(config, backward_time, is_backward=True)
|
||||
sparsity = 0.5 if is_causal else 0.0
|
||||
forward_tflops = calculate_tflops(config, forward_time, sparsity=sparsity)
|
||||
backward_tflops = calculate_tflops(
|
||||
config, backward_time, is_backward=True, sparsity=sparsity
|
||||
)
|
||||
|
||||
return ExperimentResults(
|
||||
forward_time=forward_time,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user