Commit Graph

1301 Commits

Author SHA1 Message Date
Masaki Kozuki
ef0baba23f Use int64_t for nll_loss with cuda inputs (#85395)
Related #85005
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85395
Approved by: https://github.com/t-vi, https://github.com/lezcano
2022-09-29 17:02:04 +00:00
Mikayla Gawarecki
afaee00fec Add python nested_tensor and as_nested_tensor constructors in torch.nested (#85593)
Remove `torch.nested_tensor` which has erroneous behavior wrt gradients (could be either leaf or not leaf). Introduce `torch.nested.nested_tensor` and `torch.nested.as_nested_tensor` in the vein of `torch.tensor` and `torch.as_tensor`. Done in nested `__init__.py` for now but can move to pybind in future (when we want to load from numpy/nested lists ).

Discussed offline with @cpuhrsch and pybind constructor (https://github.com/pytorch/pytorch/pull/85536) was more gnarly than expected, so we can move to that when we do need loading from numpy etc.

Differential Revision: [D39806622](https://our.internmc.facebook.com/intern/diff/D39806622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85593
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
2022-09-28 20:15:02 +00:00
Weiyi Zheng
b2311192e6 [NN module] speed up _load_from_state_dict (#85743)
Fixes #61398

The original implementation is very slow when the state_dict.keys() is long. This PR only passes relevant keys to the child module.

existing test passes: `pytest test/test_nn.py -k state_dict`
I couldn't figure out a good way to write a new test for this new behavior. Had a new snippet, but it will be flaky if integrated into the main CI because it's a timing based check.
But I can verify that the test took 30s to run, after this PR it only takes 0.5s.

```python
    def test_load_state_dict_large(self):
        # construct a module with 4 levels of module, 10 linear each, leads to 10k items in the dictionary
        import copy
        import time
        base_module = nn.Linear(1,1)
        model = base_module
        for level in range(4):
           model = nn.Sequential(*[copy.deepcopy(model) for _ in range(10)])
        state_dict = model.state_dict()
        self.assertEqual(len(state_dict), 20000)
        st = time.time()
        model.load_state_dict(state_dict, strict=True)
        strict_load_time = time.time() - st
        # it took 0.5 seconds to
        self.assertLess(strict_load_time, 10)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85743
Approved by: https://github.com/albanD
2022-09-28 15:26:03 +00:00
Eddie Yan
2bc82163eb Reduce memory usage requirement of test_warp_softmax_64bit_indexing in test_nn.py (re-open of #85037) (#85373)
CC @ngimel @xwang233 @ptrblck

Adds fix for `get_tolerances`, tested locally on a dgx Volta.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85373
Approved by: https://github.com/ngimel
2022-09-22 07:34:47 +00:00
Mikayla Gawarecki
77f1f98479 Re-introduce torch.Tensor.to_padded_tensor (#85293)
Differential Revision: [D39629004](https://our.internmc.facebook.com/intern/diff/D39629004)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85293
Approved by: https://github.com/cpuhrsch
2022-09-21 18:45:56 +00:00
PyTorch MergeBot
53fdd60635 Revert "Reduce memory usage requirement of test_warp_softmax_64bit_indexing in test_nn.py (#85037)"
This reverts commit 66a9cba221.

Reverted https://github.com/pytorch/pytorch/pull/85037 on behalf of https://github.com/clee2000 due to broke test_warp_softmax_64bit_indexing_cuda_float32 and test_warp_softmax_64bit_indexing_cuda_float16 on rocm https://github.com/pytorch/pytorch/actions/runs/3085764744/jobs/4989643817
2022-09-20 00:13:41 +00:00
eqy
66a9cba221 Reduce memory usage requirement of test_warp_softmax_64bit_indexing in test_nn.py (#85037)
For reference: #84944

CC @xwang233 @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85037
Approved by: https://github.com/ngimel, https://github.com/pmeier
2022-09-19 21:31:08 +00:00
Elias Ellison
f37069aac7 Re-enable fixed dynamo tests (#84969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84969
Approved by: https://github.com/bdhirsh, https://github.com/ezyang
2022-09-16 15:36:52 +00:00
Michael Melesse
b6d6a78c12 [ROCM] test_batchnorm_cudnn_nhwc (#84603)
This pr enables test_batchnorm_cudnn_nhwc.  This is a follow up to https://github.com/pytorch/pytorch/pull/82512
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84603
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
2022-09-14 15:50:14 +00:00
Mikayla Gawarecki
e217b30b0f Add torch.nested namespace (#84102)
First step towards #83775
- only `to_padded_tensor` is moved to the nested namespace for now
- following the schema used for `special`, `fft`, `linalg` and other namespaces, nested functions are registered in native_functions.yaml as `nested_{function_name}` and are bound to the desired Python name in
`torch/nested/__init__.py`, and the desired C++ name in `torch/csrc/api/include/torch/nested.h`.

~~**Question**: should we keep the documentation for `Tensor.to_padded_tensor` or can this deleted since it is shared by `torch.nested.to_padded_tensor`?~~

[generated nested docs](https://docs-preview.pytorch.org/84102/nested.html?highlight=nested#module-torch.nested)

Differential Revision: [D39361148](https://our.internmc.facebook.com/intern/diff/D39361148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84102
Approved by: https://github.com/drisspg
2022-09-12 16:31:05 +00:00
Kshiteej K
6d6e04d6cc [test_nn] move dropout tests to test/nn/test_dropout.py (#84165)
Ref https://github.com/pytorch/pytorch/issues/63085
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84165
Approved by: https://github.com/albanD
2022-09-03 07:21:48 +00:00
Elias Ellison
f701cb04fb Test Dynamo CI w Fake Tensors (#84282)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84282
Approved by: https://github.com/anijain2305
2022-09-01 00:15:05 +00:00
lezcano
b106a04d76 Fix the edge case when y = 0 in kl_div (#82714)
Brought up in https://github.com/pytorch/pytorch/pull/80334#issuecomment-1193600883

We also prepare its opinfo to fix https://github.com/pytorch/pytorch/issues/80488
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82714
Approved by: https://github.com/albanD
2022-08-30 18:18:25 +00:00
Edward Z. Yang
ad44670fa1 Back out "Revert D38984222: Don't introduce new overload for SymInt (#83628)" (#84173)
Also Back out "Revert D39075159: [acc_tensor] Use SymIntArrayRef for overloaded empty.memory_format's signature"

Original commit changeset: dab4a9dba4fa
Original commit changeset: dcaf16c037a9

Original Phabricator Diff: D38984222
Original Phabricator Diff: D39075159

Also update Metal registrations for C++ registration changes.

Also update NNPI registration to account for tightened schema checking

Differential Revision: [D39084762](https://our.internmc.facebook.com/intern/diff/D39084762/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39084762/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84173
Approved by: https://github.com/Krovatkin
2022-08-29 18:01:07 +00:00
soulitzer
7088a98fba conv2d: require bias to have the same dtype as input and weight on cpu (#83686)
Fixes https://github.com/pytorch/pytorch/issues/83505

BC-breaking message:
- Previously we only required input and weight to have the same dtype on cpu (when input is non-complex). After this change, the dtype of bias is now also expected to have the same dtype. This change was necessary to improve the error message for certain combinations of inputs. This behavior now also matches that of convolution on cuda.

<details>
<summary>
Old plan
</summary>
Previously convolution (at least for slow_conv2d) did not perform type promotion, i.e. the output of `conv(int, int, float)` is an int, and that leads to the autograd assert.

This PR adds type promotion handling at the `at::native::conv2d` (this is a composite) level. We also need to correct or remove many tests that assume that conv errors when input types are mixed

Pros:
- Doing type promotion at this level avoids the complex path from having any special handling for mixed dtypes, and can potentially speed up mixed dtype inputs to now dispatch to faster kernels which are only capable of handling floats.

Cons:
- Doing type promotion at this level has the risk of introducing extra overhead when we would've dispatched to a kernel capable of handle mixed type anyway. I don't know if any of these exist at all though - it is possible that inputs with any non-float arguments are dispatched to the slow path.

If this approach is OK, we can proceed with the other convolutions as well:
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83686
Approved by: https://github.com/ngimel
2022-08-29 16:41:17 +00:00
Natalia Gimelshein
0ac2986d33 Fixes softmax indexing for large tensors (#84182)
Fixes #84144

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84182
Approved by: https://github.com/janeyx99
2022-08-29 04:29:09 +00:00
PyTorch MergeBot
c7edcd6968 Revert "Don't introduce new overload for SymInt (#83628)"
This reverts commit 9790d90e4b.

Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to Breaks internal builds, see D39076487
2022-08-27 01:23:17 +00:00
Animesh Jain
6a58603956 Update Dynamo pin (#83829)
As title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83829
Approved by: https://github.com/ezyang
2022-08-26 20:49:43 +00:00
Edward Z. Yang
9790d90e4b Don't introduce new overload for SymInt (#83628)
Previously, we introduced new SymInt overloads for every function we wanted.  This led to a lot of boilerplate, and also a lot of confusion about how the overloads needed to be implemented.

This PR takes a simpler but more risky approach: just take the original function and changes its ints to SymInts.

This is BC-breaking in the following ways:

* The C++ API for registering implementations for aten operators will change from int64_t to SymInt whenever you make this change. Code generated registrations in PyTorch do not change as codegen handles the translation automatically, but manual registrations will need to follow the change.  Typically, if you now accept a SymInt where you previously only took int64_t, you have to convert it back manually.  This will definitely break XLA, see companion PR https://github.com/pytorch/xla/pull/3914 Note that not all dispatch keys get the automatic translation; all the composite keys and Meta keys are modified to take SymInt directly (because they should handle them directly), and so there are adjustments for this.

This is not BC-breaking in the following ways:

* The user facing C++ API remains compatible.  Even if a function changes from int to SymInt, the default C++ binding still takes only ints.  (e.g., at::empty(IntArrayRef, ...).  To call with SymInts, you must call at::empty_symint instead. This involved adding two more signatures to CppSignatureGroup; in many cases I refactored code to iterate over all signatures in the group instead of hard-coding the two that previously existed.
* This is TorchScript compatible; internally we treat SymInts as ints so there is no change to what happens at runtime in TorchScript. In particular, it's OK to reference an empty schema by its old type (using int types), as long as you're not doing string equality (which you shouldn't be), these parse to the same underyling type.

Structure of the PR:

* The general strategy of this PR is that, even when you write `SymInt` inside `native_functions.yaml`, sometimes, we will treat it *as if* it were an `int`. This idea pervades the codegen changes, where we have a translation from SymInt to c10::SymInt or int64_t, and this is controlled by a symint kwarg which I added and then audited all call sites to decide which I wanted. Here are some of the major places where we pick one or the other:
  * The C++ FunctionSchema representation represents `SymInt` as `int`. There are a few places we do need to know that we actually have a SymInt and we consult `real_type()` to get the real type in this case. In particular:
    * When we do schema validation of C++ operator registration, we must compare against true schema (as the C++ API will provide `c10::SymInt`, and this will only be accepted if the schema is `SymInt`. This is handled with cloneWithRealTypes before we check for schema differences.
    * In `toIValue` argument parsing, we parse against the true schema value. For backwards compatibility reasons, I do still accept ints in many places where Layout/SymInt/etc were expected. (Well, accepting int where SymInt is expected is not BC, it's just the right logic!)
  * In particular, because SymInt never shows up as type() in FunctionSchema, this means that we no longer need a dedicated Tag::SymInt. This is good, because SymInts never show up in mobile anyway.
* Changes to functorch/aten are mostly about tracking changes to the C++ API registration convention. Additionally, since SymInt overloads no longer exist, registrations for SymInt implementations are deleted. In many cases, the old implementations did not properly support SymInts; I did not add any new functionality with this PR, but I did try to annotate with TODOs where this is work to do. Finally, because the signature of `native::` API changed from int to SymInt, I need to find alternative APIs for people who were directly calling these functions to call. Typically, I insert a new dispatch call when perf doesn't matter, or use `at::compositeexplicitautograd` namespace to handle other caes.
* The change to `make_boxed_from_unboxed_functor.h` is so that we accept a plain IntList IValue anywhere a SymIntList is expected; these are read-only arguments so covariant typing is OK.
* I change how unboxing logic works slightly. Previously, we interpret the C++ type for Layout/etc directly as IntType JIT type, which works well because the incoming IValue is tagged as an integer. Now, we interpret the C++ type for Layout as its true type, e.g., LayoutType (change to `jit_type.h`), but then we accept an int IValue for it anyway. This makes it symmetric with SymInt, where we interpret the C++ type as SymIntType, and then accept SymInt and int IValues for it.
* I renamed the `empty.names` overload to `empty_names` to make it less confusing (I kept mixing it up with the real empty overload)
* I deleted the `empty.SymInt` overload, which ended up killing a pile of functions. (This was originally a separate PR but the profiler expect test was giving me grief so I folded it in.)
* I deleted the LazyDynamicOpsTest tests. These were failing after these changes, and I couldn't figure out why they used to be passing: they make use of `narrow_copy` which didn't actually support SymInts; they were immediately converted to ints.
* I bashed LTC into working. The patches made here are not the end of the story. The big problem is that SymInt translates into Value, but what if you have a list of SymInt? This cannot be conveniently represented in the IR today, since variadic Values are not supported. To work around this, I translate SymInt[] into plain int[] (this is fine for tests because LTC dynamic shapes never actually worked); but this will need to be fixed for proper LTC SymInt support. The LTC codegen also looked somewhat questionable; I added comments based on my code reading.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83628
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2022-08-26 01:35:40 +00:00
zaf
2f04ba2c7c [quant][ao_migration] torch.nn.qattorch.ao.nn.qat (#78716)
Context: In order to avoid the cluttering of the `torch.nn` namespace
the quantized modules namespace is moved to `torch.ao.nn`.

The list of the `nn.quantized` files that are being migrated:

- [X] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [X] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [X] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [X] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [X] [Current PR] `torch.nn.qat` → `torch.ao.nn.qat`
    - [X] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [X] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`

Majority of the files are just moved to the new location.
However, specific files need to be double checked:

- None

Differential Revision: [D36861197](https://our.internmc.facebook.com/intern/diff/D36861197/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36861197/)!

Differential Revision: [D36861197](https://our.internmc.facebook.com/intern/diff/D36861197)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78716
Approved by: https://github.com/jerryzh168
2022-08-25 16:50:38 +00:00
XiaobingSuper
a013597b32 fix oneDNN channels_last path issue (#83653)
Fix #82060(N>1 will call in OneDNN path) and #80837, those two issues are introduced by the definition of channels last is different between PyTorch FW side with ideep side, this PR will fix this gap which ideep will use the format flag given by FW side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83653
Approved by: https://github.com/mingfeima, https://github.com/malfet
2022-08-25 03:58:11 +00:00
PyTorch MergeBot
a7edf71360 Revert "Don't introduce new overload for SymInt (#83628)"
This reverts commit 8fae7027b3.

Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to breaking internal builds, see https://www.internalfb.com/diff/D38984222
2022-08-25 00:49:40 +00:00
kshitij12345
7a8152530d move pooling test from test_nn to test/nn/test_pooling (#83915)
Ref #63085

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83915
Approved by: https://github.com/albanD
2022-08-24 16:17:50 +00:00
Ishan-Rajgarhia
7fdc2f70c6 Task: T129772171 remove assertEqualIgnoreTypes from test/test_nn.py (#83870)
See https://github.com/pytorch/pytorch/issues/38095
Replaced assertEqualIgnoreType with assertEqual
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83870
Approved by: https://github.com/kit1980
2022-08-24 02:45:52 +00:00
Edward Z. Yang
8fae7027b3 Don't introduce new overload for SymInt (#83628)
Previously, we introduced new SymInt overloads for every function we wanted.  This led to a lot of boilerplate, and also a lot of confusion about how the overloads needed to be implemented.

This PR takes a simpler but more risky approach: just take the original function and changes its ints to SymInts.

This is BC-breaking in the following ways:

* The C++ API for registering implementations for aten operators will change from int64_t to SymInt whenever you make this change. Code generated registrations in PyTorch do not change as codegen handles the translation automatically, but manual registrations will need to follow the change.  Typically, if you now accept a SymInt where you previously only took int64_t, you have to convert it back manually.  This will definitely break XLA, see companion PR https://github.com/pytorch/xla/pull/3914 Note that not all dispatch keys get the automatic translation; all the composite keys and Meta keys are modified to take SymInt directly (because they should handle them directly), and so there are adjustments for this.

This is not BC-breaking in the following ways:

* The user facing C++ API remains compatible.  Even if a function changes from int to SymInt, the default C++ binding still takes only ints.  (e.g., at::empty(IntArrayRef, ...).  To call with SymInts, you must call at::empty_symint instead. This involved adding two more signatures to CppSignatureGroup; in many cases I refactored code to iterate over all signatures in the group instead of hard-coding the two that previously existed.
* This is TorchScript compatible; internally we treat SymInts as ints so there is no change to what happens at runtime in TorchScript. In particular, it's OK to reference an empty schema by its old type (using int types), as long as you're not doing string equality (which you shouldn't be), these parse to the same underyling type.

Structure of the PR:

* The general strategy of this PR is that, even when you write `SymInt` inside `native_functions.yaml`, sometimes, we will treat it *as if* it were an `int`. This idea pervades the codegen changes, where we have a translation from SymInt to c10::SymInt or int64_t, and this is controlled by a symint kwarg which I added and then audited all call sites to decide which I wanted. Here are some of the major places where we pick one or the other:
  * The C++ FunctionSchema representation represents `SymInt` as `int`. There are a few places we do need to know that we actually have a SymInt and we consult `real_type()` to get the real type in this case. In particular:
    * When we do schema validation of C++ operator registration, we must compare against true schema (as the C++ API will provide `c10::SymInt`, and this will only be accepted if the schema is `SymInt`. This is handled with cloneWithRealTypes before we check for schema differences.
    * In `toIValue` argument parsing, we parse against the true schema value. For backwards compatibility reasons, I do still accept ints in many places where Layout/SymInt/etc were expected. (Well, accepting int where SymInt is expected is not BC, it's just the right logic!)
  * In particular, because SymInt never shows up as type() in FunctionSchema, this means that we no longer need a dedicated Tag::SymInt. This is good, because SymInts never show up in mobile anyway.
* Changes to functorch/aten are mostly about tracking changes to the C++ API registration convention. Additionally, since SymInt overloads no longer exist, registrations for SymInt implementations are deleted. In many cases, the old implementations did not properly support SymInts; I did not add any new functionality with this PR, but I did try to annotate with TODOs where this is work to do. Finally, because the signature of `native::` API changed from int to SymInt, I need to find alternative APIs for people who were directly calling these functions to call. Typically, I insert a new dispatch call when perf doesn't matter, or use `at::compositeexplicitautograd` namespace to handle other caes.
* The change to `make_boxed_from_unboxed_functor.h` is so that we accept a plain IntList IValue anywhere a SymIntList is expected; these are read-only arguments so covariant typing is OK.
* I change how unboxing logic works slightly. Previously, we interpret the C++ type for Layout/etc directly as IntType JIT type, which works well because the incoming IValue is tagged as an integer. Now, we interpret the C++ type for Layout as its true type, e.g., LayoutType (change to `jit_type.h`), but then we accept an int IValue for it anyway. This makes it symmetric with SymInt, where we interpret the C++ type as SymIntType, and then accept SymInt and int IValues for it.
* I renamed the `empty.names` overload to `empty_names` to make it less confusing (I kept mixing it up with the real empty overload)
* I deleted the `empty.SymInt` overload, which ended up killing a pile of functions. (This was originally a separate PR but the profiler expect test was giving me grief so I folded it in.)
* I deleted the LazyDynamicOpsTest tests. These were failing after these changes, and I couldn't figure out why they used to be passing: they make use of `narrow_copy` which didn't actually support SymInts; they were immediately converted to ints.
* I bashed LTC into working. The patches made here are not the end of the story. The big problem is that SymInt translates into Value, but what if you have a list of SymInt? This cannot be conveniently represented in the IR today, since variadic Values are not supported. To work around this, I translate SymInt[] into plain int[] (this is fine for tests because LTC dynamic shapes never actually worked); but this will need to be fixed for proper LTC SymInt support. The LTC codegen also looked somewhat questionable; I added comments based on my code reading.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83628
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2022-08-23 22:04:07 +00:00
Khushi Agrawal
9095030239 [fix] edge case in MaxPool1d and add ErrorInputs (#83553)
Fixes #83224

cc @kshitij12345 @albanD!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83553
Approved by: https://github.com/albanD
2022-08-23 19:23:39 +00:00
Kshiteej K
dd67d52b57 [nn] split rnn_utils test from test_nn.py (#83675)
Ref: https://github.com/pytorch/pytorch/issues/63085
Proposed folder structure
```
-> test
  -> nn
    -> test_conv.py
    -> test_pooling.py
    -> .....
```

This PR: Moves test related RNN utilities to a different file.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83675
Approved by: https://github.com/albanD
2022-08-23 08:34:39 +00:00
XiaobingSuper
658f958bc4 fix upsample bf16 issue for channels last path by using high pricsion to compute index (#83847)
Given the following case:
```
import torch
a = torch.ones(1, 3, 320, 480).bfloat16().to(memory_format=torch.channels_last)
out_bf16 = torch.nn.functional.interpolate(a, size = (640, 960), scale_factor = None, mode = 'bilinear', align_corners = False, recompute_scale_factor= None, antialias = False)
out_fp32= torch.nn.functional.interpolate(a.float(), size = (640, 960), scale_factor = None, mode = 'bilinear', align_corners = False, recompute_scale_factor= None, antialias = False)
print(out_bf16[0, 2, :, :])
print(out_fp32[0, 2, :, :])
```
the boundary of bfloat16 output gets a wrong value:
```
tensor([[1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        ...,
        [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.8367e-40,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], dtype=torch.bfloat16)
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

```

the expected behavior is that the bfloat16 output value should also be one. The main reason is that we use low precision to compute the index,  see
fcb124406b/aten/src/ATen/native/UpSample.h (L448), we should use a high precison to do the computation as GPU path:
fcb124406b/aten/src/ATen/native/cuda/UpSample.cuh (L123)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83847
Approved by: https://github.com/frank-wei
2022-08-23 00:53:37 +00:00
PyTorch MergeBot
4cbb1986fe Revert "[quant][ao_migration] torch.nn.qattorch.ao.nn.qat (#78716)"
This reverts commit 7cd2fa1d38.

Reverted https://github.com/pytorch/pytorch/pull/78716 on behalf of https://github.com/janeyx99 due to sorry, reverting so https://github.com/pytorch/pytorch/pull/78713 could be cleanly reverted
2022-08-22 07:23:24 +00:00
zaf
7cd2fa1d38 [quant][ao_migration] torch.nn.qattorch.ao.nn.qat (#78716)
Context: In order to avoid the cluttering of the `torch.nn` namespace
the quantized modules namespace is moved to `torch.ao.nn`.

The list of the `nn.quantized` files that are being migrated:

- [X] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [X] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [X] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [X] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [X] [Current PR] `torch.nn.qat` → `torch.ao.nn.qat`
    - [X] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [X] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`

Majority of the files are just moved to the new location.
However, specific files need to be double checked:

- None

Differential Revision: [D36861197](https://our.internmc.facebook.com/intern/diff/D36861197/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36861197/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78716
Approved by: https://github.com/jerryzh168
2022-08-22 05:33:23 +00:00
Rui Zhu
e0f2eba93d Move odd num_head in TransformerEncoder to slow_path (#83483)
Summary: odd nhead is not supported for masked softmax, therefore we just move it to use old slow_path

Test Plan: CI

Differential Revision: D38720086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83483
Approved by: https://github.com/erichan1
2022-08-20 10:02:08 +00:00
Jeff Daily
d52d2bd5a9 [ROCm] MIOpen fused convolution relu (#82002)
Adds MIOpen fused convolution relu for fp32 and contiguous memory format.  Adds fallbacks for conv + z + bias + relu, fp16, and channels last until MIOpen adds these features.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82002
Approved by: https://github.com/ngimel, https://github.com/malfet
2022-08-16 20:49:33 +00:00
Nicolas Macchioni
b236352036 Add mask identifier for multiplexed src_mask/src_key_padding_mask in BT (#81947)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81947

Transformer fastpath multiplexes two arguments, src_mask [seq_len x seq_len] and src_key_padding_mask [batch_size x seq_len], and later deduces the type based on mask shape.

In the event that batch_size == seq_len, any src_mask is wrongly interpreted as a src_key padding_mask. This is fixed by requiring a mask_type identifier be supplied whenever batch_size == seq_len.

Additionally, added support for src_mask in masked_softmax CPU path.

Test Plan: existing unit tests + new unit tests (batch_size == seq_len)

Differential Revision: D37932240

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81947
Approved by: https://github.com/zrphercule
2022-08-09 23:42:16 +00:00
Sergii Dymchenko
7390ae837c Resolve TODO for GroupNorm numerical issues (#82423)
Looks like the numerical issues are resolved now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82423
Approved by: https://github.com/ngimel
2022-08-03 19:42:26 +00:00
Jiayi Sun
15a284b09e optimize softmax backward and logsoftmax backward (#80114)
Currently, if we run softmax_backward/logsoftmax_backward which are not along the last dim, the calculation will fall to a [scalar version](32593ef2dd/aten/src/ATen/native/SoftMax.cpp (L220-L287)). And we find actually we have the chance to vectorize the calculation along the inner_size dim.

Changes we made:

Use vectorized softmax_backward_kernel/log_softmax_backward_kernel instead of host_softmax_backward when not along the last dim.

We collected the benchmark data of softmax_backward and logsoftmax_backward for BFloat16 and Float32 data type by using the operator_benchmark tool of PyTorch on the platform of Intel(R) Xeon(R) Platinum 8260L CPU @ 2.40GHz.
Number of cores: 24 cores(1 socket)
[softmax_benchmark_32593ef.log](https://github.com/pytorch/pytorch/files/8962956/softmax_benchmark_32593ef.log)
[softmax_benchmark_the_pr.log](https://github.com/pytorch/pytorch/files/8962958/softmax_benchmark_the_pr.log)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80114
Approved by: https://github.com/frank-wei
2022-08-03 00:36:28 +00:00
mingfeima
b019a41674 fix bug for thnn_conv2d when input's C is 1 and weight is channels last (#82392)
To fix https://github.com/pytorch/pytorch/issues/82060

When `input` is not explicitly converted to channels last while `conv` has, the output should also be in channels last. The root cause is that when input has IC of 1, `compute_columns2d` from `\aten\src\ATen\native\ConvolutionMM2d.cpp` would consider it as channels first:

We do have logic to make sure both input and weight have the same memory format even if they are given differently, like:
```
auto input = self.contiguous(memory_format);
auto weight = weight_.contiguous(memory_format);
```

But  for a N1HW input, `.contiguous(MemoryFormat::ChannelsLast)` would not change its stride , and its `suggest_memory_format()` still returns `MemoryFormat::Contiguous`. That's how it went wrong.

Also updated the corresponding test cases, without this patch, the new test case would fail on forward path and runtime error on backward path.

attach old fail log on forward path:
```
FAIL: test_conv_thnn_nhwc_cpu_float32 (__main__.TestNNDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 974, in only_fn
    return fn(slf, *args, **kwargs)
  File "test/test_nn.py", line 19487, in test_conv_thnn_nhwc
    input_format=torch.contiguous_format, weight_format=torch.channels_last)
  File "test/test_nn.py", line 19469, in helper
    self.assertEqual(out, ref_out, exact_dtype=False)
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 2376, in assertEqual
    msg=(lambda generated_msg: f"{generated_msg} : {msg}") if isinstance(msg, str) and self.longMessage else msg,
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 988 / 1024 (96.5%)
Greatest absolute difference: 42.0 at index (1, 2, 6, 6) (up to 1e-05 allowed)
Greatest relative difference: inf at index (0, 0, 2, 1) (up to 1.3e-06 allowed)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82392
Approved by: https://github.com/jbschlosser
2022-07-28 14:20:52 +00:00
Khushi Agrawal
050aec1805 [nn] add pop to sequential and ModuleList (#81601)
Follows #71329

cc @kshitij12345!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81601
Approved by: https://github.com/albanD
2022-07-25 19:32:32 +00:00
Ansh Radhakrishnan
110cd724fc [nn] Add support for +=, * and *= operations for nn.Sequential objects (#81279)
Fixes 71329

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81279
Approved by: https://github.com/albanD
2022-07-25 15:48:47 +00:00
soulitzer
f595467e5c Reenable slow gradcheck and make it pass (#80514)
Context: For a while slow gradcheck CI was skipping nearly all tests and this hid the fact that it should've been failing and timing out (10+h runtime for TestGradients). The CI configuration has since been fixed to correct this, revealing the test failures. This PR reenables slow gradcheck CI and makes it pass again.

This PR:
- makes slow and failing tests run in fast gradcheck mode only
- reduce the input size for slow gradcheck only for unary/binary ufuncs (alternatively, skip the test entirely)
- skip entire test files on slow gradcheck runner if they don't use gradcheck (test_ops, test_meta, test_decomp, test_ops_jit)
- reduces the input size for some ops

Follow ups:
1. Investigate slow mode failures https://github.com/pytorch/pytorch/issues/80411
2. See if we can re-enable slow gradcheck tests for some of the slow tests by reducing the sizes of their inputs

The following are failing in slow mode, they are now running in fast mode only.
```
test_fn_fwgrad_bwgrad___rmod___cuda_float64
test_fn_fwgrad_bwgrad_linalg_householder_product_cuda_complex128
test_fn_fwgrad_bwgrad__masked_prod_cuda_complex128
test_fn_fwgrad_bwgrad__masked_prod_cuda_float64
test_fn_fwgrad_bwgrad_linalg_matrix_power_cuda_complex128
test_fn_fwgrad_bwgrad_cat_cuda_complex128
test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_cuda_float64
test_fn_fwgrad_bwgrad_copysign_cuda_float64
test_fn_fwgrad_bwgrad_cholesky_inverse_cuda_complex128
test_fn_fwgrad_bwgrad_float_power_cuda_complex128
test_fn_fwgrad_bwgrad_fmod_cuda_float64
test_fn_fwgrad_bwgrad_float_power_cuda_float64
test_fn_fwgrad_bwgrad_linalg_lu_cuda_float64
test_fn_fwgrad_bwgrad_remainder_cuda_float64
test_fn_fwgrad_bwgrad_repeat_cuda_complex128
test_fn_fwgrad_bwgrad_prod_cuda_complex128
test_fn_fwgrad_bwgrad_slice_scatter_cuda_float64
test_fn_fwgrad_bwgrad_tile_cuda_complex128
test_fn_fwgrad_bwgrad_pow_cuda_float64
test_fn_fwgrad_bwgrad_pow_cuda_complex128
test_fn_fwgrad_bwgrad_fft_*
test_fn_fwgrad_bwgrad_zero__cuda_complex128
test_fn_gradgrad_linalg_lu_factor_cuda_float64
test_fn_grad_div_trunc_rounding_cuda_float64
test_fn_grad_div_floor_rounding_cuda_float64
```

Marks the OpInfos for the following ops that run slowly in slow gradcheck as `fast_gradcheck` only (the left column represents runtime in seconds):
```
0  918.722  test_fn_fwgrad_bwgrad_nn_functional_conv_transpose3d_cuda_float64
1  795.042  test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_complex128
2  583.63  test_fn_fwgrad_bwgrad_nn_functional_max_pool3d_cuda_float64
3  516.946  test_fn_fwgrad_bwgrad_svd_cuda_complex128
4  503.179  test_fn_fwgrad_bwgrad_linalg_svd_cuda_complex128
5  460.985  test_fn_fwgrad_bwgrad_linalg_lu_cuda_complex128
6  401.04  test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_cuda_complex128
7  353.671  test_fn_fwgrad_bwgrad_nn_functional_max_pool2d_cuda_float64
8  321.903  test_fn_fwgrad_bwgrad_nn_functional_gaussian_nll_loss_cuda_float64
9  307.951  test_fn_fwgrad_bwgrad_stft_cuda_complex128
10  266.104  test_fn_fwgrad_bwgrad_svd_lowrank_cuda_float64
11  221.032  test_fn_fwgrad_bwgrad_istft_cuda_complex128
12  183.741  test_fn_fwgrad_bwgrad_lu_unpack_cuda_complex128
13  132.019  test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_float64
14  125.343  test_fn_fwgrad_bwgrad_nn_functional_pad_constant_cuda_complex128
15  124.2  test_fn_fwgrad_bwgrad_kron_cuda_complex128
16  123.721  test_fn_fwgrad_bwgrad_pca_lowrank_cuda_float64
17  121.074  test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64
18  119.387  test_fn_fwgrad_bwgrad_rot90_cuda_complex128
19  112.889  test_fn_fwgrad_bwgrad__masked_normalize_cuda_complex128
20  107.541  test_fn_fwgrad_bwgrad_dist_cuda_complex128
21  106.727  test_fn_fwgrad_bwgrad_diff_cuda_complex128
22  104.588  test_fn_fwgrad_bwgrad__masked_cumprod_cuda_complex128
23  100.135  test_fn_fwgrad_bwgrad_nn_functional_feature_alpha_dropout_with_train_cuda_float64
24  88.359  test_fn_fwgrad_bwgrad_mH_cuda_complex128
25  86.214  test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64
26  83.037  test_fn_fwgrad_bwgrad_nn_functional_bilinear_cuda_float64
27  79.987  test_fn_fwgrad_bwgrad__masked_cumsum_cuda_complex128
28  77.822  test_fn_fwgrad_bwgrad_diag_embed_cuda_complex128
29  76.256  test_fn_fwgrad_bwgrad_mT_cuda_complex128
30  74.039  test_fn_fwgrad_bwgrad_linalg_lu_solve_cuda_complex128
```
```
0  334.142  test_fn_fwgrad_bwgrad_unfold_cuda_complex128
1  312.791  test_fn_fwgrad_bwgrad_linalg_lu_factor_cuda_complex128
2  121.963  test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64
3  108.085  test_fn_fwgrad_bwgrad_diff_cuda_complex128
4  89.418  test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64
5  72.231  test_fn_fwgrad_bwgrad___rdiv___cuda_complex128
6  69.433  test_fn_fwgrad_bwgrad___getitem___cuda_complex128
7  68.582  test_fn_fwgrad_bwgrad_ldexp_cuda_complex128
8  68.572  test_fn_fwgrad_bwgrad_linalg_pinv_cuda_complex128
9  67.585  test_fn_fwgrad_bwgrad_nn_functional_glu_cuda_float64
10  66.567  test_fn_fwgrad_bwgrad_lu_cuda_float64
```
```
0  630.13  test_fn_gradgrad_nn_functional_conv2d_cuda_complex128
1  81.086  test_fn_gradgrad_linalg_solve_triangular_cuda_complex128
2  71.332  test_fn_gradgrad_norm_cuda_complex128
3  64.308  test_fn_gradgrad__masked_std_cuda_complex128
4  59.519  test_fn_gradgrad_div_no_rounding_mode_cuda_complex128
5  58.836  test_fn_gradgrad_nn_functional_adaptive_avg_pool3
```

Reduces the sizes of the inputs for:
- diff
- diag_embed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80514
Approved by: https://github.com/albanD
2022-07-22 02:05:37 +00:00
Saketh Are
445ee5620e Simplify torch.nn.grad by calling into aten::convolution_backward (#81839)
`torch.nn.grad` has its own implementations of gradients for conv1d, conv2d, and conv3d. This PR simplifies them by calling into the unified `aten::convolution_backward` backend instead.

The existing implementation of conv2d_weight is incorrect for some inputs (see issue #51430). This PR fixes the issue.

This PR expands coverage in test_nn to include conv1d_weight, conv2d_weight, and conv3d_weight, which were previously untested. It also expands the cases for conv2d to cover issue #51430.

Fixes #51430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81839
Approved by: https://github.com/albanD
2022-07-21 19:34:27 +00:00
Khushi Agrawal
dced803339 [nn] add insert method to sequential class (#81402)
Follows #71329

cc @kshitij12345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81402
Approved by: https://github.com/albanD
2022-07-20 14:45:52 +00:00
Khushi Agrawal
2c0b11b43b [nn] implement extend method to sequential class (#81179)
Follows #71329

cc @kshitij12345 :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81179
Approved by: https://github.com/albanD
2022-07-20 05:33:41 +00:00
PyTorch MergeBot
f82b19f15b Revert "Disable use_mkldnn when input is not contiguous for oneDNN (#80864)"
This reverts commit 4655c3bace.

Reverted https://github.com/pytorch/pytorch/pull/80864 on behalf of https://github.com/janeyx99 due to Reverting due for a perf regression https://github.com/pytorch/benchmark/issues/1040
2022-07-19 18:58:52 +00:00
yanbing-j
4655c3bace Disable use_mkldnn when input is not contiguous for oneDNN (#80864)
Fixes [#80837](https://github.com/pytorch/pytorch/issues/80837).
This PR is to disable use_mkldnn when input is not contiguous for oneDNN requirement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80864
Approved by: https://github.com/malfet
2022-07-17 14:58:26 +00:00
Rui Zhu
b22166fd62 Add a small fastpath test for native mha (#81432)
Summary: We dont have a small fast path passing test for mha before, this diff added one for better testing

Test Plan: buck build mode/dev-nosan -c fbcode.platform=platform009 -c fbcode.enable_gpu_sections=true caffe2/test:nn && buck-out/dev/gen/caffe2/test/nn\#binary.par -r test_multihead_attn_fast_path_small_test

Differential Revision: D37834319

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81432
Approved by: https://github.com/erichan1
2022-07-15 23:54:40 +00:00
Eric Han
23088fcfdf disable src mask for transformer and multiheadattention fastpath (#81277)
Disable fastpath if src_mask passed to TransformerEncoderLayer and MultiheadAttention.
- Refactored test_transformerencoder from test_nn.py to test_transformers.py. Added a src_mask test there.
- Added a specific src_mask test in test_transformers.py

Fixes https://github.com/pytorch/pytorch/issues/81129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81277
Approved by: https://github.com/zrphercule
2022-07-15 20:55:17 +00:00
n.zhuravlev
7af0200a46 Add deepcopy functionality to parametrized modules (#80811)
Fixes #69413

After applying parametrization to any `nn.Module` we lose the ability to create a deepcopy of it e.g. it makes it impossible to wrap a module by an `AveragedModel`.
Specifically, the problem is that the `deepcopy` tries to invoke `__getstate__` if object hasn't implemented its own `__deepcopy__` magic method. But we don't allow serialization of the parametrized modules: `__getstate__` raises an error.
My solution is just to create a default `__deepcopy__` method when it doesn't exist yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80811
Approved by: https://github.com/pearu, https://github.com/albanD
2022-07-15 09:06:45 +00:00
Khushi Agrawal
3da8c909da [nn] add + operator for torch.nn.Sequential to concatenate (#81170)
Fixes #78512

#### TODO
- [x] add tests

cc @kshitij12345!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81170
Approved by: https://github.com/albanD
2022-07-11 17:49:58 +00:00
eqy
3b78c5682b Don't implicitly convert to channels-first in MaxPool3D on CUDA (#80748)
MaxPool3D currently converts inputs implicitly to channels-first (via `.contiguous()`) which may yield unexpected regressions in workloads that expect a full channels-last path. This PR preserves the channels-last format in MaxPool3D while attempting to avoid seriously regressing performance.

Currently, typical case (kernel size == 2 == stride) looks good, but larger kernel sizes (>4) or the unusual case of stride 1 can sometimes be slower than converting to channels-first before doing MaxPool3D.

Additionally, this PR adds a test for 64bit-indexing backwards as testing of these changes uncovered an IMA for large tensors when doing the backwards pass with MaxPool3D.

Performance comparison on A6000:

```
[------------------------------------- max_pool3d ---------------------------------------------------------]
                                          |  channels_last=False  |   curr ch_last=True |  new ch_last=True
1 threads: ---------------------------------------------------------------------------- ---------------------
      [64, 256, 32, 32, 32] 4x4 stride 4  |        20093.5        |       34823.4       |       20640.0
      [64, 256, 32, 32, 32] 4x4 stride 2  |        28623.7        |       42625.6       |       27935.5
      [64, 256, 32, 32, 32] 4x4 stride 1  |        68177.5        |       79147.2       |       85604.8
      [64, 256, 32, 32, 32] 2x2 stride 4  |        17237.7        |       32071.3       |       16641.6
      [64, 256, 32, 32, 32] 2x2 stride 2  |        25252.5        |       39993.2       |       25054.8
      [64, 256, 32, 32, 32] 2x2 stride 1  |        43185.2        |       58164.6       |       48416.9
      [64, 256, 16, 16, 16] 4x4 stride 4  |         3017.7        |        3952.4       |        2593.8
      [64, 256, 16, 16, 16] 4x4 stride 2  |         4581.5        |        5384.3       |        3294.3
      [64, 256, 16, 16, 16] 4x4 stride 1  |        11334.1        |       11534.7       |        8651.1
      [64, 256, 16, 16, 16] 2x2 stride 4  |         2346.9        |        3304.6       |        2098.8
      [64, 256, 16, 16, 16] 2x2 stride 2  |         3550.8        |        4526.5       |        3143.6
      [64, 256, 16, 16, 16] 2x2 stride 1  |         6898.1        |        7816.0       |        5820.8
      [64, 256, 4, 4, 4] 4x4 stride 4     |          191.5        |         176.3       |          77.5
      [64, 256, 4, 4, 4] 4x4 stride 2     |          191.8        |         176.8       |          94.1
      [64, 256, 4, 4, 4] 4x4 stride 1     |          191.3        |         176.4       |          97.3
      [64, 256, 4, 4, 4] 2x2 stride 4     |           96.4        |         114.4       |          93.6
      [64, 256, 4, 4, 4] 2x2 stride 2     |          172.1        |         178.6       |          93.7
      [64, 256, 4, 4, 4] 2x2 stride 1     |          263.0        |         279.4       |          92.4
      [64, 64, 32, 32, 32] 4x4 stride 4   |         5033.2        |        7208.3       |        5167.5
      [64, 64, 32, 32, 32] 4x4 stride 2   |         7216.1        |        9218.7       |        6637.1
      [64, 64, 32, 32, 32] 4x4 stride 1   |        17192.1        |       18392.9       |       20489.0
      [64, 64, 32, 32, 32] 2x2 stride 4   |         4318.0        |        6511.2       |        4193.1
      [64, 64, 32, 32, 32] 2x2 stride 2   |         6324.4        |        8657.7       |        6263.6
      [64, 64, 32, 32, 32] 2x2 stride 1   |        10855.0        |       13040.2       |       12055.9
      [64, 64, 16, 16, 16] 4x4 stride 4   |          764.1        |         975.6       |         671.3
      [64, 64, 16, 16, 16] 4x4 stride 2   |         1163.1        |        1333.4       |         833.6
      [64, 64, 16, 16, 16] 4x4 stride 1   |         2890.0        |        2898.5       |        2209.8
      [64, 64, 16, 16, 16] 2x2 stride 4   |          593.5        |         811.2       |         536.3
      [64, 64, 16, 16, 16] 2x2 stride 2   |          895.9        |        1112.3       |         794.5
      [64, 64, 16, 16, 16] 2x2 stride 1   |         1742.5        |        1968.0       |        1475.2
      [64, 64, 4, 4, 4] 4x4 stride 4      |          101.1        |         112.2       |          93.4
      [64, 64, 4, 4, 4] 4x4 stride 2      |           96.7        |         114.6       |          92.5
      [64, 64, 4, 4, 4] 4x4 stride 1      |           98.9        |         111.9       |          96.5
      [64, 64, 4, 4, 4] 2x2 stride 4      |          100.1        |         107.1       |          94.2
      [64, 64, 4, 4, 4] 2x2 stride 2      |           96.6        |         108.0       |          94.5
      [64, 64, 4, 4, 4] 2x2 stride 1      |           96.7        |         107.9       |          95.2
      [64, 3, 32, 32, 32] 4x4 stride 4    |          250.1        |         326.6       |         278.0
      [64, 3, 32, 32, 32] 4x4 stride 2    |          350.4        |         414.0       |         323.2
      [64, 3, 32, 32, 32] 4x4 stride 1    |          825.6        |         846.9       |         982.5
      [64, 3, 32, 32, 32] 2x2 stride 4    |          213.3        |         289.8       |         219.9
      [64, 3, 32, 32, 32] 2x2 stride 2    |          308.2        |         384.9       |         305.9
      [64, 3, 32, 32, 32] 2x2 stride 1    |          523.5        |         594.7       |         589.9
      [64, 3, 16, 16, 16] 4x4 stride 4    |          103.8        |         116.7       |          93.0
      [64, 3, 16, 16, 16] 4x4 stride 2    |          100.9        |         108.3       |          93.3
      [64, 3, 16, 16, 16] 4x4 stride 1    |          139.4        |         140.7       |         104.8
      [64, 3, 16, 16, 16] 2x2 stride 4    |           97.5        |         114.7       |          92.7
      [64, 3, 16, 16, 16] 2x2 stride 2    |           97.4        |         108.8       |          91.7
      [64, 3, 16, 16, 16] 2x2 stride 1    |           99.9        |         108.0       |          94.1
      [64, 3, 4, 4, 4] 4x4 stride 4       |           97.2        |         110.2       |          94.7
      [64, 3, 4, 4, 4] 4x4 stride 2       |          105.7        |         107.4       |          92.8
      [64, 3, 4, 4, 4] 4x4 stride 1       |           98.0        |         110.0       |          93.7
      [64, 3, 4, 4, 4] 2x2 stride 4       |           98.3        |         116.7       |          93.0
      [64, 3, 4, 4, 4] 2x2 stride 2       |           98.6        |         107.5       |          92.8
      [64, 3, 4, 4, 4] 2x2 stride 1       |          100.6        |         110.3       |          94.0
      [16, 256, 32, 32, 32] 4x4 stride 4  |         5034.2        |        8838.0       |        5165.9
      [16, 256, 32, 32, 32] 4x4 stride 2  |         7236.3        |       10869.9       |        7038.2
      [16, 256, 32, 32, 32] 4x4 stride 1  |        17385.4        |       21401.6       |       21900.7
      [16, 256, 32, 32, 32] 2x2 stride 4  |         4318.7        |        8101.2       |        4172.9
      [16, 256, 32, 32, 32] 2x2 stride 2  |         6324.0        |       10147.5       |        6279.7
      [16, 256, 32, 32, 32] 2x2 stride 1  |        10899.7        |       14826.0       |       12256.3
      [16, 256, 16, 16, 16] 4x4 stride 4  |          765.4        |        1012.7       |         675.6
      [16, 256, 16, 16, 16] 4x4 stride 2  |         1162.8        |        1376.9       |         843.4
      [16, 256, 16, 16, 16] 4x4 stride 1  |         2928.9        |        2969.8       |        2222.5
      [16, 256, 16, 16, 16] 2x2 stride 4  |          593.5        |         845.8       |         534.2
      [16, 256, 16, 16, 16] 2x2 stride 2  |          896.9        |        1152.2       |         796.9
      [16, 256, 16, 16, 16] 2x2 stride 1  |         1750.2        |        2009.4       |        1481.8
      [16, 256, 4, 4, 4] 4x4 stride 4     |           96.6        |         107.1       |          92.7
      [16, 256, 4, 4, 4] 4x4 stride 2     |           97.9        |         114.9       |          93.8
      [16, 256, 4, 4, 4] 4x4 stride 1     |           98.2        |         115.6       |          94.0
      [16, 256, 4, 4, 4] 2x2 stride 4     |           97.0        |         106.7       |          93.8
      [16, 256, 4, 4, 4] 2x2 stride 2     |           96.8        |         108.1       |          93.3
      [16, 256, 4, 4, 4] 2x2 stride 1     |           95.8        |         120.9       |          95.7
      [16, 64, 32, 32, 32] 4x4 stride 4   |         1266.4        |        1815.4       |        1312.3
      [16, 64, 32, 32, 32] 4x4 stride 2   |         1818.5        |        2328.0       |        1678.9
      [16, 64, 32, 32, 32] 4x4 stride 1   |         4352.9        |        4649.3       |        5204.6
      [16, 64, 32, 32, 32] 2x2 stride 4   |         1090.0        |        1631.2       |        1060.8
      [16, 64, 32, 32, 32] 2x2 stride 2   |         1589.4        |        2141.1       |        1576.4
      [16, 64, 32, 32, 32] 2x2 stride 1   |         2733.5        |        3286.0       |        3041.6
      [16, 64, 16, 16, 16] 4x4 stride 4   |          201.7        |         259.6       |         175.0
      [16, 64, 16, 16, 16] 4x4 stride 2   |          301.0        |         350.1       |         226.3
      [16, 64, 16, 16, 16] 4x4 stride 1   |          740.1        |         748.7       |         570.6
      [16, 64, 16, 16, 16] 2x2 stride 4   |          156.0        |         214.8       |         140.8
      [16, 64, 16, 16, 16] 2x2 stride 2   |          232.3        |         292.3       |         208.7
      [16, 64, 16, 16, 16] 2x2 stride 1   |          449.1        |         504.0       |         382.1
      [16, 64, 4, 4, 4] 4x4 stride 4      |           97.5        |         111.4       |          94.5
      [16, 64, 4, 4, 4] 4x4 stride 2      |           98.8        |         111.9       |          94.4
      [16, 64, 4, 4, 4] 4x4 stride 1      |           98.2        |         112.0       |          95.2
      [16, 64, 4, 4, 4] 2x2 stride 4      |           99.7        |         111.0       |          94.0
      [16, 64, 4, 4, 4] 2x2 stride 2      |          100.3        |         110.0       |          93.2
      [16, 64, 4, 4, 4] 2x2 stride 1      |           97.5        |         107.6       |          93.5
      [16, 3, 32, 32, 32] 4x4 stride 4    |          100.5        |         117.1       |          95.7
      [16, 3, 32, 32, 32] 4x4 stride 2    |           97.5        |         121.3       |          92.5
      [16, 3, 32, 32, 32] 4x4 stride 1    |          216.0        |         227.4       |         258.4
      [16, 3, 32, 32, 32] 2x2 stride 4    |           97.1        |         109.0       |          91.9
      [16, 3, 32, 32, 32] 2x2 stride 2    |           95.8        |         108.5       |          92.9
      [16, 3, 32, 32, 32] 2x2 stride 1    |          139.4        |         161.2       |         157.8
      [16, 3, 16, 16, 16] 4x4 stride 4    |           96.4        |         113.6       |          91.9
      [16, 3, 16, 16, 16] 4x4 stride 2    |           97.4        |         108.1       |          93.5
      [16, 3, 16, 16, 16] 4x4 stride 1    |           99.0        |         107.5       |          92.1
      [16, 3, 16, 16, 16] 2x2 stride 4    |           96.9        |         118.1       |          93.4
      [16, 3, 16, 16, 16] 2x2 stride 2    |           97.3        |         106.7       |          95.8
      [16, 3, 16, 16, 16] 2x2 stride 1    |           98.8        |         109.2       |          93.8
      [16, 3, 4, 4, 4] 4x4 stride 4       |           97.8        |         108.0       |          94.2
      [16, 3, 4, 4, 4] 4x4 stride 2       |           92.7        |         108.0       |          93.9
      [16, 3, 4, 4, 4] 4x4 stride 1       |           97.8        |         107.6       |          93.5
      [16, 3, 4, 4, 4] 2x2 stride 4       |          100.3        |         107.7       |          94.3
      [16, 3, 4, 4, 4] 2x2 stride 2       |           97.2        |         107.5       |          96.1
      [16, 3, 4, 4, 4] 2x2 stride 1       |           98.1        |         111.1       |          93.8

Times are in microseconds (us).
```

Performance comparison on V100:
(these times have been updated after working around some noisy measurements in my setup)
```
[------------------------------------- max_pool3d ---------------------------------------------------------]
                                          |  channels_last=False  |  curr ch_last=True |  new ch_last=True
1 threads: -------------------------------------------------------------------------------------------------
      [64, 256, 32, 32, 32] 4x4 stride 4  |        15810.7        |       33807.7      |        16452.9
      [64, 256, 32, 32, 32] 4x4 stride 2  |        24422.7        |       42515.3      |        27700.3
      [64, 256, 32, 32, 32] 4x4 stride 1  |        71756.0        |       89916.5      |       106464.0
      [64, 256, 32, 32, 32] 2x2 stride 4  |        12102.9        |       30210.4      |        11319.8
      [64, 256, 32, 32, 32] 2x2 stride 2  |        19101.7        |       37210.8      |        20373.3
      [64, 256, 32, 32, 32] 2x2 stride 1  |        41418.0        |       59650.5      |        53009.2
      [64, 256, 16, 16, 16] 4x4 stride 4  |         2362.0        |        4210.3      |         2114.0
      [64, 256, 16, 16, 16] 4x4 stride 2  |         4102.4        |        5897.4      |         3179.7
      [64, 256, 16, 16, 16] 4x4 stride 1  |        11339.3        |       13116.6      |        10032.6
      [64, 256, 16, 16, 16] 2x2 stride 4  |         1709.7        |        3506.7      |         1423.6
      [64, 256, 16, 16, 16] 2x2 stride 2  |         2966.6        |        4760.8      |         2499.3
      [64, 256, 16, 16, 16] 2x2 stride 1  |         6998.4        |        8797.3      |         6152.0
      [64, 256, 4, 4, 4] 4x4 stride 4     |          173.0        |         176.3      |          127.9
      [64, 256, 4, 4, 4] 4x4 stride 2     |          149.1        |         176.3      |          125.5
      [64, 256, 4, 4, 4] 4x4 stride 1     |          150.0        |         177.2      |          125.6
      [64, 256, 4, 4, 4] 2x2 stride 4     |          158.0        |         192.7      |          127.9
      [64, 256, 4, 4, 4] 2x2 stride 2     |          169.7        |         199.2      |          125.3
      [64, 256, 4, 4, 4] 2x2 stride 1     |          289.6        |         318.2      |          116.5
      [64, 64, 32, 32, 32] 4x4 stride 4   |         3914.4        |        6993.3      |         4141.4
      [64, 64, 32, 32, 32] 4x4 stride 2   |         6107.4        |        9186.4      |         6378.5
      [64, 64, 32, 32, 32] 4x4 stride 1   |        17920.0        |       20993.5      |        23891.1
      [64, 64, 32, 32, 32] 2x2 stride 4   |         3029.7        |        6112.6      |         2895.6
      [64, 64, 32, 32, 32] 2x2 stride 2   |         4787.8        |        7870.6      |         4724.8
      [64, 64, 32, 32, 32] 2x2 stride 1   |        10366.4        |       13446.4      |        12603.8
      [64, 64, 16, 16, 16] 4x4 stride 4   |          605.8        |         962.9      |          499.7
      [64, 64, 16, 16, 16] 4x4 stride 2   |         1037.0        |        1394.8      |          791.6
      [64, 64, 16, 16, 16] 4x4 stride 1   |         2835.4        |        3191.8      |         2484.3
      [64, 64, 16, 16, 16] 2x2 stride 4   |          438.6        |         795.7      |          368.6
      [64, 64, 16, 16, 16] 2x2 stride 2   |          749.1        |        1108.0      |          612.0
      [64, 64, 16, 16, 16] 2x2 stride 1   |         1756.4        |        2112.2      |         1538.5
      [64, 64, 4, 4, 4] 4x4 stride 4      |          132.6        |         163.9      |          115.4
      [64, 64, 4, 4, 4] 4x4 stride 2      |          129.3        |         153.7      |          117.8
      [64, 64, 4, 4, 4] 4x4 stride 1      |          128.0        |         153.8      |          117.6
      [64, 64, 4, 4, 4] 2x2 stride 4      |          128.2        |         154.1      |          117.5
      [64, 64, 4, 4, 4] 2x2 stride 2      |          130.5        |         157.3      |          117.6
      [64, 64, 4, 4, 4] 2x2 stride 1      |          128.8        |         156.4      |          120.6
      [64, 3, 32, 32, 32] 4x4 stride 4    |          200.4        |         261.0      |          228.8
      [64, 3, 32, 32, 32] 4x4 stride 2    |          305.3        |         366.5      |          344.4
      [64, 3, 32, 32, 32] 4x4 stride 1    |          860.9        |         922.1      |         1136.0
      [64, 3, 32, 32, 32] 2x2 stride 4    |          157.0        |         216.9      |          158.1
      [64, 3, 32, 32, 32] 2x2 stride 2    |          240.5        |         300.9      |          247.7
      [64, 3, 32, 32, 32] 2x2 stride 1    |          503.5        |         565.1      |          609.8
      [64, 3, 16, 16, 16] 4x4 stride 4    |          136.0        |         159.0      |          120.3
      [64, 3, 16, 16, 16] 4x4 stride 2    |          131.2        |         156.9      |          120.0
      [64, 3, 16, 16, 16] 4x4 stride 1    |          146.6        |         158.5      |          123.8
      [64, 3, 16, 16, 16] 2x2 stride 4    |          133.8        |         158.4      |          117.1
      [64, 3, 16, 16, 16] 2x2 stride 2    |          132.1        |         160.8      |          117.9
      [64, 3, 16, 16, 16] 2x2 stride 1    |          133.7        |         174.4      |          118.0
      [64, 3, 4, 4, 4] 4x4 stride 4       |          156.8        |         166.2      |          119.4
      [64, 3, 4, 4, 4] 4x4 stride 2       |          126.8        |         150.4      |          118.2
      [64, 3, 4, 4, 4] 4x4 stride 1       |          125.2        |         151.7      |          117.8
      [64, 3, 4, 4, 4] 2x2 stride 4       |          127.3        |         152.7      |          116.2
      [64, 3, 4, 4, 4] 2x2 stride 2       |          128.6        |         153.3      |          114.6
      [64, 3, 4, 4, 4] 2x2 stride 1       |          128.6        |         153.5      |          114.7
      [16, 256, 32, 32, 32] 4x4 stride 4  |         3921.7        |        8445.7      |         4064.7
      [16, 256, 32, 32, 32] 4x4 stride 2  |         6111.7        |       10630.0      |         6944.4
      [16, 256, 32, 32, 32] 4x4 stride 1  |        17938.9        |       22896.8      |        26648.7
      [16, 256, 32, 32, 32] 2x2 stride 4  |         3029.6        |        7552.7      |         2840.9
      [16, 256, 32, 32, 32] 2x2 stride 2  |         4788.0        |        9322.1      |         5110.5
      [16, 256, 32, 32, 32] 2x2 stride 1  |        10363.7        |       14885.9      |        13213.6
      [16, 256, 16, 16, 16] 4x4 stride 4  |          606.0        |        1059.1      |          535.9
      [16, 256, 16, 16, 16] 4x4 stride 2  |         1037.5        |        1491.5      |          822.3
      [16, 256, 16, 16, 16] 4x4 stride 1  |         2835.4        |        3306.8      |         2522.8
      [16, 256, 16, 16, 16] 2x2 stride 4  |          438.6        |         892.3      |          369.0
      [16, 256, 16, 16, 16] 2x2 stride 2  |          749.2        |        1203.7      |          638.7
      [16, 256, 16, 16, 16] 2x2 stride 1  |         1756.1        |        2212.5      |         1547.0
      [16, 256, 4, 4, 4] 4x4 stride 4     |          159.6        |         187.6      |          117.6
      [16, 256, 4, 4, 4] 4x4 stride 2     |          161.1        |         185.5      |          117.3
      [16, 256, 4, 4, 4] 4x4 stride 1     |          160.0        |         148.1      |          117.8
      [16, 256, 4, 4, 4] 2x2 stride 4     |          123.9        |         148.3      |          117.6
      [16, 256, 4, 4, 4] 2x2 stride 2     |          126.0        |         151.7      |          117.4
      [16, 256, 4, 4, 4] 2x2 stride 1     |          127.1        |         152.3      |          117.9
      [16, 64, 32, 32, 32] 4x4 stride 4   |          983.5        |        1756.7      |         1067.8
      [16, 64, 32, 32, 32] 4x4 stride 2   |         1542.4        |        2315.2      |         1621.5
      [16, 64, 32, 32, 32] 4x4 stride 1   |         4498.7        |        5273.4      |         6006.7
      [16, 64, 32, 32, 32] 2x2 stride 4   |          767.2        |        1543.4      |          736.7
      [16, 64, 32, 32, 32] 2x2 stride 2   |         1207.8        |        1981.5      |         1197.0
      [16, 64, 32, 32, 32] 2x2 stride 1   |         2603.3        |        3367.5      |         3161.9
      [16, 64, 16, 16, 16] 4x4 stride 4   |          169.5        |         264.6      |          142.8
      [16, 64, 16, 16, 16] 4x4 stride 2   |          274.6        |         368.9      |          216.8
      [16, 64, 16, 16, 16] 4x4 stride 1   |          723.3        |         820.4      |          643.2
      [16, 64, 16, 16, 16] 2x2 stride 4   |          131.4        |         216.0      |          116.1
      [16, 64, 16, 16, 16] 2x2 stride 2   |          199.9        |         295.0      |          166.8
```
CC @ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80748
Approved by: https://github.com/ngimel
2022-07-08 04:26:01 +00:00
Michael Gschwind
25449292a0 Run mask test with and without nested tensor (#81008)
Summary: Run mask test with and without nested tensor

Test Plan: sandcastle

Differential Revision: D37665532

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81008
Approved by: https://github.com/malfet
2022-07-07 23:54:37 +00:00