pytorch/docs/source
Tianyu Liu af5376c444 [dtensor] add support for loss parallel (#119877)
Loss parallel is the last piece of sequence parallelism to be enabled. It enables efficient distributed cross-entropy computation when the input is sharded on the class dimension (as in a classification problem with many classes). It is implemented as a context manager, `loss_parallel`; once it is enabled, users can call `torch.nn.functional.cross_entropy` or `torch.nn.CrossEntropyLoss` directly without modifying other parts of their code.

Here is the rationale for these op replacements:

1. `nn.functional.cross_entropy` is the function OSS users commonly use for workloads such as transformer training. To avoid changing user code, we want users who already compute their loss with this function to keep using it.
2. `nn.functional.cross_entropy` boils down to `aten.log_softmax` and `aten.nll_loss_forward/backward`, and DTensor already supports those ops (#117723 #119255 #118917 #119256). Those implementations compute with the input *replicated* on the class dimension.
3. However, when the input to this loss calculation is **sharded on the class dimension**, running the sharded computation efficiently requires running both `aten.log_softmax` and `aten.nll_loss_forward` with multiple all-reduce collectives **in the middle of** those aten ops. This is not possible if we only override these two ops, so we need a way to **decompose** them into smaller ops so that the collectives can run between the pieces.
4. We explored the existing decompositions (#118950). They appear to work, except that `log_softmax_backward` and `nll_loss_backward` combined in aten are implemented in an inefficient way, which would trigger an additional expensive collective. Users have also recently reported a similar issue: https://github.com/pytorch/pytorch/issues/119261.
5. Therefore, we currently do our own decomposition inside a context manager, specifically for sequence parallelism. Once there is a better decomposition in core, we can possibly adopt it instead of reinventing the wheel here.
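
To illustrate point 3, here is a minimal single-process sketch of why cross entropy over class-sharded logits needs collectives in the middle of the computation. It is not the code from this PR: `all_reduce`, `sharded_cross_entropy`, and `reference_cross_entropy` are hypothetical names, the "ranks" are plain Python lists, and the number and order of reductions are only one plausible arrangement.

```python
import math


def all_reduce(per_rank_values, op):
    """Stand-in for a collective: combine one value per rank into one result."""
    return op(per_rank_values)


def sharded_cross_entropy(logit_shards, target):
    """Cross entropy for one sample whose logits are split by class across ranks."""
    # Step 1: each rank takes its local max, then all-reduce(max) gives the
    # global max used to keep the exponentials numerically stable.
    global_max = all_reduce([max(shard) for shard in logit_shards], max)

    # Step 2: each rank sums exp(x - global_max) locally, then all-reduce(sum)
    # gives the softmax denominator.
    global_sum = all_reduce(
        [sum(math.exp(x - global_max) for x in shard) for shard in logit_shards],
        sum,
    )
    log_normalizer = global_max + math.log(global_sum)

    # Step 3: only the rank that owns the target class contributes its logit;
    # every other rank contributes 0, and all-reduce(sum) shares the value.
    contributions = []
    offset = 0
    for shard in logit_shards:
        local_index = target - offset
        owns_target = 0 <= local_index < len(shard)
        contributions.append(shard[local_index] if owns_target else 0.0)
        offset += len(shard)
    target_logit = all_reduce(contributions, sum)

    return log_normalizer - target_logit


def reference_cross_entropy(logits, target):
    """Unsharded log-softmax followed by negative log-likelihood."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]
shards = [logits[0:3], logits[3:6]]  # two simulated ranks, three classes each
loss = sharded_cross_entropy(shards, target=4)
```

Each step depends on a value reduced across all shards, which is why overriding the two aten ops as opaque units is not enough and a decomposition is needed.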

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119877
Approved by: https://github.com/wanchaol
2024-03-02 05:06:26 +00:00
_static Removing HTA documentation (#116513) 2023-12-28 23:04:23 +00:00
_templates Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
community Fix typo on Contribution Guide (#119428) 2024-02-08 01:07:27 +00:00
elastic [TorchElastic] Refactoring to support non-default logging strategy (#120691) 2024-02-29 20:59:17 +00:00
notes Test seo torch cuda (#119324) 2024-02-07 00:39:51 +00:00
rpc
scripts Fixed typo in build_activation_images.py (#117458) 2024-01-15 03:27:40 +00:00
amp.rst add GradScaler on CPU (#109993) 2024-01-29 23:42:35 +00:00
autograd.rst Autograd doc cleanup (#118500) 2024-01-29 21:51:33 +00:00
backends.rst [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663) 2024-02-14 22:02:06 +00:00
benchmark_utils.rst
bottleneck.rst
checkpoint.rst Add missing words to torch.utils.checkpoint doc (#120196) 2024-02-20 20:18:42 +00:00
complex_numbers.rst Update mentions of deprecated functions if complex_numbers.rst (#113391) 2023-11-09 22:32:26 +00:00
cond.rst Fix typo under docs directory (#119657) 2024-02-15 21:14:34 +00:00
conf.py [dtensor] add support for loss parallel (#119877) 2024-03-02 05:06:26 +00:00
config_mod.rst
cpp_extension.rst
cpp_index.rst
cpu.rst Add current_device() to torch.cpu (#110987) 2023-10-11 05:13:10 +00:00
cuda_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
cuda._sanitizer.rst
cuda.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
cudnn_persistent_rnn.rst
cudnn_rnn_determinism.rst
data.rst Revert "reseed all Generators in Dataloader's _worker_loop() -- via GC (#107131)" 2023-08-23 17:08:07 +00:00
ddp_comm_hooks.rst [DOCS][DDP]Fix the simple of saving and reloading PowerSGD state and hook. (#102721) 2023-06-10 00:15:00 +00:00
debugging_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
deploy.rst
deterministic.rst Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377) 2023-11-01 16:10:09 +00:00
distributed.algorithms.join.rst
distributed.checkpoint.rst [DCP] Adds storage reader and planner classes for online loading/sharding of models in torch.save format (#119816) 2024-03-01 00:21:05 +00:00
distributed.elastic.rst
distributed.optim.rst
distributed.rst [dtensor] add support for loss parallel (#119877) 2024-03-02 05:06:26 +00:00
distributed.tensor.parallel.rst [dtensor] add support for loss parallel (#119877) 2024-03-02 05:06:26 +00:00
distributions.rst Add inverse gamma distribution and fix sign bug in PowerTransform. (#104501) 2023-11-01 02:26:25 +00:00
dlpack.rst
docutils.conf
export.ir_spec.rst [export] Remove torch._export.export (#119095) 2024-02-08 21:22:04 +00:00
export.rst [Export] Remove ScriptObjectMeta (#118241) 2024-01-26 00:37:19 +00:00
fft.rst
fsdp.rst [FSDP][state_dict] Expose optimizer state_dict config (#105949) 2023-08-21 07:29:49 +00:00
func.api.rst [functorch] linearize (#94173) 2023-02-09 15:45:08 +00:00
func.batch_norm.rst Fix typo under docs directory (#97202) 2023-03-21 01:24:10 +00:00
func.migrating.rst
func.rst
func.ux_limitations.rst
func.whirlwind_tour.rst
future_mod.rst Add swap_tensors path to nn.Module._apply (#117167) 2024-02-07 18:55:44 +00:00
futures.rst
fx.experimental.rst Add basic reference documentation for symbolic_shapes.py (#118997) 2024-02-07 14:33:42 +00:00
fx.rst Introduce size oblivious guards (#118579) 2024-02-06 19:45:32 +00:00
hub.rst
index.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
jit_builtin_functions.rst
jit_language_reference_v2.rst Fix typo under docs directory (#97202) 2023-03-21 01:24:10 +00:00
jit_language_reference.rst [BE] [1/3] Rewrite super() calls in caffe2 and benchmarks (#94587) 2023-02-11 18:19:48 +00:00
jit_python_reference.rst
jit_unsupported.rst Add support for torch.Generator type in TorchScript (#110413) 2023-11-21 23:07:21 +00:00
jit_utils.rst
jit.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
library.rst Rewrite torch.library's documentation (#111310) 2023-10-23 23:02:41 +00:00
linalg.rst
logging.rst Change classification to beta for TORCH_LOGS (#118682) 2024-01-31 21:50:55 +00:00
masked.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
math-quantizer-equation.png
meta.rst Add documentation for meta device (#119119) 2024-02-04 01:05:22 +00:00
miscellaneous_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
mobile_optimizer.rst
model_zoo.rst
monitor.rst
mps.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
multiprocessing.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
name_inference.rst [docs] Properly link register_post_accumulate_grad_hook docs (#108157) 2023-08-29 22:13:33 +00:00
named_tensor.rst fixing named tensor unflatten example (#106921) 2023-08-22 18:00:10 +00:00
nested.rst Replace master with main in links and docs/conf.py (#100176) 2023-05-02 18:20:32 +00:00
nn.attention.bias.rst Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
nn.attention.rst Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
nn.functional.rst Add python and C++ support for LPPool3d (#114199) 2023-12-08 18:18:44 +00:00
nn.init.rst
nn.rst Added missing CircularPad*d references so the docs are actually built. (#118465) 2024-01-27 22:39:01 +00:00
onnx_dynamo_onnxruntime_backend.rst Follow-up #108379 (#108905) 2023-09-09 01:38:36 +00:00
onnx_dynamo.rst [ez][doc] Fix sample code in onnx_dynamo.rst (#114770) 2023-11-29 19:27:52 +00:00
onnx_torchscript_supported_aten_ops.rst Refactor torch.onnx documentation (#108379) 2023-09-08 18:23:48 +00:00
onnx_torchscript.rst Follow-up #108379 (#108905) 2023-09-09 01:38:36 +00:00
onnx.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
optim.rst Added example regarding weight_decay distinction with per-parameter API (#117436) 2024-01-22 21:26:02 +00:00
package.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
pipeline.rst docs: Linking ResNeXt PyTorch Hub Pipeline (#98689) 2023-04-11 02:20:26 +00:00
profiler.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
quantization-accuracy-debugging.rst
quantization-backend-configuration.rst
quantization-support.rst [quant][pt2e] Add model_is_exported util function (#119726) 2024-02-16 19:29:36 +00:00
quantization.rst Fix typo under docs directory (#119657) 2024-02-15 21:14:34 +00:00
random.rst
rpc.rst [BE] RPC is missing RRef docs (#106902) 2023-08-10 16:26:27 +00:00
signal.rst
sparse.rst Fix typo in https://pytorch.org/docs/stable/sparse.html (#115282) 2023-12-08 18:31:33 +00:00
special.rst
storage.rst
tensor_attributes.rst Include the scalar tensor auto-transfer in the doc (#119967) 2024-02-15 22:37:39 +00:00
tensor_view.rst
tensorboard.rst
tensors.rst Integrate swap_tensors into nn.Module.load_state_dict (#117913) 2024-02-09 22:32:29 +00:00
testing.rst
threading_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
torch_cuda_memory.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
torch.ao.ns._numeric_suite_fx.rst
torch.ao.ns._numeric_suite.rst
torch.compiler_aot_inductor.rst Fix typo under docs directory (#119657) 2024-02-15 21:14:34 +00:00
torch.compiler_api.rst [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
torch.compiler_best_practices_for_backends.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_cudagraph_trees.rst [docs] add mode="reduce-overhead" into torch.compile to enable cuda g… (#116529) 2024-01-05 22:54:20 +00:00
torch.compiler_custom_backends.rst [docs, dynamo] fix typos in dynamo custom backend docs (#115444) 2023-12-08 23:58:26 +00:00
torch.compiler_deepdive.rst [Dynamo]Expose bytecode hooks and add example usage for decompilation in docs (#110714) 2023-10-13 12:36:00 +00:00
torch.compiler_dynamic_shapes.rst Update dynamic shapes documentation (#109764) 2023-09-21 13:53:43 +00:00
torch.compiler_fake_tensor.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_faq.rst Add a wrapper to transform a NumPy function into a PyTorch function (#114610) 2024-01-02 18:35:29 +00:00
torch.compiler_fine_grain_apis.rst [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
torch.compiler_get_started.rst [Reland2] [inductor][BE] split triton_meta and inductor_meta (#112351) 2023-11-02 00:40:12 +00:00
torch.compiler_guards_overview.rst Do not use a specific LOC in link (#108957) 2023-09-13 19:21:45 +00:00
torch.compiler_inductor_profiling.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_ir.rst [export] torch.export landing page (#108783) 2023-09-10 01:40:42 +00:00
torch.compiler_nn_module.rst Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964) 2023-10-11 05:16:47 +00:00
torch.compiler_performance_dashboard.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_profiling_torch_compile.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_transformations.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch.compiler_troubleshooting.rst Update DDP dynamo debug docs (#118295) 2024-01-29 14:58:26 +00:00
torch.compiler.rst [docs] Fix torch.compile "tensorrt" backend docs (#113711) 2023-11-15 08:42:53 +00:00
torch.overrides.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
torch.rst Autograd doc cleanup (#118500) 2024-01-29 21:51:33 +00:00
type_info.rst
utils.rst New swap function (#111747) 2023-12-08 18:49:35 +00:00
xpu.rst [2/2] Intel GPU Runtime Upstreaming for Generator (#118613) 2024-02-28 05:28:11 +00:00