```
[-------------------------------------- Fused SGD --------------------------------------]
                                                          |  Fused: True  |  Fused: False
1 threads: ------------------------------------------------------------------------------
      numel: 1024, num_tensors: 100, momentum: True       |       2       |      15
      numel: 1024, num_tensors: 100, momentum: False      |       2       |       5
      numel: 65536, num_tensors: 100, momentum: True      |       3       |      16
      numel: 65536, num_tensors: 100, momentum: False     |       2       |       5
      numel: 1048576, num_tensors: 100, momentum: True    |      11       |      16
      numel: 1048576, num_tensors: 100, momentum: False   |       8       |       6
      numel: 1024, num_tensors: 500, momentum: True       |      29       |      70
      numel: 1024, num_tensors: 500, momentum: False      |      20       |      24
      numel: 65536, num_tensors: 500, momentum: True      |      33       |      76
      numel: 65536, num_tensors: 500, momentum: False     |      22       |      26
      numel: 1048576, num_tensors: 500, momentum: True    |      70       |      80
      numel: 1048576, num_tensors: 500, momentum: False   |      43       |      40
      numel: 1024, num_tensors: 1000, momentum: True      |      108      |      139
      numel: 1024, num_tensors: 1000, momentum: False     |      72       |      48
      numel: 65536, num_tensors: 1000, momentum: True     |      116      |      150
      numel: 65536, num_tensors: 1000, momentum: False    |      77       |      52
      numel: 1048576, num_tensors: 1000, momentum: True   |      190      |      170
      numel: 1048576, num_tensors: 1000, momentum: False  |      120      |      50
```
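The timings above compare the fused MPS path (`Fused: True`) against the existing non-fused, single-tensor path (`Fused: False`, i.e. `foreach=False, fused=False` in the functional `sgd` call) across tensor sizes, tensor counts, and momentum settings. As a minimal sketch (not part of this change's diff), this is roughly how a user would opt into the fused path through the regular `torch.optim.SGD` constructor via its existing `fused=True` flag; the model, sizes, and hyperparameters below are illustrative placeholders:

```python
import torch

# Illustrative only: any model and loss will do; the relevant piece is the
# fused=True flag, which selects the fused optimizer path benchmarked above.
model = torch.nn.Linear(1024, 1024).to("mps")
opt = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, fused=True)

loss = model(torch.randn(64, 1024, device="mps")).sum()
loss.backward()
opt.step()       # dispatches to the fused MPS implementation
opt.zero_grad()
```

The table itself was generated with the following script: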
```python
def profile_fused_sgd():
    from torch.optim.sgd import sgd
    import torch.utils.benchmark as benchmark
    import itertools
    import torch  # needed for torch.arange and torch.mps.synchronize below

    def profile(fn, params, grads, momentum_buffer_list, fused):
        fn(
            params,
            grads,
            momentum_buffer_list,
            momentum=True if len(momentum_buffer_list) > 0 else False,
            dampening=0.0,
            nesterov=False,
            foreach=False,
            fused=fused,
            lr=1e-3,
            weight_decay=0.0,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        # Wait for the queued MPS kernels so the timing covers the whole step.
        torch.mps.synchronize()

    device = "mps"
    results = []
    for num_tensors, numel, momentum in itertools.product(
        [100, 500, 1000], [1024, 65536, 1048576], [True, False]
    ):
        sublabel = f"numel: {numel}, num_tensors: {num_tensors}, momentum: {momentum}"
        print(sublabel)
        params, grads = [
            [torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)]
            for _ in range(2)
        ]
        momentum_buffer_list = (
            [torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)]
            if momentum
            else []
        )
        fn = sgd
        for fused in [True, False]:
            t = benchmark.Timer(
                stmt="profile(fn, params, grads, momentum_buffer_list, fused)",
                label="Fused SGD",
                sub_label=sublabel,
                globals=locals(),
                description=f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```
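To reproduce the table, the function can be called directly; the small guard below is an assumption (not part of the original snippet) that simply skips the run on machines without an MPS backend:

```python
import torch

if __name__ == "__main__":
    # Hypothetical driver, not part of the original script.
    if torch.backends.mps.is_available():
        profile_fused_sgd()
    else:
        print("MPS backend not available; skipping the fused SGD benchmark.")
```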
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129350
Approved by: https://github.com/janeyx99
ghstack dependencies: #129006, #129008, #129007, #129105