Commit Graph

27 Commits

Author SHA1 Message Date
Bairen Yi
b6672b10e1 Fix incorrect decomposition for native_dropout (#77933)
Quick sanity check: it should be identity function if p=0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77933
Approved by: https://github.com/Chillee
2022-05-30 20:08:48 +00:00
Aidyn-A
31016eb81e [primTorch] Elementwise Binary Ops I (#78023)
This PR is a result of collaboration with @rdspring1 and @mruberry on primTorch.

It adds the following prims:
- `fmax`
- `fmin`
- `fmod`

And adds the following refs:
- `fmax`
- `fmin`
- `fmod`
- `logical_xor`

The work is in progress as there are some tests that fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78023
Approved by: https://github.com/mruberry
2022-05-26 20:22:27 +00:00
Horace He
ea5d01e629 [Primtorch] Tried porting leaky_relu into a ref (#78041)
Feels good to delete it from `torch._decomps`. This is mainly to clarify the process for me -

Seems like there's still some components missing of the `torch <-> refs` mapping? For example, seems like methods don't work yet for mapping from torch <-> refs, and neither do the meta tests? (cc: @ezyang).

If I replace the `torch` with `refs`, then the tests seem to pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78041
Approved by: https://github.com/mruberry
2022-05-23 18:00:21 +00:00
Horace He
4428218945 [primtorch] Added native_group_norm decomp (#78029)
cc: @jansel @bertmaher

More or less identical in spirit to the layer norm and batch norm ones.

One annoying thing about all 3 of these is that layer_norm has slightly different `mean/var` semantics than batch norm and group norm. After normalization, `layer_norm` keeps them unsqueezed (so they're something like [1, 5, 1, 1]) while batch norm and group norm squeeze out the 1-dims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78029
Approved by: https://github.com/bertmaher
2022-05-21 08:07:02 +00:00
Edward Z. Yang
6b273444c4 Add logit ref; allow non-refs to be called in refs.
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77816

Approved by: https://github.com/mruberry
2022-05-21 02:35:14 +00:00
Horace He
64b4bb4b01 Fix meta tests on norm (and relanding norm fixes) (#77930)
Had a land race with meta tests.

Will also be relanding https://github.com/pytorch/pytorch/pull/77407
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77930
Approved by: https://github.com/malfet, https://github.com/ezyang
2022-05-20 23:15:53 +00:00
PyTorch MergeBot
03546e9c07 Revert "Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)"
This reverts commit 70d80fb424.

Reverted https://github.com/pytorch/pytorch/pull/77407 on behalf of https://github.com/malfet due to as it broke meta tests ( I guess due to landrace), see 70d80fb424
2022-05-20 02:31:57 +00:00
Horace He
70d80fb424 Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)
Originally, when these were written, they simply used the naive strategy of "upcast all inputs to floats, and downcast all inputs back". In addition to being... not quite what the kernels did, they also didn't capture some additional semantics. Namely, that the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values.

Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`. Which is fair - so I refactored the common logic into a `normalize` function.

cc: @jansel / @bertmaher , who've been looking at lowering layer norm/batch norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407
Approved by: https://github.com/bertmaher
2022-05-19 17:11:47 +00:00
Edward Z. Yang
88c89c9eb9 log_sigmoid_forward out support; out_wrapper_multi
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77739

Approved by: https://github.com/mruberry
2022-05-19 14:43:35 +00:00
Edward Z. Yang
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00
Mike Ruberry
580a053832 [primTorch] Enforces stride metadata (#77542)
This PR...

**Filed the Following Issues**
- https://github.com/pytorch/pytorch/issues/77553
- https://github.com/pytorch/pytorch/issues/77526
- https://github.com/pytorch/pytorch/issues/77600

**Testing**
- Updates test_dtypes to longer attempt to test the backward of sample inputs where no inputs require grad
- Adds a new test_python_reference_errors; it ensures the meta operations for references throw errors as expected
- Updates compare_tensor_meta to better handle CUDA devices, and (temporarily) restricts stride checking to the CUDA device type
- Elementwise unary and elementwise binary operators now have arbitrarily strided reference inputs
- Reference inputs for _like functions are added
- An OpInfo for torch.empty is added
- Reference inputs for torch.clone are added
- A NumPy reference for clone is added
- Adds OpInfos for refs.empty and refs.empty_like

**Prims**
- Renames the "max" and "min" prims have been renamed to "maximum" and "minimum," respectively, to better conform to their ATen names
- Adds the empty, empty_like, full, and full_like prims
- Fixes the elementwise meta function's stride propagation
- Fixes clone's meta function's stride propagation
- Fixes convert_element_type's meta's stride propagation
- Adds a (temporary) _to_dtype pprivate prim that casts a tensor while preserving its stride permutation
- Removes the _set prim comment
- Adds utils.compute_elementwise_output_strides, which computes the correct output strides for elementwise operations
- Corrects an issue where utils.make_contiguous_strides_for was creating the incorrect strides for tensors with no elements

**References**
- Adds the empty, empty_like, full, full_like, and ones_like refs
- Extends make_elementwise_unary_reference to accept an additional callable to perform extra input validation
- Adds an extra validation function to handle refs.neg(BoolTensor)
- Updates the isfinite ref to call ones_like when appropriate
- Models Python scalar handling for elementwise binary operations
- Added a 64 dim check for the amin and amax references
- opmath is now a flag that can be set separately for cpu and CUDA
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77542
Approved by: https://github.com/ezyang
2022-05-18 13:57:26 +00:00
PyTorch MergeBot
48581d74ad Revert "Add dispatch mode testing for meta tensors and other stuff"
This reverts commit c1cdb1216b.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
2022-05-18 02:56:48 +00:00
Edward Z. Yang
c1cdb1216b Add dispatch mode testing for meta tensors and other stuff
We don't have any coverage for meta tensor correctness for backwards
because torch function mode can only allow us to interpose on
Python torch API calls, but backwards invocations happen from C++.
To make this possible, I add torch_dispatch_meta test which runs the
tests with __torch_dispatch__

While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient.  So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic.  Here's whats
new:

- Expected failure / skip is now done on a per function call basis,
  rather than the entire test.  This means that separate OpInfo
  samples for a function don't affect each other.

- There are now only two lists: expect failure list (where the test
  consistently fails on all runs) and skip list (where the test
  sometimes passes and fails.

- We explicitly notate the dtype that failed.  I considered detecting
  when something failed on all dtypes, but this was complicated and
  listing everything out seemed to be nice and simple.  To keep the
  dtypes short, I introduce a shorthand notation for dtypes.

- Conversion to meta tensors is factored into its own class
  MetaConverter

- To regenerate the expected failure / skip lists, just run with
  PYTORCH_COLLECT_EXPECT and filter on a specific test type
  (test_meta or test_dispatch_meta) for whichever you want to update.

Other misc fixes:

- Fix max_pool1d to work with BFloat16 in all circumstances, by making
  it dispatch and then fixing a minor compile error (constexpr doesn't
  work with BFloat16)

- Add resolve_name for turning random torch API functions into string
  names

- Add push classmethod to the Mode classes, so that you can more easily
  push a mode onto the mode stack

- Add some more skips for missing LAPACK

- Added an API to let you query if there's already a registration for
  a function, added a test to check that we register_meta for all
  decompositions (except detach, that decomp is wrong lol), and then
  update all the necessary sites to make the test pass.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477

Approved by: https://github.com/zou3519
2022-05-18 00:18:34 +00:00
Horace He
8626f76555 Add trace and log_sigmoid_forward decomps (#77329)
Main question mark is that `log_sigmoid_forward` uses `acc_t` instead of `opmath_t` - not sure if we have a decorator today for that?

Glad to add one if we don't.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77329
Approved by: https://github.com/ezyang
2022-05-13 04:55:52 +00:00
Edward Z. Yang
d5ed73badd Make it possible to register decompositions to Meta key
Decompositions can be used to fill in meta support where necessary,
assuming the operations they decompose to support meta key.
This PR adds register_meta kwarg to register_decomposition that
optionally lets you register the meta to the C++ dispatch table
for meta tensors.  I use this to then get the meta function for
where and huber_loss for free.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77353

Approved by: https://github.com/mruberry
2022-05-12 23:20:16 +00:00
Horace He
c25bdeea26 Added logsumexp decomposition (#77219)
Pretty simple.

cc: @jansel who mentioned this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77219
Approved by: https://github.com/jansel
2022-05-12 02:01:31 +00:00
samdow
d694cf60fe add decomposition for nll_loss2d_backward (#77198)
Adds a decomposition for `nll_loss2d_backward`

This will let us actually run all the tests for jvpvjp ([see this functorch PR](https://github.com/pytorch/functorch/pull/792)). I confirmed locally that this made those tests pass too
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77198
Approved by: https://github.com/Chillee
2022-05-11 20:41:20 +00:00
Mike Ruberry
bb8baea932 [primTorch] flatten, squeeze, unsqueeze... (#77043)
This PR ...

Makes the following testing changes:

- Updates stride testing in test_python_reference_consistency to only check strides of dimensions with length > 1
- Creates reference inputs for reshape
- Creates reference inputs for chunk
- Extends the sample inputs for unsqueeze
- Extends the sample inputs for stack -- test_conj_view and test_neg_view are now xfailed
  - https://github.com/pytorch/pytorch/issues/77046

Makes the following architecture changes:
- Adds the refs.special (sub)module
- Adds the refs.nn.functional (sub)module

Adds the following prims:
- expand_dims
- view_of
- rev
- clone

Adds the following references:
  -  flatten
  - squeeze
  - unsqueeze
  - special.i0e
  - special.i1e
  - logical_or
  - logical_and
  - isclose
  - flip
  - stack
  - nn.functional.elu
  - chunk
  - clone
  - narrow

Identifies the following bugs in PyTorch today:
- https://github.com/pytorch/pytorch/issues/77054
- https://github.com/pytorch/pytorch/issues/77055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77043
Approved by: https://github.com/ngimel
2022-05-09 11:24:55 +00:00
Mike Ruberry
c031643e39 Adds decorators for Python References and extends Python Reference testing (#76945)
This PR does the following...

Tests:
- fixes test_type_promotion in test_binary_ufuncs to correctly generate scalar cpu tensors
- fixes test_python_reference_consistency to use the Python Reference's reference inputs
- extends Python reference testing to test_conj_view, test_neg_view, and test_neg_conj_view
- adds a NaN propagation sample input for elementwise unary and binary operations
- fixes the UnaryUfuncInfo class to properly register its reference inputs
- Updates the Python Reference OpInfos to skip error inputs when their behavior on scalar inputs is inconsistent with their reference operators

Code organization:
- moves elementwise type promotion functionality to prims.utils

Prims & Refs:
- fixes scalar cpu tensor handling by having them pass through broadcasting and device and shape checks
- adds two decorators, `elementwise_type_promotion_wrapper` and `out_wrapper`, the former allows for elementwise type promotion to be automated and the latter automatically adds the out kwarg and handles it properly

cc @ezyang who also had some thoughts on cpu scalar tensor handling
cc @chillee -- might want to use this new decorator as we converge decompositions and references
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76945
Approved by: https://github.com/ngimel
2022-05-07 03:42:24 +00:00
Edward Z. Yang
f2eed9400d Register PrimTorch refs as decompositions.
For the most part, PrimTorch refs have the same signature as their
ATen equivalents.  I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch).  There are
some exclusions, falling into one of two categories:

- The torch equivalent was already implemented as a CompositeImplicitAutograd
  decomposition in C++

- The ref doesn't support enough features (e.g., the real deal has more
  kwargs / overloads than are currently implemented)

PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet.  This is technically
BC breaking but no tests started failing because of it.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835

Approved by: https://github.com/Chillee, https://github.com/mruberry
2022-05-06 20:11:45 +00:00
Horace He
e9f34931ef Add some shape decomps (t, transpose, rot90, stack)
Also fixes xlogy (turns out the only thing it was missing was a type cast annotation! nice!)

I also renamed `canonicalize_idx` => `canonicalize_dim` (to align with `canonicalize_dims`) and fixed a bug in it (cc: @mruberry)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76873
Approved by: https://github.com/mruberry
2022-05-06 02:40:57 +00:00
Horace He
6917034afb Added logit/reciprocal decomps, fixed var for complex, moved type promotion logic to standardize on primtorch's
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76633
Approved by: https://github.com/ezyang
2022-05-04 21:29:52 +00:00
PyTorch MergeBot
ce63c53c9b Revert "Add binary_cross_entropy and trace decomp - fixed _log_softmax/_softmax dtype promotion semantics"
This reverts commit 8a3e9255ea.

Reverted https://github.com/pytorch/pytorch/pull/76670 on behalf of https://github.com/mruberry
2022-05-04 10:42:39 +00:00
Horace He
ed18181d83 Added gelu decomposition
^
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76763
Approved by: https://github.com/ezyang
2022-05-03 23:23:18 +00:00
Horace He
8a3e9255ea Add binary_cross_entropy and trace decomp - fixed _log_softmax/_softmax dtype promotion semantics
cc: @zou3519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76670
Approved by: https://github.com/ezyang
2022-05-03 18:20:17 +00:00
Horace He
fb24614011 Port functorch decomps over and fix some tests
Still some stuff to fix up, will finish later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76621
Approved by: https://github.com/ezyang
2022-05-01 08:48:48 +00:00
Edward Z. Yang
a3f10ec281 Move functorch decompositions to PyTorch
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76311

Approved by: https://github.com/Chillee
2022-04-30 16:47:53 +00:00