Commit Graph

616 Commits

Author SHA1 Message Date
Michael Gschwind
03b6e6979c Transformers: fix src and key padding mask bool regression (#96009)
Summary: fix src and pad mask bool regression

This fixes a regression introduced previously with #92733. That PR unified testing of masks to remove Byte Tensors as permissible mask, introduced mask compatibility check, and mask conversion to FP mask.  The problem addressed in this PR was that after the first mask had been converted, a check for mask compatibility would fail.

Test Plan: sandcastle & github

Differential Revision: D43782858

Fixes  https://github.com/pytorch/pytorch/issues/95702

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96009
Approved by: https://github.com/malfet
2023-03-05 01:50:46 +00:00
soulitzer
e5c2a35d83 Add check that embedding_bag's weight is 2D (#94931)
Fixes https://github.com/pytorch/pytorch/issues/94445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94931
Approved by: https://github.com/albanD
2023-02-16 02:37:47 +00:00
Driss Guessous
70026aaad6 [SDPA] update type hint for scaled_dot_product_attention and documentation (#94008)
# Summary
- Adds type hinting support for SDPA
- Updates the documentation adding warnings and notes on the context manager
- Adds scaled_dot_product_attention to the non-linear activation function section of nn.functional docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94008
Approved by: https://github.com/cpuhrsch
2023-02-10 18:02:43 +00:00
Natalia Gimelshein
a5daea69fb teach inductor to handle floor (#94341)
Per title, happen when there's upsampling with non-integer scale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94341
Approved by: https://github.com/ezyang
2023-02-10 11:21:57 +00:00
PyTorch MergeBot
6007874bbb Revert "teach inductor to handle floor (#94341)"
This reverts commit e7df9aaec8.

Reverted https://github.com/pytorch/pytorch/pull/94341 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but the CudaTest failure looks related.  It fails on both PR and trunk e7df9aaec8
2023-02-09 19:31:08 +00:00
Natalia Gimelshein
e7df9aaec8 teach inductor to handle floor (#94341)
Per title, happen when there's upsampling with non-integer scale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94341
Approved by: https://github.com/ezyang
2023-02-09 17:09:35 +00:00
milesial
6c555b29a8 MHA optimizations (#93234)
Slight perf optimizations for regular MHA by reducing the number of kernels called

Before:
![image](https://user-images.githubusercontent.com/30204471/215349212-172c6364-9e3c-4fd1-92b6-8ddd9931613e.png)

After:
![image](https://user-images.githubusercontent.com/30204471/215349247-021dd9e6-f6ca-40a2-8de8-0805af001f69.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93234
Approved by: https://github.com/drisspg
2023-02-03 15:18:35 +00:00
Driss Guessous
3df0e26e20 [SDPA] Remove private version and only utilize public version (#94004)
# Summary
Due to internal failures we needed to keep the private call in torch.nn.mha. This PR undoes this change, so that we call the public function and remove the private function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94004
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2023-02-03 08:12:09 +00:00
103yiran
d9117b93fb unsqueeze only when dim = 3 (#91052)
unsqueeze is not necessary if use view

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91052
Approved by: https://github.com/albanD
2023-01-31 16:28:23 +00:00
Driss Guessous
ca8f5e177a Use the old aten underscored function for Predictor (#93096)
Summary:
Errors reported via https://fb.prod.workplace.com/groups/1405155842844877/permalink/6644919482201794/

The problem is that the scriptable op set between predictor and the latest build of master is different.

Test Plan: Sandcastle testing

Differential Revision: D42786069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93096
Approved by: https://github.com/mikekgfb
2023-01-28 03:14:18 +00:00
Michael Gschwind
7265f60ad0 Regularize mask handling for attn_mask and key_padding_mask (#92733)
Summary:
Regularize mask handling for attn_mask and key_padding_mask
* Update documentation to remove reference to byte masks (which were deprecated long ago)
* Introduce check and warn about deprecation if attn_mask and key_padding_mask types mismatch
* Convert all masks to float before combining
* Combine by adding

Test Plan: sandcastle & github CI

Differential Revision: D42653215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92733
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-01-24 14:12:05 +00:00
Driss Guessous
df14650f0b [SDPA] Update SDPA API and make function Public (#92189)
# Summary
In preparation for pt 2.0 launch this PR updates SDPA's API and makes the function a nn.funcitonal public function.

## Changes
### API
Previously the the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`

This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.

#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels e.g.  (FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.

The fused kernels save on memory usage by not materializing the weights and it is unlikely that a fast fused implementation will enable this feature so we are removing.

Discussed with folks at FAIR/Xformers and +1 this API change.

#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
2023-01-23 20:50:46 +00:00
Michael Gschwind
af589b3d1f switch causal mask for is_causal flag (#91171)
Summary: switch causal mask for is_causal flag

Test Plan: sandcastle & github

Differential Revision: D42089340

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91171
Approved by: https://github.com/wushirong, https://github.com/drisspg
2022-12-30 17:24:58 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Joel Schlosser
3d8834bdbf SymIntify F.interpolate() with recompute_scale_factor=True (#91318)
This PR makes the minor changes necessary to get `F.interpolate()` working with symbolic shapes when `recompute_scale_factor=True` + adds `OpInfo` samples to test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91318
Approved by: https://github.com/ezyang
2022-12-29 01:42:56 +00:00
Michael Gschwind
512ec181ec Introduce causal mask (#90508)
Summary: Introduce causal mask

This PR introduces a causal mask option _causal_mask (as well as causal mask detection if attn_mask is provided), since current custom kernels do not support arbitrary masks.

Test Plan: sandcastle & github ci/cd

Differential Revision: D41723137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90508
Approved by: https://github.com/albanD
2022-12-16 21:39:42 +00:00
Driss Guessous
78bdb858f9 Call _sdp_attention in nn.functional.mha (#89470)
# Summary
Replaces the the inline block of code in nn.funcitonal.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
2022-12-02 19:46:22 +00:00
PyTorch MergeBot
f1415b8cb6 Revert "Call _sdp_attention in nn.functional.mha (#89470)"
This reverts commit 4d7ec30220.

Reverted https://github.com/pytorch/pytorch/pull/89470 on behalf of https://github.com/jeanschmidt due to breaking internal builds
2022-11-30 16:16:24 +00:00
PyTorch MergeBot
618a585f6c Revert "replace double transpose with single permute in nn.f.mha (#89847)"
This reverts commit b9afa92827.

Reverted https://github.com/pytorch/pytorch/pull/89847 on behalf of https://github.com/jeanschmidt due to Need to revert this commit as it is causing conflict when reverting #89470
2022-11-30 16:03:48 +00:00
Driss Guessous
b9afa92827 replace double transpose with single permute in nn.f.mha (#89847)
# Summary

I forgot about permute which was exactly what I wanted. Quick perf bump
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89847
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2022-11-29 22:18:42 +00:00
Driss Guessous
4d7ec30220 Call _sdp_attention in nn.functional.mha (#89470)
# Summary
Replaces the the inline block of code in nn.funcitonal.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
2022-11-29 03:02:10 +00:00
foram-chandra
e19a7165fd [nn] Remove deprecation warning from nn.functional.{tanh, sigmoid} (#86905)
Fixes #65909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86905
Approved by: https://github.com/albanD, https://github.com/kit1980
2022-11-24 00:34:26 +00:00
Nikita Karetnikov
0a1a53083e [primTorch] Enable regex error testing for some refs (#87765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87765
Approved by: https://github.com/mruberry
2022-11-23 23:36:27 +00:00
David Boetius
b652fbc57a Fix torch.nn.functional.gelu docstring formatting (#89061)
The docstring of `torch.nn.functional.gelu` is formatted incorrectly, so that part of the math isn't rendered and there are extra blocks when there shouldn't: https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html

I didn't build the docs, so I am not 100% sure that I got the formatting right, but I am confident.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89061
Approved by: https://github.com/bdhirsh, https://github.com/kit1980
2022-11-18 01:57:41 +00:00
Ryan Spring
534ae6ae47 [primTorch] Implement group norm reference (#87054)
Add group norm reference
Split from #81191
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87054
Approved by: https://github.com/mruberry
2022-11-11 01:08:20 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos, in comments of Python files, that are found from a search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Rui Zhu
4b757f4633 Assert if padding mask type is unexpected (#86353) (#87106)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86353

Fix the issue described in
https://github.com/pytorch/pytorch/issues/86120

Test Plan: buck test mode/opt caffe2/test:test_transformers -- test_train_with_long_type_pad

Differential Revision: D40129968

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87106
Approved by: https://github.com/malfet
2022-10-20 16:01:54 +00:00
Andrew M. James
db65909255 [Docs] Update mm family ops and F.linear to note limited sparse support. (#86220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86220
Approved by: https://github.com/cpuhrsch
2022-10-18 19:55:18 +00:00
Nikita Karetnikov
d56017a14f [primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-12 18:37:58 +00:00
lezcano
787028cadb Implement col2im decomposition and fix im2col and add a few preconditions (#85541)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85541
Approved by: https://github.com/jansel
2022-09-30 09:31:53 +00:00
Srikumar Sastry
c8776dca6a Remove extra with in value error exception statement (#84713)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84713
Approved by: https://github.com/ngimel
2022-09-27 18:43:39 +00:00
Driss Guessous
253ffbf28b Exposing native _scaled_dot_product_attention to torch.nn (#85044)
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the api for args, and kwargs is still in flux for the next few weeks and will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
2022-09-22 16:30:16 +00:00
PyTorch MergeBot
a3dc338ee1 Revert "Exposing native _scaled_dot_product_attention to torch.nn (#85044)"
This reverts commit 9fdd8a8b7f.

Reverted https://github.com/pytorch/pytorch/pull/85044 on behalf of https://github.com/huydhn due to This breaks CUDA 10.2 in trunk. We are deprecating CUDA 10.2, but it is still here in the mean time
2022-09-21 08:34:51 +00:00
Driss Guessous
9fdd8a8b7f Exposing native _scaled_dot_product_attention to torch.nn (#85044)
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the api for args, and kwargs is still in flux for the next few weeks and will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
2022-09-21 03:09:08 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Edward Z. Yang
cb64b558ee Add spaces so example is flake8 compatible (#83420)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83420
Approved by: https://github.com/jbschlosser
2022-08-15 21:39:57 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Alex Li
1fedd40424 Update cross entropy documentation to metion logits clearly (#82538)
### Description
Improved the documentation for cross entropy as it is a common point of confusion.

### Issue
#82081

### Testing
I did not test this change as it is tiny and documentation-only
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82538
Approved by: https://github.com/jbschlosser
2022-08-08 22:24:28 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, like you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes this issue by ensuring that all mentions of the string type in docstrings, are using the same format that Sphinx generates hyperlinks for.

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
kylematoba
66cf1b6459 correct argument name in docs (#81485)
Recently introduced `average_attn_weights` argument is documented incorrectly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81485
Approved by: https://github.com/albanD
2022-07-20 20:07:16 +00:00
soulitzer
bd75b2fea1 Add ref for nn.functional.prelu (#79768)
TODO:
- not sure if these error-inputs work for all devices (awaiting CI)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79768
Approved by: https://github.com/mruberry
2022-07-07 17:04:47 +00:00
Albert Chung
b4ed13ea0f Update docstring for scale_factor in torch.nn.functional.interpolate. (#80807)
Fixes #80786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80807
Approved by: https://github.com/ezyang
2022-07-04 04:36:16 +00:00
Joel Benjamin Schlosser
5953fd9133 Revert behavior of Dropout2d on 3D inputs to 1D channel-wise dropout behavior & warn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79549

Approved by: https://github.com/ngimel, https://github.com/albanD
2022-06-15 14:56:43 +00:00
Joel Benjamin Schlosser
2d73c8e6e0 Add Dropout1d module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79545

Approved by: https://github.com/ngimel, https://github.com/albanD
2022-06-15 14:39:07 +00:00
PyTorch MergeBot
3556457dd2 Revert "kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. (#79007)"
This reverts commit 72ad222cff.

Reverted https://github.com/pytorch/pytorch/pull/79007 on behalf of https://github.com/janeyx99 due to Broke test_fn_fwgrad_bwgrad_nn_functional_kl_div_cpu_float64 on trunk https://hud.pytorch.org/minihud?name_filter=pull%20/%20linux-xenial-py3.7-clang7-asan%20/%20test%20(default,%202,%205,%20linux.2xlarge)
2022-06-09 13:07:03 +00:00
Nikita Vedeneev
72ad222cff kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. (#79007)
Fixes https://github.com/pytorch/pytorch/issues/78867,
fixes https://github.com/pytorch/pytorch/issues/65466.
Adds forward-over-reverse AD support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79007
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2022-06-09 09:06:52 +00:00
Rohit Goswami
5a95b20d0f DOC: Harmonize ELU documentation with the module doc (#78909)
Fixes #77055 by simply referring to the module docs as noted in the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78909
Approved by: https://github.com/albanD
2022-06-06 14:14:11 +00:00
samdow
b7cb4eae6b Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)
On functorch, we started seeing [embedding forward mode fail](https://github.com/pytorch/functorch/pull/816). From looking at it, we figured out that recently [embedding got forward mode support enabled](369d9f4137) and then doing forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so it's not checked.

What was happening is that `embedding_renorm` was setting `torch.no_grad()` which only turns off the backwards mode AD so functorch's jvp tests were still using forward mode AD during the `embedding_renorm` call. This makes it so that we don't use forward mode during the embedding_renorm call
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78560
Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-03 19:14:51 +00:00
PyTorch MergeBot
d578197747 Revert "Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)"
This reverts commit ce7c7bb2a9.

Reverted https://github.com/pytorch/pytorch/pull/78560 on behalf of https://github.com/malfet due to broke XLA (on CI and trunk), see ce7c7bb2a9
2022-06-02 17:40:34 +00:00
samdow
ce7c7bb2a9 Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)
On functorch, we started seeing [embedding forward mode fail](https://github.com/pytorch/functorch/pull/816). From looking at it, we figured out that recently [embedding got forward mode support enabled](369d9f4137) and then doing forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so it's not checked.

What was happening is that `embedding_renorm` was setting `torch.no_grad()` which only turns off the backwards mode AD so functorch's jvp tests were still using forward mode AD during the `embedding_renorm` call. This makes it so that we don't use forward mode during the embedding_renorm call
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78560
Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-02 13:40:21 +00:00