Commit Graph

513 Commits

Author SHA1 Message Date
Syed Tousif Ahmed
1637a40796 Adds snapshot API for MemPools to get pool memory segments (#133601)
Canonically, the snapshot API returns the entire memory state of the CUDACachingAllocator (using `get_all_blocks`). There is no API that returns the memory state of only a given pool.

In this PR, we extend the snapshot API so that it can return only the memory addresses of an active pool. When the snapshot API is called under a MemPoolContext, we return only the blocks that correspond to the pool id of the active pool.

Part of https://github.com/pytorch/pytorch/issues/124807.
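
A minimal usage sketch, assuming the `torch.cuda.MemPool` / `torch.cuda.use_mem_pool` APIs from the companion PRs in this series (names may differ across versions):

```python
import torch

pool = torch.cuda.MemPool()
with torch.cuda.use_mem_pool(pool):
    x = torch.randn(1024, device="cuda")  # allocated into `pool`
    # Called under the active pool's context, the snapshot contains only this
    # pool's segments rather than the allocator's entire state.
    segments = torch.cuda.memory_snapshot()
```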

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133601
Approved by: https://github.com/ezyang
2024-10-29 01:01:47 +00:00
PyTorch MergeBot
3b0f39336c Revert "Adds snapshot API for MemPools to get pool memory segments (#133601)"
This reverts commit 00504aa6b8.

Reverted https://github.com/pytorch/pytorch/pull/133601 on behalf of https://github.com/wdvr due to reverting for now as this breaks lots of internal tests. Details below ([comment](https://github.com/pytorch/pytorch/pull/133601#issuecomment-2441864871))
2024-10-28 15:12:20 +00:00
Syed Tousif Ahmed
00504aa6b8 Adds snapshot API for MemPools to get pool memory segments (#133601)
Canonically, the snapshot API returns the entire memory state of the CUDACachingAllocator (using `get_all_blocks`). There is no API that returns the memory state of only a given pool.

In this PR, we extend the snapshot API so that it can return only the memory addresses of an active pool. When the snapshot API is called under a MemPoolContext, we return only the blocks that correspond to the pool id of the active pool.

Part of https://github.com/pytorch/pytorch/issues/124807.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133601
Approved by: https://github.com/ezyang
2024-10-26 03:34:59 +00:00
Syed Tousif Ahmed
03c72976a5 Properly uses ref-counting for torch.cuda.use_mem_pool (#133600)
This PR refactors some ref-counting functionality out of `beginAllocateToPool` and `releasePool`. The ref-counting logic is then used in construction and destruction of `torch.cuda.MemPool`.

The `use_count` variable in the CUDACachingAllocator is essentially a refcount of how many context managers are using the pool. Since we are now lifting the MemPool abstraction up to the user, the MemPool object itself now needs to hold an extra reference as well.
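
A sketch of the intended lifetime semantics (illustrative only; the `use_count` bookkeeping itself is internal to the allocator):

```python
import torch

pool = torch.cuda.MemPool()          # the MemPool object itself holds one reference
with torch.cuda.use_mem_pool(pool):  # entering the context manager adds a reference
    x = torch.randn(1024, device="cuda")
# Exiting drops the context manager's reference; the pool's memory can be
# reclaimed once the MemPool object is destroyed and the count reaches zero.
del pool
```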

Part of https://github.com/pytorch/pytorch/issues/124807.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133600
Approved by: https://github.com/eqy, https://github.com/ezyang
2024-10-22 03:21:53 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
albanD
c4ed03cea1 Add proper handling for view and factory function for csan (#138236)
In particular, properly handle functions that only read/write metadata on the Tensor, so that csan does not flag them as data reads/writes.
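
For illustration, a hedged sketch of the behavior described above, using the sanitizer's public entry point:

```python
import torch
import torch.cuda._sanitizer as csan

csan.enable_cuda_sanitizer()

x = torch.randn(4, device="cuda")
v = x.view(2, 2)  # creating a view only touches metadata; csan should not
                  # record it as a data read/write on `x`
```
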
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138236
Approved by: https://github.com/ngimel
2024-10-18 14:04:18 +00:00
albanD
69ba89da11 Fix cuda sanitizer and as_subclass calls (#138218)
This fixes 4 main issues:
- The way the cuda sanitizer handles its state is weird. In particular, because the lifetime of the Mode is linked to the submodule, it might outlive the python runtime and other loaded modules. On my current version, this even outlives the "sys" module. Since I'm not sure of the impact of changing this lifetime handling, I'm making the exit handler a no-op when python is already dying, at which point there is no point cleaning up.
- Adds a "disable" method to be able to test after the mode is enabled.
- Fix `Tensor.as_subclass()` to properly disable modes when creating the new Tensor object, just like we already do in `make_subclass` and `make_wrapper_subclass`. The change here is just to apply the exact same treatment to it (see the sketch after this list).
- ~Fix `Tensor.as_subclass()` not to propagate autograd as there is no valid backward associated here.~ We have tests that check that this behavior happens, so I guess this is not an obvious bugfix but expected behavior. Reverted that change.
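
A sketch of the call being fixed (the subclass below is a hypothetical placeholder):

```python
import torch

class LoggingTensor(torch.Tensor):  # hypothetical subclass, for illustration
    pass

t = torch.randn(3, device="cuda")
s = t.as_subclass(LoggingTensor)  # with this fix, active modes are disabled while the
                                  # new object is created, matching `make_subclass` and
                                  # `make_wrapper_subclass`
assert type(s) is LoggingTensor
```
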
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138218
Approved by: https://github.com/ngimel
2024-10-17 21:18:32 +00:00
Jack Taylor
966a1a971e [ROCm] Add AMDSMI support for UUID input (#129741)
Adds support for using UUIDs with the AMDSMI utilities in PyTorch via CUDA_VISIBLE_DEVICES/HIP_VISIBLE_DEVICES.
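
For example (the UUID below is a hypothetical placeholder; the variable must be set before the device runtime is initialized):

```python
import os

os.environ["HIP_VISIBLE_DEVICES"] = "GPU-f00dcafe-dead-beef-0000-000000000000"

import torch

print(torch.cuda.device_count())  # counts only the devices matching the listed UUID(s)
```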

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129741
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily
2024-10-15 15:56:30 +00:00
Jin Zhou
5516ac5c21 [ROCm] Tunableop record untuned (#128813)
When TunableOp is enabled, it is easy to hit OOM since the application usually needs a large amount of GPU memory, such as when running an LLM for inference. So we need an offline mode to tune the GEMMs. This PR provides an offline mode for TunableOp:

- record untuned GEMMs to a file.

- a Python API named `tune_gemm_in_file` is added to read the untuned file and tune the GEMMs in it (see the sketch after this list).
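
A hedged sketch of the offline flow (environment-variable and file names are assumptions based on the TunableOp documentation):

```python
import torch

# Step 1 (separate run): record untuned GEMMs without tuning inline, e.g. with
# PYTORCH_TUNABLEOP_ENABLED=1 and inline tuning disabled, producing a file such
# as tunableop_untuned0.csv.

# Step 2: tune the recorded GEMMs offline and write out the results.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tune_gemm_in_file("tunableop_untuned0.csv")
torch.cuda.tunable.write_file("tunableop_results0.csv")
```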

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128813
Approved by: https://github.com/jeffdaily, https://github.com/hongxiayang, https://github.com/naromero77amd

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-10-09 21:59:03 +00:00
Jeff Daily
c7b0d4b148 raw_alloc ignores PYTORCH_NO_CUDA_MEMORY_CACHING (#131114)
raw_alloc is used by cudnn, miopen, thrust, and tunableop.  Without this PR, the env var for disabling the caching allocator will only partially work.
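
For reference, the env var in question; it must be set before the first CUDA allocation to take effect:

```python
import os

os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

import torch

# With this PR, raw_alloc callers (cudnn, miopen, thrust, tunableop) also
# bypass the caching allocator when the variable is set.
x = torch.randn(8, device="cuda")
```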

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131114
Approved by: https://github.com/eqy, https://github.com/houseroad, https://github.com/albanD

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2024-10-04 15:36:29 +00:00
PyTorch MergeBot
0d1701f310 Revert "raw_alloc ignores PYTORCH_NO_CUDA_MEMORY_CACHING (#131114)"
This reverts commit 7001907480.

Reverted https://github.com/pytorch/pytorch/pull/131114 on behalf of https://github.com/PaliC due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/131114#issuecomment-2390615007))
2024-10-03 06:22:55 +00:00
Jeff Daily
7001907480 raw_alloc ignores PYTORCH_NO_CUDA_MEMORY_CACHING (#131114)
raw_alloc is used by cudnn, miopen, thrust, and tunableop.  Without this PR, the env var for disabling the caching allocator will only partially work.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131114
Approved by: https://github.com/eqy, https://github.com/houseroad, https://github.com/albanD

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2024-10-02 16:27:15 +00:00
Yu, Guangye
d29094888b Use torch.Stream & torch.Event for Dynamo capture (#134850)
# Motivation
This PR aims to solve the multiple inheritance problem.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134850
Approved by: https://github.com/yf225, https://github.com/EikanWang
2024-10-02 14:15:33 +00:00
drisspg
d05645841e Update get_device_properties to take in optional device (#136683)
Aligns behavior with the rest of cuda's device info query methods
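
For example:

```python
import torch

# After this PR the device argument is optional, defaulting to the current device,
# like the other CUDA device-info queries.
props = torch.cuda.get_device_properties()
props = torch.cuda.get_device_properties("cuda")  # index filled in from current device
print(props.name, props.total_memory)
```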

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136683
Approved by: https://github.com/eqy
2024-09-26 15:07:31 +00:00
Jeff Daily
15dba021bb [ROCm][CI] upgrade CI to ROCm 6.2 (#132555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132555
Approved by: https://github.com/pruthvistony, https://github.com/malfet
2024-09-20 17:39:31 +00:00
Dan Zimmerman
fc88ba260f [amdsmi][torch] Update amdsmi API usages (#135504)
Summary: In ROCm 6.2.0 there were API name changes; we check if the new APIs exist and use them in this diff (see 7b2463abe0 for the changes).

Test Plan: CI

Differential Revision: D62325661

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135504
Approved by: https://github.com/eqy, https://github.com/houseroad
2024-09-10 19:15:39 +00:00
Syed Tousif Ahmed
4655eb3ee2 Uses MemPoolContext to route allocations from CUDACachingAllocator (#134685)
Re-open of https://github.com/pytorch/pytorch/pull/133599 that was mistakenly closed by issuing `ghstack land`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134685
Approved by: https://github.com/ezyang
2024-08-29 03:56:31 +00:00
Nikita Shulga
f7c1f32803 Fix partially initialized module error (#134019)
https://github.com/pytorch/pytorch/pull/132990 introduced a dependency on `torch.version`, which might not be imported yet, and can result in `AttributeError: partially initialized module 'torch' has no attribute 'version' (most likely due to a circular import)` if a user starts their code with `import torch.cuda`

Fix it by importing `torch.version` explicitly

Test Plan: CI

Differential Revision: D61549284

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134019
Approved by: https://github.com/seemethere
2024-08-20 22:20:02 +00:00
Jack Taylor
92151c814b [ROCm] Set _HAS_PYNVML to false if amdsmi not installed (#132990)
This is a bugfix that was recently encountered in ROCm/Deepspeed. Currently, if a library installs pynvml and runs on ROCm, PyTorch will break: `_HAS_PYNVML` is set to true, and PyTorch will attempt to use the amdsmi library for the `device_count` call even though it is not installed.

This fix sets `_HAS_PYNVML` to false on ROCm if amdsmi is not installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132990
Approved by: https://github.com/pruthvistony, https://github.com/eqy, https://github.com/malfet
2024-08-19 09:45:58 +00:00
Mikayla Gawarecki
018e48c337 [Reland] Add wrappers for synchronous GPUDirect Storage APIs (#133489)
Reland #130633

USE_CUFILE turned off by default in this version
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133489
Approved by: https://github.com/albanD
2024-08-15 17:11:52 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, an empty function does not need a `pass` statement as a placeholder.
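
For example:

```python
# Before: redundant `pass` alongside the docstring.
def on_event() -> None:
    """Hook for subclasses; intentionally a no-op."""
    pass

# After: the docstring alone is a valid function body.
def on_event() -> None:
    """Hook for subclasses; intentionally a no-op."""
```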

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
Tobias Ringwald
6753ee127c Allow torch.cuda.memory.mem_get_info to take a device str argument with an unspecified device index. (#132616)
`torch.cuda.memory.mem_get_info` allows device strings given the current type hints. However, `device = torch.device('cuda')` leads to `device.index = None`, which results in downstream problems. Setting `optional=True` will insert the default device index in such cases.
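
For example:

```python
import torch

free, total = torch.cuda.mem_get_info("cuda")  # index omitted: with optional=True,
                                               # the current device's index is used
free, total = torch.cuda.mem_get_info(torch.device("cuda"))  # same, via device object
print(free, total)  # bytes
```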

Fixes #132583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132616
Approved by: https://github.com/soulitzer
2024-08-06 13:19:46 +00:00
Xuehai Pan
f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00
Nikita Shulga
cd5452aace [CUDA] is_bf16_supported() should not crash if there are no GPUs (#132313)
`False` is the right answer on a system that does not have any CUDA GPUs.
- Added a regression test to TestTorch.
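
For example:

```python
import torch

# On a machine without any CUDA GPUs this now returns False instead of raising.
print(torch.cuda.is_bf16_supported())
```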

Fixes https://github.com/pytorch/pytorch/issues/132303

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132313
Approved by: https://github.com/eqy, https://github.com/syed-ahmed
2024-08-02 02:50:43 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
Syed Tousif Ahmed
7c89ec0f7c Implements torch.cuda.MemPool() API (#131152)
In this PR:
- Pool id creation logic is refactored and moved to a MemPool class. `graph_pool_handle()` API now uses `torch.cuda.MemPool()` to get a unique id for a pool. Existing tests should cover this change.
- MemPool holds a pointer to a CUDAAllocator as proposed in https://github.com/pytorch/pytorch/issues/124807#issuecomment-2077506997. Tests are added to show usage with CUDAPluggableAllocator.
- MemPoolContext API makes a mempool active. Tests are added to show usage of this API. This API will be used in CUDACachingAllocator to route allocations to a user-provided allocator. See draft here: https://github.com/pytorch/pytorch/pull/125722/ (a usage sketch follows this list)
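
A hedged usage sketch of the pieces described above (the shared-library path and symbol names are hypothetical placeholders, and some of these names changed in later releases):

```python
import torch
from torch.cuda.memory import CUDAPluggableAllocator

alloc = CUDAPluggableAllocator("./my_alloc.so", "my_malloc", "my_free")
pool = torch.cuda.MemPool(alloc.allocator())  # pool backed by the pluggable allocator

ctx = torch.cuda.MemPoolContext(pool)    # makes `pool` the active pool
handle = torch.cuda.graph_pool_handle()  # pool ids now come from MemPool internally
```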

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131152
Approved by: https://github.com/eqy, https://github.com/ezyang
2024-08-01 01:29:30 +00:00
PyTorch MergeBot
e191b83462 Revert "Add wrappers for synchronous GPUDirect Storage APIs (#130633)"
This reverts commit 709ddf7a9d.

Reverted https://github.com/pytorch/pytorch/pull/130633 on behalf of https://github.com/clee2000 due to still failing internally D60265673 ([comment](https://github.com/pytorch/pytorch/pull/130633#issuecomment-2253239607))
2024-07-26 18:08:20 +00:00
Mikayla Gawarecki
709ddf7a9d Add wrappers for synchronous GPUDirect Storage APIs (#130633)
Based in part on https://github.com/NVIDIA/apex/pull/1774

Differential Revision: [D60155434](https://our.internmc.facebook.com/intern/diff/D60155434)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130633
Approved by: https://github.com/albanD
2024-07-25 22:23:38 +00:00
PyTorch MergeBot
e4b5645f83 Revert "Add wrappers for synchronous GPUDirect Storage APIs (#130633)"
This reverts commit 5b5e0698a5.

Reverted https://github.com/pytorch/pytorch/pull/130633 on behalf of https://github.com/clee2000 due to breaking a lot of jobs and build rules internally D60085885, possibly needs to update some bazel build? ([comment](https://github.com/pytorch/pytorch/pull/130633#issuecomment-2245806738))
2024-07-23 17:19:34 +00:00
Xiaodong Wang
9e753d1f20 [AMD] catch exception when other processes belong to other users (#131018)
Summary:
It is a long-known pain point that if other users are running things, calling `torch.cuda.memory.list_gpu_processes()` will error out:
```
  torch.cuda.memory.list_gpu_processes()
  File "torch/cuda/memory.py", line 647, in list_gpu_processes
    procs = amdsmi.amdsmi_get_gpu_process_list(handle)  # type: ignore[attr-defined]
  File "amdsmi/py_interface/amdsmi_interface.py", line 1946, in amdsmi_get_gpu_process_list
    _check_res(
  File "amdsmi/py_interface/amdsmi_interface.py", line 510, in _check_res
    raise AmdSmiLibraryException(ret_code)
amdsmi.py_interface.amdsmi_exception.AmdSmiLibraryException: Error code:
	10 | AMDSMI_STATUS_NO_PERM - Permission Denied

```

So just catch this error

Test Plan: torch.cuda.memory.list_gpu_processes() no longer fails

Differential Revision: D59901053

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131018
Approved by: https://github.com/eqy, https://github.com/clee2000
2024-07-22 19:38:51 +00:00
Mikayla Gawarecki
5b5e0698a5 Add wrappers for synchronous GPUDirect Storage APIs (#130633)
Based in part on https://github.com/NVIDIA/apex/pull/1774

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130633
Approved by: https://github.com/albanD
2024-07-22 14:51:24 +00:00
Jack Taylor
e9023d57b0 [ROCm] Return correct AMDSMI socket_power metric (#130331)
Extending on the change in https://github.com/pytorch/pytorch/pull/127729

Depending on gcnArch, the API that returns socket power changes based on the underlying gpu_metrics. This PR handles both cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130331
Approved by: https://github.com/jeffdaily, https://github.com/eqy, https://github.com/malfet
2024-07-17 01:58:58 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
PyTorch MergeBot
07450e9713 Revert "[MPS] Add support for autocast in MPS (#99272)"
This reverts commit 6240cfd5c7.

Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/jeanschmidt due to introduced breakages in trunk ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2203033719))
2024-07-02 12:29:51 +00:00
Jeff Willette
5c9d5272e4 fixes #124582 (#128483)
Added a check to `make_graphed_callables` for the existence of outputs requiring grad.

Added a new test case and updated an existing test case to include parameterless modules.
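
A sketch of the scenario being exercised (a parameterless module passed to `make_graphed_callables`):

```python
import torch

model = torch.nn.ReLU().cuda()  # parameterless module
x = torch.randn(8, device="cuda", requires_grad=True)

# The added check inspects whether any output requires grad before capturing
# a backward graph; here the input requires grad, so capture proceeds as usual.
graphed = torch.cuda.make_graphed_callables(model, (x,))
y = graphed(x)
```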

Fixes #124582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128483
Approved by: https://github.com/eqy, https://github.com/ezyang
2024-07-02 08:45:59 +00:00
Kulin Seth
6240cfd5c7 [MPS] Add support for autocast in MPS (#99272)
Fixes https://github.com/pytorch/pytorch/issues/88415

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet
2024-07-02 01:49:52 +00:00
Jack Taylor
e1b426b345 [ROCm] CUDA_VISIBLE_DEVICES fallback option for device_count (#129650)
Updating `_parse_visible_devices` to allow use of CUDA_VISIBLE_DEVICES if HIP_VISIBLE_DEVICES is unset, to avoid any unnecessary code changes in workloads that already rely on CUDA_VISIBLE_DEVICES.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129650
Approved by: https://github.com/hongxiayang, https://github.com/malfet
2024-07-01 11:40:09 +00:00
Nikita Shulga
14dc08ddc7 Inductor to fail gracefully on Voltas for bf16 tensors (#129288)
Volta (sm_7x) does not have HW support for the bfloat16 datatype, though it is emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with either CUDA bf16 input or output tensors is used, raise a warning and skip the frame.

Add an optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.
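
A sketch of the new parameter (keyword name as given above):

```python
import torch

torch.cuda.is_bf16_supported()                           # True on Volta: emulated
                                                         # bf16 counts by default
torch.cuda.is_bf16_supported(including_emulation=False)  # False on sm_7x, letting
                                                         # Inductor skip bf16 frames
```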

Test plan: Modify `is_bf16_supported` to return False and check that the warning is generated

Fixes https://github.com/pytorch/pytorch/issues/118122 and https://github.com/pytorch/pytorch/issues/118581

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129288
Approved by: https://github.com/eqy, https://github.com/jansel
2024-06-25 00:04:13 +00:00
awayzjj
b70440f0a7 Document the torch.cuda.profiler.profile function (#128216)
Fixes https://github.com/pytorch/pytorch/issues/127901

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128216
Approved by: https://github.com/malfet, https://github.com/eqy
2024-06-17 23:42:40 +00:00
ibartol
c6b180a316 Created docs (and example) for cudart function in torch.cuda (#128741)
Fixes #127908

## Description

Created docs to document the torch.cuda.cudart function to solve the issue #127908.
I tried to stick to the [guidelines to document a function](https://github.com/pytorch/pytorch/wiki/Docstring-Guidelines#documenting-a-function), but I was not sure if there is a consensus on how to handle the docs of a function that calls an internal function. So I went ahead, tried what the function will raise, etc. from the user's end, and documented it (i.e. I am documenting what `_lazy_init()` actually raises).

Updated PR from #128298 since I made quite a big mistake in my branch. I apologize for the newbie mistake.

### Summary of Changes

- Added docs for torch.cuda.cudart
- Added the cudart function in the autosummary of docs/source/cuda.rst

## Checklist
- [X] The issue that is being fixed is referred in the description
- [X] Only one issue is addressed in this pull request
- [X] Labels from the issue that this PR is fixing are added to this pull request
- [X] No unnecessary issues are included into this pull request

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128741
Approved by: https://github.com/msaroufim
2024-06-17 16:50:37 +00:00
anandptl84
f48ca2561d Document torch.cuda.profiler.start (#128098)
Documents the `start` function of cuda/profiler.py (https://github.com/pytorch/pytorch/issues/127917).

Fixes #127917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128098
Approved by: https://github.com/aaronenyeshi
2024-06-14 01:44:18 +00:00
Xuehai Pan
83bb9b7c53 [BE] explicitly export subpackage torch.utils (#128342)
Resolves #126401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128342
Approved by: https://github.com/Skylion007
ghstack dependencies: #127707
2024-06-13 04:39:16 +00:00
anandptl84
0f52dc7e51 Document torch.cuda.profiler.stop (#128196)
Fixes #127918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128196
Approved by: https://github.com/malfet, https://github.com/eqy
2024-06-12 17:39:43 +00:00
Aaron Orenstein
62bcdc0ac9 Flip default value for mypy disallow_untyped_defs [4/11] (#127841)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127841
Approved by: https://github.com/oulgen
2024-06-08 18:36:48 +00:00
albanD
2ffdf556ea Add back API that some people rely on in torch.cuda.amp.grad_scaler namespace (#128056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128056
Approved by: https://github.com/kit1980, https://github.com/eqy
2024-06-06 17:02:32 +00:00
Kazuaki Ishizaki
6adcf21b2b Documenting the torch.cuda.nccl.version function (#128022)
Fixes #127892

This PR adds a docstring to the torch.cuda.nccl.version function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128022
Approved by: https://github.com/malfet
2024-06-06 01:13:07 +00:00
Jack Taylor
db515b6ac7 [ROCm] Fix error in torch.cuda initialisation if amdsmi is not available (#127528)
Reported in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15874

When nvml_count is set via 9f73c65b8f/torch/cuda/__init__.py (L834)

If amdsmi is not available, this will throw an error:
```
File "python3.10/site-packages/torch/cuda/__init__.py", line 634, in _raw_device_count_amdsmi
    except amdsmi.AmdSmiException as e:
NameError: name 'amdsmi' is not defined
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127528
Approved by: https://github.com/jeffdaily, https://github.com/eqy, https://github.com/pruthvistony, https://github.com/atalman
2024-06-04 11:16:02 +00:00
Jeff Daily
0e7bd7fedd [ROCm] TunableOp improvements (#124362)
- use less memory; smaller default hipblaslt workspace size
- options to avoid cache effects
  - icache flush option
  - rotating buffers during tuning
- python APIs
- unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124362
Approved by: https://github.com/xw285cornell
2024-06-03 22:30:11 +00:00
Xiaodong Wang
406532f864 [AMD] Fix power_draw api (#127729)
Summary: average_socket_power only gives me NA, so we need to change it to current_socket_power.

Test Plan: Before, `torch.cuda.power_draw` gives me NA; after, it gives me the right power reading (e.g. 441).
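
For reference, a minimal sketch of the Python-side query affected:

```python
import torch

mw = torch.cuda.power_draw()  # socket power in milliwatts; on AMD this now reads
                              # current_socket_power instead of average_socket_power
print(mw)
```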

Differential Revision: D58047484

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127729
Approved by: https://github.com/nmacchioni, https://github.com/eqy
2024-06-03 21:46:50 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
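
For example (the function name is hypothetical):

```python
from typing_extensions import deprecated

@deprecated("old_api() is deprecated, use new_api() instead", category=FutureWarning)
def old_api() -> None:
    ...
```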

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00