make broadcasting explanation clearer in matmul doc: #22763 (#45699)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45699

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D24065584

Pulled By: bdhirsh

fbshipit-source-id: 5e2cdd00ed18ad47d24d11751cfa5bee63853cc9
This commit is contained in:
Brian Hirsh 2020-10-02 06:49:19 -07:00 committed by Facebook GitHub Bot
parent 82cc86b64c
commit c703602e17


@@ -4964,8 +4964,14 @@ The behavior depends on the dimensionality of the tensors as follows:
1 is appended to its dimension for the purpose of the batched matrix multiply and removed after.
The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus
must be broadcastable). For example, if :attr:`input` is a
:math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)`
tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor.
Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs
are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a
:math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)`
tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the
matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor.
{tf32_note}
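
As a quick illustration of the broadcasting behavior described in the new doc text, here is a minimal sketch using torch.matmul (the concrete sizes j=2, k=5, n=3, m=4, p=6 are chosen only for the example):

    import torch

    # input has shape (j=2, 1, n=3, m=4); other has shape (k=5, m=4, p=6).
    # Only the batch dims broadcast: (2, 1) with (5,) -> (2, 5).
    # The matrix dims follow matmul rules: (3, 4) @ (4, 6) -> (3, 6).
    input = torch.randn(2, 1, 3, 4)
    other = torch.randn(5, 4, 6)
    out = torch.matmul(input, other)
    print(out.shape)  # torch.Size([2, 5, 3, 6]), i.e. (j, k, n, p)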