Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Enabled some cases to work where num_microbatches % pp_size != 0. With the flex_pp schedule, num_rounds = max(1, n_microbatches // pp_group_size), and the schedule works as long as n_microbatches % num_rounds is 0. Two examples that are now supported:

1. pp_group_size = 4, n_microbatches = 10: num_rounds = 2 and n_microbatches % 2 is 0.
2. pp_group_size = 4, n_microbatches = 3: num_rounds = 1 and n_microbatches % 1 is 0.

Moved over from PiPPy (https://github.com/pytorch/PiPPy/pull/1129).

Tested using the config in (1); the schedule looks like the following graph:

```
=========== ALL_RANK_ACTIONS ===========
         Rank 0  Rank 1  Rank 2  Rank 3
Step 00: F0_s0   None    None    None
Step 01: F1_s0   F0_s1   None    None
Step 02: F2_s0   F1_s1   F0_s2   None
Step 03: F3_s0   F2_s1   F1_s2   F0_s3
Step 04: F4_s0   F3_s1   F2_s2   F1_s3
Step 05: F0_s4   F4_s1   F3_s2   F2_s3
Step 06: F1_s4   F0_s5   F4_s2   F3_s3
Step 07: F2_s4   F1_s5   F0_s6   F4_s3
Step 08: F3_s4   F2_s5   F1_s6   F0_s7
Step 09: F4_s4   F3_s5   None    B0_s7
Step 10: F5_s0   None    F2_s6   F1_s7
Step 11: None    None    B0_s6   B1_s7
Step 12: None    F4_s5   F3_s6   F2_s7
Step 13: None    B0_s5   B1_s6   B2_s7
Step 14: F6_s0   F5_s1   F4_s6   F3_s7
Step 15: B0_s4   B1_s5   B2_s6   B3_s7
Step 16: F7_s0   F6_s1   F5_s2   F4_s7
Step 17: B1_s4   B2_s5   B3_s6   B4_s7
Step 18: F8_s0   F7_s1   F6_s2   F5_s3
Step 19: B2_s4   B3_s5   B4_s6   B0_s3
Step 20: F9_s0   F8_s1   F7_s2   F6_s3
Step 21: B3_s4   B4_s5   B0_s2   B1_s3
Step 22: F5_s4   F9_s1   F8_s2   F7_s3
Step 23: B4_s4   B0_s1   B1_s2   B2_s3
Step 24: F6_s4   F5_s5   F9_s2   F8_s3
Step 25: B0_s0   B1_s1   B2_s2   B3_s3
Step 26: F7_s4   F6_s5   F5_s6   F9_s3
Step 27: B1_s0   B2_s1   B3_s2   B4_s3
Step 28: F8_s4   F7_s5   F6_s6   F5_s7
Step 29: B2_s0   B3_s1   B4_s2   B5_s7
Step 30: F9_s4   F8_s5   F7_s6   F6_s7
Step 31: B3_s0   B4_s1   B5_s6   B6_s7
Step 32: None    F9_s5   F8_s6   F7_s7
Step 33: B4_s0   B5_s5   B6_s6   B7_s7
Step 34: None    None    F9_s6   F8_s7
Step 35: B5_s4   B6_s5   B7_s6   B8_s7
Step 36: None    None    None    F9_s7
Step 37: B6_s4   B7_s5   B8_s6   B9_s7
Step 38: None    None    None    None
Step 39: B7_s4   B8_s5   B9_s6   B5_s3
Step 40: None    None    None    None
Step 41: B8_s4   B9_s5   B5_s2   B6_s3
Step 42: None    None    None    None
Step 43: B9_s4   B5_s1   B6_s2   B7_s3
Step 44: None    None    None    None
Step 45: B5_s0   B6_s1   B7_s2   B8_s3
Step 46: None    None    None    None
Step 47: B6_s0   B7_s1   B8_s2   B9_s3
Step 48: None    None    None
Step 49: B7_s0   B8_s1   B9_s2
Step 50: None    None
Step 51: B8_s0   B9_s1
Step 52: None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129597
Approved by: https://github.com/H-Huang
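
The round-count rule from the commit message can be checked with a few lines of Python. This is a minimal sketch of the arithmetic only, not the PyTorch implementation: the function name `flex_pp_num_rounds` is hypothetical, and the error handling is an assumption about how an unsupported configuration might be reported.

```python
def flex_pp_num_rounds(n_microbatches: int, pp_group_size: int) -> int:
    """Return num_rounds for a flex_pp-style schedule, per the commit message.

    Hypothetical helper for illustration; it is not part of torch.distributed.
    """
    if n_microbatches < 1 or pp_group_size < 1:
        raise ValueError("n_microbatches and pp_group_size must be positive")

    # Formula quoted in the commit message.
    num_rounds = max(1, n_microbatches // pp_group_size)

    # Condition quoted in the commit message: the microbatches must split
    # evenly across the rounds.
    if n_microbatches % num_rounds != 0:
        raise ValueError(
            f"n_microbatches={n_microbatches} is not divisible by "
            f"num_rounds={num_rounds}"
        )
    return num_rounds


# The two examples from the commit message.
print(flex_pp_num_rounds(n_microbatches=10, pp_group_size=4))  # 2
print(flex_pp_num_rounds(n_microbatches=3, pp_group_size=4))   # 1
```

Under the same rule, a configuration such as pp_group_size = 4, n_microbatches = 11 would be rejected, because num_rounds = 2 and 11 % 2 is 1.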

| Name | | |
|---|---|---|
| _static | ||
| _templates | ||
| community | ||
| elastic | ||
| notes | ||
| rpc | ||
| scripts | ||
| amp.rst | ||
| autograd.rst | ||
| backends.rst | ||
| benchmark_utils.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| complex_numbers.rst | ||
| cond.rst | ||
| conf.py | ||
| config_mod.rst | ||
| cpp_extension.rst | ||
| cpp_index.rst | ||
| cpu.rst | ||
| cuda_environment_variables.rst | ||
| cuda._sanitizer.rst | ||
| cuda.rst | ||
| cuda.tunable.rst | ||
| cudnn_persistent_rnn.rst | ||
| cudnn_rnn_determinism.rst | ||
| data.rst | ||
| ddp_comm_hooks.rst | ||
| debugging_environment_variables.rst | ||
| deploy.rst | ||
| deterministic.rst | ||
| distributed.algorithms.join.rst | ||
| distributed.checkpoint.rst | ||
| distributed.elastic.rst | ||
| distributed.optim.rst | ||
| distributed.pipelining.rst | ||
| distributed.rst | ||
| distributed.tensor.parallel.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| docutils.conf | ||
| export.ir_spec.rst | ||
| export.rst | ||
| fft.rst | ||
| fsdp.rst | ||
| func.api.rst | ||
| func.batch_norm.rst | ||
| func.migrating.rst | ||
| func.rst | ||
| func.ux_limitations.rst | ||
| func.whirlwind_tour.rst | ||
| future_mod.rst | ||
| futures.rst | ||
| fx.experimental.rst | ||
| fx.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference_v2.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit_utils.rst | ||
| jit.rst | ||
| library.rst | ||
| linalg.rst | ||
| logging.rst | ||
| masked.rst | ||
| math-quantizer-equation.png | ||
| meta.rst | ||
| miscellaneous_environment_variables.rst | ||
| mobile_optimizer.rst | ||
| model_zoo.rst | ||
| module_tracker.rst | ||
| monitor.rst | ||
| mps_environment_variables.rst | ||
| mps.rst | ||
| mtia.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nested.rst | ||
| nn.attention.bias.rst | ||
| nn.attention.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx_dynamo_onnxruntime_backend.rst | ||
| onnx_dynamo.rst | ||
| onnx_torchscript_supported_aten_ops.rst | ||
| onnx_torchscript.rst | ||
| onnx.rst | ||
| optim.rst | ||
| package.rst | ||
| profiler.rst | ||
| quantization-accuracy-debugging.rst | ||
| quantization-backend-configuration.rst | ||
| quantization-support.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| signal.rst | ||
| size.rst | ||
| sparse.rst | ||
| special.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensor_view.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| testing.rst | ||
| threading_environment_variables.rst | ||
| torch_cuda_memory.rst | ||
| torch_environment_variables.rst | ||
| torch_nccl_environment_variables.rst | ||
| torch.ao.ns._numeric_suite_fx.rst | ||
| torch.ao.ns._numeric_suite.rst | ||
| torch.compiler_aot_inductor.rst | ||
| torch.compiler_api.rst | ||
| torch.compiler_best_practices_for_backends.rst | ||
| torch.compiler_cudagraph_trees.rst | ||
| torch.compiler_custom_backends.rst | ||
| torch.compiler_dynamic_shapes.rst | ||
| torch.compiler_dynamo_deepdive.rst | ||
| torch.compiler_dynamo_overview.rst | ||
| torch.compiler_fake_tensor.rst | ||
| torch.compiler_faq.rst | ||
| torch.compiler_fine_grain_apis.rst | ||
| torch.compiler_get_started.rst | ||
| torch.compiler_inductor_profiling.rst | ||
| torch.compiler_ir.rst | ||
| torch.compiler_nn_module.rst | ||
| torch.compiler_performance_dashboard.rst | ||
| torch.compiler_profiling_torch_compile.rst | ||
| torch.compiler_transformations.rst | ||
| torch.compiler_troubleshooting.rst | ||
| torch.compiler.rst | ||
| torch.overrides.rst | ||
| torch.rst | ||
| type_info.rst | ||
| utils.rst | ||
| xpu.rst | ||