pytorch/torch/distributed/tensor/_ops
nikitaved edc2d539d1 torch.tensordot: performance improvements when contracting to a scalar. (#145936)
As per title.
Fixes https://github.com/pytorch/pytorch/issues/145731

Touches only the compute path; the CPU overhead could potentially be reduced further.
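For intuition: when every dimension of both inputs is contracted, the result is just a dot product of the two tensors flattened in matching dimension order, which avoids the usual permute/reshape/matmul path. Below is a minimal sketch of that idea; the helper `tensordot_to_scalar` is hypothetical and is not the code in this PR:

```python
import torch

def tensordot_to_scalar(a, b, dims_a, dims_b):
    # Hypothetical fast path: if all dims of both operands are contracted,
    # tensordot(a, b, [dims_a, dims_b]) equals a dot product of the two
    # tensors flattened with their contracted dims aligned.
    assert len(dims_a) == a.dim() and len(dims_b) == b.dim()
    # reshape(-1) is a free view when the permutation is the identity,
    # which is why the contiguous [[0, 1], [0, 1]] case is fastest in the
    # timings below; other orders pay for one contiguous copy.
    return torch.dot(
        a.permute(dims_a).reshape(-1),
        b.permute(dims_b).reshape(-1),
    )

A, B = torch.rand(512, 512), torch.rand(512, 512)
expected = torch.tensordot(A, B, [[1, 0], [0, 1]])
assert torch.allclose(tensordot_to_scalar(A, B, [1, 0], [0, 1]), expected)
```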

Before:
```python
In [2]: import torch

In [3]: n = 512

In [4]: A = torch.rand(n, n)

In [5]: B = torch.rand(n, n)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
2.04 ms ± 70 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [7]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
2.85 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
2.9 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
4.07 ms ± 262 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

After:
```python
In [1]: import torch

In [2]: n = 512

In [3]: A = torch.rand(n, n)

In [4]: B = torch.rand(n, n)

In [5]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
30.7 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
141 µs ± 6.52 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [7]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
142 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
62.8 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
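As context (not part of the PR itself), "contracting to a scalar" means all four calls above reduce to a 0-dim tensor, and each is equivalent to an elementwise-product sum:

```python
import torch

A, B = torch.rand(512, 512), torch.rand(512, 512)
out = torch.tensordot(A, B, [[0, 1], [0, 1]])
assert out.dim() == 0  # full contraction yields a scalar
assert torch.allclose(out, (A * B).sum())
# Swapping the contracted dims on one side pairs A with B transposed:
assert torch.allclose(torch.tensordot(A, B, [[0, 1], [1, 0]]), (A * B.T).sum())
```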

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145936
Approved by: https://github.com/albanD, https://github.com/ngimel
2025-05-13 10:57:30 +00:00
| File | Last commit | Date |
| --- | --- | --- |
| __init__.py | [dtensor] add op support for select_backward and slice_backward (#150357) | 2025-04-01 19:15:25 +00:00 |
| _common_rules.py | [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257) | 2025-03-18 00:46:07 +00:00 |
| _conv_ops.py | [DTensor][conv] add DTensor convolution_backward op support for case where the input Tensor has requires_grad=False (#142278) | 2025-02-10 07:06:40 +00:00 |
| _einsum_strategy.py | [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547) | 2025-02-28 07:35:56 +00:00 |
| _embedding_ops.py | [DTensor] clean up _local_shard_size_and_offset (#150650) | 2025-04-09 22:07:48 +00:00 |
| _math_ops.py | [dtensor] add op support for torch.cumsum (#151071) | 2025-04-11 16:42:19 +00:00 |
| _matrix_ops.py | torch.tensordot: performance improvements when contracting to a scalar. (#145936) | 2025-05-13 10:57:30 +00:00 |
| _pointwise_ops.py | Let pointwise sharding take arg with largest number of dims in case of ties (#149721) | 2025-03-24 15:39:39 +00:00 |
| _random_ops.py | [dtensor] refactor sharding prop to handle cross mesh computation (#147869) | 2025-03-04 18:30:44 +00:00 |
| _tensor_ops.py | [dtensor] add op support for select_backward and slice_backward (#150357) | 2025-04-01 19:15:25 +00:00 |
| _view_ops.py | [DTensor] Error on illegal view op during sharding prop (#149764) | 2025-04-28 18:21:49 +00:00 |
| utils.py | [BE][Easy]: Normalize Dim typing in torch distributed (#151566) | 2025-04-17 19:30:09 +00:00 |