Which inherits from `RuntimeError` and contains `error_code`, which in case of CUDA should contain error returned by `cudaGetLastError`
`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's [`PyErr_SetString`](cb8a72b301/Python/errors.c (L282)), namely
- Convert cstr into Python string with `PyUnicode_FromString`
- Create new exception object using `PyObject_CallOneArg` just like it's done in [`_PyErr_CreateException`](cb8a72b301/Python/errors.c (L32))
- Set `error_code` property using `PyObject_SetAttrString`
- decref all temporary references
Test that it works and captures CPP backtrace (in addition to CI) by running
```python
import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'
import torch
x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
x[y] = 2
print(x)
except torch.AcceleratorError as e:
print("Exception was raised", e.args[0])
print("Captured error code is ", e.error_code)
```
which produces following output
```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0
Captured error code is 710
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: #154436
Removes MemPoolContext from custom user mempools. The ground truth for which pool should be used is in graph_pools active pool, and MemPoolContext just introduced an opportunity for the pool pointed to by MemPoolContext and active pool in graph_pools to go out of sync (see all the asserts in the code to make sure that happens, and yet it still could happen in a multithread scenario, see my recent PRs (#153990).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154042
Approved by: https://github.com/albanD, https://github.com/syed-ahmed
This is an initial attempt to provide some statistics for the pinned host memory allocations flowing through CachingHostAllocator. Many times in the past we have had inexplicable slowdowns that would be much easier to diagnose if we had some host memory characteristics.
This change tries very hard not to disrupt the initial design of the allocator, and it uses existing locking mechanism, whenever possible, to gather statistics "for free". Only deviation from that is on the "slow path" where we incur CUDA calls anyway, so taking a short lock is not going to hurt the performance much, especially in the steady state where most allocations will come from cache.
As mentioned before, this is the first PR, to introduce the concept and to see if it fits the right paradigm. We can always add more later.
Metrics that would require more involved changes to the code base and locks, like requested memory, have been punted for now. I also tried to reuse the Stat structure used in CUDA caching allocator, in order to maintain symmetry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147660
Approved by: https://github.com/ngimel
This is an initial attempt to provide some statistics for the pinned host memory allocations flowing through CachingHostAllocator. Many times in the past we have had inexplicable slowdowns that would be much easier to diagnose if we had some host memory characteristics.
This change tries very hard not to disrupt the initial design of the allocator, and it uses existing locking mechanism, whenever possible, to gather statistics "for free". Only deviation from that is on the "slow path" where we incur CUDA calls anyway, so taking a short lock is not going to hurt the performance much, especially in the steady state where most allocations will come from cache.
As mentioned before, this is the first PR, to introduce the concept and to see if it fits the right paradigm. We can always add more later.
Metrics that would require more involved changes to the code base and locks, like requested memory, have been punted for now. I also tried to reuse the Stat structure used in CUDA caching allocator, in order to maintain symmetry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147660
Approved by: https://github.com/ngimel
Triton 2.2 and greater have a bug where allowing TF32 generation for a GPU that does not support TF32 will cause code generation errors. Patch around this problem by:
1. Adding a function to `torch.cuda` that determines whether CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.
To demonstrate that this fix works, try running `test/inductor/test_max_autotune.py` on a GPU with CUDA compute capability < 8 (e.g. any NVIDIA consumer GPU) without this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
Certain `cpp_wrapper`-enabled tests were OOM-ing in the CI pipeline, with error messages suggesting that sufficient memory was accessible. This ultimately resulted from an internal memory limitation that was not queryable in the API. This PR adds querying for that limit.
Additionally, the failing tests had incorrect memory availability checks, and are updated with measured memory requirements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140620
Approved by: https://github.com/malfet, https://github.com/eqy
ghstack dependencies: #141367
Fixes#127908
## Description
Created docs to document the torch.cuda.cudart function to solve the issue #127908.
I tried to stick to the [guidelines to document a function](https://github.com/pytorch/pytorch/wiki/Docstring-Guidelines#documenting-a-function) but I was not sure if there is a consensus on how to handle the docs of a function that calls an internal function. So I went ahead and tried what the function will raise, etc. from the user endpoint and documented it (i.e. I am giving what actually _lazy_init() will raise).
Updated PR from #128298 since I made quite a big mistake in my branch. I apologize for the newbie mistake.
### Summary of Changes
- Added docs for torch.cuda.cudart
- Added the cudart function in the autosummary of docs/source/cuda.rst
## Checklist
- [X] The issue that is being fixed is referred in the description
- [X] Only one issue is addressed in this pull request
- [X] Labels from the issue that this PR is fixing are added to this pull request
- [X] No unnecesary issues are included into this pull request
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128741
Approved by: https://github.com/msaroufim
Add non-package python modules to the public API checks.
The original change is to remove the `ispkg` check in this line
https://github.com/pytorch/pytorch/blob/main/docs/source/conf.py#L518
Everything else is to add the appropriate modules to the rst files, make sure every module we provide can be imported (fixed by either making optional dependencies optional or just deleting files that have been un-importable for 3 years), make API that are both modules and functions (like torch.autograd.gradcheck) properly rendered on the docs website without confusion and add every non-documented API to the allow list (~3k of them).
Next steps will be to try and fix these missing docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110568
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/serve/issues/1937
A fairly common query I see folks running while using pytorch is
`nvidia-smi --format=csv,noheader,nounits --query-gpu=utilization.gpu,utilization.memory,memory.total,memory.used,temperature.gpu,power.draw,clocks.current.sm,clocks.current.memory -l 10`
Existing metrics we have
* For kernel utilization`torch.cuda.utilization()`
* For memory utilization we have them under `torch.cuda.memory` the memory allocated with `torch.cuda.memory.memory_allocated()`
* For total available memory we have `torch.cuda.get_device_properties(0).total_memory`
Which means the only metrics we're missing are
* Temperature: now in `torch.cuda.temperature()`
* Power draw: now in `torch.cuda.power()`
* Clock speed: now in `torch.cuda.clock_speed()`
With some important details on each
* Clock speed settings: I picked the SM clock domain which is documented here https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceEnumvs.html#group__nvmlDeviceEnumvs_1g805c0647be9996589fc5e3f6ff680c64
* Temperature: I use `pynvml.nvmlDeviceGetTemperature(handle, 0)` where 0 refers to the GPU die temperature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91717
Approved by: https://github.com/ngimel
Fixes#43144
This uses the Backend system added by [82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during the code execution. This will allow us to use RMM, use CUDA managed memory for some portions of the code that do not fit in GPU memory. Write static memory allocators to reduce fragmentation while training models and improve interoperability with external DL compilers/libraries.
For example, we could have the following allocator in c++
```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
void *ptr;
std::cout<<"alloc "<< size<<std::endl;
cudaMalloc(&ptr, size);
return ptr;
}
void my_free(void* ptr) {
std::cout<<"free "<<std::endl;
cudaFree(ptr);
}
}
```
Compile it as a shared library
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```
And use it from PyTorch as follows
```python
import torch
# Init caching
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(old)
```
Things to discuss
- How to test this, needs compiling external code ...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
Resubmit of https://github.com/pytorch/pytorch/pull/77673, which was reverted due to Windows test failures: https://github.com/pytorch/pytorch/pull/77673#issuecomment-1130425845.
I suspect these failures happened because I don't explicitly set a side stream for graph capture in the new test.
Not setting a side stream explicitly is alright on Linux because cuda tests implicitly use a side stream.
I think Windows cuda tests implicitly use the default stream, breaking capture and leaving the backend in a bad state.
Other graphs tests explicitly set side streams and don't error in Windows builds, so i'm 95% sure doing the same for the new test will work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77789
Approved by: https://github.com/ezyang
This PR allows user to author a CUDA kernel in python.
```
from torch.cuda.jiterator import create_jit_fn
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
Limitations:
- Only supports elementwise kernel
- 1~8 tensor inputs (empty input, e.g. factory methods, is not supported)
- inputs tensors must live in cuda device
- cpu Scalar is not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, one of float64, int64_t, bool. (complex not support for now)
TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in python interface e.g <8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
Summary:
Also fixes the documentation failing to appear and adds a test to validate that op works with multiple devices properly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69640
Reviewed By: ngimel
Differential Revision: D32965391
Pulled By: mruberry
fbshipit-source-id: 4fe502809b353464da8edf62d92ca9863804f08e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69104
Add nvidia-smi memory and utilization as native Python API
Test Plan:
testing the function returns the appropriate value.
Unit tests to come.
Reviewed By: malfet
Differential Revision: D32711562
fbshipit-source-id: 01e676203299f8fde4f3ed4065f68b497e62a789
Summary:
Powers have decided this API should be listed as beta.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65247
Reviewed By: malfet
Differential Revision: D31057940
Pulled By: ngimel
fbshipit-source-id: 137b63cbd2c7409fecdc161a22135619bfc96bfa
Summary:
This creates `torch.cuda.set_warn_on_synchronization()` function that would warn or error when synchronizing operation is performed. We could wrap it in a context manager for ease of use, but it would be a lie, because it sets global, and not thread-local state. Since it's intended for debugging, maybe that's ok though.
As all `torch.cuda.*` functions, it's going through CPython, not pybind, so the argument is converted to long before being passed to c10 function. I'll make python argument a python enum class, but without pybind it'll still have to go thourgh long conversion.
For a test script
```
import torch
torch.cuda.set_warn_on_synchronization(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
Summary:
Related to https://github.com/pytorch/pytorch/issues/52256
Use autosummary instead of autofunction to create subpages for optim and cuda functions/classes.
Also fix some minor formatting issues in optim.LBFGS and cuda.stream docstings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55673
Reviewed By: jbschlosser
Differential Revision: D27747741
Pulled By: zou3519
fbshipit-source-id: 070681f840cdf4433a44af75be3483f16e5acf7d