Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Introduce a lightweight TorchDispatchMode for understanding the magic behind DTensor.
- Tracks redistribution; see `redistribute_input(input_idx, from_placement, to_placement)`
- Optionally tracks torch-level functions, via `__torch_function__`
- Optionally tracks FakeTensor operations, which are needed for propagating tensor metadata as a step of sharding propagation
- Optionally tracks real tensor operations, including functional c10d ops and regular ops
- Calls are shown in a hierarchical structure
- Shorthand representations:
  - dt: DTensor, ft: FakeTensor, t: Tensor
  - DM(2, 2) == DeviceMesh(shape = [2, 2])
  - [R, P, S(0)] == Placement[Replicate, Partial, Shard(0)]
  - f32[8,8] == float32 with shape [8, 8]
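The shorthand above can be reproduced with a tiny formatter. This is an illustrative helper only (`fmt` is a hypothetical name, not part of the actual DTensorDebugMode code):

```python
def fmt(kind, dtype, shape, placements=None):
    """Render debug-mode shorthand, e.g. ("dt", "f32", (8, 8), ["S(0)"])
    becomes 'dt: f32[8, 8][S(0)]'; plain tensors omit the placement list."""
    s = f"{kind}: {dtype}[{', '.join(str(d) for d in shape)}]"
    if placements is not None:
        s += f"[{', '.join(placements)}]"
    return s

print(fmt("dt", "f32", (8, 8), ["S(0)"]))  # dt: f32[8, 8][S(0)]
print(fmt("t", "f32", (8, 32)))            # t: f32[8, 32]
```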
```
debug_mode = DTensorDebugMode(record_faketensor=False, record_realtensor=True)
with debug_mode:
    torch.mm(x_dtensor, y_dtensor)
print(debug_mode.debug_string())
```
produces:
```
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
  aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
    redistribute_input(1, [S(0)], [R])
      _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
      _c10d_functional::wait_tensor(t: f32[8, 32])
    aten::mm(t: f32[1, 8], t: f32[8, 32])
```
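The indented rendering can be approximated with a depth-tracked recorder: each intercepted call is logged at the current nesting level, and entering a nested dispatch bumps the depth. This is a minimal pure-Python sketch of the idea, not the actual DTensorDebugMode implementation (`CallLogger` and its methods are hypothetical names):

```python
class CallLogger:
    """Records calls as indented lines, two spaces per nesting level."""

    def __init__(self):
        self.depth = 0
        self.lines = []

    def record(self, name, *args):
        # Render the call at the current depth, debug_string-style.
        self.lines.append("  " * self.depth + f"{name}({', '.join(map(str, args))})")

    def __enter__(self):
        self.depth += 1  # nested dispatch: indent subsequent records
        return self

    def __exit__(self, *exc):
        self.depth -= 1

log = CallLogger()
log.record("torch.mm", "dt: f32[8, 8][S(0)]", "dt: f32[8, 32][S(0)]")
with log:
    log.record("aten::mm", "dt: f32[8, 8][S(0)]", "dt: f32[8, 32][S(0)]")
    with log:
        log.record("redistribute_input", 1, "[S(0)]", "[R]")
print("\n".join(log.lines))
```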
Another example, for `torch.einsum`:
```
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
  aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
    aten::unsqueeze(t: f32[16, 6, 8], 3)
  aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
    aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
  aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
    aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
  aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
    aten::unsqueeze(t: f32[8, 4, 4], 3)
  aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
    aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
  aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
    aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
  aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
    aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
  aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
    aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
  aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
    aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
  aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
    aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
  aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
    redistribute_input(0, [P, R], [S(2), S(2)])
      aten::chunk(t: f32[1, 96, 8], 4, 2)
      aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
      _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 2)
      aten::clone(t: f32[1, 96, 1])
    redistribute_input(1, [R, P], [S(1), S(1)])
      aten::chunk(t: f32[1, 8, 16], 4, 1)
      aten::clone(t: f32[1, 2, 16])
      aten::chunk(t: f32[1, 2, 16], 2, 1)
      aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
      _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
      _c10d_functional::wait_tensor(t: f32[1, 1, 16])
    aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
  aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
    aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
  aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
    aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
  aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
    aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])
```
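The einsum trace follows the usual lowering to `bmm`: both operands are unsqueezed and permuted into `(batch, m, k) x (batch, k, n)` form, multiplied, then reshaped back to `blnh`. The shape bookkeeping in the trace can be checked with plain Python arithmetic (no tensors involved; this only verifies the shapes shown above):

```python
from math import prod

# Shapes from the example: bld,dnh->blnh with b=16, l=6, d=8, n=4, h=4.
lhs = (16, 6, 8)   # b, l, d
rhs = (8, 4, 4)    # d, n, h

# Left operand: (b, l) flattens into the bmm "m" dim; d is the contracted "k".
bmm_lhs = (1, prod(lhs[:2]), lhs[2])
# Right operand: d stays as "k"; (n, h) flattens into the bmm "n" dim.
bmm_rhs = (1, rhs[0], prod(rhs[1:]))

assert bmm_lhs[2] == bmm_rhs[1]          # contraction dims must agree
out = (1, bmm_lhs[1], bmm_rhs[2])        # bmm output shape

# The result is then viewed back to blnh = (16, 6, 4, 4); element counts match.
final = (16, 6, 4, 4)
assert prod(out) == prod(final)
print(bmm_lhs, bmm_rhs, out)
```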
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162665
Approved by: https://github.com/ezyang
# torch.utils
.. automodule:: torch.utils
.. currentmodule:: torch.utils
.. autosummary::
:toctree: generated
:nosignatures:
rename_privateuse1_backend
generate_methods_for_privateuse1_backend
get_cpp_backtrace
set_module
swap_tensors
.. py:module:: torch.utils.backend_registration
.. py:module:: torch.utils.benchmark.examples.compare
.. py:module:: torch.utils.benchmark.examples.fuzzer
.. py:module:: torch.utils.benchmark.examples.op_benchmark
.. py:module:: torch.utils.benchmark.examples.simple_timeit
.. py:module:: torch.utils.benchmark.examples.spectral_ops_fuzz_test
.. py:module:: torch.utils.benchmark.op_fuzzers.binary
.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_binary
.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_unary
.. py:module:: torch.utils.benchmark.op_fuzzers.spectral
.. py:module:: torch.utils.benchmark.op_fuzzers.unary
.. py:module:: torch.utils.benchmark.utils.common
.. py:module:: torch.utils.benchmark.utils.compare
.. py:module:: torch.utils.benchmark.utils.compile
.. py:module:: torch.utils.benchmark.utils.cpp_jit
.. py:module:: torch.utils.benchmark.utils.fuzzer
.. py:module:: torch.utils.benchmark.utils.sparse_fuzzer
.. py:module:: torch.utils.benchmark.utils.timer
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
.. py:module:: torch.utils.data.dataloader
.. py:module:: torch.utils.data.datapipes.dataframe.dataframe_wrapper
.. py:module:: torch.utils.data.datapipes.dataframe.dataframes
.. py:module:: torch.utils.data.datapipes.dataframe.datapipes
.. py:module:: torch.utils.data.datapipes.dataframe.structures
.. py:module:: torch.utils.data.datapipes.datapipe
.. py:module:: torch.utils.data.datapipes.gen_pyi
.. py:module:: torch.utils.data.datapipes.iter.callable
.. py:module:: torch.utils.data.datapipes.iter.combinatorics
.. py:module:: torch.utils.data.datapipes.iter.combining
.. py:module:: torch.utils.data.datapipes.iter.filelister
.. py:module:: torch.utils.data.datapipes.iter.fileopener
.. py:module:: torch.utils.data.datapipes.iter.grouping
.. py:module:: torch.utils.data.datapipes.iter.routeddecoder
.. py:module:: torch.utils.data.datapipes.iter.selecting
.. py:module:: torch.utils.data.datapipes.iter.sharding
.. py:module:: torch.utils.data.datapipes.iter.streamreader
.. py:module:: torch.utils.data.datapipes.iter.utils
.. py:module:: torch.utils.data.datapipes.map.callable
.. py:module:: torch.utils.data.datapipes.map.combinatorics
.. py:module:: torch.utils.data.datapipes.map.combining
.. py:module:: torch.utils.data.datapipes.map.grouping
.. py:module:: torch.utils.data.datapipes.map.utils
.. py:module:: torch.utils.data.datapipes.utils.common
.. py:module:: torch.utils.data.datapipes.utils.decoder
.. py:module:: torch.utils.data.datapipes.utils.snapshot
.. py:module:: torch.utils.data.dataset
.. py:module:: torch.utils.data.distributed
.. py:module:: torch.utils.data.graph
.. py:module:: torch.utils.data.graph_settings
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.debug_mode
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract
.. py:module:: torch.utils.mkldnn
.. py:module:: torch.utils.mobile_optimizer
.. py:module:: torch.utils.show_pickle
.. py:module:: torch.utils.tensorboard.summary
.. py:module:: torch.utils.tensorboard.writer
.. py:module:: torch.utils.throughput_benchmark
.. py:module:: torch.utils.weak