Introduce a lightweight TorchDispatchMode, `DTensorDebugMode`, for understanding the magic behind DTensor. It:
- Tracks redistributions, logged as `redistribute_input(input_idx, from_placement, to_placement)`
- Optionally tracks torch-level functions via `__torch_function__`
- Optionally tracks FakeTensor operations, which are needed for propagating tensor metadata as a step of sharding propagation
- Optionally tracks real tensor operations, including functional c10d ops and regular ops
- Shows calls in a hierarchical structure
- Uses shorthand representations:
  - `dt`: DTensor, `ft`: FakeTensor, `t`: Tensor
  - `DM(2, 2)` == `DeviceMesh(shape=[2, 2])`
  - `[R, P, S(0)]` == placements `[Replicate, Partial, Shard(0)]`
  - `f32[8, 8]` == float32 with shape `[8, 8]`
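
The snippet below assumes `x_dtensor` and `y_dtensor` are DTensors sharded along dim 0 on an 8-rank mesh. A minimal setup sketch (not part of this PR; assumes the public `torch.distributed.tensor` API and 8 processes launched via `torchrun`):
```
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Assumed setup: 8 ranks, e.g. `torchrun --nproc-per-node 8 demo.py`.
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (8,))

# Shard both operands along dim 0 -> the [S(0)] placements in the trace below.
x_dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
y_dtensor = distribute_tensor(torch.randn(8, 32), mesh, [Shard(0)])
```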
```
# Import path is illustrative; see this PR for where DTensorDebugMode actually lives.
from torch.distributed.tensor.debug import DTensorDebugMode

debug_mode = DTensorDebugMode(record_faketensor=False, record_realtensor=True)
with debug_mode:
    torch.mm(x_dtensor, y_dtensor)
print(debug_mode.debug_string())
```
produces:
```
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
  aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
    redistribute_input(1, [S(0)], [R])
      _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
      _c10d_functional::wait_tensor(t: f32[8, 32])
    aten::mm(t: f32[1, 8], t: f32[8, 32])
```
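Reading the trace: both `mm` operands are sharded along dim 0, so sharding propagation redistributes input 1 from `S(0)` to `R`; the all-gather turns each rank's `f32[1, 32]` shard into the replicated `f32[8, 32]`, and the local `aten::mm` then runs on the rank-local `f32[1, 8]` shard. The logged redistribution is the same collective an explicit redistribute would issue (a sketch using the public DTensor API, not part of this PR):
```
from torch.distributed.tensor import Replicate

# Matches the redistribute_input(1, [S(0)], [R]) step above:
# all-gathers each rank's [1, 32] shard into a replicated [8, 32] DTensor.
y_replicated = y_dtensor.redistribute(placements=[Replicate()])
```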
Another example, for `torch.einsum`:
```
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
  aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
    aten::unsqueeze(t: f32[16, 6, 8], 3)
  aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
    aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
  aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
    aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
  aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
    aten::unsqueeze(t: f32[8, 4, 4], 3)
  aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
    aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
  aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
    aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
  aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
    aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
  aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
    aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
  aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
    aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
  aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
    aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
  aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
    redistribute_input(0, [P, R], [S(2), S(2)])
      aten::chunk(t: f32[1, 96, 8], 4, 2)
      aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
      _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 2)
      aten::clone(t: f32[1, 96, 1])
    redistribute_input(1, [R, P], [S(1), S(1)])
      aten::chunk(t: f32[1, 8, 16], 4, 1)
      aten::clone(t: f32[1, 2, 16])
      aten::chunk(t: f32[1, 2, 16], 2, 1)
      aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
      _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
      _c10d_functional::wait_tensor(t: f32[1, 1, 16])
    aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
  aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
    aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
  aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
    aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
  aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
    aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])
```
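Here `einsum` decomposes into paired `unsqueeze`/`permute`/`view` calls (the outer `dt:` line is the DTensor-level dispatch, the nested `t:` line the corresponding local-tensor op) and a single `aten::bmm`. Both `bmm` inputs carry a `Partial` placement on a 2D mesh (shape `[4, 2]`, judging from the chunk and reduce-scatter group sizes), so each is redistributed to a sharded placement via chunk/cat plus `reduce_scatter` before the local `bmm` runs; the trailing `view`/`permute` pairs reshape the `[P, P]` result back to `blnh`. A sketch of producing this trace, assuming `a` and `b` are DTensors that already carry the `[P, R]` and `[R, P]` placements shown (e.g. as outputs of earlier sharded ops):
```
debug_mode = DTensorDebugMode(record_faketensor=False, record_realtensor=True)
with debug_mode:
    torch.einsum("bld,dnh->blnh", a, b)
print(debug_mode.debug_string())
```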
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162665
Approved by: https://github.com/ezyang