Commit Graph

323 Commits

Author SHA1 Message Date
Zachary DeVito
3b3ed25109 Add a way to visualize memory snapshot traces (#90348)
This adds a d3-based interactive visualization for exploring the memory
allocation traces that the caching allocator can capture. This visualization
code can also be attached to Kineto trace information in the future to
provide visualization for the memory events captured there, which come with
additional information about the graph.
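
A minimal, hedged sketch of producing a snapshot for the viewer, assuming the private memory-history APIs from the earlier tracing PRs (#86241, #86145); the exact signatures of these underscore-prefixed functions may differ between versions:
```python
import pickle

import torch

torch.cuda.memory._record_memory_history(True)   # start recording allocator trace events

buffers = [torch.empty(1 << 20, device="cuda") for _ in range(8)]
del buffers[::2]                                  # create some allocation churn to trace

snapshot = torch.cuda.memory._snapshot()          # segments plus the recorded trace buffer
with open("snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)                      # load this into the d3 viewer / _memory_viz tooling
```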
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90348
Approved by: https://github.com/robieta
2022-12-10 02:45:11 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation of #90134 and hopefully the final PR in this series.
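
For context, the pattern applied throughout this series is explicit exception chaining with `raise ... from err`; a minimal illustration (the module name and message below are hypothetical, not code from the PR):
```python
import importlib

# Re-raise with an explicit cause so the original traceback is chained via __cause__.
try:
    importlib.import_module("some_missing_backend")  # hypothetical module name
except ImportError as err:
    raise RuntimeError("the requested backend is not available") from err
```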

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
eqy
62e450d55f [CUDA Graphs] Add option to dump a captured graph for debugging (#85519)
CC @xwang233 @ptrblck @ngimel
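
A hedged usage sketch (the method names are assumed from this PR; verify against your build):
```python
import torch

g = torch.cuda.CUDAGraph()
g.enable_debug_mode()                     # keep enough capture state around for dumping

static_x = torch.randn(8, device="cuda")
with torch.cuda.graph(g):
    static_y = static_x * 2

g.debug_dump("captured_graph.dot")        # write the captured graph out for inspection
```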
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85519
Approved by: https://github.com/ngimel
2022-12-06 22:03:05 +00:00
albanD
8713119c89 Stream actually overrides __new__ so we need to patch it as well (#89592)
Avoids
```
$ python foo.py
Traceback (most recent call last):
  File "foo.py", line 3, in <module>
    a = torch.cuda.Stream()
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/cuda/streams.py", line 34, in __new__
    return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
TypeError: object.__new__() takes exactly one argument (the type to instantiate)
```
And now produces
```
$ python foo.py
Traceback (most recent call last):
  File "foo.py", line 3, in <module>
    a = torch.cuda.Stream()
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/cuda/streams.py", line 34, in __new__
    return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/cuda/_utils.py", line 44, in err_fn
    raise RuntimeError(
RuntimeError: Tried to instantiate dummy base class Stream

```
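The underlying technique, sketched generically (this is not the actual `torch.cuda` code): the dummy class must override `__new__` in addition to `__init__`, because `Stream.__new__` forwards kwargs that `object.__new__` rejects.
```python
def _make_dummy(name):
    def err_fn(*args, **kwargs):
        raise RuntimeError(f"Tried to instantiate dummy base class {name}")
    # Patching __init__ alone is not enough for classes that override __new__.
    return type(name, (), {"__new__": err_fn, "__init__": err_fn})

Stream = _make_dummy("Stream")
try:
    Stream(priority=0)
except RuntimeError as e:
    print(e)   # Tried to instantiate dummy base class Stream
```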
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89592
Approved by: https://github.com/soumith
2022-11-29 21:43:23 +00:00
Edward Z. Yang
b589e726d9 Refactor how AOTAutograd backends are defined (#89736)
There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements:

- There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping `torch._dynamo.optimize("aot_autograd")` when you meant "aot_eager"
- Deleted aot_print because it's annoying and anyway there are no uses of it
- Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs to configure AOTAutograd, and then gives you a compiler function that does AOTAutograd given those kwargs. Easy.
- The primary downside is that we are now eagerly populating all of the kwargs, and that can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we now no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does it for us), but for getting inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later.

New code is 130 lines shorter!

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89736
Approved by: https://github.com/anjali411, https://github.com/albanD
2022-11-28 18:39:12 +00:00
Edward Z. Yang
c9a0cc8640 Simplify aot_module_simplified by removing top_args/top_kwargs (#89666)
This makes good on Chillee's CR comment at
af30d351cc (r843315222)
which was never addressed in the original PR.

There is no logic change, just unpack the args/kwargs at the top
level and remove the inner function indirection.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89666
Approved by: https://github.com/voznesenskym
2022-11-25 20:43:13 +00:00
Emilio Castillo
c9d4390d13 Add Pluggable CUDA allocator backend (#86786)
Fixes #43144

This uses the Backend system added by [82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during code execution. This will allow us to use RMM, to use CUDA managed memory for portions of the code that do not fit in GPU memory, to write static memory allocators that reduce fragmentation while training models, and to improve interoperability with external DL compilers/libraries.

For example, we could have the following allocator in c++

```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void *ptr;
   std::cout<<"alloc "<< size<<std::endl;
   cudaMalloc(&ptr, size);
   return ptr;
}

void my_free(void* ptr) {
   std::cout<<"free "<<std::endl;
   cudaFree(ptr);
}
}
```

Compile it as a shared library
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```

And use it from PyTorch as follows

```python
import torch

# Init caching
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(old)
```

Things to discuss
- How to test this; it requires compiling external code ...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
2022-11-23 17:54:36 +00:00
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos in messages in `.py` files under the torch directory.
Only in `torch/onnx/symbolic_opset16.py`, it fixes a typo in a comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos in comments of Python files that were found via the search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Masaki Kozuki
bc03aa6013 Store autocast_gpu_dtype in custom_fwd and custom_bwd for BFloat16 autocast (#88029)
As per #87979, `custom_bwd` seems to forcefully use `torch.float16` for `torch.autograd.Function.backward` regardless of the `dtype` used in the forward.

Changes:
- store the `dtype` in `args[0]`
- update tests to confirm the dtype of intermediate result tensors that are outputs of autocast compatible `torch` functions
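
A minimal sketch of the affected decorators under bfloat16 autocast (illustrative only; the function and tensors below are not taken from the PR's tests):
```python
import torch

class Mul(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    @torch.cuda.amp.custom_bwd   # backward should now reuse the dtype recorded during forward
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out * b, grad_out * a

a = torch.randn(4, device="cuda", requires_grad=True)
b = torch.randn(4, device="cuda", requires_grad=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = Mul.apply(a, b)
out.sum().backward()
```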

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88029
Approved by: https://github.com/ngimel
2022-10-31 22:45:26 +00:00
Soumith Chintala
ff43288d31 [AOT][CUDAGraphs] torchdynamo -> torch._dynamo (#87243)
Fixes lingering issues from the torchdynamo -> torch._dynamo migration
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87243
Approved by: https://github.com/suo, https://github.com/voznesenskym, https://github.com/jansel
2022-10-21 03:14:28 +00:00
Syed Tousif Ahmed
77d94ac5ab Sets CUDA_MODULE_LOADING to LAZY when not set by the user (#85692)
This PR sets CUDA_MODULE_LOADING if it's not set by the user. By default, it sets it to "LAZY".

It was tested using the following commands:
```
python -c "import torch; tensor=torch.randn(20, 16, 50, 100).cuda(); free, total = torch.cuda.cudart().cudaMemGetInfo(0); print(total-free)"
```
which shows a memory usage of: 287,047,680 bytes

vs

```
CUDA_MODULE_LOADING="DEFAULT" python -c "import torch; tensor=torch.randn(20, 16, 50, 100).cuda(); free, total = torch.cuda.cudart().cudaMemGetInfo(0); print(total-free)"
```
which shows 666,632,192 bytes.

A C++ implementation is needed for libtorch users (otherwise this could have been pure Python functionality).

cc: @ptrblck @ngimel @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85692
Approved by: https://github.com/malfet
2022-10-13 14:03:01 +00:00
Daniel Dale
ce56ee11fd Extend torch.cuda.is_available() to attempt an NVML-based CUDA availability assessment when explicitly requested by the user (#85951)
Fixes #83973 (This is a substitute PR for https://github.com/pytorch/pytorch/pull/85024)

First of all, thanks for your invaluable contributions to PyTorch everyone!

Given how extensively `torch.cuda.is_available` is used in the PyTorch ecosystem, IMHO it's worthwhile to provide downstream libraries/frameworks/users the ability to alter the default behavior of `torch.cuda.is_available` in the context of their PyTorch usage.

I'm confident there are many current and future such use cases which could benefit from leveraging a weakened, NVML-based `torch.cuda.is_available` assessment at a downstream framework's explicit direction (thanks @malfet 81da50a972 !). Though one could always patch out the `torch.cuda.is_available` function with another implementation in a downstream library, I think this environment-variable-based configuration option is more convenient, and the cost of including the option is quite low.

As discussed in https://github.com/pytorch/pytorch/pull/85024#issuecomment-1261542045, this PR gates the new non-default NVML-based CUDA behavior behind an environment variable (PYTORCH_NVML_BASED_CUDA_CHK) that allows a user/framework to invoke non-default, NVML-based `is_available()` assessments if desired.
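
A hedged usage sketch of opting in (the variable must be set before the check is consulted; one motivation is that the NVML path avoids initializing a CUDA context, so a later fork is not poisoned):
```python
import os

os.environ["PYTORCH_NVML_BASED_CUDA_CHK"] = "1"   # opt in to the weaker NVML-based assessment

import torch

print(torch.cuda.is_available())   # NVML-based check; does not initialize a CUDA context
```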

Thanks again for your work everyone!
@ngimel @malfet @awaelchli

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85951
Approved by: https://github.com/ngimel
2022-10-12 18:37:50 +00:00
Eddie Yan
25725fd624 (Re-open) Adds cudaMallocAsync as an alternative backend for the CUDA allocator (#82682)
Rebased version of @mcarilli 's cudaMallocAsync #65365 for continued testing
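
A hedged sketch of selecting the new backend through the allocator config environment variable (which must be set before the first CUDA allocation):
```python
import os

# setdefault keeps any value the user already provided.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

import torch

x = torch.empty(1 << 20, device="cuda")   # served by cudaMallocAsync rather than the native caching backend
```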
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82682
Approved by: https://github.com/ngimel
2022-10-12 03:44:21 +00:00
anjali411
e2a4dfa468 Add correct __all__ for torch.distributed and torch.cuda submodules (#85702)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85702
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/rohan-varma
2022-10-10 19:15:24 +00:00
anjali411
a6c0442cce Add __all__ to torch.{autograd, fx, cuda} submodules (#85343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85343
Approved by: https://github.com/albanD
2022-10-09 14:46:54 +00:00
Zachary DeVito
91b1bae1df Caching allocator tracing (#86241)
We can currently take snapshots of the state of allocated CUDA memory, but we do not have a way to correlate these snapshots with the actions the allocator took between snapshots. This PR adds a simple fixed-size buffer that records the major actions the allocator takes (ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT) and includes them with the snapshot information. Capturing periodic snapshots with a big enough trace buffer makes it possible to see how the allocator state changes over time.

We plan to use this functionality to guide how settings in the allocator can be adjusted and eventually have a more robust overall algorithm.

As a component of this functionality, we also add the ability to get a callback when the allocator will throw an OOM, primarily so that snapshots can be taken immediately to see why the program ran out of memory (most programs have some C++ state that would free tensors before the OutOfMemory exception can be caught).

This PR also updates the _memory_viz.py script to pretty-print the trace information and provide a better textual summary of snapshots distinguishing between internal and external fragmentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86241
Approved by: https://github.com/ngimel
2022-10-07 23:19:54 +00:00
Edward Z. Yang
adf5919720 Add option to record C++ backtraces in _record_memory_history (#86145)
I used this to debug https://github.com/pytorch/pytorch/issues/86136, so it is useful. The implementation is not very fast, so it is not enabled by default.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86145
Approved by: https://github.com/albanD, https://github.com/zdevito
2022-10-06 04:07:37 +00:00
Horace He
0e256c2550 removed compile cache and static argnums (#85783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85783
Approved by: https://github.com/wconstab
2022-09-28 08:33:59 +00:00
anjali411
85073b8ddc Add __all__ to fx, distributed and cuda submodules (#85080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85080
Approved by: https://github.com/albanD
2022-09-21 18:04:58 +00:00
Mateusz Sypniewski
b70c254ebb Rework printing tensor aliases in CSAN error message (#85008)
Small rework of how the error message is formatted; it introduces a distinction between the arguments and the output of kernels. Verified manually on multiple examples that the message is printed as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85008
Approved by: https://github.com/lw
2022-09-21 13:41:52 +00:00
Nikita Shulga
9024015adf [BE] Small improvements to device_count (#85192)
If `_parse_visible_devices` returns an empty set, there is no need to make NVML calls. Also, reduce the indentation a bit in `_device_count_nvml`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85192
Approved by: https://github.com/kit1980, https://github.com/ngimel
2022-09-18 20:38:43 +00:00
Nikita Shulga
45a9dcd4dd [BE] Add explicit __all__ to torch.cuda (#85193)
This helps one avoid re-exporting torch, warnings and other system modules from `torch.cuda`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85193
Approved by: https://github.com/kit1980
2022-09-17 18:20:17 +00:00
Hector Yuen
d23ce29761 allow changing the cuda allocator settings even after the process started (#84970)
Summary:
- expose a python call to set the allocator settings, it uses the same format as the value for PYTORCH_CUDA_ALLOCATOR
- keep the implementation contained within the cpp file to avoid increasing build times, only expose a function to call the setting
- make some of the Allocator Config methods public, now it looks more like a singleton
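
A hedged sketch of the exposed call (the private helper name below is assumed; the settings string uses the same key:value format as the allocator environment variable):
```python
import torch

# Assumed private helper; it takes the same comma-separated key:value
# string as the allocator configuration environment variable.
torch.cuda.memory._set_allocator_settings("max_split_size_mb:128,roundup_power2_divisions:4")
```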

Test Plan: added the unit test

Differential Revision: D39487522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84970
Approved by: https://github.com/zdevito
2022-09-17 09:42:42 +00:00
Nikita Shulga
81da50a972 Return device count using nvml (#84879)
Fixes https://github.com/pytorch/pytorch/issues/83973
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84879
Approved by: https://github.com/ngimel
2022-09-13 20:42:41 +00:00
Nikita Shulga
94f20c3514 Memoize torch.cuda.device_count (#84878)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84878
Approved by: https://github.com/ngimel
2022-09-13 20:42:41 +00:00
Nikita Shulga
cd3731bd17 [BE] Refactor _is_compiled() function (#84877)
Call it from `is_available()` and `device_count()`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84877
Approved by: https://github.com/ngimel
2022-09-12 20:45:13 +00:00
Mateusz Sypniewski
d12f3524b7 Add user facing documentation for CSAN (#84689)
This adds a user facing tutorial for the CSAN tool. The documentation preview should be available [here](https://docs-preview.pytorch.org/84689/index.html) once the GitHub job completes on this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84689
Approved by: https://github.com/lw
2022-09-09 15:29:34 +00:00
Mateusz Sypniewski
8e57ce63a1 Add CSAN support for CPU synchronizations (#84428)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84428
Approved by: https://github.com/ngimel, https://github.com/lw
2022-09-09 15:19:33 +00:00
Mateusz Sypniewski
2b2e0fddf8 Add CUDA Sanitizer (#83984)
Example of a simple synchronization error:
```
a = torch.rand(4, 2, device="cuda")

with torch.cuda.stream(second_stream):
    torch.mul(a, 5, out=a)
```
Output produced by CSAN:
```
============================
CSAN detected a possible data race on tensor with data pointer 139719969079296
Access by stream 94646435460352 during kernel:
aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
writing to argument: self, out, output
With stack trace:
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 364, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 544, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/private/home/sypniewski/pytorch/torch/utils/_python_dispatch.py", line 76, in wrapped
    return f(self, *args, **kwargs)
  File "/private/home/sypniewski/pytorch/tester.py", line 9, in <module>
    torch.mul(a, 5, out=a)

Previous access by stream 0 during kernel:
aten::rand(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
writing to argument: output
With stack trace:
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 364, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 544, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/private/home/sypniewski/pytorch/torch/utils/_python_dispatch.py", line 76, in wrapped
    return f(self, *args, **kwargs)
  File "/private/home/sypniewski/pytorch/tester.py", line 6, in <module>
    a = torch.rand(10000, device="cuda")

Tensor was allocated with stack trace:
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 420, in _handle_memory_allocation
    traceback.StackSummary.extract(
  File "/private/home/sypniewski/pytorch/torch/utils/_cuda_trace.py", line 23, in fire_callbacks
    cb(*args, **kwargs)
  File "/private/home/sypniewski/pytorch/torch/_ops.py", line 60, in __call__
    return self._op(*args, **kwargs or {})
  File "/private/home/sypniewski/pytorch/torch/cuda/_sanitizer.py", line 541, in __torch_dispatch__
    outputs = func(*args, **kwargs)
  File "/private/home/sypniewski/pytorch/torch/utils/_python_dispatch.py", line 76, in wrapped
    return f(self, *args, **kwargs)
  File "/private/home/sypniewski/pytorch/tester.py", line 6, in <module>
    a = torch.rand(10000, device="cuda")
```
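
A hedged sketch of enabling the sanitizer programmatically (names assumed from `torch/cuda/_sanitizer.py`; the `TORCH_CUDA_SANITIZER` environment variable is the other entry point):
```python
import torch
import torch.cuda._sanitizer as csan

csan.enable_cuda_sanitizer()              # must run before the kernels to be checked

a = torch.rand(4, 2, device="cuda")
with torch.cuda.stream(torch.cuda.Stream()):
    torch.mul(a, 5, out=a)                # unsynchronized access -> a report like the one above
```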

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83984
Approved by: https://github.com/ezyang
2022-09-07 16:55:03 +00:00
Aidyn-A
ce1b727e77 Disable autocast cache in torch.cuda.make_graphed_callables (#84289)
There are conflicts between `torch.clear_autocast_cache()` and `cudaMallocAsync` from #82682.
Moreover, the use of autocast caching is not reasonable during training, which is the main target of `make_graphed_callables`.

cc @eqy @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84289
Approved by: https://github.com/ngimel
2022-09-01 21:34:51 +00:00
Jeff Daily
ff5fe9e622 [ROCm] enable jiterator (#77982)
### Description
Enables jiterator for ROCm builds. This includes the necessary porting where hiprtc and nvrtc behavior differed. It also ports the ROCm-versus-CUDA differences w.r.t. MAX_DIMS and NUM_THREADS from the non-jiterator code paths into jiterator.

### Testing
CI with ciflow/trunk label to force running ROCm workflows that are currently trunk-only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77982
Approved by: https://github.com/ngimel
2022-08-15 16:04:09 +00:00
Horace He
c2808571bf Removed trace_factory_functions=False option (#83215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83215
Approved by: https://github.com/ezyang
2022-08-13 03:06:45 +00:00
Zachary DeVito
4128712397 Propagate CUDAOutOfMemoryError to Python. (#83146)
The intention is to make it easier to catch this situation for debugging,
logging, or application-specific recovery.
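
A hedged sketch, assuming the propagated exception is surfaced as `torch.cuda.OutOfMemoryError`:
```python
import torch

try:
    huge = torch.empty(1 << 40, device="cuda")   # deliberately far too large
except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()                     # application-specific recovery / logging hook
```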
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83146
Approved by: https://github.com/albanD
2022-08-11 21:32:11 +00:00
Mark Saroufim
0e8beb7d0d Deleted cuda graph files that were moved to torchdynamo (#83128)
These files are now in https://github.com/pytorch/torchdynamo/pull/757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83128
Approved by: https://github.com/ezyang
2022-08-10 20:23:18 +00:00
Zachary DeVito
726d040692 annotated allocator snapshots (#82146)
Record stack trace information for each allocated segment in the allocator.
It takes around 1.5us to record 50 stack frames of context.
Since invoking a PyTorch operator takes around 8us, this adds minimal overhead, but we still leave it disabled by default so that we can test it more on real workloads first.

Stack information is kept both for allocated blocks and for the last allocation that used now-inactive blocks. We could potentially keep around the _first_ allocation that caused the block to be allocated from CUDA as well.

Potential Followups:
* stack frame entries are small (16 bytes), but the list of frames is not compressed even though most frames will share some entries. So far this doesn't produce huge dumps (7MB for one real workload that uses all memory on the GPU), but it could be made much smaller through compression.
* Code to format the information is slow (a few seconds) because it uses python and FlameGraph.pl
* Things allocated during the backward pass have no stack frames because they are run on another C++ thread.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82146
Approved by: https://github.com/albanD
2022-08-09 17:21:35 +00:00
Pruthvi Madugundu
b57188760b [ROCm] torch.cuda.is_bf16_supported() returns True (#80410)
`torch.cuda.is_bf16_supported()` returns False on ROCm, which is not correct, since BF16 is supported on all AMD GPU architectures: gfx906, gfx908, and gfx90a.

cc @jithunnair-amd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80410
Approved by: https://github.com/jeffdaily, https://github.com/malfet
2022-08-03 01:18:37 +00:00
Aidyn-A
da0a3fe058 [Re-land] [CUDA graphs] Clear autocast amp cache (#81896)
Re-lands #81558, which was reverted due to failing tests.

This failure happened because of a test that I had poorly designed. [The loop here](https://github.com/pytorch/pytorch/pull/81558/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3837) runs `cache_enabled=False` and then `cache_enabled=True`. Because of this loop, the graph from the previous iteration (case `False`) conflicts with the next one (case `True`). I redesigned the test so that it does not loop; the new test makes separate function calls with different argument values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81896
Approved by: https://github.com/ngimel
2022-08-02 23:22:00 +00:00
ProGamerGov
8def154e00 Fix multiple docstring type mistakes (#82474)
### Description

* Docstrings using `(tuple of ints)` shows up as `(tuple of python:ints)`, so I fixed them by making the `int` no longer plural. Example: https://pytorch.org/docs/stable/generated/torch.permute.html#torch.permute
* A docstring type in JIT had one of its types incorrectly highlighted as code. Example: https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script
* I found some docstring type usages of `string` that had not yet been converted to `str` after #82410
* Some docstrings incorrectly listed their defaults inside the docstring types.
* I also found a docstring that was missing its type

### Testing
No testing should be required.

---

In the developer guidelines, there should probably be standards listed for the docstring types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82474
Approved by: https://github.com/albanD
2022-07-29 17:45:37 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, like you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes this issue by ensuring that all mentions of the string type in docstrings use the same format that Sphinx generates hyperlinks for.

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
Edward Z. Yang
3c2c2cc947 cudagraphs dynamo backend (#80566)
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: https://github.com/pytorch/functorch/pull/935
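
A hedged usage sketch (the backend string is assumed to be "cudagraphs"):
```python
import torch
import torch._dynamo as dynamo

@dynamo.optimize("cudagraphs")
def f(x):
    return torch.sin(x) + x

f(torch.randn(8, device="cuda"))
```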

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80566
Approved by: https://github.com/ngimel, https://github.com/wconstab
2022-07-22 14:06:07 +00:00
PyTorch MergeBot
f5b460b200 Revert "[CUDA graphs] Clear autocast amp cache (#81558)"
This reverts commit e9d07bd4f0.

Reverted https://github.com/pytorch/pytorch/pull/81558 on behalf of https://github.com/janeyx99 due to Breaks windows 11.6 tests on trunk e9d07bd4f0
2022-07-21 12:46:36 +00:00
Aidyn-A
e9d07bd4f0 [CUDA graphs] Clear autocast amp cache (#81558)
According to [autocast_mode.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/autocast_mode.cpp), `cached_casts` is to be cleared at the end of each forward pass. However, this was not the case in the current implementation of `make_graphed_callables`, so a graph created the following way:

```
    with torch.cuda.amp.autocast(cache_enabled=True):
        graphed_foo = torch.cuda.make_graphed_callables(foo, tensors)
```
behaves incorrectly.

cc @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81558
Approved by: https://github.com/ngimel
2022-07-21 01:44:14 +00:00
Sergii Dymchenko
e3a569384a Correct super class name (#81507)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81507
Approved by: https://github.com/seemethere
2022-07-18 16:26:50 +00:00
Spencer Kelly
bdf5abd6f0 fixed return type for cuda.memory.mem_get_info() (#81073)
The return type was `int`, but the function actually returns a tuple of two ints: the first is the free GPU memory in bytes and the second is the total available GPU memory in bytes.

The return type was fixed to correctly read `Tuple[int, int]`, and the `Tuple` class was imported from `typing`.
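
Usage matching the corrected annotation:
```python
import torch

free, total = torch.cuda.mem_get_info()   # Tuple[int, int], both in bytes
print(f"{free / 2**20:.0f} MiB free of {total / 2**20:.0f} MiB total")
```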
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81073
Approved by: https://github.com/ngimel
2022-07-14 04:21:59 +00:00
Sergii Dymchenko
99244435f6 Resolve TODO after Python 2 for custom_fwd (#78592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78592
Approved by: https://github.com/seemethere
2022-06-01 05:17:41 +00:00
Shawn Zhong
a468941355 Fix jiterator doc format (#78471)
Current docs do not show the code example properly:
https://pytorch.org/docs/master/generated/torch.cuda.jiterator._create_jit_fn.html
https://pytorch.org/docs/master/generated/torch.cuda.jiterator._create_multi_output_jit_fn.html

This PR fixes the formatting issue:
https://docs-preview.pytorch.org/78471/generated/torch.cuda.jiterator._create_jit_fn.html
https://docs-preview.pytorch.org/78471/generated/torch.cuda.jiterator._create_multi_output_jit_fn.html
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78471
Approved by: https://github.com/ngimel
2022-05-31 03:44:52 +00:00
Xinfeng Xie
72a4f6773d Add an argument to specify warmup iterations (#78124)
Summary: Add an argument to specify the number of warmup iterations to the API ``torch.cuda.make_graphed_callables``. By default, it needs 3 warmup iterations. To work with NCCL, it needs 11 warmup iterations.
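
A hedged usage sketch (the keyword name `num_warmup_iters` is assumed from this PR):
```python
import torch

model = torch.nn.Linear(16, 16).cuda()
sample_args = (torch.randn(8, 16, device="cuda"),)
graphed = torch.cuda.make_graphed_callables(model, sample_args, num_warmup_iters=11)  # e.g. when NCCL is involved
```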

Differential Revision: D36606758

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78124
Approved by: https://github.com/jianyuh
2022-05-25 01:21:15 +00:00
Sherlock Huang
6db8440f35 Python Jiterator supports multiple outputs (#78139)
This PR is part 3.
Part1: https://github.com/pytorch/pytorch/pull/77902
Part2: https://github.com/pytorch/pytorch/pull/77921

Python Jiterator now supports returning multiple outputs

```
fn = torch.cuda.jiterator._create_multi_output_jit_fn(
"""
template <typename T>
T binary_2outputs(T i0, T i1, T& out0, T& out1) {
    out0 = i0 + i1;
    out1 = i0 - i1;
}
""",
num_outputs=2)

x = torch.rand(3, device='cuda')
y = torch.rand(3, device='cuda')
out0, out1 = fn(x, y)

torch.allclose(out0, x+y)
torch.allclose(out1, x-y)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78139
Approved by: https://github.com/ngimel
2022-05-24 21:52:56 +00:00