mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
|
||
|---|---|---|
| .. | ||
| _internal | ||
| __init__.py | ||
| _comparison.py | ||
| _creation.py | ||
| _utils.py | ||