pytorch/docs/source/utils.md
Sherlock Huang f8d379d29e [DTensor] Introduce DebugMode (#162665)
Introduce a lightweight TorchDispatchMode for understanding the magic behind DTensor.

- Tracks redistribution, shown as `redistribute_input(input_idx, from_placement, to_placement)`
- Optionally tracks torch-level functions, via `__torch_function__`
- Optionally tracks FakeTensor operations, which are needed for propagating tensor metadata as a step of sharding propagation
- Optionally tracks real tensor operations, including functional c10d ops and regular ops
- Calls are shown in a hierarchical structure
- Shorthand representation:
  - dt: DTensor, ft: FakeTensor, t: Tensor
  - DM(2, 2) == DeviceMesh(shape = [2, 2])
  - [R, P, S(0)] == Placement[Replicate, Partial, Shard(0)]
  - f32[8,8] == float32 with shape[8, 8]
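The shorthand can be sketched as a small formatting rule. The helper below is illustrative only (`short_repr` and the dtype table are not part of the PyTorch API); it just shows how a prefix, dtype, and shape map onto strings like `t: f32[8, 8]` in the traces below:

```python
# Illustrative only: render a tensor in the shorthand used by the debug output.
# `short_repr` and `_DTYPE_ABBREV` are hypothetical helpers, not PyTorch API.
_DTYPE_ABBREV = {"torch.float32": "f32", "torch.float64": "f64", "torch.int64": "i64"}

def short_repr(prefix, dtype, shape):
    abbrev = _DTYPE_ABBREV.get(str(dtype), str(dtype))
    return f"{prefix}: {abbrev}[{', '.join(map(str, shape))}]"

print(short_repr("t", "torch.float32", (8, 8)))  # t: f32[8, 8]
```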

```
debug_mode = DTensorDebugMode(record_faketensor=False, record_realtensor=True)
with debug_mode:
    torch.mm(x_dtensor, y_dtensor)
print(debug_mode.debug_string())
```
produces:
```
  torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
    aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
      redistribute_input(1, [S(0)], [R])
        _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
        _c10d_functional::wait_tensor(t: f32[8, 32])
      aten::mm(t: f32[1, 8], t: f32[8, 32])
```
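The hierarchy above comes from intercepting ops at the dispatcher level. A minimal sketch of the underlying mechanism, using the public `TorchDispatchMode` hook with plain tensors (no DTensor, no process group) — unlike DebugMode, this records a flat list with no hierarchy or redistribution bookkeeping:

```python
# Minimal sketch of the mechanism DebugMode builds on: a TorchDispatchMode
# sees every aten op as it reaches the dispatcher.
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.calls.append(str(func))  # e.g. "aten.mm.default"
        return func(*args, **(kwargs or {}))

x, y = torch.randn(2, 3), torch.randn(3, 4)
mode = LoggingMode()
with mode:
    torch.mm(x, y)
print(mode.calls)
```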

Another example, for `torch.einsum`:
```
  torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
    aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
      aten::unsqueeze(t: f32[16, 6, 8], 3)
    aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
      aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
    aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
      aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
    aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
      aten::unsqueeze(t: f32[8, 4, 4], 3)
    aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
      aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
    aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
      aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
    aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
      aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
    aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
      aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
    aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
      aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
    aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
      aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
    aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
      redistribute_input(0, [P, R], [S(2), S(2)])
        aten::chunk(t: f32[1, 96, 8], 4, 2)
        aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
        _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 2)
        aten::clone(t: f32[1, 96, 1])
      redistribute_input(1, [R, P], [S(1), S(1)])
        aten::chunk(t: f32[1, 8, 16], 4, 1)
        aten::clone(t: f32[1, 2, 16])
        aten::chunk(t: f32[1, 2, 16], 2, 1)
        aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
        _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
        _c10d_functional::wait_tensor(t: f32[1, 1, 16])
      aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
    aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
      aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
    aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
      aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
    aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
      aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])
```
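Since `debug_string()` returns plain indented text, it is easy to post-process, for example to count collectives triggered by a region of code. The parsing helper below is illustrative, not part of the PyTorch API; the trace literal is copied from the first example above:

```python
# Count collective ops in a DebugMode-style trace string.
# `collective_counts` is an illustrative helper, not part of the PyTorch API.
from collections import Counter

trace = """\
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
  aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
    redistribute_input(1, [S(0)], [R])
      _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
      _c10d_functional::wait_tensor(t: f32[8, 32])
    aten::mm(t: f32[1, 8], t: f32[8, 32])
"""

def collective_counts(trace):
    counts = Counter()
    for line in trace.splitlines():
        op = line.strip().split("(", 1)[0]
        if op.startswith("_c10d_functional::"):
            counts[op.split("::", 1)[1]] += 1
    return counts

print(collective_counts(trace))
```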

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162665
Approved by: https://github.com/ezyang
2025-09-16 07:30:05 +00:00


# torch.utils

.. automodule:: torch.utils
.. currentmodule:: torch.utils
.. autosummary::
    :toctree: generated
    :nosignatures:

    rename_privateuse1_backend
    generate_methods_for_privateuse1_backend
    get_cpp_backtrace
    set_module
    swap_tensors
.. py:module:: torch.utils.backend_registration
.. py:module:: torch.utils.benchmark.examples.compare
.. py:module:: torch.utils.benchmark.examples.fuzzer
.. py:module:: torch.utils.benchmark.examples.op_benchmark
.. py:module:: torch.utils.benchmark.examples.simple_timeit
.. py:module:: torch.utils.benchmark.examples.spectral_ops_fuzz_test
.. py:module:: torch.utils.benchmark.op_fuzzers.binary
.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_binary
.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_unary
.. py:module:: torch.utils.benchmark.op_fuzzers.spectral
.. py:module:: torch.utils.benchmark.op_fuzzers.unary
.. py:module:: torch.utils.benchmark.utils.common
.. py:module:: torch.utils.benchmark.utils.compare
.. py:module:: torch.utils.benchmark.utils.compile
.. py:module:: torch.utils.benchmark.utils.cpp_jit
.. py:module:: torch.utils.benchmark.utils.fuzzer
.. py:module:: torch.utils.benchmark.utils.sparse_fuzzer
.. py:module:: torch.utils.benchmark.utils.timer
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
.. py:module:: torch.utils.data.dataloader
.. py:module:: torch.utils.data.datapipes.dataframe.dataframe_wrapper
.. py:module:: torch.utils.data.datapipes.dataframe.dataframes
.. py:module:: torch.utils.data.datapipes.dataframe.datapipes
.. py:module:: torch.utils.data.datapipes.dataframe.structures
.. py:module:: torch.utils.data.datapipes.datapipe
.. py:module:: torch.utils.data.datapipes.gen_pyi
.. py:module:: torch.utils.data.datapipes.iter.callable
.. py:module:: torch.utils.data.datapipes.iter.combinatorics
.. py:module:: torch.utils.data.datapipes.iter.combining
.. py:module:: torch.utils.data.datapipes.iter.filelister
.. py:module:: torch.utils.data.datapipes.iter.fileopener
.. py:module:: torch.utils.data.datapipes.iter.grouping
.. py:module:: torch.utils.data.datapipes.iter.routeddecoder
.. py:module:: torch.utils.data.datapipes.iter.selecting
.. py:module:: torch.utils.data.datapipes.iter.sharding
.. py:module:: torch.utils.data.datapipes.iter.streamreader
.. py:module:: torch.utils.data.datapipes.iter.utils
.. py:module:: torch.utils.data.datapipes.map.callable
.. py:module:: torch.utils.data.datapipes.map.combinatorics
.. py:module:: torch.utils.data.datapipes.map.combining
.. py:module:: torch.utils.data.datapipes.map.grouping
.. py:module:: torch.utils.data.datapipes.map.utils
.. py:module:: torch.utils.data.datapipes.utils.common
.. py:module:: torch.utils.data.datapipes.utils.decoder
.. py:module:: torch.utils.data.datapipes.utils.snapshot
.. py:module:: torch.utils.data.dataset
.. py:module:: torch.utils.data.distributed
.. py:module:: torch.utils.data.graph
.. py:module:: torch.utils.data.graph_settings
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.debug_mode
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract
.. py:module:: torch.utils.mkldnn
.. py:module:: torch.utils.mobile_optimizer
.. py:module:: torch.utils.show_pickle
.. py:module:: torch.utils.tensorboard.summary
.. py:module:: torch.utils.tensorboard.writer
.. py:module:: torch.utils.throughput_benchmark
.. py:module:: torch.utils.weak