Also fixes xlogy (turns out the only thing it was missing was a type cast annotation! nice!)
I also renamed `canonicalize_idx` => `canonicalize_dim` (to align with `canonicalize_dims`) and fixed a bug in it (cc: @mruberry)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76873
Approved by: https://github.com/mruberry
This PR...
Adds the following prims:
- slice
- slice_in_dim
- transpose
Adds the following refs:
- cat
- permute
- transpose
- swap_axes (alias for transpose)
- tensor_split
Makes the following test improvements:
- adds reference inputs for torch.permute
- adds a NumPy reference for torch.permute
- adds reference inputs for torch.cat
Fixes the following bugs:
- adds support for scalars to the min and max prims
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76727
Approved by: https://github.com/ngimel
This PR makes the following changes:
Prims:
- igamma and igammac are now correctly listed as elementwise binary operations, not elementwise unary operations
- elementwise prims now must specify their type promotion kind (this is currently unused)
Refs:
- complexhalf is now handled by opmath-style type promotion
- adds references for: abs, acos, acosh, asin, atan, ceil, cos, cosh, digamma, erf, erfinv, erfc, exp, expm1, isfinite, isnan, lgamma, log, log1p, neg, reciprocal, sign, sin, sinh, sqrt, square, tan, igamma, igammac
- adds "complex to float" and "bool to long" type promotion kinds
- updates out behavior to warn when resizing a non-empty tensor, consistent with current ops
- updates the elementwise unary reference template with type promotion
Tests:
- fixes torch.pow's OpInfo to correctly specify it only supports one scalar input, not two
- fixes elementwise binary reference inputs to not attempt generating certain tensors in complex half (for now, cc @kshitij12345)
- adds OpInfos for the following Python references: abs, acos, acosh, asin, atan, ceil, cos, cosh, digamma, erf, erfinv, erfc, exp, expm1, isfinite, isnan, lgamma, log, log1p, neg, reciprocal, round, sign, sin, sinh, sqrt, square, tan, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_xor, eq, float_power, ge, gt, igamma, igammac, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76647
Approved by: https://github.com/ngimel
This adds prototype nvFuser integration for the following prims:
- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul
Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.
This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:
```
def foo(a, b):
return torch.add(a, b)
traced_foo = make_traced(foo)
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```
Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel
Summary:
This PR adds an initial set of experimental primitive operations and Python references that reimplement existing PyTorch operations using them. See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 for additional context.
The following experimental primitives are added:
- Elementwise unary prims -- abs, acos, acosh, asin, atan, cos, cosh, bessel_i0e, bessel_i1e, cbrt, ceil, digamma, erf, erf_inv, erfc, exp, expm1, floor, igamma, igammac, is_finite, lgamma, log, log1p, neg, reciprocal, round, sign, sinh, sqrt, square, tan.
- Elementwise binary prims -- add, atan2, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, div, eq, ge, gt, le, lt, max, min, mul, ne, nextafter, pow, rsqrt, shift_left, shift_right_arithmetic
- View prims -- brodcast_in_dim, collapse_view, split_dim, squeeze
- Shape prims -- collapse, concatenate, reshape
- Conditional prims -- select
- Data conversion & movement prims -- convert_element_type, device_put
- Inplace prims -- copy_to, resize
These primitives do not add any new functionality to PyTorch, but are intended to be the semantic building blocks for reference operators. We have tried to make them consistent with the operations in [jax.lax](https://jax.readthedocs.io/en/latest/jax.lax.html) where possible (because PyTorch prefers being consistent with other frameworks), although there are key differences between these prims and operations in jax.lax. Most notably is that these prims model view semantics and inplace operations.
In addition to these primitives the following elementwise binary Python references are added:
- Elementwise binary Python references -- add, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_right_shift, bitwise_xor, eq, float_power, ge, gt, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide
- Conditional Python references - where
- Data conversion & movement references - copy_to
A Python reference implements the same behavior as its corresponding PyTorch operator (excepting slight numerical differences, bug fixes, and in some cases additional features).
The start of an OpInfo-based test architecture for these references is also included in this PR. A new list, `python_ref_db`, is added to `common_methods_invocations.py`. This list introduces the new `ElementwiseBinaryPythonRefInfo`, which inherits input arguments from the original operators' OpInfo, allows them to be overridden, and then constructs the OpInfo for the Python reference using the (potentially modified) arguments. OpInfo-based tests can opt-into testing references by including this new list in the Sequence passed to the `ops` decorator.
cc ngimel csarofeen kevinstephano Lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75095
Reviewed By: ngimel
Differential Revision: D35888004
Pulled By: mruberry
fbshipit-source-id: 21e77c4456c2a02113367d4bdae168a3a2f33f25
(cherry picked from commit 1d5bcfa99d4e8cf36f60642803a0bfca50e2ea4e)