Commit Graph

82 Commits

Author SHA1 Message Date
PyTorch MergeBot
66460c4a6a Revert "Fixes maybe_broadcast to actually broadcast only when needed (#79298)"
This reverts commit 1cb1c2c08c.

Reverted https://github.com/pytorch/pytorch/pull/79298 on behalf of https://github.com/suo due to Broke FakeTensor tests on master, see: 1cb1c2c08c
2022-06-11 23:36:18 +00:00
Mike Ruberry
1cb1c2c08c Fixes maybe_broadcast to actually broadcast only when needed (#79298)
Adds a `same_shape` util and updates maybe_broadcast to use it; previously maybe_broadcast was always broadcasting because its equality check was always failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79298
Approved by: https://github.com/ezyang
2022-06-11 22:04:47 +00:00
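For context, a minimal sketch of the kind of shape-equality helper this commit describes, in plain Python over shape tuples (names and details are assumptions for illustration, not the actual PyTorch code):

```python
from typing import Sequence

def same_shape(a: Sequence[int], b: Sequence[int]) -> bool:
    # Compare dimension by dimension rather than relying on `==`
    # between possibly different sequence types (e.g. torch.Size
    # vs. list) -- the kind of always-failing equality check that
    # made maybe_broadcast broadcast unconditionally.
    return len(a) == len(b) and all(x == y for x, y in zip(a, b))
```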
kshitij12345
5e656eaae5 [refs] ravel (#78421)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78421
Approved by: https://github.com/mruberry
2022-06-10 20:20:13 +00:00
kshitij12345
3d77017674 [primTorch] refs: masked_fill (#78132)
TODO

* [x] Add error inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78132
Approved by: https://github.com/mruberry
2022-06-10 20:19:48 +00:00
PyTorch MergeBot
fefff54cad Revert "Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"""
This reverts commit a2d2981e8e.

Reverted https://github.com/pytorch/pytorch/pull/79224 on behalf of https://github.com/suo due to broke lots of things a2d2981e8e
2022-06-10 04:40:43 +00:00
Horace He
a2d2981e8e Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads""
This reverts commit d67309aefb.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79224

Approved by: https://github.com/mruberry
2022-06-10 03:07:14 +00:00
Edward Z. Yang
b18ba7e036 Properly setup __name__ on refs functions.
My hands hurt now. Yes, I could have added type annotations; if you care, do it
yourself.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79178

Approved by: https://github.com/mruberry
2022-06-09 20:17:48 +00:00
PyTorch MergeBot
d67309aefb Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"
This reverts commit 64b6bd8c1e.

Reverted https://github.com/pytorch/pytorch/pull/79000 on behalf of https://github.com/malfet due to Introduces test failure, see https://hud.pytorch.org/pr/79000
2022-06-09 13:11:23 +00:00
Horace He
64b6bd8c1e Added {logical_not, trace} refs, moved logical ops to use method overloads
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79000

Approved by: https://github.com/ezyang
2022-06-09 07:16:36 +00:00
Horace He
dc11a5642d Improved stack ref and added more decomposition annotations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78994

Approved by: https://github.com/mruberry
2022-06-09 03:20:28 +00:00
Horace He
c3531c9bce Ported roll to use torch ops and added as a decomposition
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78991

Approved by: https://github.com/mruberry
2022-06-09 03:20:28 +00:00
PyTorch MergeBot
8ce310b943 Revert "Revert "moved logit to use torch ops instead of refs + added …a couple more decompositions"" (#79082)
cc: @osalpekar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79082
Approved by: https://github.com/eellison
2022-06-08 01:44:53 +00:00
PyTorch MergeBot
7d192d48d2 Revert "moved logit to use torch ops instead of refs + added a couple more decompositions"
This reverts commit 1d9f445b5d.

Reverted https://github.com/pytorch/pytorch/pull/78984 on behalf of https://github.com/osalpekar due to broke some jobs, like meta functorch builds
2022-06-07 21:51:41 +00:00
Horace He
1d9f445b5d moved logit to use torch ops instead of refs + added a couple more decompositions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78984

Approved by: https://github.com/ezyang
2022-06-07 05:34:05 +00:00
Horace He
69778ee4eb Ported nn.functional functions to use torch calls instead of ref calls
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78978

Approved by: https://github.com/ezyang
2022-06-07 05:09:05 +00:00
Sergii Dymchenko
5a4325955c Correct var_mean in _refs/__init__.py (#78971)
I've found this by looking at https://lgtm.com/projects/g/pytorch/pytorch/

Maybe there is value in lgtm.com integration, as proposed in https://github.com/pytorch/pytorch/issues/78425
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78971
Approved by: https://github.com/ngimel
2022-06-07 00:49:30 +00:00
Horace He
e675dbadc4 Ported gelu decomp to ref (#78697)
Ugh... these are actually so painful to write without operator overloading lol.

Decided to just utilize operator overloading, and xfail the ref tests for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78697
Approved by: https://github.com/mruberry
2022-06-06 22:30:20 +00:00
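For reference, the exact (erf-based) gelu formula that such a decomposition computes, sketched for scalars in plain Python (illustrative only, not the ported ref itself):

```python
import math

def gelu_ref(x: float) -> float:
    # gelu(x) = x * Phi(x), where Phi is the standard normal CDF:
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
```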
Kshiteej K
c461d8a977 [primTorch] refs: hsplit, vsplit (#78418)
As per title

TODO:
* [x] Add error inputs (already exist)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78418
Approved by: https://github.com/mruberry
2022-06-06 19:54:05 +00:00
kshitij12345
17f0c3be2e [primtorch] refs: {h, v}stack (#78614)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78614
Approved by: https://github.com/mruberry
2022-06-06 14:04:00 +00:00
kshitij12345
030f721b51 [primTorch] refs: isclose - throw error (#78922)
Previously, isclose only checked the error condition but never actually threw.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78922
Approved by: https://github.com/mruberry
2022-06-06 13:56:26 +00:00
Horace He
080cf84bed Reland hardtanh ref (again) (#78914)
Fixes land race between 823ddb6e87 and Ed's stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78914
Approved by: https://github.com/wanchaol
2022-06-06 09:39:01 +00:00
PyTorch MergeBot
ddf1930734 Revert "reland Hardtanh ref (#78894)"
This reverts commit 823ddb6e87.

Reverted https://github.com/pytorch/pytorch/pull/78894 on behalf of https://github.com/suo due to this caused unexpected successes on master (lol), search test_python_ref_meta__refs_nn_functional_hardtanh_cpu_bfloat16: 823ddb6e87
2022-06-06 03:59:53 +00:00
Horace He
823ddb6e87 reland Hardtanh ref (#78894)
Reland of https://github.com/pytorch/pytorch/pull/78689

cc: @kit1980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78894
Approved by: https://github.com/kit1980
2022-06-06 02:09:31 +00:00
Edward Z. Yang
99882fc492 Make check() strongly typed, fix erroneous call sites
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78896

Approved by: https://github.com/Lezcano, https://github.com/anjali411
2022-06-05 23:10:55 +00:00
PyTorch MergeBot
e6cc2e8d38 Revert "Ported hardtanh decomposition to ref (#78689)"
This reverts commit 484282a6fd.

Reverted https://github.com/pytorch/pytorch/pull/78689 on behalf of https://github.com/kit1980 due to test_meta_nn_functional_hardtanh_cuda_float32 failed on both PR and trunk, see 484282a6fd
2022-06-05 17:46:54 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
Horace He
484282a6fd Ported hardtanh decomposition to ref (#78689)
One note:

The logic for handling scalar boundary conditions seems to be a bit different than other ops - I simply copied the ATen logic (https://github.com/pytorch/pytorch/blob/hardtanh_ref/aten/src/ATen/native/Activation.cpp#L370). Not sure if it's an inconsistency we should fix.

Will add error opinfo after figuring out the scalar boundary condition stuff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78689
Approved by: https://github.com/mruberry
2022-06-05 11:41:23 +00:00
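In scalar terms, hardtanh is just a clamp; a hypothetical sketch of the semantics (the scalar boundary handling discussed above is the subtle part in the real ref):

```python
def hardtanh_ref(x: float, min_val: float = -1.0, max_val: float = 1.0) -> float:
    # Clamp x into the interval [min_val, max_val].
    return min(max(x, min_val), max_val)
```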
jjsjann123
afe8a25eb0 [primTorch] Adds dsplit/dstack reference (#78696)
Added references _refs.dsplit & _refs.dstack and PythonRefInfo tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78696
Approved by: https://github.com/mruberry
2022-06-04 03:46:28 +00:00
Edward Z. Yang
54c99a9e1d relu ref
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78601

Approved by: https://github.com/ngimel
2022-06-04 02:18:56 +00:00
Horace He
1ea4075bda Ported t decomp to become a ref (#78686)
Also added an error input for `t`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78686
Approved by: https://github.com/mruberry
2022-06-03 01:16:20 +00:00
Ivan Yashchuk
48256f3cbb Reference implementations for rot90, roll, atleast_1d,2d,3d (#78080)
This PR adds the following references:

- `rot90`
- `roll`
- `atleast_1d`
- `atleast_2d`
- `atleast_3d`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78080
Approved by: https://github.com/mruberry
2022-06-02 09:05:11 +00:00
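To give a flavor of the roll semantics these references implement, a 1-D sketch in plain Python (illustrative only; the real ref handles arbitrary dims and dim lists):

```python
def roll1d(xs: list, shift: int) -> list:
    # Circularly shift elements to the right by `shift`
    # (negative shifts roll left); empty input rolls to itself.
    n = len(xs)
    if n == 0:
        return list(xs)
    shift %= n
    return xs[n - shift:] + xs[:n - shift]
```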
jjsjann123
fea909b43e [primTorch] Adds broadcast_shapes reference (#78612)
1. Added references `_refs.broadcast_shapes`
2. Added OpInfo test for `torch.broadcast_shapes`

A few minor changes:
- `test_python_ref_meta` and `_ref_test_helper` update to avoid non-tensor outputs
- type annotation update for `_resize_meta`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612
Approved by: https://github.com/mruberry
2022-06-02 08:56:37 +00:00
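A sketch of the broadcasting rule such a reference implements: shapes are right-aligned, and each dimension must either match or be 1 (plain Python, illustrative; not the actual _refs code):

```python
def broadcast_shapes(*shapes):
    # Right-align all shapes and apply the standard broadcasting rule:
    # a dimension of 1 stretches to match; unequal non-1 dims are an error.
    ndim = max((len(s) for s in shapes), default=0)
    result = [1] * ndim
    for s in shapes:
        for i, d in enumerate(reversed(s)):
            j = ndim - 1 - i
            if d != 1:
                if result[j] not in (1, d):
                    raise ValueError(f"shapes are not broadcastable: {shapes}")
                result[j] = d
    return tuple(result)
```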
Xiang Gao
b651148fc3 remove prims::square (#78627)
because it is just `x * x`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78627
Approved by: https://github.com/mruberry
2022-06-02 02:18:17 +00:00
PyTorch MergeBot
2d5eac48d5 Revert "Reference implementations for rot90, roll, atleast_1d,2d,3d (#78080)"
This reverts commit 96c134854d.

Reverted https://github.com/pytorch/pytorch/pull/78080 on behalf of https://github.com/malfet due to as it broke XLA on trunk (see https://github.com/pytorch/pytorch/runs/6678429656?check_suite_focus=true ) and the same pattern was observable on PR CI https://github.com/pytorch/pytorch/runs/6672733779?check_suite_focus=true
2022-06-01 16:52:25 +00:00
jjsjann123
7ea9c6edc2 [primTorch] Adds broadcast_to, column_stack references (#78416)
1. Added references for the two ops;
2. Inherited original operators' OpInfo tests;

TODO for future PR:
adding primTorch references for `dsplit` and `dstack`. <- Those two should use `atleast_3d`, which is in a different package right now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78416
Approved by: https://github.com/mruberry
2022-06-01 04:16:54 +00:00
Ivan Yashchuk
96c134854d Reference implementations for rot90, roll, atleast_1d,2d,3d (#78080)
This PR adds the following references:

- `rot90`
- `roll`
- `atleast_1d`
- `atleast_2d`
- `atleast_3d`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78080
Approved by: https://github.com/mruberry
2022-05-31 20:36:01 +00:00
Edward Z. Yang
eee2aa14a6 Register std_mean ref as a decomposition
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468

Approved by: https://github.com/ngimel
2022-05-31 18:59:16 +00:00
Edward Z. Yang
b7215de32f prod ref
It turns out the prim is implemented incorrectly as torch.prod does not accept
a dim list, so I added a little stub for this.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78461

Approved by: https://github.com/ngimel
2022-05-31 14:18:49 +00:00
Edward Z. Yang
e562ed0964 Register PrimTorch sum as a decomposition.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78460

Approved by: https://github.com/ngimel
2022-05-31 14:18:49 +00:00
Jason Ansel
dabf8f0569 Populate the torch._decomp table on import (#78476)
#78041 broke TorchInductor, because of:
```
>>> from torch import _decomp
>>> import torch
>>> _decomp.get_decompositions([torch.ops.aten.leaky_relu])
{}
>>> import torch._refs.nn.functional
>>> _decomp.get_decompositions([torch.ops.aten.leaky_relu])
{<OpOverload(op='aten.leaky_relu', overload='default')>: <function leaky_relu at 0x7f5a39b56c10>, <OpOverload(op='aten.leaky_relu', overload='out')>: <function leaky_relu at 0x7f5a39b56c10>}
```

cc @Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78476
Approved by: https://github.com/Chillee
2022-05-31 03:46:38 +00:00
Ryan Spring
2df1da09e1 Add Elementwise unary ops 4 references (#78216)
Add reference implementations for `nan_to_num, positive, sigmoid, signbit, tanhshrink`
Add prims for `minimum_value(dtype)` and `maximum_value(dtype)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78216
Approved by: https://github.com/mruberry
2022-05-27 21:55:34 +00:00
kshitij12345
8bd8f62812 [primTorch] refs: margin_ranking_loss, hinge_embedding_loss (#78057)
Refs for `nn.functional.margin_ranking_loss` and `nn.functional.hinge_embedding_loss`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78057
Approved by: https://github.com/mruberry
2022-05-26 21:01:57 +00:00
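For reference, the elementwise formula margin_ranking_loss computes, sketched for scalars (illustrative only, not the actual ref):

```python
def margin_ranking_loss(x1: float, x2: float, y: float, margin: float = 0.0) -> float:
    # y = 1 means x1 should rank higher than x2; y = -1 the reverse.
    # loss = max(0, -y * (x1 - x2) + margin)
    return max(0.0, -y * (x1 - x2) + margin)
```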
Aidyn-A
31016eb81e [primTorch] Elementwise Binary Ops I (#78023)
This PR is a result of collaboration with @rdspring1 and @mruberry on primTorch.

It adds the following prims:
- `fmax`
- `fmin`
- `fmod`

And adds the following refs:
- `fmax`
- `fmin`
- `fmod`
- `logical_xor`

The work is in progress as there are some tests that fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78023
Approved by: https://github.com/mruberry
2022-05-26 20:22:27 +00:00
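The fmax/fmin prims follow C-style NaN-ignoring semantics (unlike NaN-propagating max/min); a scalar sketch of fmax to illustrate (an assumption-laden illustration, not the actual prim):

```python
import math

def fmax(a: float, b: float) -> float:
    # C-style fmax: if exactly one operand is NaN, return the other;
    # NaN is returned only when both operands are NaN.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)
```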
Gao, Xiang
5ecd30e857 [primTorch] Rename is_finite->isfinite (#78211)
`isfinite` sounds like a better name, because PyTorch, C++, and NumPy all use this name instead of `is_finite`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78211
Approved by: https://github.com/ngimel, https://github.com/mruberry
2022-05-26 16:17:51 +00:00
Edward Z. Yang
a1765f0176 addr ref
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78014

Approved by: https://github.com/ngimel
2022-05-25 01:40:11 +00:00
Ryan Spring
bb4653e736 Add i0, i1, zeta refs (#78111)
Add reference implementations for i0, i1, zeta
Add prim operations for i0, i1, zeta
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78111
Approved by: https://github.com/mruberry
2022-05-23 21:33:56 +00:00
Horace He
ea5d01e629 [Primtorch] Tried porting leaky_relu into a ref (#78041)
Feels good to delete it from `torch._decomps`. This is mainly to clarify the process for me -

Seems like there are still some components missing from the `torch <-> refs` mapping? For example, methods don't seem to work yet for mapping from torch <-> refs, and neither do the meta tests? (cc: @ezyang).

If I replace the `torch` with `refs`, then the tests seem to pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78041
Approved by: https://github.com/mruberry
2022-05-23 18:00:21 +00:00
Mike Ruberry
2738405a76 [primTorch] Adds any, all, equal, item references (#78072)
This PR adds the item, equal, any, and all references.

While doing this I found the following issues:
- https://github.com/pytorch/pytorch/issues/78070
- https://github.com/pytorch/pytorch/issues/78071

And I fixed a bug where the `convert_element_type` prim could not convert tensors requiring grad to datatypes that don't require grad.

Creating the item reference required adding item as a prim, but per @ngimel's suggestion I removed the prims for any and all and implemented them as references, so this is net negative one prim.

Reference OpInfos are added for any and all, but item and equal don't even have regular OpInfos.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78072
Approved by: https://github.com/ngimel
2022-05-23 12:49:04 +00:00
Mike Ruberry
d4345ed0a6 [primTorch] Adds random operations (#78026)
This PR...

**Issues Found**
- https://github.com/pytorch/pytorch/issues/78058
- https://github.com/pytorch/pytorch/issues/78054
- https://github.com/pytorch/pytorch/issues/78053
- https://github.com/pytorch/pytorch/issues/78050
- https://github.com/pytorch/pytorch/issues/77932

**Testing**
- disables stride consistency checks in test_ops and test_meta pending resolution of https://github.com/pytorch/pytorch/issues/78050
- skips chalf in reference tests (addressing https://github.com/pytorch/pytorch/issues/78054)
- splits test_python_reference_consistency into one test for the context where torch.foo is torch.foo, and another for when torch.foo is refs.foo
- updates test names to be more natural and consistent:
  - test_python_reference_errors -> test_python_ref_errors
  - test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback
  - test_python_reference_meta_functions -> test_python_ref_meta
  - test_reference_testing -> test_numpy_ref
- updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op when the reference and torch op results are not close; a warning is raised when this occurs (addressing https://github.com/pytorch/pytorch/issues/77687)
- adds reference inputs for broadcast_tensors
- Updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator
- Adds 1D no element sample inputs to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly
- Adds reference inputs for elementwise ternary operations, like clamp
- Adds a NumPy reference for clamp
- Adds reference inputs to where's OpInfo
- Makes softplus an elementwise unary OpInfo
- Removes the great majority of Python reference OpInfo skips and xfails due to the above test changes
- Adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where

**Prims**
- adds the fill, empty_strided, and uniform prims
- removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill
- renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch
- extends the `_elementwise_meta` operation to accept tensors that don't participate in type promotion, like the `cond` tensor in `where`
- fixes a bug in the stride propagation of broadcast_in_dim
- moves some error checks from prims.cat and prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible

**Utils**
- adds the canonicalize_device, extract_shape, and extract_shape_from_varargs helpers
- adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (ex. refs.sin(1) will return .84...)

**Refs**
- adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references
- adds the nn.functional.dropout reference
- fixes refs.cat to handle 1D tensors with no inputs consistent with eager mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78026
Approved by: https://github.com/ngimel
2022-05-23 01:56:28 +00:00
PyTorch MergeBot
acfbc16b1c Revert "[primTorch] Adds random operations (#78026)"
This reverts commit 043cf1f9c7.

Reverted https://github.com/pytorch/pytorch/pull/78026 on behalf of https://github.com/suo due to This broke trunk: 043cf1f9c7
2022-05-22 18:11:14 +00:00