This PR fixes some issues with NJT backward / compile backward tests:
1. `requires_grad` was not being propagated appropriately during `SampleInput` generation, so a LOT of backward cases were untested before (sad times). This PR utilizes a helper function `_clone()` to clone() / detach() NJTs for SampleInputs while preserving `requires_grad` status. Note: the clone() / detach() stuff is for autograd; can't have two SampleInputs as part of the same autograd graph.
2. Per-sample skips weren't -fully- working; the op logic would still be invoked even with a skip. I found this out thanks to `split_with_sizes`, which segfaults during backwards because it tries to use an NST-specific formula. As annoying as it is, I tried a ton of things but ultimately had to split the `subtest_ctx` into that + a `skip_xfail_ctx` to run the subtests within.
* Updated all uses of per-sample skips / xfails: 4 in `test_nestedtensor.py` and 1 in `test_vmap.py`
3. Added the appropriate skips / xfails to get everything passing. There are a shitton of bugs to fix!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143072
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
Allow mutations mutations for subclasses that are non-contiguous.
Changes:
Removing assert in collect_metadata_analysis
Main requested testcase:
Compilation of NJT.index_put()
Adding test in test_nestedtensor.py, that compiles NJT.index_put()
It is decomposed to NJT split,unbind, which needed additional `torch._check`, `torch._check_is_size` for NJT.unbind() and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.
Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.
We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.
disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.
2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125
Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.
Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").
Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`
Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.
I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).
**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
* Uses limited `broadcast_tensors()` / `broadcast_to()` support
* Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`
**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125
Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.
Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").
Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`
Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.
I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
### Background
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality.
This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.
### Design
#### Principles
* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
* This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.
#### Details
The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.
```python
@dataclass
class SampleRule(ABC):
# function to indicate whether the rule applies to this op; return True if so
# NB: str arg of callable is device_type
op_match_fn: Callable[[str, OpInfo], bool] = None
# function to indicate whether the rule applies to this sample; return True if so
sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
@dataclass
class XFailRule(SampleRule):
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
@dataclass
class SkipRule(SampleRule):
...
```
* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.
### Example Usage
Consider the following OpInfo test:
```python
class MyTestCase(TestCase):
@ops(op_db)
def test_foo(self, device, dtype, op):
for sample in op.sample_inputs(device, dtype, requires_grad=False):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic.
This PR lets you do this to get very flexible xfail / skips based on op / sample input properties:
```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but
# it can be more readable to maintain these somewhere else. These are attempted to be matched in order and
# the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
XFailRule(
error_type=ValueError,
error_mg="2D inputs not supported",
op_match_fn=lambda device, op: (
# NB: logic for which ops this rule applies to goes here
op.full_name == "add"
),
sample_match_fn=lambda device, sample: (
# NB: logic which samples this rule applies to goes here
sample.input.dim() == 2
),
# NB: optional rule identifier can help with debugging matched rules
name="add_with_2D_inputs_not_supported",
),
# NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously
# this skips a particular SampleInput instead of xfailing :)
SkipRule(...),
...
]
class MyTestCase(TestCase):
@ops(op_db)
@sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
# NB: the @ops decorator automatically filters out any rules that don't apply to this op
def test_foo(self, device, dtype, op):
for sample, subtest_ctx in op.sample_inputs(
# NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True,
# an informative error will be thrown.
device, dtype, requires_grad=False, use_subtests=True
):
# NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately
with subtest_ctx(self):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.
I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`:
```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
"to",
lambda sample: (
sample.kwargs["memory_format"] == torch.channels_last
),
),
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: #140160, #138370
Per discussion with @malfet, only allow weights_only unpickler to load NJT if `torch.nested` and `torch._dynamo` are imported
(this is slightly weird as technically `torch.nested` is actually imported by default and `torch._dynamo.decorators._DimRange` is actually what needs to be imported)
we can't import this from `torch.nested` as this would
- undo dynamo lazy import
- cause circular import
===========================
Redo of https://github.com/pytorch/pytorch/pull/140304 caused issues as `torch.nested._internal.foo` needs to be imported, which causes issues like
```python
torch/_weights_only_unpickler.py", line 339, in load
if full_path in _get_allowed_globals():
torch/_weights_only_unpickler.py", line 188, in _get_allowed_globals
torch.nested._internal.nested_tensor.NestedTensor
AttributeError: module 'torch.nested' has no attribute '_internal'
```
**This likely wasn't caught in our CI because imports are global during unit tests(?), so we use subprocess to properly test this time**
Differential Revision: [D65961691](https://our.internmc.facebook.com/intern/diff/D65961691)
@jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140739
Approved by: https://github.com/malfet
Fixes#140598
Allows ragged structures for query and key+value sequence lengths to differ (i.e. supports cross attention for Flex + NJT).
Technically, this is BC-breaking thanks to arg renaming and positional arg reordering in `create_nested_block_mask()`, but Flex + NJT support isn't in a major release yet so I'm hoping we can just do it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140723
Approved by: https://github.com/drisspg
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
Follow up to some issues @malfet's recent PR pointed out about missing ops #139763. Tried to mirror it to other important nearby ops. Seems like we could automate / autogen this more for generic pointwise ops like this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139890
Approved by: https://github.com/malfet
Fixes#137512
Relaxes the restriction that the ragged dim is immediately next to the batch dim e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. It's possible before this PR to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.
At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.
It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`
The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.
Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar
Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
* `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
* Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.
Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query) # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)
def causal_score_mod(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, float("-inf"))
# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```
TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
* Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841