Commit Graph

73 Commits

Sherlock Huang
ee4b99cc3a Decomp for aten.dropout (#106274)
When exporting dropout with a CPU tensor, we get the following graph module:
```
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[512, 10]):
            empty_memory_format: f32[512, 10] = torch.ops.aten.empty.memory_format([512, 10], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.contiguous_format)
            bernoulli_p: f32[512, 10] = torch.ops.aten.bernoulli.p(empty_memory_format, 0.9);  empty_memory_format = None
            div_scalar: f32[512, 10] = torch.ops.aten.div.Scalar(bernoulli_p, 0.9);  bernoulli_p = None
            mul_tensor: f32[512, 10] = torch.ops.aten.mul.Tensor(arg0_1, div_scalar);  arg0_1 = div_scalar = None
            return (mul_tensor,)
```

In addition, if we export in eval() mode, we get an empty graph.

However, when exporting with a CUDA tensor, we get
```
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[512, 10]):
            native_dropout_default = torch.ops.aten.native_dropout.default(arg0_1, 0.1, True);  arg0_1 = None
            getitem: f32[512, 10] = native_dropout_default[0];  native_dropout_default = None
            return (getitem,)
```
and exporting in eval() mode still leaves a dropout node in the graph.

This PR makes exporting with a CPU tensor also produce aten.native_dropout.
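A minimal sketch of such a decomposition, assuming the standard `torch._decomp` registration mechanism (the local registry and the exact eval-mode handling here are illustrative, not the PR's exact code):
```
import torch
from torch._decomp import register_decomposition

aten = torch.ops.aten
_table = {}  # illustrative local registry, to avoid touching the global table

@register_decomposition(aten.dropout, registry=_table)
def dropout_decomp(input: torch.Tensor, p: float, train: bool) -> torch.Tensor:
    # Route training-mode dropout through native_dropout so CPU and CUDA
    # exports lower to the same aten.native_dropout node.
    if train and p != 0.0:
        return aten.native_dropout(input, p, train)[0]
    # In eval mode (or with p == 0) dropout is the identity.
    return input.clone()
```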

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106274
Approved by: https://github.com/ezyang
2023-08-23 21:12:37 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling the rule to keep it that way. :)
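For illustration, the pattern RUF017 flags, together with the linear alternatives (a hedged example, not code from the PR):
```
import functools
import itertools
import operator

lists = [[1, 2], [3, 4], [5, 6]]

# Flagged by RUF017: each addition copies the accumulated list, so this is O(n^2).
flat_quadratic = sum(lists, [])

# Linear alternatives:
flat_chain = list(itertools.chain.from_iterable(lists))
flat_reduce = functools.reduce(operator.iadd, lists, [])
```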

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was probably accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling the rule to keep it that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
lezcano
2c5f96deac [Inductor] Make softshrink composite implicit (#107052)
The backward is pretty much equivalent to the one we had written.
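For context, a hedged sketch of softshrink written purely in terms of differentiable ops, which is what lets autograd derive the backward once the op is composite implicit (the PR's actual definition may differ):
```
import torch

def softshrink(a: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    # Shrink values toward zero by lambd and zero out the band [-lambd, lambd];
    # every op here is differentiable, so no hand-written backward is needed.
    return torch.where(torch.abs(a) > lambd,
                       a - torch.sign(a) * lambd,
                       torch.zeros_like(a))
```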

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107052
Approved by: https://github.com/peterbell10
ghstack dependencies: #107038, #107039, #107051
2023-08-14 21:01:50 +00:00
lezcano
3b1254e800 Make hardshrink's decomp composite implicit (#107039)
The generated code is the same.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107039
Approved by: https://github.com/peterbell10
ghstack dependencies: #107038
2023-08-14 21:01:50 +00:00
lezcano
45c7880486 Simplify some decompositions. (#107038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107038
Approved by: https://github.com/peterbell10
2023-08-14 21:01:50 +00:00
Justin Chu
8a688277a2 [BE] Enable ruff's UP rules and autoformat dynamo / functorch and refs (#105432)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105432
Approved by: https://github.com/ezyang
2023-07-19 13:48:44 +00:00
Kurt Mohler
ee83c646bb Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219, which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment noting that the function will be removed in the future and that `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage.

Part of #72948
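A hedged before/after sketch of the migration (the function and message are illustrative):
```
import torch

def validate_dim(a: torch.Tensor, dim: int) -> None:
    # Old style, now deprecated:
    #   torch._prims_common.check(dim < a.ndim, lambda: f"dim {dim} out of range")
    # New style:
    torch._check(dim < a.ndim, lambda: f"dim {dim} out of range")
```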

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
2023-06-21 00:46:17 +00:00
BowenBao
724a1ba2de Tidy __all__ under torch._refs (#103712)
- Added ops that were missing under `__all__`.
- Some misc changes to helper functions to make them private.
- Set the correct `fn.__module__` for `fn` created by `_make_alias` when it is called from another module (see the sketch below).

All modifications largely reference results from a hacked version of `test_public_bindings::test_correct_module_names`.
By default `torch._refs` is not included in the test because it is technically a private package.
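A sketch of the `_make_alias` fix, assuming the usual caller-frame approach (the helper body here is illustrative):
```
import sys

def _make_alias(fn, name):
    def alias(*args, **kwargs):
        return fn(*args, **kwargs)
    alias.__name__ = name
    # Attribute the alias to the module that called _make_alias (the caller's
    # frame), not to the module where fn happens to be defined.
    alias.__module__ = sys._getframe(1).f_globals["__name__"]
    return alias
```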

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103712
Approved by: https://github.com/lezcano
2023-06-20 00:04:58 +00:00
Ivan Zaitsev
821493715c Back out "Remove check from _prims_common, replace with torch._check* (#102219)", Back out "Forward fix for D46427687" (#103128)
Test Plan: revertitparrot

Reviewed By: malfet

Differential Revision: D46506433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103128
Approved by: https://github.com/malfet
2023-06-07 01:41:41 +00:00
Kurt Mohler
a84bb2709a Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-03 02:23:21 +00:00
PyTorch MergeBot
a7efa0ce35 Revert "Remove check from _prims_common, replace with torch._check* (#102219)"
This reverts commit fb79d43649.

Reverted https://github.com/pytorch/pytorch/pull/102219 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/5158949959/jobs/9293466925 ([comment](https://github.com/pytorch/pytorch/pull/102219#issuecomment-1574245414))
2023-06-02 20:00:48 +00:00
Kurt Mohler
fb79d43649 Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-02 19:13:45 +00:00
vfdev-5
319a1cb4e5 [inductor] Replaced refs.op by torch.op in _refs/* (#102176)
Description:
- Replaced `refs.op` with `torch.op` in `_refs/*`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102176
Approved by: https://github.com/lezcano
2023-05-29 22:36:14 +00:00
vfdev-5
e3d97b6213 [inductor] Added smooth_l1_loss refs (#102077)
Added `smooth_l1_loss` to refs + tests
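A minimal sketch of the standard smooth_l1_loss definition such a ref follows (default beta=1.0 and mean reduction assumed):
```
import torch

def smooth_l1_loss_ref(input: torch.Tensor, target: torch.Tensor,
                       beta: float = 1.0) -> torch.Tensor:
    # Quadratic inside the beta band, linear outside.
    diff = torch.abs(input - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    return torch.mean(loss)
```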

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102077
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-05-24 15:07:08 +00:00
BowenBao
60a68477a6 Bump black version to 23.1.0 (#96578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578
Approved by: https://github.com/ezyang
2023-03-15 06:27:59 +00:00
Fabio Rocha
1dbaa5c290 Use decompositions for some fallbacks introduced in #94039 (#94206)
In some cases, this implements the required inductor primitives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94206
Approved by: https://github.com/jansel, https://github.com/ngimel
2023-02-14 09:31:30 +00:00
Edward Z. Yang
37fcc53096 Remove import cycle from torch._refs.nn.functional (#93948)
This makes it possible to import torch._refs from
torch._subclasses.fake_tensor.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93948
Approved by: https://github.com/albanD
2023-02-02 21:06:37 +00:00
Nikita Shulga
fd3a7264ae [MPS] Add group_norm[fwd+backward] and mean_var (take 2) (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages, in order to make them importable from `torch/backends/mps/__init__.py`: the `torch.ops` alias, defined in 15af4b1cee/torch/__init__.py (L1095), is executed last during the init process.

Add `__all__` to `torch/backends/mps/__init__.py`, and alias all imports as private.

Add `TestNNMPS.test_group_norm_backward`, which validates that no NaNs are generated during the backward pass.
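For context, a hedged sketch of group_norm expressed through composite ops such as var_mean, in the spirit of the prims-based implementation (shapes and helper names assumed):
```
import torch

def group_norm_ref(x, num_groups, weight=None, bias=None, eps=1e-5):
    n, c = x.shape[:2]
    # Normalize each group of channels using var_mean.
    grouped = x.reshape(n, num_groups, -1)
    var, mean = torch.var_mean(grouped, dim=-1, unbiased=False, keepdim=True)
    out = ((grouped - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    affine_shape = [1, c] + [1] * (x.dim() - 2)
    if weight is not None:
        out = out * weight.reshape(affine_shape)
    if bias is not None:
        out = out + bias.reshape(affine_shape)
    return out
```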

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-22 08:54:37 +00:00
PyTorch MergeBot
645eda0a00 Revert "[MPS] Add group_norm[fwd+backward] and mean_var (#91190)"
This reverts commit 371716eb36.

Reverted https://github.com/pytorch/pytorch/pull/91190 on behalf of https://github.com/kit1980 due to Broke test_correct_module_names because of underscore _ops
2022-12-21 19:37:43 +00:00
Nikita Shulga
371716eb36 [MPS] Add group_norm[fwd+backward] and mean_var (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages, in order to make them importable from `torch/backends/mps/__init__.py`: the `torch.ops` alias, defined in 15af4b1cee/torch/__init__.py (L1095), is executed last during the init process.

Depends on https://github.com/pytorch/pytorch/pull/91203

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-21 17:33:27 +00:00
Nikita Shulga
c8546c930f [BE] Use aten global in torch._refs (#91189)
Similar to the pattern used in `torch._decomp`.
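The pattern, sketched (op choice illustrative):
```
import torch

# Bind the aten namespace once at module scope instead of spelling out
# torch.ops.aten at every call site.
aten = torch.ops.aten

def add_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return aten.add.Tensor(a, b)
```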
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91189
Approved by: https://github.com/ngimel
2022-12-21 02:28:51 +00:00
soulitzer
98a9235dce Fix prelu ref when a.ndim < 2 (#89809)
Fixes https://github.com/pytorch/pytorch/issues/89560

Previously the test case for "input is 1-D or scalar + weight is not scalar" did not exist; adding it introduced some failures:
- forward AD (fixed in this PR)
- vmap (filed https://github.com/pytorch/pytorch/issues/89895)
- ref/meta (fixed in this PR, though this also regresses nvFuser support)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89809
Approved by: https://github.com/ngimel
2022-12-12 23:55:31 +00:00
Yanbo Liang
25f39c1bce Fix uniform ref implementation (#90094)
Fixes https://github.com/pytorch/torchdynamo/issues/1954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90094
Approved by: https://github.com/ngimel
2022-12-06 21:28:17 +00:00
lezcano
154e58c032 Add most in-place references/decompositions (#88117)
We add most in-place references in a generic way. We also implement a wrapper to handle the annoying interface that `nn.functional` nonlinearities have.

Along the way, we fix a couple of decompositions for some nonlinearities by extending the arguments that the references accept.
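A minimal sketch of such a generic in-place wrapper (names assumed; the PR's wrapper also handles the `nn.functional` interface):
```
import torch

def _make_inplace(fn):
    # Generic in-place reference: compute out-of-place, then copy_ the
    # result back into the input tensor.
    def inplace(a: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return a.copy_(fn(a, *args, **kwargs))
    inplace.__name__ = fn.__name__ + "_"
    return inplace

# e.g. relu_ = _make_inplace(torch.relu)
```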
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88117
Approved by: https://github.com/mruberry
2022-11-18 14:59:46 +00:00
lezcano
ce0e22a81a Fix names of some reference functions (#88115)
The `__name__` field of some binary reference functions was wrong. We
fix this to be consistent with unary reference functions. In the future,
we should probably make the binary reference wrapper return a wrapper
itself to avoid all those calls to `partial`.

This change helps treat functions homogeneously by name.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88115
Approved by: https://github.com/mruberry
2022-11-18 14:59:43 +00:00
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos in messages in `.py` files under the torch directory. Only in `torch/onnx/symbolic_opset16.py`, it fixes a typo in a comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
Khushi Agrawal
f1a5044de0 [primTorch] _refs & opinfo alpha_dropout (#87989)
Add _refs and OpInfo for `nn.functional.alpha_dropout`
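For context, a hedged sketch of the usual alpha dropout formulation (constants from the SELU paper; the ref added by the PR may differ in details):
```
import torch

def alpha_dropout_ref(a: torch.Tensor, p: float = 0.5,
                      training: bool = True) -> torch.Tensor:
    if not training or p == 0.0:
        return a
    # Dropped units saturate at SELU's negative limit; the affine correction
    # (a_coef, b_coef) restores zero mean and unit variance.
    alpha = -1.7580993408473766
    keep = torch.rand_like(a) > p
    dropped = torch.where(keep, a, torch.full_like(a, alpha))
    a_coef = ((1 - p) * (1 + p * alpha * alpha)) ** -0.5
    b_coef = -a_coef * alpha * p
    return dropped * a_coef + b_coef
```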
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87989
Approved by: https://github.com/mruberry
2022-11-14 18:18:45 +00:00
Ryan Spring
534ae6ae47 [primTorch] Implement group norm reference (#87054)
Add group norm reference
Split from #81191
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87054
Approved by: https://github.com/mruberry
2022-11-11 01:08:20 +00:00
Nikita Karetnikov
1b8af28fe8 [primTorch] Add refs for softmax, softmin, log_softmax (#84956)
cc @ezyang @mruberry @ngimel @Lezcano @fdrocha
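For context, a minimal sketch of the numerically stable formulation such softmax refs follow (softmin and log_softmax reduce to it via negation and log):
```
import torch

def softmax_ref(a: torch.Tensor, dim: int) -> torch.Tensor:
    # Subtract the max along `dim` so exp() cannot overflow.
    shifted = a - torch.amax(a, dim, keepdim=True)
    exp = torch.exp(shifted)
    return exp / torch.sum(exp, dim, keepdim=True)
```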
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84956
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-20 12:29:04 +00:00
PyTorch MergeBot
cd21613526 Revert "[primTorch] Add refs for softmax, softmin, log_softmax (#84956)"
This reverts commit c09ca93e47.

Reverted https://github.com/pytorch/pytorch/pull/84956 on behalf of https://github.com/ZainRizvi due to This is causing the MPS test test_output_match_log_softmax_with_dtype_cpu_float32 (__main__.TestConsistencyCPU) to fail
2022-10-19 20:36:55 +00:00
Nikita Karetnikov
c09ca93e47 [primTorch] Add refs for softmax, softmin, log_softmax (#84956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84956
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-19 18:45:40 +00:00
Ryan Spring
847ded6db3 [primTorch] Implement NLL loss reference (#81128)
Add Reference:
- nll_loss

Depends on:
- expand https://github.com/pytorch/pytorch/pull/79820
- advanced indexing (see the sketch below)
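A minimal sketch of nll_loss via advanced indexing (no weights or ignore_index; mean reduction assumed):
```
import torch

def nll_loss_ref(log_probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Advanced indexing picks each row's target-class log-probability.
    n = log_probs.shape[0]
    rows = torch.arange(n, device=log_probs.device)
    return -log_probs[rows, target].mean()
```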

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81128
Approved by: https://github.com/mruberry
2022-10-17 06:20:31 +00:00
Nikita Karetnikov
d56017a14f [primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-12 18:37:58 +00:00
Nikita Karetnikov
8dd45424ea [primTorch] Add ref for huber_loss and error inputs (#85041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85041
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-09-28 19:56:17 +00:00
Nikita Karetnikov
85b889fa5f [primTorch] Add ref for poisson_nll_loss (#83805)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83805
Approved by: https://github.com/Lezcano, https://github.com/ngimel
2022-08-31 17:39:34 +00:00
Nikita Karetnikov
305af90d0f [primTorch] Add docstring and promotion for l1_loss ref (#83803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83803
Approved by: https://github.com/Lezcano, https://github.com/ngimel
2022-08-31 17:39:31 +00:00
Fabio Rocha
9e1560821e [primTorch] Refs for pdist, triu and related ops (#82819)
This PR adds refs for the following ops:
   - `torch.triu`
   - `torch.tril`
   - `torch.triu_indices`
   - `torch.tril_indices`
   - `torch.nn.functional.pairwise_distance`
   - `torch.nn.functional.pdist`

It adds OpInfos for
   - `torch.triu_indices`
   - `torch.tril_indices`

Note that these were already tested in `test/test_tensor_creation_ops.py`
but for the ref tests we need the OpInfos.

Finally, it improves documentation for PairwiseDistance and adds
a missing import to `torch/testing/_internal/opinfo/core.py`.

This started with an attempt to just add the `nn.functional` refs above, but it turned out that `pdist` was easiest to implement using `triu_indices` (sketched below), so I added that one and the related functions as well.
~~In the end, I changed the `pdist` implementation to not use `triu_indices`
but kept the other refs.~~
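A hedged sketch of the triu_indices-based pdist mentioned above:
```
import torch

def pdist_ref(a: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    # Distances between all row pairs (i, j) with i < j, enumerated by
    # the strict upper triangle of an n x n index grid.
    n = a.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)
    return torch.linalg.vector_norm(a[i] - a[j], ord=p, dim=-1)
```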
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82819
Approved by: https://github.com/ngimel
2022-08-18 21:52:03 +00:00
Fabio Rocha
85ef1a1cd1 [primTorch] added ref for nn.functional.glu (#82214)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82214
Approved by: https://github.com/ngimel
2022-08-17 18:38:43 +00:00
Huy Do
12cb26509a Apply ufmt to torch internal (#81643)
This is a big-bang PR; merge conflicts are expected and will be addressed at merge.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81643
Approved by: https://github.com/ezyang
2022-07-22 02:19:50 +00:00
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder (#81088)"" (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
soulitzer
6cf0d9249f Add prelu and relu6 refs missing from __all__ and decomp db (#81420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81420
Approved by: https://github.com/ngimel
2022-07-19 20:50:04 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00
soulitzer
28776c45e3 Add ref for relu6, fixes hardshrink and improves testing of related ops (#81142)
This PR:
- Adds a ref for relu6 and makes its OpInfo a UnaryUfuncInfo
- Corrects the hardshrink ref when lambd < 0 and when inputs are NaN (see the sketch below)
- Corrects the NaN behavior of the vectorized implementation of hardshrink (fixes https://github.com/pytorch/pytorch/issues/81138)
- Makes the OpInfos for {hard,soft}shrink and hardtanh UnaryUfuncInfos and adds error_inputs for softshrink
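A hedged sketch of the corrected hardshrink semantics (the ref in the PR may differ in form):
```
import torch

def hardshrink_ref(a: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    # Phrasing the predicate as |a| <= lambd lets NaN (for which the
    # comparison is False) fall through to `a` and propagate, and makes
    # lambd < 0 act as the identity.
    return torch.where(torch.abs(a) <= lambd, torch.zeros_like(a), a)
```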

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81142
Approved by: https://github.com/Lezcano, https://github.com/ngimel, https://github.com/mruberry
2022-07-11 21:40:57 +00:00
Ryan Spring
d26516fd1b [primTorch] Implement loss function references (#80573)
Add Reference:

- mse_loss
- l1_loss
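Minimal sketches of the two references above (mean reduction assumed):
```
import torch

def mse_loss_ref(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.mean((input - target) ** 2)

def l1_loss_ref(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.mean(torch.abs(input - target))
```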
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80573
Approved by: https://github.com/mruberry
2022-07-09 03:31:20 +00:00
David Berard
4c57cf9a8b Register unregistered refs and add a test to check registration (#80497)
Added missing `register_decomposition`s, which register the refs so they can be used as decompositions.

Also added a test for verifying that new refs are registered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80497
Approved by: https://github.com/ezyang
2022-07-08 16:29:52 +00:00
soulitzer
bd75b2fea1 Add ref for nn.functional.prelu (#79768)
TODO:
- not sure if these error-inputs work for all devices (awaiting CI)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79768
Approved by: https://github.com/mruberry
2022-07-07 17:04:47 +00:00
lezcano
eb0889cf7d Add support for multiple inputs to out_wrapper and strict dtype checking (#80601)
Reland of https://github.com/pytorch/pytorch/pull/79941
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80601
Approved by: https://github.com/albanD
2022-07-05 12:31:21 +00:00