This PR introduces `ExtraOpData`, a structure that contains op metadata describing whether the op is a view and the dim-related args it accepts, and it populates a large database of dim-wise / view ops with this info. Test logic (sample input generation, references) has been updated to utilize this data. This allows for a fairly generic set of sample inputs and a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops"); a rough sketch of the idea appears below. Testing is added over the following ops:

* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple of minor fixes (sorry):

1. The `_wrap_jagged_dim()` helper now avoids assuming that `nt._ragged_idx == 1` and allows a batch dim as a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items); see the sketch below.
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
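
For illustration, here is a minimal sketch of what an `ExtraOpData`-style record and a generic "unary dimwise" dim sweep could look like. The field names (`is_view`, `dim_args`), the example database entries, and the helper function are hypothetical assumptions, not the actual test-suite code; they only mirror the shape of the metadata described above.

```python
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class ExtraOpData:
    # Whether the op returns a view of its input (hypothetical field name).
    is_view: bool = False
    # Names of the dim-related args the op accepts (hypothetical field name).
    dim_args: List[str] = field(default_factory=list)


# Illustrative entries for a few of the dim-wise ops tested in this PR.
OP_DB: Dict[str, ExtraOpData] = {
    "unsqueeze": ExtraOpData(is_view=True, dim_args=["dim"]),
    "squeeze": ExtraOpData(is_view=True, dim_args=["dim"]),
    "chunk": ExtraOpData(is_view=True, dim_args=["dim"]),
}


def candidate_dims(nt: torch.Tensor) -> List[int]:
    # A generic sample-input sweep for unary dimwise ops: try every dim of
    # the NJT; tests then xfail the ragged / batch dims where unsupported.
    return list(range(nt.dim()))
```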
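
A small usage example of the kind of case the new tests exercise: a unary dimwise op applied to a jagged-layout NJT on a dim that is neither the batch nor the ragged dim (which dims each op actually supports is exactly what the xfails above encode).

```python
import torch

# Jagged-layout nested tensor: batch of 2 with ragged lengths 2 and 3.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)],
    layout=torch.jagged,
)

# Dim-wise op on the trailing (non-ragged, non-batch) dim.
out = nt.unsqueeze(-1)
print(out.shape)  # expected something like torch.Size([2, j1, 5, 1]), j1 being the ragged dim
```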
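
Finally, an illustrative sketch of the dim-wrapping behavior described in fix (1). This is not the actual `_wrap_jagged_dim()` signature or implementation; it only demonstrates how an outer NJT dim could be converted to a dim on the inner values tensor while reporting whether the caller passed the batch dim, without assuming `ragged_idx == 1`.

```python
from typing import Tuple


def wrap_jagged_dim_sketch(ndim: int, dim: int, ragged_idx: int) -> Tuple[int, bool]:
    """Map an outer NJT dim to an inner values-tensor dim (illustrative only).

    Returns (inner_dim, operating_on_batch). The batch dim (0) and the ragged
    dim (ragged_idx) both land on the same inner dim, because the values
    tensor packs the ragged dim across all batch items; operating_on_batch
    lets the op distinguish the two cases.
    """
    dim = dim % ndim  # normalize negative dims
    if dim == 0:
        # Batch dim: maps to the packed ragged dim on the values tensor.
        return ragged_idx - 1, True
    # All other outer dims shift down by one, since the batch dim is dropped
    # when moving to the values tensor (so the ragged dim maps to ragged_idx - 1).
    return dim - 1, False


# With ragged_idx == 1 (the common case): both dim=0 and dim=1 map to inner
# dim 0, as described in the fix above.
assert wrap_jagged_dim_sketch(3, 0, 1) == (0, True)
assert wrap_jagged_dim_sketch(3, 1, 1) == (0, False)
assert wrap_jagged_dim_sketch(3, 2, 1) == (1, False)
```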