1. Added the `_refs.broadcast_shapes` reference (see the sketch below)
2. Added an OpInfo test for `torch.broadcast_shapes`
A few minor changes:
- updated `test_python_ref_meta` and `_ref_test_helper` to avoid non-tensor outputs
- updated the type annotation for `_resize_meta`
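A minimal sketch of the broadcasting rule the new reference implements, assuming a standalone helper rather than the exact `_refs.broadcast_shapes` code:

```python
from typing import Tuple

# Hedged sketch: compute the broadcasted shape of several shapes using the
# standard NumPy/PyTorch broadcasting rule (right-align, size-1 dims expand).
# Illustrative only, not the actual _refs.broadcast_shapes implementation.
def broadcast_shapes_sketch(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    ndim = max((len(s) for s in shapes), default=0)
    result = [1] * ndim
    for shape in shapes:
        offset = ndim - len(shape)  # right-align this shape against result
        for i, dim in enumerate(shape):
            j = i + offset
            if dim == 1:
                continue  # a size-1 dim broadcasts against anything
            if result[j] != 1 and result[j] != dim:
                raise RuntimeError(f"shapes {shapes} are not broadcastable")
            result[j] = dim
    return tuple(result)

# (3, 1) broadcast with (1, 4) and (4,) gives (3, 4).
assert broadcast_shapes_sketch((3, 1), (1, 4), (4,)) == (3, 4)
```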
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612
Approved by: https://github.com/mruberry
1. Added references for the two ops.
2. Inherited the original operators' OpInfo tests.
TODO for a future PR:
add primTorch references for `dsplit` and `dstack`; those two should use `atleast_3d`, which is in a different package right now.
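As a rough illustration of why those refs want `atleast_3d`, here is a hedged sketch of a `dstack`-style composition, written with public `torch` ops rather than the eventual refs:

```python
import torch

def dstack_sketch(tensors):
    # Promote every input to at least 3 dimensions, then concatenate along
    # the third dimension (dim=2), matching torch.dstack semantics.
    return torch.cat([torch.atleast_3d(t) for t in tensors], dim=2)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
assert torch.equal(dstack_sketch([a, b]), torch.dstack([a, b]))
```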
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78416
Approved by: https://github.com/mruberry
This PR is a result of collaboration with @rdspring1 and @mruberry on primTorch.
It adds the following prims:
- `fmax`
- `fmin`
- `fmod`
And adds the following refs:
- `fmax`
- `fmin`
- `fmod`
- `logical_xor`
This work is in progress, as some tests still fail.
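For illustration, a ref like `logical_xor` can be composed from existing elementwise ops. A hedged sketch in terms of public `torch` ops (not the actual refs code):

```python
import torch

def logical_xor_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Non-boolean inputs are interpreted as "nonzero means True",
    # matching torch.logical_xor semantics.
    if a.dtype is not torch.bool:
        a = a != 0
    if b.dtype is not torch.bool:
        b = b != 0
    return a ^ b

x = torch.tensor([0.0, 1.0, 2.0, 0.0])
y = torch.tensor([0.0, 0.0, 3.0, 4.0])
assert torch.equal(logical_xor_sketch(x, y), torch.logical_xor(x, y))
```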
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78023
Approved by: https://github.com/mruberry
Feels good to delete it from `torch._decomps`. This is mainly to clarify the process for me:
it seems there are still some components missing from the `torch <-> refs` mapping? For example, methods don't seem to work yet when mapping from torch to refs, and neither do the meta tests (cc: @ezyang).
If I replace `torch` with `refs`, then the tests pass.
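For illustration, that swap amounts to calling the reference namespace directly. A minimal sketch, using `refs.sin` as a stand-in for the op in question (`torch._refs` is the private module backing the refs):

```python
import torch
import torch._refs as refs

# A ref can be called in place of the corresponding torch op and should
# produce matching results on the same inputs.
x = torch.randn(4)
assert torch.allclose(refs.sin(x), torch.sin(x))
```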
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78041
Approved by: https://github.com/mruberry
This PR adds the item, equal, any, and all references.
While doing this I found the following issues:
- https://github.com/pytorch/pytorch/issues/78070
- https://github.com/pytorch/pytorch/issues/78071
I also fixed a bug where the `convert_element_type` prim could not convert tensors that require grad to datatypes that cannot require grad.
Creating the item reference required adding item as a prim, but per @ngimel's suggestion I removed the prims for any and all and implemented them as references (sketched below), so this PR is a net reduction of one prim.
Reference OpInfos are added for any and all, but item and equal don't even have regular OpInfos.
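A hedged sketch of the idea behind making any and all references rather than prims: compose them from existing elementwise and reduction ops. The helper names are assumptions, not the actual refs code:

```python
import torch

def any_sketch(a: torch.Tensor) -> torch.Tensor:
    # "any" is true iff at least one element is nonzero.
    return (a != 0).sum() != 0

def all_sketch(a: torch.Tensor) -> torch.Tensor:
    # "all" is true iff no element is zero, i.e. not(any(a == 0)).
    return torch.logical_not(any_sketch(a == 0))

t = torch.tensor([1.0, 2.0, 0.0])
assert torch.equal(any_sketch(t), torch.any(t))
assert torch.equal(all_sketch(t), torch.all(t))
```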
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78072
Approved by: https://github.com/ngimel
This PR...
**Issues Found**
- https://github.com/pytorch/pytorch/issues/78058
- https://github.com/pytorch/pytorch/issues/78054
- https://github.com/pytorch/pytorch/issues/78053
- https://github.com/pytorch/pytorch/issues/78050
- https://github.com/pytorch/pytorch/issues/77932
**Testing**
- disables stride consistency checks in test_ops and test_meta pending resolution of https://github.com/pytorch/pytorch/issues/78050
- skips chalf in reference tests (addressing https://github.com/pytorch/pytorch/issues/78054)
- splits test_python_reference_consistency into one test for the context where torch.foo is torch.foo, and another for when torch.foo is refs.foo
- updates test names to be more natural and consistent:
  - test_python_reference_errors -> test_python_ref_errors
  - test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback
  - test_python_reference_meta_functions -> test_python_ref_meta
  - test_reference_testing -> test_numpy_ref
- updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op when the reference and torch op results are not close; a warning is raised when this occurs (addressing https://github.com/pytorch/pytorch/issues/77687) -- see the sketch after this list
- adds reference inputs for broadcast_tensors
- updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator
- adds 1D sample inputs with no elements to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly
- adds reference inputs for elementwise ternary operations, like clamp
- adds a NumPy reference for clamp
- adds reference inputs to where's OpInfo
- makes softplus an elementwise unary OpInfo
- removes the great majority of Python reference OpInfo skips and xfails due to the above test changes
- adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where
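Here is the accuracy comparison mentioned for test_python_ref / test_python_ref_torch_fallback, as a hedged sketch; the helper name and the use of double precision as the ground-truth proxy are assumptions, not the actual test code:

```python
import torch
import torch._refs as refs

def ref_at_least_as_accurate(ref_fn, torch_fn, x: torch.Tensor) -> bool:
    # If the reference and the torch op agree, there is nothing to arbitrate.
    ref_out, torch_out = ref_fn(x), torch_fn(x)
    if torch.allclose(ref_out, torch_out, equal_nan=True):
        return True
    # Otherwise, recompute in double precision as a proxy for the exact
    # result and only accept the reference if it is the closer of the two.
    precise = torch_fn(x.to(torch.double)).to(x.dtype)
    ref_err = (ref_out - precise).abs().max()
    torch_err = (torch_out - precise).abs().max()
    return bool(ref_err <= torch_err)

assert ref_at_least_as_accurate(refs.exp, torch.exp, torch.randn(8))
```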
**Prims**
- adds the fill, empty_strided, and uniform prims
- removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill (see the sketch after this list)
- renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch
- extends the `_elementwise_meta` operation to accept tensors that don't participate in type promotion, like the `cond` tensor in `where`
- fixes a bug in the stride propagation of broadcast_in_dim
- moves some error checks from prims.cat and prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible
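A hedged sketch of the empty_strided + fill decomposition referenced above, written with public torch ops and an assumed helper name rather than the actual refs code:

```python
import torch

def contiguous_strides(shape):
    # Row-major (contiguous) strides for a shape, e.g. (2, 3) -> (3, 1).
    strides, running = [], 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))

def full_sketch(shape, fill_value, dtype=torch.float32, device="cpu"):
    # Allocate with an empty_strided-style call, then fill in the value.
    out = torch.empty_strided(shape, contiguous_strides(shape),
                              dtype=dtype, device=device)
    return out.fill_(fill_value)

assert torch.equal(full_sketch((2, 3), 7.0), torch.full((2, 3), 7.0))
```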
**Utils**
- adds the canonicalize_device, extract_shape, and extract_shape_from_varargs helpers
- adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (e.g., refs.sin(1) will return 0.84...)
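A minimal sketch of what such a scalar wrapper can look like; the decorator name and the dtype handling here are assumptions, not the actual utils code:

```python
import functools
import math
import torch

def scalar_wrapper(fn):
    # If the input is a Python number, wrap it in a tensor, run the op, and
    # unwrap the result back into a Python number; tensors pass through as-is.
    @functools.wraps(fn)
    def wrapped(a, *args, **kwargs):
        if isinstance(a, (bool, int, float, complex)):
            t = torch.tensor(a)
            if not (t.dtype.is_floating_point or t.dtype.is_complex):
                t = t.to(torch.get_default_dtype())
            return fn(t, *args, **kwargs).item()
        return fn(a, *args, **kwargs)
    return wrapped

@scalar_wrapper
def sin_sketch(a):
    return torch.sin(a)

# Scalar in, scalar out -- matching the refs.sin(1) example above.
assert math.isclose(sin_sketch(1), math.sin(1), rel_tol=1e-6)
```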
**Refs**
- adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references
- adds the nn.functional.dropout reference (see the sketch after this list)
- fixes refs.cat to handle 1D tensors with no elements, consistent with eager mode
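A hedged sketch of how a dropout reference can be composed from simpler ops (a mask drawn from a uniform sample plus rescaling); illustrative only, not the exact refs implementation:

```python
import torch

def dropout_sketch(a: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    if not training or p == 0.0:
        return a.clone()
    if p == 1.0:
        return torch.zeros_like(a)
    # Keep each element with probability (1 - p) and rescale to preserve the mean.
    mask = torch.rand_like(a) < (1.0 - p)
    return a * mask / (1.0 - p)

x = torch.randn(1000)
y = dropout_sketch(x, p=0.5)
# Each element is either zeroed out or scaled by 1 / (1 - p) = 2.
assert ((y == 0) | torch.isclose(y, 2 * x)).all()
```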
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78026
Approved by: https://github.com/ngimel