Summary:
cudaEventCreate/Destroy can be expensive especially when the process is calling lots of other CUDA APIs.
Pool the `cudaEvent_t` objects so that we create them once and reuse as much as possible.
Test Plan:
Unit tests to check the functionality.
Manual performance testing shows that this diff is perf positive.
| | create_event_internal (us) | free_event_internal/destructor (us) | insert_events (us) | process_events (us) |
| baseline | 2.411 | 2.647 | 3.968 | 0.321 |
| this diff | 0.115 | 0.147 | 2.846 | 0.262 |
| speed up | 20.9x | 18.0x | 1.4x | 1.2x |
Differential Revision: D35729059
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78279
Approved by: https://github.com/jianyuh
Summary:
If fraction is not set, don't trigger GC!
In the current codebase, if you turn on the GC and *do not set the fraction* in the application, the GC will be triggered every time which does not make much sense -- perf will be as bad as turning off the caching allocator.
With this fix, GC is invoked only when the fraction is set.
Test Plan: Unit tests
Differential Revision: D36026128
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76648
Approved by: https://github.com/yinghai
release_cached_blocks calls this:
```
void synchronize_and_free_events() {
TORCH_INTERNAL_ASSERT(captures_underway == 0);
```
Which means we can't call that function when we are capturing a cuda graph:
```
import torch
with torch.cuda.graph(torch.cuda.CUDAGraph()):
torch.zeros(2 ** 40, device="cuda")
```
results in:
```
RuntimeError: captures_underway == 0INTERNAL ASSERT FAILED at "/tmp/torch/c10/cuda/CUDACachingAllocator.cpp":1224, please report a bug to PyTorch.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76247
Approved by: https://github.com/ngimel
Summary:
Introduces additional ways of handling CUDA errors that allow automated linters to detect if errors are being handled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74865
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D35194530
fbshipit-source-id: f4fe61594edbfd81e97a4b605935961b893df167
(cherry picked from commit 919ddf677c5b9b46c5e493ed64346a5f2527bf08)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74261
### Goal
Implement a cheap way to reclaim GPU memory (garbage collection) without incurring GPU sync.
### Why do we need this?
Currently, there are only two ways to reclaim GPU memory block already assigned to a particular stream.
- `release_available_cached_blocks(params)`: Free blocks exceeding the `CachingAllocatorConfig::max_split_size()` until we can satisfy the request.
Issue: If the `max_split_size` is unset (default), this function is a no-op. Even if this is set, the reclamation is quite conservative (e.g., never frees blocks under max_split_size).
- `release_cached_blocks()`: Waits for all the in-flight events and then reclaim blocks.
Issue: 'waiting for all event' is very expensive as it will likely stall all the GPU operations. Many GPU applications without a proper handling of potential GPU throttling would suffer/crash.
### Proposed idea
- If the garbage collection threshold is set, try to reclaim some memory blocks *without* synchronization. It should be safe to do so, as `release_available_cached_blocks` essentially does the same thing (but less aggressively).
- GC is triggered only when we fail to serve a `malloc` request from the block pool. No need to free blocks when the block pool is functioning just fine.
- Prioritize reclaiming blocks that weren't reused for long time. Reclamation stops once the used memory capacity < threshold.
- This code path is totally optional; by default it won't be invoked.
Test Plan:
- Unit tests
- Manually checked that the GPU memory usage stays as indicated by the garbage collector. If not the caching allocator at least tries to keep freeing the blocks.
Reviewed By: jianyuh
Differential Revision: D34482514
fbshipit-source-id: d5eae62ac60b94b0bca851f9d233a092d086e3c2
(cherry picked from commit 05780f1ed4b176f05e765b2411c9eaa2eaeb48b0)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74213
In the current CUDACachingAllocator, the sizes are rounded up in multiple of blocks size of 512, so this works for smaller sizes. However for large sizes, we can have lots of different size blocks in the larger pool. This is problematic when we have variable batch sizes 1001, 1021, 1023 -> all will go to different block size and will create different size of blocks. This will create lots of unused blocks and will waste GPU memory capacity.
This diff adds a rounding approach to allocation size. It rounds up the size to nearest power-of-2 divisions and the power2-division can be changed with env variable setting.
For example, if we need to round-up size of1200 and if number of divisions is 4,
the size 1200 lies between 1024 and 2048 and if we do 4 divisions between
them, the values are 1024, 1280, 1536, and 1792. So the function will
return 1280 as the nearest ceiling of power-2 division.
env setting:
export PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
ghstack-source-id: 151446017
Reviewed By: ezyang
Differential Revision: D34868036
fbshipit-source-id: 494785add16e6b37c920dcb5a2b81d4c637b554a
(cherry picked from commit 548454ccacbd8700e7ffd2d762e40b4ba37abbae)
Summary:
# Problem
The error message `RuntimeError: Invalid device argument` is not friendly when users just forget calling `torch.cuda.init()`.
This error message is shown for example by calling `torch.cuda.reset_accumulated_memory_stats`, or other methods which internally calls [assertValidDevice](6297aa114f/c10/cuda/CUDACachingAllocator.cpp (L1561-L1566)).
# Reproduce
```python
$ python
Python 3.8.6 (default, Apr 1 2021, 08:23:31)
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch.cuda
>>> torch.cuda.reset_accumulated_memory_stats(0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.8/site-packages/torch/cuda/memory.py", line 219, in reset_accumulated_memory_stats
return torch._C._cuda_resetAccumulatedMemoryStats(device)
RuntimeError: Invalid device argument.
>>> torch.cuda.current_device()
0
```
# This PR
Shows better error message like `RuntimeError: Invalid device argument 0: did you call init?`. I cited the error message from 6297aa114f/c10/cuda/CUDACachingAllocator.cpp (L1392-L1396).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72404
Reviewed By: mruberry
Differential Revision: D34063268
Pulled By: ngimel
fbshipit-source-id: 0775d9c83a4a0eb0eb41bf6efecca94a00692141
(cherry picked from commit 07a1a3d0b4)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71429
Note that this was untested in OSS Bazel.
ghstack-source-id: 148159363
Test Plan: Tested locally. Rely on CI to validate.
Reviewed By: malfet
Differential Revision: D33638407
fbshipit-source-id: 12ae383ccadc1375b92d9c6a12d43821e48f9dcb
(cherry picked from commit 12be8c195c)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70863
ghstack-source-id: 148159368
Test Plan: Ought to be a no-op: rely on CI to validate.
Reviewed By: malfet
Differential Revision: D33367290
fbshipit-source-id: cb550538b9eafaa0117f94077ebd4cb920688881
(cherry picked from commit 077d9578bc)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/71616
This fixes the leaks in my test case. I have not tested it on our big models yet, but will report back if we can.
This potentially impacts allocator performance in that it slightly increases the amount of CPU memory we allocate for data structures, and it means that `process_events` may look at a larger number of events in the case where there are multiple streams with long-running ops on them.
However, I suspect that in general, either:
- An application isn't using very many streams or very many long-running ops, in which case the performance is essentially the same
- Or, they are, which is precisely the case where https://github.com/pytorch/pytorch/issues/71616 bites you, and so freeing memory faster is probably more valuable than the slight CPU overhead here.
I'm not attached to this approach or any of its details, but figured it was worth throwing up for discussion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71745
Reviewed By: soulitzer
Differential Revision: D33948288
Pulled By: ngimel
fbshipit-source-id: 73e95f8a9bbe385a77de483d1c58b857b5d84e81
(cherry picked from commit d233719c07)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71669
This was relatively inefficient. Rather than looping for each type of stat we want to update, we now do one loop covering all the stats.
ghstack-source-id: 148013645
Reviewed By: ngimel
Differential Revision: D33725458
fbshipit-source-id: 39ef5d65a73d4ef67f259de8c02c7df29487d990
(cherry picked from commit 7ca46689b7)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71667
We have flat_hash_set because it performs better than std::unordered_set.
ghstack-source-id: 148013648
Reviewed By: ngimel
Differential Revision: D33720595
fbshipit-source-id: aa6077c474dd6fc61ce17e24ebde4056c8bae361
(cherry picked from commit 386082eaf1)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70851
This is a step towards OSS/fbcode convergence since OSS uses this file
in both CMake and Bazel.
ghstack-source-id: 147170896
Test Plan: Relying on the extensive CI internal tests for this.
Reviewed By: malfet
Differential Revision: D33299102
fbshipit-source-id: c650dd4755f8d696d5fce81c583d5c73782e3990
(cherry picked from commit 741ca140c8)
Summary:
The `TORCH_CHECK` asserts for strictly-greater-than `kLargeBuffer`,
but the exception claims `>=`. Fix the error message to match the
code.
Happy to open an issue if it's helpful; I was hopeful the trivial fix doesn't need a separate issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69174
Reviewed By: zou3519
Differential Revision: D32760055
Pulled By: H-Huang
fbshipit-source-id: 1a8ab68f36b326ed62d78afdcb198f4d6572d017
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62445
PyTorch currently uses the old style of compiling CUDA in CMake which is just a
bunch of scripts in `FindCUDA.cmake`. Newer versions support CUDA natively as
a language just like C++ or C.
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D31503350
fbshipit-source-id: 2ee817edc9698531ae1b87eda3ad271ee459fd55
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610
- Replace HIP_PLATFORM_HCC with USE_ROCM
- Dont rely on CUDA_VERSION or HIP_VERSION and use USE_ROCM and ROCM_VERSION.
- In the next PR
- Will be removing the mapping from CUDA_VERSION to HIP_VERSION and CUDA to HIP in hipify.
- HIP_PLATFORM_HCC is deprecated, so will add HIP_PLATFORM_AMD to support HIP host code compilation on gcc.
cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd
Reviewed By: jbschlosser
Differential Revision: D30909053
Pulled By: ezyang
fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
Summary:
- HIP_VERSION semantic versioning will change in ROCm4.3. The changes essentially remove the dependency on HIP_VERSION provided in the hip header to keep code compatible with older and newer versions of ROCm.
- TORCH_HIP_VERSION is derived from HIP_VERSION_MAJOR and HIP_VERSION_MINOR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62786
Reviewed By: bdhirsh
Differential Revision: D30281682
Pulled By: seemethere
fbshipit-source-id: e41e69fb9e13de5ddd1af99ba5bbdcbb7b64b673
Summary:
Report pointed memory size, total allocated memory, total reserved size all in one report.
`ptr` and `alloc_size` will be used for associating with op trace.
`allocated_size`, `reserved_size` will be used for memory trace.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61282
Reviewed By: ejguan
Differential Revision: D29796282
Pulled By: chaekit
fbshipit-source-id: 5314c867632d3af1fa9a3811b35eaa5e931a5d87
Summary:
This creates `torch.cuda.set_warn_on_synchronization()` function that would warn or error when synchronizing operation is performed. We could wrap it in a context manager for ease of use, but it would be a lie, because it sets global, and not thread-local state. Since it's intended for debugging, maybe that's ok though.
As all `torch.cuda.*` functions, it's going through CPython, not pybind, so the argument is converted to long before being passed to c10 function. I'll make python argument a python enum class, but without pybind it'll still have to go thourgh long conversion.
For a test script
```
import torch
torch.cuda.set_warn_on_synchronization(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
Summary:
This is a first step towards creating context manager that errors out on synchronizing calls.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61889
Reviewed By: albanD
Differential Revision: D29805280
Pulled By: ngimel
fbshipit-source-id: b66400fbe0941b7daa51e6b30abe27b9cccd4e8a
Summary:
Follow-up to https://github.com/pytorch/pytorch/issues/18584. This PR covers the remaining places where event or stream query might result in not ready errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61554
Reviewed By: mrshenli
Differential Revision: D29763973
Pulled By: ezyang
fbshipit-source-id: 41d988d1826b2309cc6b01a81144094b353abdf9
Summary:
Fixes https://github.com/pytorch/pytorch/issues/35901
This change is designed to prevent fragmentation in the Caching Allocator. Permissive block splitting in the allocator allows very large blocks to be split into many pieces. Once split too finely it is unlikely all pieces will be 'free' at that same time so the original allocation can never be returned. Anecdotally, we've seen a model run out of memory failing to alloc a 50 MB block on a 32 GB card while the caching allocator is holding 13 GB of 'split free blocks'
Approach:
- Large blocks above a certain size are designated "oversize". This limit is currently set 1 decade above large, 200 MB
- Oversize blocks can not be split
- Oversize blocks must closely match the requested size (e.g. a 200 MB request will match an existing 205 MB block, but not a 300 MB block)
- In lieu of splitting oversize blocks there is a mechanism to quickly free a single oversize block (to the system allocator) to allow an appropriate size block to be allocated. This will be activated under memory pressure and will prevent _release_cached_blocks()_ from triggering
Initial performance tests show this is similar or quicker than the original strategy. Additional tests are ongoing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44742
Reviewed By: zou3519
Differential Revision: D29186394
Pulled By: ezyang
fbshipit-source-id: c88918836db3f51df59de6d1b3e03602ebe306a9
Summary:
Previous is https://github.com/pytorch/pytorch/issues/57781
We add now two CUDA bindings to avoid using ctypes to fix a windows issue.
However, we use ctypes to allocate the stream and create its pointer
(we can do this with a 0-dim tensor too if it feels better).
CC. ezyang rgommers ngimel mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59527
Reviewed By: albanD
Differential Revision: D29053062
Pulled By: ezyang
fbshipit-source-id: 661e7e58de98b1bdb7a0871808cd41d91fe8f13f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59903
D29034650 (cf0c4ac258) probably breaks something because it changes a `for` loop on ~Line 1200 from `[size,max)` to `[0,max)`. This fixes that
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D29081688
fbshipit-source-id: 21f08e3f244fc02cf97d137b3cc80d4378d17185
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59560
`at::cuda::CUDAStream` has the `query` and `synchronize` methods, but `c10::Stream` does not, and I couldn't find any generic way to accomplish this. Hence I added helpers to do this to the DeviceGuardImpl interface, and then defined these methods on `c10::Stream`. (I had to do it out-of-line to circumvent a circular dependency).
ghstack-source-id: 130932249
Test Plan: CI
Reviewed By: ezyang
Differential Revision: D28931377
fbshipit-source-id: cd0c19cf021e305d0c0cf9af364afb445d010248
Summary:
After the change async error warnings look as follows:
```
$ python -c "import torch;torch.eye(3,3,device='cuda:777')"
Traceback (most recent call last):
File "<string>", line 1, in <module>
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59467
Reviewed By: ngimel
Differential Revision: D28904360
Pulled By: malfet
fbshipit-source-id: 2a8fa5affed5b4ffcaa602c8ab2669061cde7db0
Summary:
This is required in https://github.com/pytorch/pytorch/pull/57110#issuecomment-828357947
We need to provide means to synchronize on externally allocated streams for dlpack support in python array data api.
cc mruberry rgommers leofang asi1024 kmaehashi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57781
Reviewed By: mrshenli
Differential Revision: D28326365
Pulled By: ezyang
fbshipit-source-id: b67858c8033949951b49a3d319f649884dfd0a91
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57609
Throw c10::CudaError for CUDA Exceptions for better classification of errors
Test Plan: Test locally by running some workflows
Reviewed By: dzhulgakov
Differential Revision: D28209356
fbshipit-source-id: 19a5fc8548433238dc224ea81a5f63a945fc5cc3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57345
Already back in https://github.com/pytorch/pytorch/pull/57046 we realized that calling this method `getStreamFromPool` could cause issues because that name gets HIPified and thus in some callsites we'd end up calling a method that doesn't exist. In the end we got away with it because the places where we were calling that method weren't HIPified. However in the next PR we'll use this method inside RPC, and that will start causing problems, hence here I rename it to something that should not cause conflicts. This is a private API (since it's inside `impl`) thus there's no backwards compatibility concerns.
ghstack-source-id: 127916484
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28114923
fbshipit-source-id: e027ad08a8e02090c08c6407c2db5a7fde104812
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56830
Opt into formatting on GitHub and format everything. This is a trial run before turning on formatting for more and eventually all of the codebase.
Test Plan: CI
Reviewed By: zertosh
Differential Revision: D27979080
fbshipit-source-id: a80f0c48691c08ae8ca0af06377b87e6a2351151
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57049
There was a comment above CUDAMultiStreamGuard which said "TODO: Implement this generically in c10". This is what I'm doing here.
The new generic MultiStreamGuard class is able to take a vector of device-agnostic c10::Streams and is able to support any device type (CUDA, but also ROCm and others) by using a VirtualGuardImpl. A class called CUDAMultiStreamGuard is still kept around, for convenience, and slightly for performance as it avoids a vtable lookup.
ghstack-source-id: 127713139
(Note: this ignores all push blocking failures!)
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28029158
fbshipit-source-id: 2f3181371f8cb0d77a3b2e6aa510f1dd74e8f69b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57047
We intend to merge CUDAFuture into ivalue::Future by using DeviceGuardImplInterface to avoid explicitly referring to CUDA. For that we need to add two methods to DeviceGuardImplInterface. In this PR, we add a method to record a DataPtr onto a stream with the caching allocator.
ghstack-source-id: 127713135
(Note: this ignores all push blocking failures!)
Test Plan: Used later in this stack
Reviewed By: ezyang
Differential Revision: D28029161
fbshipit-source-id: ff337ab8ccc98437b5594b2f263476baa1ae93e7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57046
We intend to merge CUDAFuture into ivalue::Future by using DeviceGuardImplInterface to avoid explicitly referring to CUDA. For that we need to add two methods to DeviceGuardImplInterface. In this PR, we add a method to get a stream from the global ATen pool.
ghstack-source-id: 127713137
(Note: this ignores all push blocking failures!)
Test Plan: Used later in this stack
Reviewed By: ezyang
Differential Revision: D28029159
fbshipit-source-id: 5055d84c1f3c2a4d86442f3149455c5ebd976dea
Summary:
Safely deallocating and repurposing memory used across streams relies on recording end-of-life events in all an allocation's usage streams beyond its original allocation stream. The events are later queried to see if all GPU work in those extra streams that could have used the allocation is done (from the CPU's perspective) before repurposing the allocation for use in its original stream.
The trouble is, calling EventQuery on an ordinary event recorded in a capturing stream is illegal. Calling EventQuery while capture is underway is also illegal. So when we call `tensor.record_stream` (or `c10::cuda::cudaCachingAllocator::recordStream`) on any tensor that's used or deleted in or around a capture, we often end up with a confusing error thrown from the cudaEventQuery in DeviceCachingAllocator::process_events().
This PR enables hopefully-safe deletion of tensors used across streams in or around capture with a conservative but simple approach: don't record or process end of life events for such tensors until the allocator's sure no captures are underway. You could whiteboard cases where this causes cross-stream-used allocations to be unavailable for reuse longer than absolutely necessary, but cross-stream-used allocations are uncommon, so for practical purposes this approach's impact on the memory footprint of captured sequences should be small.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55860
Reviewed By: ejguan
Differential Revision: D27822557
Pulled By: ezyang
fbshipit-source-id: b2e18a19d83ed05bad67a8157a14a606ed14d04e