pytorch/torch/testing
Joel Schlosser e7ec294c10 NJT OpInfo tests v2 (#138370)
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
    * The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
    * I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
        * Test perf - adding `SampleInput`s is faster than generating entire new tests
        * Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed

```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"
    # function to indicate whether the rule applies; return True if so
    match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

    def match(self, device, dtype, op, sample) -> bool:
        return self.match_fn(device, dtype, op, sample)
```

Example:
```python
    # Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
    # tensor with 1 in ragged dim.
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot call binary pointwise function .* with inputs of shapes",
        match_fn=lambda device, dtype, op, sample: (
            isinstance(op, BinaryUfuncInfo)
            and "noncontig_holes" in sample.name
            and "broadcasting 1 over ragged" in sample.name
        ),
        name="binary_noncontig_holes_broadcasting_1_over_ragged",
    ),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
2024-11-11 19:35:24 +00:00
..
_internal NJT OpInfo tests v2 (#138370) 2024-11-11 19:35:24 +00:00
__init__.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
_comparison.py Strict shape checking for NJTs with TestCase.assertEqual() (#131898) 2024-07-30 20:05:48 +00:00
_creation.py Ensure noncontiguous tensor creation tests offsetting (#136396) 2024-10-02 00:40:43 +00:00
_utils.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00