Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70649
As described in https://fb.quip.com/oxpiA1uDBjgP
This implements the first parts of the RFC, and is a rough draft showing the approach. The idea is that for the first cut we can maintain very close (identical I believe in this diff) numerical equivalence to the existing nn.MHA implementation, which is what this diff attempts to do. In subsequent implementations, once we have a working and adopted native self-attention implementation, we could then explore alternative implementations, etc.
The current implementation is similar to existing dedicated implementations such as LightSeq/FasterTransformer/DeepSpeed, and for MHA on both CPUs and GPUs is between 1.2x and 2x faster depending on the setting. It makes some approximations/restrictions (doesn't handle masking in masked softmax, etc), but these shouldn't materially impact performance.
This does the first few items:
* add native_multi_head_attention(...) , native_multi_head_attention_backward(..) to native_functions.yaml
* Implement native_multi_head_attention(..) on GPU, extracting bits and pieces out of LS/DS/FT as appropriate
* Implement native_multi_head_attention(..) on CPU
The backward implementation is still WIP, but the idea would be to:
* Hook these up in derivatives.yaml
Implement native_multi_head_attention_backward(..) on GPU, extracting out bits and pieces out of LS/DS (not FT since it’s inference only)
* Implement native_multi_head_attention_backward(..) on CPU
* In torch.nn.functional.multi_head_attention_forward 23321ba7a3/torch/nn/functional.py (L4953), add some conditionals to check if we are being called in a BERT/ViT-style encoder fashion, and invoke the native function directly.
Test Plan: TODO
Reviewed By: mikekgfb
Differential Revision: D31829981
fbshipit-source-id: c430344d91ba7a5fbee3138e50b3e62efbb33d96
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66933
This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.
This PR also adds support for matrices with zero elements both in
the size of the matrix and the batch. Note that this function simply
returns empty tensors of the correct size in this case.
We add a test and an OpInfo for the new function.
This PR also adds documentation for this new function in line of
the documentation in the rest of `torch.linalg`.
Fixes https://github.com/pytorch/pytorch/issues/56590
Fixes https://github.com/pytorch/pytorch/issues/64014
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D32834069
Pulled By: mruberry
fbshipit-source-id: 51ef12535fa91d292f419acf83b800b86ee9c7eb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69327
Original commit changeset: d44096d88265
Original Phabricator Diff: D32144240 (668574af4a)
Test Plan:
CI
original diff failed 175 builds in CI
Reviewed By: airboyang, anjali411
Differential Revision: D32809407
fbshipit-source-id: c7c8e69bcee0274992e2d5da901f035332e60071
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66933
This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.
This PR also adds support for matrices with zero elements both in
the size of the matrix and the batch. Note that this function simply
returns empty tensors of the correct size in this case.
We add a test and an OpInfo for the new function.
This PR also adds documentation for this new function in line of
the documentation in the rest of `torch.linalg`.
Fixes https://github.com/pytorch/pytorch/issues/56590
Fixes https://github.com/pytorch/pytorch/issues/64014
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D32521980
Pulled By: mruberry
fbshipit-source-id: 26a49ebd87f8a41472f8cd4e9de4ddfb7f5581fb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63568
This PR adds the first solver with structure to `linalg`. This solver
has an API compatible with that of `linalg.solve` preparing these for a
possible future merge of the APIs. The new API:
- Just returns the solution, rather than the solution and a copy of `A`
- Removes the confusing `transpose` argument and replaces it by a
correct handling of conj and strides within the call
- Adds a `left=True` kwarg. This can be achieved via transposes of the
inputs and the result, but it's exposed for convenience.
This PR also implements a dataflow that minimises the number of copies
needed before calling LAPACK / MAGMA / cuBLAS and takes advantage of the
conjugate and neg bits.
This algorithm is implemented for `solve_triangular` (which, for this, is
the most complex of all the solvers due to the `upper` parameters).
Once more solvers are added, we will factor out this calling algorithm,
so that all of them can take advantage of it.
Given the complexity of this algorithm, we implement some thorough
testing. We also added tests for all the backends, which was not done
before.
We also add forward AD support for `linalg.solve_triangular` and improve the
docs of `linalg.solve_triangular`. We also fix a few issues with those of
`torch.triangular_solve`.
Resolves https://github.com/pytorch/pytorch/issues/54258
Resolves https://github.com/pytorch/pytorch/issues/56327
Resolves https://github.com/pytorch/pytorch/issues/45734
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D32588230
Pulled By: mruberry
fbshipit-source-id: 69e484849deb9ad7bb992cc97905df29c8915910
Summary:
Adds native_dropout to have a reasonable target for torchscript in auto diff. native_dropout has scale and train as arguments in its signature, this makes native_dropout more consistent with other operators and removes conditionals in the autodiff definition.
cc gmagogsfm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63937
Reviewed By: mruberry
Differential Revision: D32477657
Pulled By: ngimel
fbshipit-source-id: d37b137a37acafa50990f60c77f5cea2818454e4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63568
This PR adds the first solver with structure to `linalg`. This solver
has an API compatible with that of `linalg.solve` preparing these for a
possible future merge of the APIs. The new API:
- Just returns the solution, rather than the solution and a copy of `A`
- Removes the confusing `transpose` argument and replaces it by a
correct handling of conj and strides within the call
- Adds a `left=True` kwarg. This can be achieved via transposes of the
inputs and the result, but it's exposed for convenience.
This PR also implements a dataflow that minimises the number of copies
needed before calling LAPACK / MAGMA / cuBLAS and takes advantage of the
conjugate and neg bits.
This algorithm is implemented for `solve_triangular` (which, for this, is
the most complex of all the solvers due to the `upper` parameters).
Once more solvers are added, we will factor out this calling algorithm,
so that all of them can take advantage of it.
Given the complexity of this algorithm, we implement some thorough
testing. We also added tests for all the backends, which was not done
before.
We also add forward AD support for `linalg.solve_triangular` and improve the
docs of `linalg.solve_triangular`. We also fix a few issues with those of
`torch.triangular_solve`.
Resolves https://github.com/pytorch/pytorch/issues/54258
Resolves https://github.com/pytorch/pytorch/issues/56327
Resolves https://github.com/pytorch/pytorch/issues/45734
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Test Plan: Imported from OSS
Reviewed By: zou3519, JacobSzwejbka
Differential Revision: D32283178
Pulled By: mruberry
fbshipit-source-id: deb672e6e52f58b76536ab4158073927a35e43a8
Summary:
### Create `linalg.cross`
Fixes https://github.com/pytorch/pytorch/issues/62810
As discussed in the corresponding issue, this PR adds `cross` to the `linalg` namespace (**Note**: There is no method variant) which is slightly different in behaviour compared to `torch.cross`.
**Note**: this is NOT an alias as suggested in mruberry's [https://github.com/pytorch/pytorch/issues/62810 comment](https://github.com/pytorch/pytorch/issues/62810#issuecomment-897504372) below
> linalg.cross being consistent with the Python Array API (over NumPy) makes sense because NumPy has no linalg.cross. I also think we can implement linalg.cross without immediately deprecating torch.cross, although we should definitely refer users to linalg.cross. Deprecating torch.cross will require additional review. While it's not used often it is used, and it's unclear if users are relying on its unique behavior or not.
The current default implementation of `torch.cross` is extremely weird and confusing. This has also been reported multiple times previously. (See https://github.com/pytorch/pytorch/issues/17229, https://github.com/pytorch/pytorch/issues/39310, https://github.com/pytorch/pytorch/issues/41850, https://github.com/pytorch/pytorch/issues/50273)
- [x] Add `torch.linalg.cross` with default `dim=-1`
- [x] Add OpInfo and other tests for `torch.linalg.cross`
- [x] Add broadcasting support to `torch.cross` and `torch.linalg.cross`
- [x] Remove out skip from `torch.cross` OpInfo
- [x] Add docs for `torch.linalg.cross`. Improve docs for `torch.cross` mentioning `linalg.cross` and the difference between the two. Also adds a warning to `torch.cross`, that it may change in the future (we might want to deprecate it later)
---
### Additional Fixes to `torch.cross`
- [x] Fix Doc for Tensor.cross
- [x] Fix torch.cross in `torch/overridres.py`
While working on `linalg.cross` I noticed these small issues with `torch.cross` itself.
[Tensor.cross docs](https://pytorch.org/docs/stable/generated/torch.Tensor.cross.html) still mentions `dim=-1` default which is actually wrong. It should be `dim=None` after the behaviour was updated in PR https://github.com/pytorch/pytorch/issues/17582 but the documentation for the `method` or `function` variant wasn’t updated. Later PR https://github.com/pytorch/pytorch/issues/41850 updated the documentation for the `function` variant i.e `torch.cross` and also added the following warning about the weird behaviour.
> If `dim` is not given, it defaults to the first dimension found with the size 3. Note that this might be unexpected.
But still, the `Tensor.cross` docs were missed and remained outdated. I’m finally fixing that here. Also fixing `torch/overrides.py` for `torch.cross` as well now, with `dim=None`.
To verify according to the docs the default behaviour of `dim=-1` should raise, you can try the following.
```python
a = torch.randn(3, 4)
b = torch.randn(3, 4)
b.cross(a) # this works because the implementation finds 3 in the first dimension and the default behaviour as shown in documentation is actually not true.
>>> tensor([[ 0.7171, -1.1059, 0.4162, 1.3026],
[ 0.4320, -2.1591, -1.1423, 1.2314],
[-0.6034, -1.6592, -0.8016, 1.6467]])
b.cross(a, dim=-1) # this raises as expected since the last dimension doesn't have a 3
>>> RuntimeError: dimension -1 does not have size 3
```
Please take a closer look (particularly the autograd part, this is the first time I'm dealing with `derivatives.yaml`). If there is something missing, wrong or needs more explanation, please let me know. Looking forward to the feedback.
cc mruberry Lezcano IvanYashchuk rgommers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63285
Reviewed By: gchanan
Differential Revision: D32313346
Pulled By: mruberry
fbshipit-source-id: e68c2687c57367274e8ddb7ef28ee92dcd4c9f2c
Summary:
Adds `torch.argwhere` as an alias to `torch.nonzero`
Currently, `torch.nonzero` is actually provides equivalent functionality to `np.argwhere`.
From NumPy docs,
> np.argwhere(a) is almost the same as np.transpose(np.nonzero(a)), but produces a result of the correct shape for a 0D array.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64257
Reviewed By: qihqi
Differential Revision: D32049884
Pulled By: saketh-are
fbshipit-source-id: 016e49884698daa53b83e384435c3f8f6b5bf6bb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64430
The functionalization pass needs `{view}_scatter` versions of the slice/select/diagonal ops in order to correctly propagate mutations from a view to its base. On top of that, the implementations need to be primitive w.r.t. autograd, because they look something like `...slice().copy_()`, and the functionalization pass can't use views + mutations inside of it's own alias-removal machinery!
I added some basic tests that I tried to base off of existing tests for views (particularly around testing the derivative formulas), but I'm wondering if I should add something more comprehensive.
Also, as_strided fits into this category - the functionalization pass will need an `as_strided_scatter` op that's primitive w.r.t. autograd. I didn't add it for now, because it'll involve duplicating a bunch of logic from the current `as_strided_backward()` function, and also writing a derivative formula that I wasn't sure how to write :)
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31942092
Pulled By: bdhirsh
fbshipit-source-id: c702a57c2748a7c771c14e4bcc3e996b48fcc4c8
Summary:
Adds mixed precision autocasting support between fp32/fp16 to torchscript/JIT. More in depth descriptoin can be found at [torch/csrc/jit/JIT-AUTOCAST.md](https://github.com/pytorch/pytorch/pull/63939/files#diff-1f1772aaa508841c5bb58b74ab98f49a1e577612cd9ea5c386c8714a75db830b)
This PR implemented an autocast optimization pass that inserts casting ops per AMP rule (torch/csrc/jit/passes/autocast.cpp), that mimics the behavior of eager autocast. The pass also takes into consideration the context of `torch.cuda.amp.autocast` and only inserts casting ops within the enabled context manager, giving feature parity as with eager amp autocast.
We currently provide JIT AMP autocast as a prototyping feature, so it is default off and could be turned on via `torch._C._jit_set_autocast_mode(True)`
The JIT support for autocast is subject to different constraints compared to the eager mode implementation (mostly related to the fact that TorchScript is statically typed), restriction on the user facing python code is described in doc torch/csrc/jit/JIT-AUTOCAST.md
This is a prototype, there are also implementation limitation that's necessary to keep this PR small and get something functioning quickly on upstream, so we can iterate on designs.
Few limitation/challenge that is not properly resolved in this PR:
1. Autocast inserts cast operation, which would have impact on scalar type of output tensor feeding downstream operations. We are not currently propagating the updated scalar types, this would give issues/wrong results on operations in promotion rules.
2. Backward for autodiff in JIT misses the casting of dgrad to input scalar type, as what autograd does in eager. This forces us to explicitly mark the casting operation for certain operations (e.g. binary ops), otherwise, we might be feeding dgrad with mismatch scalar type to input. This could potentially break gradient function consuming dgrad. (e.g. gemm backwards, which assumes grad_output to be of same scalar type as input')
3. `torch.autocast` api has an optional argument `dtype` which is not currently supported in the JIT autocast and we require a static value.
Credit goes mostly to:
tlemo
kevinstephano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63939
Reviewed By: navahgar
Differential Revision: D31093381
Pulled By: eellison
fbshipit-source-id: da6e26c668c38b01e296f304507048d6c1794314
Summary:
Adds `torch.argwhere` as an alias to `torch.nonzero`
Currently, `torch.nonzero` is actually provides equivalent functionality to `np.argwhere`.
From NumPy docs,
> np.argwhere(a) is almost the same as np.transpose(np.nonzero(a)), but produces a result of the correct shape for a 0D array.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64257
Reviewed By: dagitses
Differential Revision: D31474901
Pulled By: saketh-are
fbshipit-source-id: 335327a4986fa327da74e1fb8624cc1e56959c70
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030
Remove dtype tracking from Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible
Fixes https://github.com/pytorch/pytorch/issues/47442
* **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today.
* There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of storage is appropriate.
* As we no longer know what dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse number of elements with number of bytes.
* `Storage._new_shared` takes a `nbytes` kwarg and will reject previous positional only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments.
* It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor.
* It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling.
* The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall.
To accommodate this case, we introduce a new TypedStorage concept that exists only during unpickling time which is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization.
Original pull request: https://github.com/pytorch/pytorch/pull/59671
Reviewed By: soulitzer, ngimel
Differential Revision: D29466819
Pulled By: ezyang
fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65621
Add a new attribute to the FusedMovingAvgObsFakeQuantize that controls if the Fake Quant operation should be applied at the output of a particular layer. The motivation is to give the users additional control to control the numerics of the fake_quant operators during training. It defaults to always fake quant the output (True).
Note: We will still observer the tensors as before (only the fake_quant operation is controlled using this flag)
For example
```
input model
x -> fc1 -> fc2 -> non_quantizable_op -> fc3
After fake_quant
x -> fake_quant(x) -> fc1 -> fake_quant(fc1) -> fc2 -> fake_quant(fc2) -> non_quantizable_op -> fake_quant() -> fc3 -> fake_quantize(fc3)
With output_fake_quant disabled at the output of fc2 and fc3 (since their outputs are non-quantizable)
x -> fake_quant(x) -> fc1 -> fake_quant(fc1) -> fc2 -> non_quantizable_op -> fake_quant() -> fc3
```
Test Plan: ./buck-out/gen/caffe2/test/quantization_fx\#binary.par -r test_disable_output_fake_quant
Reviewed By: jerryzh168
Differential Revision: D31174526
fbshipit-source-id: bffe776216d041fb09133a6fb09bfc2c0bb46b89
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65340
I thought about a few possible ways of doing this. The main hazard is
that if I create a CPU tensor that doesn't have any real storage, the
moment I actually try to access the data on the tensor I will segfault.
So I don't want to use _make_subclass on a "cpu meta tensor" because
the CPU meta tensor (with no subclass) is radioactive: printing it
will immediately cause a segfault. So instead, I have to create
the CPU meta tensor AND subclass all in one go, and that means I need
another function for it. One downside to doing it this way is
I need another overload for explicit strides, and in general it is
difficult to get the view relationships to all work out properly;
tracked at https://github.com/pytorch/pytorch/issues/65339
Fixes https://github.com/pytorch/pytorch/issues/62972
Fixes https://github.com/pytorch/pytorch/issues/62730
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31057231
Pulled By: ezyang
fbshipit-source-id: 73522769e093ae8a1bf0c7f7e594659bfb827b28
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62671
Very crude first implementation of `torch.nanmean`. The current reduction kernels do not have good support for implementing nan* variants. Rather than implementing new kernels for each nan* operator, I will work on new reduction kernels with support for a `nan_policy` flag and then I will port `nanmean` to use that.
**TODO**
- [x] Fix autograd issue
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D30515181
Pulled By: heitorschueroff
fbshipit-source-id: 303004ebd7ac9cf963dc4f8e2553eaded5f013f0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64689
This brings it in line with the C++ implementation.
Fixes https://github.com/pytorch/pytorch/issues/64687
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D30816215
Pulled By: ezyang
fbshipit-source-id: ed36af6c35467ae678d9548197efd97c36d38dec
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63552
In this PR, we want to exclude these 2 cases in the `Autocast` weight cache usages:
- Using `torch.jit.trace` under the `Autocast`
As report in https://github.com/pytorch/pytorch/issues/50231 and several other discussions, using `torch.jit.trace` under the `Autocast`, the trace process would hit Autocast's weight cache and fails. So we should disable weight cache under the trace process.
- Using `Autocast` with `Grad mode`
- Usually we are using `Grad mode` for training. Since in the training phase, the weight will change in every step. So we doesn't need to cache the weight.
- For the recommended `Autocast` training case in the [doc](https://pytorch.org/docs/stable/amp.html), `Autocast` will clear the cache every step leaving the context. We should disable it to save the clear operations.
```
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
```
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D30644913
Pulled By: ezyang
fbshipit-source-id: ad7bc87372e554e7aa1aa0795e9676871b3974e7
Summary:
Fixes https://github.com/pytorch/pytorch/issues/62811
Add `torch.linalg.matmul` alias to `torch.matmul`. Note that the `linalg.matmul` doesn't have a `method` variant.
Also cleaning up `torch/_torch_docs.py` when formatting is not needed.
cc IvanYashchuk Lezcano mruberry rgommers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63227
Reviewed By: mrshenli
Differential Revision: D30770235
Pulled By: mruberry
fbshipit-source-id: bfba77dfcbb61fcd44f22ba41bd8d84c21132403
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61767
## Changes
- [x] Add `torch.concat` alias to `torch.cat`
- [x] Add OpInfo for `cat`/`concat`
- [x] Fix `test_out` skips (Use `at::native::resize_output` or `at::native::resize_output_check`)
- [x] `cat`/`concat`
- [x] `stack`
- [x] `hstack`
- [x] `dstack`
- [x] `vstack`/`row_stack`
- [x] Remove redundant tests for `cat`/`stack`
~I've not added `cat`/`concat` to OpInfo `op_db` yet, since cat is a little more tricky than other OpInfos (should have a lot of tests) and currently there are no OpInfos for that. I can try to add that in a subsequent PR or maybe here itself, whatever is suggested.~
**Edit**: cat/concat OpInfo has been added.
**Note**: I've added the named tensor support for `concat` alias as well, maybe that's out of spec in `array-api` but it is still useful for consistency in PyTorch.
Thanks to krshrimali for guidance on my first PR :))
cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi heitorschueroff krshrimali
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62560
Reviewed By: saketh-are
Differential Revision: D30762069
Pulled By: mruberry
fbshipit-source-id: 6985159d1d9756238890488a0ab3ae7699d94337
Summary:
This PR implements the necessary hooks/stubs/enums/etc for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort.
We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends).
The ORT backend will allow us to ferry [almost] all torch ops into granular ONNX kernels that ORT will eagerly execute against any devices it supports (therefore, we only need a single ORT backend from a PyTorch perspective).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248
Reviewed By: astaff
Differential Revision: D30344992
Pulled By: albanD
fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2