Commit Graph

45 Commits

Edward Z. Yang
b7215de32f prod ref
It turns out the prim is implemented incorrectly, as torch.prod does not accept
a dim list, so I added a little stub for this.
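
For illustration, a hedged sketch of the kind of stub described (not the PR's actual code): since torch.prod only reduces over a single int dim, a dim list can be handled by reducing one dim at a time.

```
import torch

def prod_dims(a: torch.Tensor, dims) -> torch.Tensor:
    # Reduce the highest dims first so the remaining indices stay valid.
    for dim in sorted(dims, reverse=True):
        a = torch.prod(a, dim=dim)
    return a
```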

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78461

Approved by: https://github.com/ngimel
2022-05-31 14:18:49 +00:00
Edward Z. Yang
e562ed0964 Register PrimTorch sum as a decomposition.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78460

Approved by: https://github.com/ngimel
2022-05-31 14:18:49 +00:00
Jason Ansel
dabf8f0569 Populate the torch._decomp table on import (#78476)
#78041 broke TorchInductor, because of:
```
>>> from torch import _decomp
>>> import torch
>>> _decomp.get_decompositions([torch.ops.aten.leaky_relu])
{}
>>> import torch._refs.nn.functional
>>> _decomp.get_decompositions([torch.ops.aten.leaky_relu])
{<OpOverload(op='aten.leaky_relu', overload='default')>: <function leaky_relu at 0x7f5a39b56c10>, <OpOverload(op='aten.leaky_relu', overload='out')>: <function leaky_relu at 0x7f5a39b56c10>}
```
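
As a hedged illustration of the intended behavior after the fix:

```
import torch
from torch import _decomp

# With the table populated on import, this is non-empty without
# manually importing torch._refs first:
decomps = _decomp.get_decompositions([torch.ops.aten.leaky_relu])
assert decomps
```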

cc @Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78476
Approved by: https://github.com/Chillee
2022-05-31 03:46:38 +00:00
Ryan Spring
2df1da09e1 Add Elementwise unary ops 4 references (#78216)
Add reference implementations for `nan_to_num, positive, sigmoid, signbit, tanhshrink`
Add prims for `minimum_value(dtype)` and `maximum_value(dtype)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78216
Approved by: https://github.com/mruberry
2022-05-27 21:55:34 +00:00
kshitij12345
8bd8f62812 [primTorch] refs: margin_ranking_loss, hinge_embedding_loss (#78057)
Refs for `nn.functional.margin_ranking_loss` and `nn.functional.hinge_embedding_loss`.
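For reference, a minimal sketch of the formula the margin_ranking_loss ref implements (illustrative, not the PR's code):

```
import torch

def margin_ranking_loss(input1, input2, target, margin=0.0, reduction="mean"):
    # loss_i = max(0, -target_i * (input1_i - input2_i) + margin)
    loss = torch.clamp_min(-target * (input1 - input2) + margin, 0)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
```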
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78057
Approved by: https://github.com/mruberry
2022-05-26 21:01:57 +00:00
Aidyn-A
31016eb81e [primTorch] Elementwise Binary Ops I (#78023)
This PR is a result of collaboration with @rdspring1 and @mruberry on primTorch.

It adds the following prims:
- `fmax`
- `fmin`
- `fmod`

And adds the following refs:
- `fmax`
- `fmin`
- `fmod`
- `logical_xor`

The work is in progress as there are some tests that fail.
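
For context, the NaN semantics that distinguish `fmax`/`fmin` from `maximum`/`minimum` (an illustrative check, not from the PR):

```
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([float("nan"), 2.0])
torch.fmax(a, b)     # tensor([1., 2.]) -- NaN is ignored when the other operand is a number
torch.maximum(a, b)  # tensor([nan, nan]) -- NaN propagates
```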
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78023
Approved by: https://github.com/mruberry
2022-05-26 20:22:27 +00:00
Gao, Xiang
5ecd30e857 [primTorch] Rename is_finite->isfinite (#78211)
`isfinite` sounds like a better name, because PyTorch, C++, and NumPy all use this name instead of `is_finite`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78211
Approved by: https://github.com/ngimel, https://github.com/mruberry
2022-05-26 16:17:51 +00:00
Edward Z. Yang
a1765f0176 addr ref
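A hedged sketch of what the addr reference boils down to (ignoring the special casing needed when beta == 0 so NaN/Inf in the input don't propagate):

```
import torch

def addr(input, vec1, vec2, *, beta=1, alpha=1):
    # addr: beta * input + alpha * outer(vec1, vec2)
    return beta * input + alpha * torch.outer(vec1, vec2)
```
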
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78014

Approved by: https://github.com/ngimel
2022-05-25 01:40:11 +00:00
Ryan Spring
bb4653e736 Add i0, i1, zeta refs (#78111)
Add reference implementations for i0, i1, zeta
Add prim operations for i0, i1, zeta
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78111
Approved by: https://github.com/mruberry
2022-05-23 21:33:56 +00:00
Horace He
ea5d01e629 [Primtorch] Tried porting leaky_relu into a ref (#78041)
Feels good to delete it from `torch._decomp`. This is mainly to clarify the process for me -

Seems like there are still some components missing from the `torch <-> refs` mapping? For example, it seems like methods don't work yet for mapping from torch <-> refs, and neither do the meta tests? (cc: @ezyang).

If I replace the `torch` with `refs`, then the tests seem to pass.
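For illustration, the core of a leaky_relu ref is tiny (a sketch, not the PR's exact code):

```
import torch

def leaky_relu(a, negative_slope=0.01):
    # leaky_relu(x) = x if x >= 0 else negative_slope * x
    return torch.where(a >= 0, a, a * negative_slope)
```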
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78041
Approved by: https://github.com/mruberry
2022-05-23 18:00:21 +00:00
Mike Ruberry
2738405a76 [primTorch] Adds any, all, equal, item references (#78072)
This PR adds the item, equal, any, and all references.

While doing this I found the following issues:
- https://github.com/pytorch/pytorch/issues/78070
- https://github.com/pytorch/pytorch/issues/78071

And I fixed a bug where the `convert_element_type` prim could not convert tensors requiring grad to datatypes that don't require grad.

Creating the item reference required adding item as a prim, but per @ngimel's suggestion I removed the prims for any and all and implemented them as references, so this is net negative one prim.

Reference OpInfos are added for any and all, but item and equal don't even have regular OpInfos.
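
As a hedged illustration of "any as a reference rather than a prim" (ignoring legacy dtype quirks such as uint8 outputs):

```
import torch

def any_ref(a, dim=None, keepdim=False):
    # Cast to bool, reduce with sum, and compare against zero.
    a_bool = a != 0
    if dim is None:
        return a_bool.sum() != 0
    return a_bool.sum(dim=dim, keepdim=keepdim) != 0
```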
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78072
Approved by: https://github.com/ngimel
2022-05-23 12:49:04 +00:00
Mike Ruberry
d4345ed0a6 [primTorch] Adds random operations (#78026)
This PR...

**Issues Found**
- https://github.com/pytorch/pytorch/issues/78058
- https://github.com/pytorch/pytorch/issues/78054
- https://github.com/pytorch/pytorch/issues/78053
- https://github.com/pytorch/pytorch/issues/78050
- https://github.com/pytorch/pytorch/issues/77932

**Testing**
- disables stride consistency checks in test_ops and test_meta pending resolution of https://github.com/pytorch/pytorch/issues/78050
- skips chalf in reference tests (addressing https://github.com/pytorch/pytorch/issues/78054)
- splits test_python_reference_consistency into one test for the context where torch.foo is torch.foo, and another for when torch.foo is refs.foo
- updates test names to be more natural and consistent:
  - test_python_reference_errors -> test_python_ref_errors
  - test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback
  - test_python_reference_meta_functions -> test_python_ref_meta
  - test_reference_testing -> test_numpy_ref
- updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op when the reference and torch op results are not close; a warning is raised when this occurs (addressing https://github.com/pytorch/pytorch/issues/77687)
- adds reference inputs for broadcast_tensors
- Updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator
- Adds 1D sample inputs with no elements to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly
- Adds reference inputs for elementwise ternary operations, like clamp
- Adds a NumPy reference for clamp
- Adds reference inputs to where's OpInfo
- Makes softplus an elementwise unary OpInfo
- Removes the great majority of Python reference OpInfo skips and xfails due to the above test changes
- Adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where

**Prims**
- adds the fill, empty_strided, and uniform prims
- removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill
- renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch
- extends the `_elementwise_meta` operation to accept tensors that don't participate in type promotion, like the `cond` tensor in `where`
- fixes a bug in the stride propagation of broadcast_in_dim
- moves some error checks from prims.cat and prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible

**Utils**
- adds the canonicalize_device, extract_shape, and extract_shape_from_varargs helpers
- adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (e.g., refs.sin(1) returns 0.84...)

**Refs**
- adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references
- adds the nn.functional.dropout reference
- fixes refs.cat to handle 1D tensors with no inputs consistent with eager mode
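
As a hedged illustration of the new uniform-based references, a dropout ref can be built from a uniform sample plus elementwise ops (a sketch, not the PR's exact code):

```
import torch

def dropout(a, p=0.5, training=True):
    # Keep each element with probability (1 - p) and rescale by 1/(1 - p).
    if not training or p == 0.0:
        return a
    if p == 1.0:
        return torch.zeros_like(a)
    mask = torch.rand_like(a) > p
    return a * mask / (1 - p)
```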
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78026
Approved by: https://github.com/ngimel
2022-05-23 01:56:28 +00:00
PyTorch MergeBot
acfbc16b1c Revert "[primTorch] Adds random operations (#78026)"
This reverts commit 043cf1f9c7.

Reverted https://github.com/pytorch/pytorch/pull/78026 on behalf of https://github.com/suo because it broke trunk: 043cf1f9c7
2022-05-22 18:11:14 +00:00
Mike Ruberry
043cf1f9c7 [primTorch] Adds random operations (#78026)
This PR...

**Issues Found**
- https://github.com/pytorch/pytorch/issues/78058
- https://github.com/pytorch/pytorch/issues/78054
- https://github.com/pytorch/pytorch/issues/78053
- https://github.com/pytorch/pytorch/issues/78050
- https://github.com/pytorch/pytorch/issues/77932

**Testing**
- disables stride consistency checks in test_ops and test_meta pending resolution of https://github.com/pytorch/pytorch/issues/78050
- skips chalf in reference tests (addressing https://github.com/pytorch/pytorch/issues/78054)
- splits test_python_reference_consistency into one test for the context where torch.foo is torch.foo, and another for when torch.foo is refs.foo
- updates test names to be more natural and consistent:
  - test_python_reference_errors -> test_python_ref_errors
  - test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback
  - test_python_reference_meta_functions -> test_python_ref_meta
  - test_reference_testing -> test_numpy_ref
- updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op when the reference and torch op results are not close; a warning is raised when this occurs (addressing https://github.com/pytorch/pytorch/issues/77687)
- adds reference inputs for broadcast_tensors
- Updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator
- Adds 1D sample inputs with no elements to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly
- Adds reference inputs for elementwise ternary operations, like clamp
- Adds a NumPy reference for clamp
- Adds reference inputs to where's OpInfo
- Makes softplus an elementwise unary OpInfo
- Removes the great majority of Python reference OpInfo skips and xfails due to the above test changes
- Adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where

**Prims**
- adds the fill, empty_strided, and uniform prims
- removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill
- renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch
- extends the `_elementwise_meta` operation to accept tensors that don't participate in type promotion, like the `cond` tensor in `where`
- fixes a bug in the stride propagation of broadcast_in_dim
- moves some error checks from prims.cat and prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible

**Utils**
- adds the canonicalize_device, extract_shape, and extract_shape_from_varargs helpers
- adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (e.g., refs.sin(1) returns 0.84...)

**Refs**
- adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references
- adds the nn.functional.dropout reference
- fixes refs.cat to handle 1D tensors with no inputs consistent with eager mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78026
Approved by: https://github.com/ngimel
2022-05-22 10:06:24 +00:00
Natalia Gimelshein
192aa3ad5f adds std and var refs and var prim (#77948)
Per title
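
For illustration (a sketch, not the PR's code), the std ref can be a thin composite over the var prim:

```
import torch

def std(a, dim=None, unbiased=True, keepdim=False):
    # std = sqrt(var)
    return torch.sqrt(torch.var(a, dim=dim, unbiased=unbiased, keepdim=keepdim))
```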

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77948
Approved by: https://github.com/mruberry
2022-05-22 04:01:21 +00:00
kshitij12345
5f1b0a4f48 [primTorch] add exp2 (prim and ref), log10 (prim and ref), frac (ref) (#78046)
Adds `exp2`, `log10` to the prims (both also exist in the C++ standard library, and Intel SIMD intrinsics include `exp2`)

Adds `exp2`, `log10`, `frac` to refs with corresponding entries to OpInfo.

Tried to decompose `exp2` (before adding it as prim) as
* `exp(log(2) * x)` but it wasn't stable at large numbers.
* `pow(2, x)` in which case there was a stride mismatch. At a cursory look, `pow` tries to preserve the strides of its first argument when possible.

Tried to decompose `log10` (before adding it as prim) as
* `log(x) / log(10)` passed for real dtypes. Failed for complex at extremals. Probably related to https://github.com/pytorch/pytorch/issues/52332 (not 100% sure)
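
The attempted decompositions, written out for clarity (illustrative):

```
import math
import torch

def exp2_decomp(x):
    return torch.exp(math.log(2.0) * x)   # not stable for large x

def log10_decomp(x):
    return torch.log(x) / math.log(10.0)  # fails for complex extremals
```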
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78046
Approved by: https://github.com/mruberry
2022-05-22 03:43:54 +00:00
Kshiteej K
57fab66fdc [primTorch] add refs fliplr, flipud (#78049)
Add refs for `fliplr, flipud` with corresponding OpInfo entries.
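Both refs reduce to `flip` along a fixed dim; a hedged sketch:

```
import torch

def fliplr(a):
    if a.ndim < 2:
        raise RuntimeError("Input must be >= 2-d.")
    return torch.flip(a, (1,))

def flipud(a):
    if a.ndim < 1:
        raise RuntimeError("Input must be >= 1-d.")
    return torch.flip(a, (0,))
```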
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78049
Approved by: https://github.com/mruberry
2022-05-22 01:04:01 +00:00
Edward Z. Yang
6b273444c4 Add logit ref; allow non-refs to be called in refs.
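A minimal sketch of the formula the logit ref implements (illustrative, not the PR's code):

```
import torch

def logit(a, eps=None):
    # logit(x) = log(x / (1 - x)), optionally clamping x into [eps, 1 - eps]
    if eps is not None:
        a = torch.clamp(a, eps, 1.0 - eps)
    return torch.log(a / (1.0 - a))
```
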
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77816

Approved by: https://github.com/mruberry
2022-05-21 02:35:14 +00:00
Horace He
64b4bb4b01 Fix meta tests on norm (and relanding norm fixes) (#77930)
Had a land race with meta tests.

Will also be relanding https://github.com/pytorch/pytorch/pull/77407
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77930
Approved by: https://github.com/malfet, https://github.com/ezyang
2022-05-20 23:15:53 +00:00
PyTorch MergeBot
03546e9c07 Revert "Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)"
This reverts commit 70d80fb424.

Reverted https://github.com/pytorch/pytorch/pull/77407 on behalf of https://github.com/malfet because it broke meta tests (I guess due to a land race), see 70d80fb424
2022-05-20 02:31:57 +00:00
Kevin Stephano
11daf200e8 Adding activation references for celu, mish, selu, softplus, and tanh (#77473)
Adding activation references for celu, softplus, mish, selu.
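
For illustration, softplus is representative of these refs (a sketch with the standard stability threshold, not the PR's exact code):

```
import torch

def softplus(a, beta=1.0, threshold=20.0):
    # softplus(x) = (1/beta) * log(1 + exp(beta * x)), reverting to the
    # identity where beta * x > threshold for numerical stability.
    scaled = a * beta
    return torch.where(scaled > threshold, a, torch.log1p(torch.exp(scaled)) / beta)
```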
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77473
Approved by: https://github.com/mruberry
2022-05-20 00:47:31 +00:00
Horace He
70d80fb424 Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)
Originally, when these were written, they simply used the naive strategy of "upcast all inputs to floats, and downcast all inputs back". In addition to being... not quite what the kernels did, they also didn't capture some additional semantics. Namely, that the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values.

Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`. Which is fair - so I refactored the common logic into a `normalize` function.

cc: @jansel / @bertmaher , who've been looking at lowering layer norm/batch norm.
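
A hedged sketch of the shared `normalize` idea described above (statistics computed in fp32 for reduced-precision inputs, with mean/rstd staying fp32; the real version also casts the normalized output back to the input dtype):

```
import torch

def normalize(a, norm_dims, eps):
    computation_dtype = (
        torch.float32 if a.dtype in (torch.half, torch.bfloat16) else a.dtype
    )
    a_acc = a.to(computation_dtype)
    mean = a_acc.mean(dim=norm_dims, keepdim=True)
    var = a_acc.var(dim=norm_dims, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    out = (a_acc - mean) * rstd
    return out, mean, rstd
```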
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407
Approved by: https://github.com/bertmaher
2022-05-19 17:11:47 +00:00
Ryan Spring
3c4af1c496 [WIP] Add support for elementwise unary ops (#77807)
* Add support for `log2, isinf, zeros_like`
* Add primitives for `log2` and `is_infinite`

I left a TODO to remove the `is_infinite` prim and to implement `isinf` reference using `isfinite` and `isnan`.
We're missing `real` and `imag` ops to handle complex tensors.
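
The TODO amounts to this identity (a sketch):

```
import torch

def isinf(a):
    # inf is exactly "not finite and not NaN"
    return ~torch.isfinite(a) & ~torch.isnan(a)
```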
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77807
Approved by: https://github.com/mruberry
2022-05-19 16:26:01 +00:00
Edward Z. Yang
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00
Mike Ruberry
580a053832 [primTorch] Enforces stride metadata (#77542)
This PR...

**Filed the Following Issues**
- https://github.com/pytorch/pytorch/issues/77553
- https://github.com/pytorch/pytorch/issues/77526
- https://github.com/pytorch/pytorch/issues/77600

**Testing**
- Updates test_dtypes to no longer attempt to test the backward of sample inputs where no inputs require grad
- Adds a new test_python_reference_errors; it ensures the meta operations for references throw errors as expected
- Updates compare_tensor_meta to better handle CUDA devices, and (temporarily) restricts stride checking to the CUDA device type
- Elementwise unary and elementwise binary operators now have arbitrarily strided reference inputs
- Reference inputs for _like functions are added
- An OpInfo for torch.empty is added
- Reference inputs for torch.clone are added
- A NumPy reference for clone is added
- Adds OpInfos for refs.empty and refs.empty_like

**Prims**
- Renames the "max" and "min" prims to "maximum" and "minimum," respectively, to better conform to their ATen names
- Adds the empty, empty_like, full, and full_like prims
- Fixes the elementwise meta function's stride propagation
- Fixes clone's meta function's stride propagation
- Fixes convert_element_type's meta's stride propagation
- Adds a (temporary) _to_dtype private prim that casts a tensor while preserving its stride permutation
- Removes the _set prim comment
- Adds utils.compute_elementwise_output_strides, which computes the correct output strides for elementwise operations
- Corrects an issue where utils.make_contiguous_strides_for was creating the incorrect strides for tensors with no elements
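
For illustration, the fix to stride computation for zero-element tensors boils down to guarding zero-length dims in the multiplier (a sketch of utils.make_contiguous_strides_for, not the exact code):

```
def make_contiguous_strides_for(shape):
    strides = []
    multiplier = 1
    for length in reversed(shape):
        strides.append(multiplier)
        multiplier *= length if length != 0 else 1  # the zero-element guard
    return tuple(reversed(strides))
```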

**References**
- Adds the empty, empty_like, full, full_like, and ones_like refs
- Extends make_elementwise_unary_reference to accept an additional callable to perform extra input validation
- Adds an extra validation function to handle refs.neg(BoolTensor)
- Updates the isfinite ref to call ones_like when appropriate
- Models Python scalar handling for elementwise binary operations
- Adds a 64-dim check for the amin and amax references
- opmath is now a flag that can be set separately for cpu and CUDA
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77542
Approved by: https://github.com/ezyang
2022-05-18 13:57:26 +00:00
PyTorch MergeBot
48581d74ad Revert "Add dispatch mode testing for meta tensors and other stuff"
This reverts commit c1cdb1216b.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
2022-05-18 02:56:48 +00:00
Edward Z. Yang
c1cdb1216b Add dispatch mode testing for meta tensors and other stuff
We don't have any coverage for meta tensor correctness for backwards
because torch function mode only allows us to interpose on
Python torch API calls, but backwards invocations happen from C++.
To make this possible, I add a torch_dispatch_meta test which runs the
tests with __torch_dispatch__.
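
For context, a minimal illustration (not from the PR) of why __torch_dispatch__ sees backward while __torch_function__ does not:

```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # aten ops dispatched from C++ land here too
        return func(*args, **(kwargs or {}))

with LoggingMode():
    x = torch.randn(3, requires_grad=True)
    x.sin().sum().backward()  # the backward ops are logged as well
```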

While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient.  So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic. Here's what's
new:

- Expected failure / skip is now done on a per function call basis,
  rather than the entire test.  This means that separate OpInfo
  samples for a function don't affect each other.

- There are now only two lists: an expected failure list (where the test
  consistently fails on all runs) and a skip list (where the test
  sometimes passes and sometimes fails).

- We explicitly notate the dtype that failed.  I considered detecting
  when something failed on all dtypes, but this was complicated and
  listing everything out seemed to be nice and simple.  To keep the
  dtypes short, I introduce a shorthand notation for dtypes.

- Conversion to meta tensors is factored into its own class
  MetaConverter

- To regenerate the expected failure / skip lists, just run with
  PYTORCH_COLLECT_EXPECT and filter on a specific test type
  (test_meta or test_dispatch_meta) for whichever you want to update.

Other misc fixes:

- Fix max_pool1d to work with BFloat16 in all circumstances, by making
  it dispatch and then fixing a minor compile error (constexpr doesn't
  work with BFloat16)

- Add resolve_name for turning random torch API functions into string
  names

- Add push classmethod to the Mode classes, so that you can more easily
  push a mode onto the mode stack

- Add some more skips for missing LAPACK

- Added an API to let you query if there's already a registration for
  a function, added a test to check that we register_meta for all
  decompositions (except detach, that decomp is wrong lol), and then
  update all the necessary sites to make the test pass.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477

Approved by: https://github.com/zou3519
2022-05-18 00:18:34 +00:00
Natalia Gimelshein
375e21b2c6 check that flip doesn't accept repeating dimensions (#77500)
Per title.
Before this PR `flip` throws errors on invalid inputs from the ATen implementation itself, and not from error checks happening in prims/refs.
We should make sure that prims/refs do all the necessary error checking (@mruberry is going to test that by moving reference error inputs testing to call meta implementations instead of real ones).
In general, most error checking should live in refs; prims' meta functions should propagate the necessary properties, but they should assume that they are getting valid inputs. The checks on the inputs should happen in refs, where they can be traced to the necessary guards, or lead to RuntimeErrors during tracing.
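A hedged sketch of this division of labor (torch.flip standing in for the prim):

```
import torch

def flip_ref(a, dims):
    # The ref does the user-facing validation...
    dims = tuple(d % a.ndim for d in dims)
    if len(set(dims)) != len(dims):
        raise RuntimeError(f"dims has duplicates, got {dims}")
    # ...while the prim's meta function may assume valid inputs.
    return torch.flip(a, dims)
```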
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77500
Approved by: https://github.com/mruberry
2022-05-16 19:33:14 +00:00
Edward Z. Yang
2f602abf14 Register more decomps for meta.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77362

Approved by: https://github.com/mruberry
2022-05-14 02:24:23 +00:00
Mike Ruberry
64c6a89bd6 [primTorch] reshape and view (#77220)
This PR makes the following changes...

Prims
- adds as_strided
- fixes errors in flatten meta

Testing
- enables view consistency checking (which can be opted out of, see issues below)
- adds reference inputs for view, reshape, and flatten
- adds error inputs for reshape

Refs
- adds as_strided, reshape, and view
- fixes an error in the flatten ref where it was not returning self on a no-op
- fixes a bug in transpose where it was not returning a view when the transposed tensor has 1 or fewer dims

Issues
- https://github.com/pytorch/pytorch/issues/77218
- https://github.com/pytorch/pytorch/issues/77216
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77220
Approved by: https://github.com/ngimel
2022-05-13 13:12:04 +00:00
Edward Z. Yang
d5ed73badd Make it possible to register decompositions to Meta key
Decompositions can be used to fill in meta support where necessary,
assuming the operations they decompose into support the meta key.
This PR adds register_meta kwarg to register_decomposition that
optionally lets you register the meta to the C++ dispatch table
for meta tensors.  I use this to then get the meta function for
where and huber_loss for free.
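
A hedged sketch of the kwarg in use (huber_loss per the description above; the decomposition body and exact signature are illustrative and may differ from the PR):

```
import torch
from torch._decomp import register_decomposition

@register_decomposition(torch.ops.aten.huber_loss, register_meta=True)
def huber_loss(self, target, reduction=1, delta=1.0):
    diff = self - target
    abs_diff = diff.abs()
    loss = torch.where(
        abs_diff < delta, 0.5 * diff * diff, delta * (abs_diff - 0.5 * delta)
    )
    if reduction == 1:  # Reduction::Mean
        return loss.mean()
    if reduction == 2:  # Reduction::Sum
        return loss.sum()
    return loss
```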

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77353

Approved by: https://github.com/mruberry
2022-05-12 23:20:16 +00:00
Jane Xu
b0bd5926c9 Fix prims lint broken on trunk due to land race (#77271)
Fixes upstream lint errors https://hud.pytorch.org/minihud?name_filter=Lint%20/%20lintrunner
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77271
Approved by: https://github.com/seemethere, https://github.com/malfet
2022-05-11 17:45:56 +00:00
Edward Z. Yang
0a14a4c280 Register prims as operators.
This makes prims look as if they were defined in native_functions.yaml
but they're still all written in Python.  You now need to give a full
schema string for your prims. The returned prim object is now a
torch.ops.prims overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed).
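
For example (assuming the prims namespace as registered by this change):

```
import torch

# Since prims may not be overloaded, this binds the lone "default"
# overload directly, not an OverloadPacket:
add_prim = torch.ops.prims.add.default
```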

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117

Approved by: https://github.com/mruberry, https://github.com/albanD
2022-05-11 16:38:14 +00:00
Mike Ruberry
bb8baea932 [primTorch] flatten, squeeze, unsqueeze... (#77043)
This PR ...

Makes the following testing changes:

- Updates stride testing in test_python_reference_consistency to only check strides of dimensions with length > 1
- Creates reference inputs for reshape
- Creates reference inputs for chunk
- Extends the sample inputs for unsqueeze
- Extends the sample inputs for stack -- test_conj_view and test_neg_view are now xfailed
  - https://github.com/pytorch/pytorch/issues/77046

Makes the following architecture changes:
- Adds the refs.special (sub)module
- Adds the refs.nn.functional (sub)module

Adds the following prims:
- expand_dims
- view_of
- rev
- clone

Adds the following references:
  -  flatten
  - squeeze
  - unsqueeze
  - special.i0e
  - special.i1e
  - logical_or
  - logical_and
  - isclose
  - flip
  - stack
  - nn.functional.elu
  - chunk
  - clone
  - narrow
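
As a hedged illustration of refs composing other refs/prims, chunk can be expressed via narrow (a sketch, not the PR's code):

```
import torch

def chunk(a, chunks, dim=0):
    dim_size = a.shape[dim]
    chunk_size = (dim_size + chunks - 1) // chunks  # ceil division
    pieces = []
    start = 0
    while start < dim_size:
        length = min(chunk_size, dim_size - start)
        pieces.append(torch.narrow(a, dim, start, length))
        start += length
    return tuple(pieces)
```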

Identifies the following bugs in PyTorch today:
- https://github.com/pytorch/pytorch/issues/77054
- https://github.com/pytorch/pytorch/issues/77055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77043
Approved by: https://github.com/ngimel
2022-05-09 11:24:55 +00:00
Mike Ruberry
c031643e39 Adds decorators for Python References and extends Python Reference testing (#76945)
This PR does the following...

Tests:
- fixes test_type_promotion in test_binary_ufuncs to correctly generate scalar cpu tensors
- fixes test_python_reference_consistency to use the Python Reference's reference inputs
- extends Python reference testing to test_conj_view, test_neg_view, and test_neg_conj_view
- adds a NaN propagation sample input for elementwise unary and binary operations
- fixes the UnaryUfuncInfo class to properly register its reference inputs
- Updates the Python Reference OpInfos to skip error inputs when their behavior on scalar inputs is inconsistent with their reference operators

Code organization:
- moves elementwise type promotion functionality to prims.utils

Prims & Refs:
- fixes scalar cpu tensor handling by having them pass through broadcasting and device and shape checks
- adds two decorators, `elementwise_type_promotion_wrapper` and `out_wrapper`, the former allows for elementwise type promotion to be automated and the latter automatically adds the out kwarg and handles it properly
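
For illustration, a minimal sketch of the `out_wrapper` idea (the real decorator also validates safety and warns when resizing a non-empty tensor):

```
import functools
import torch

def out_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, out=None, **kwargs):
        result = fn(*args, **kwargs)
        if out is None:
            return result
        out.resize_(result.shape)
        return out.copy_(result)
    return wrapper
```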

cc @ezyang who also had some thoughts on cpu scalar tensor handling
cc @chillee -- might want to use this new decorator as we converge decompositions and references
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76945
Approved by: https://github.com/ngimel
2022-05-07 03:42:24 +00:00
Edward Z. Yang
f2eed9400d Register PrimTorch refs as decompositions.
For the most part, PrimTorch refs have the same signature as their
ATen equivalents.  I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch).  There are
some exclusions, falling into one of two categories:

- The torch equivalent was already implemented as a CompositeImplicitAutograd
  decomposition in C++

- The ref doesn't support enough features (e.g., the real deal has more
  kwargs / overloads than are currently implemented)

PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet.  This is technically
BC breaking but no tests started failing because of it.
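
A hedged sketch of the overload-packet registration described above (helper name illustrative):

```
import torch

def register_for_all_overloads(packet, fn, registry):
    # Given an OverloadPacket, register fn for every overload in it.
    for name in packet.overloads():
        registry[getattr(packet, name)] = fn
```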

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835

Approved by: https://github.com/Chillee, https://github.com/mruberry
2022-05-06 20:11:45 +00:00
Horace He
e9f34931ef Add some shape decomps (t, transpose, rot90, stack)
Also fixes xlogy (turns out the only thing it was missing was a type cast annotation! nice!)

I also renamed `canonicalize_idx` => `canonicalize_dim` (to align with `canonicalize_dims`) and fixed a bug in it (cc: @mruberry)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76873
Approved by: https://github.com/mruberry
2022-05-06 02:40:57 +00:00
Natalia Gimelshein
1c776d209c Adds amax and amin references
Also extends reference testing to error inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76855
Approved by: https://github.com/mruberry
2022-05-05 15:53:09 +00:00
Mike Ruberry
b557e102d8 Fixes prim type promotion and updates type promotion testing
This PR fixes prim elementwise type promotion, tests elementwise binary references using `test_type_promotion` in the elementwise binary test suite, and updates that test with additional cases for float x complex and scalar type promotion.

The following issues were discovered while working on this PR:

- https://github.com/pytorch/pytorch/issues/76806
- https://github.com/pytorch/pytorch/issues/76805
- https://github.com/pytorch/pytorch/issues/76804
- https://github.com/pytorch/pytorch/issues/76803
- https://github.com/pytorch/pytorch/issues/76801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76809
Approved by: https://github.com/ngimel
2022-05-04 17:58:10 +00:00
Mike Ruberry
ef9f56eb0b [primTorch] slice and transpose & etc.
This PR...

Adds the following prims:
- slice
- slice_in_dim
- transpose

Adds the following refs:
- cat
- permute
- transpose
- swap_axes (alias for transpose)
- tensor_split

Makes the following test improvements:
- adds reference inputs for torch.permute
- adds a NumPy reference for torch.permute
- adds reference inputs for torch.cat

Fixes the following bugs:
- adds support for scalars to the min and max prims

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76727
Approved by: https://github.com/ngimel
2022-05-04 05:38:33 +00:00
Natalia Gimelshein
c51b53d4ef [WIP] sum reference
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76714
Approved by: https://github.com/mruberry
2022-05-04 02:50:00 +00:00
Mike Ruberry
c9bd73878a adds elementwise opinfos and unary references, extends to out testing
This PR makes the following changes:

Prims:
- igamma and igammac are now correctly listed as elementwise binary operations, not elementwise unary operations
- elementwise prims now must specify their type promotion kind (this is currently unused)

Refs:
- complexhalf is now handled by opmath-style type promotion
- adds references for: abs, acos, acosh, asin, atan, ceil, cos, cosh, digamma, erf, erfinv, erfc, exp, expm1, isfinite, isnan, lgamma, log, log1p, neg, reciprocal, sign, sin, sinh, sqrt, square, tan, igamma, igammac
- adds "complex to float" and "bool to long" type promotion kinds
- updates out behavior to warn when resizing a non-empty tensor, consistent with current ops
- updates the elementwise unary reference template with type promotion

Tests:
- fixes torch.pow's OpInfo to correctly specify it only supports one scalar input, not two
- fixes elementwise binary reference inputs to not attempt generating certain tensors in complex half (for now, cc @kshitij12345)
- adds OpInfos for the following Python references: abs, acos, acosh, asin, atan, ceil, cos, cosh, digamma, erf, erfinv, erfc, exp, expm1, isfinite, isnan, lgamma, log, log1p, neg, reciprocal, round, sign, sin, sinh, sqrt, square, tan, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_xor, eq, float_power, ge, gt, igamma, igammac, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76647
Approved by: https://github.com/ngimel
2022-05-02 14:23:05 +00:00
Mike Ruberry
fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:

```
def foo(a, b):
  return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00
Mike Ruberry
4048d4cdd2 [primTorch] Prototype tracer and elementwise unary reference opinfo class
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel
2022-04-27 14:40:21 +00:00
Mike Ruberry
28c3e0f77c Initial prims, references, and test architecture for them (#75095)
Summary:
This PR adds an initial set of experimental primitive operations and Python references that reimplement existing PyTorch operations using them. See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 for additional context.

The following experimental primitives are added:

- Elementwise unary prims -- abs, acos, acosh, asin, atan, cos, cosh, bessel_i0e, bessel_i1e, cbrt, ceil, digamma, erf, erf_inv, erfc, exp, expm1, floor, igamma, igammac, is_finite, lgamma, log, log1p, neg, reciprocal, round, sign, sinh, sqrt, square, tan.
- Elementwise binary prims -- add, atan2, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, div, eq, ge, gt, le, lt, max, min, mul, ne, nextafter, pow, rsqrt, shift_left, shift_right_arithmetic
- View prims -- broadcast_in_dim, collapse_view, split_dim, squeeze
- Shape prims -- collapse, concatenate, reshape
- Conditional prims -- select
- Data conversion & movement prims -- convert_element_type, device_put
- Inplace prims -- copy_to, resize

These primitives do not add any new functionality to PyTorch, but are intended to be the semantic building blocks for reference operators. We have tried to make them consistent with the operations in [jax.lax](https://jax.readthedocs.io/en/latest/jax.lax.html) where possible (because PyTorch prefers being consistent with other frameworks), although there are key differences between these prims and operations in jax.lax. Most notably, these prims model view semantics and inplace operations.
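
As a hedged illustration of these semantics, broadcast_in_dim (jax.lax-style) can be emulated with existing PyTorch ops:

```
import torch

def broadcast_in_dim(a, shape, broadcast_dimensions):
    # Place a's dims at the given positions, then expand to the target shape.
    s = [1] * len(shape)
    for i, dim in enumerate(broadcast_dimensions):
        s[dim] = a.shape[i]
    return a.reshape(s).expand(shape)
```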

In addition to these primitives the following elementwise binary Python references are added:

- Elementwise binary Python references -- add, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_right_shift, bitwise_xor, eq, float_power, ge, gt, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide
- Conditional Python references - where
- Data conversion & movement references - copy_to

A Python reference implements the same behavior as its corresponding PyTorch operator (excepting slight numerical differences, bug fixes, and in some cases additional features).

The start of an OpInfo-based test architecture for these references is also included in this PR. A new list, `python_ref_db`, is added to `common_methods_invocations.py`. This list introduces the new `ElementwiseBinaryPythonRefInfo`, which inherits input arguments from the original operators' OpInfo, allows them to be overridden, and then constructs the OpInfo for the Python reference using the (potentially modified) arguments. OpInfo-based tests can opt-into testing references by including this new list in the Sequence passed to the `ops` decorator.

cc ngimel csarofeen kevinstephano Lezcano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75095

Reviewed By: ngimel

Differential Revision: D35888004

Pulled By: mruberry

fbshipit-source-id: 21e77c4456c2a02113367d4bdae168a3a2f33f25
(cherry picked from commit 1d5bcfa99d4e8cf36f60642803a0bfca50e2ea4e)
2022-04-25 09:57:20 +00:00