Commit Graph

1178 Commits

Author SHA1 Message Date
PyTorch MergeBot
afa1eda901 Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"
This reverts commit ef6296e7f2.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
2025-03-17 22:43:15 +00:00
James Wu
a9c55277d7 [Reland] First version of statically compiled launcher for triton compiled CUDA kernels (#149238)
This is a new version of https://github.com/pytorch/pytorch/pull/148561 fixing the ROCM test failure

Putting this up for a first pass review, though I will likely make a bunch of changes before landing to add more features, etc.

This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal here is to take a cubin file and some metadata from a CompiledKernel from `triton`, and launch the cubin file directly.

Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66

Normally, using triton's CompiledKernel.make_launcher(), we would pay the cost of codegenning C++ and running it at compile time. With this new approach, we can use one statically compiled library to launch the kernel.

The tradeoff here is that this new kernel launcher will not be able to use codegen to deal with different lengths/types of arguments. So we use templating to handle up to 10 arguments for now. We also allocate 8 bytes on the stack per argument no matter the argument type, which can take more memory than codegenning. On the other hand, we improve compile time on cold and warm start by not having to call the C++ compiler at all.

This diff does not add the launcher to torch, but introduces a basic test suite.

A list of TODOs that are not yet complete:
- Handle `nvTmaDesc` and `cuTensorMap`, which triton handles
- Embed the grid logic instead of passing in gridX,Y,Z
- Handle launch_enter and exit hooks? (Not sure if inductor has these)
- Benchmarking to see if there's runtime performance loss
- Probably lots of features of the triton C++ generated code that I haven't handled yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149238
Approved by: https://github.com/oulgen
2025-03-15 15:06:46 +00:00
PyTorch MergeBot
643aaea133 Revert "[RFC] First version of statically compiled launcher for triton compiled CUDA kernels (#148561)"
This reverts commit 5a843f8973.

Reverted https://github.com/pytorch/pytorch/pull/148561 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/148561#issuecomment-2725969268))
2025-03-14 23:01:26 +00:00
James Wu
5a843f8973 [RFC] First version of statically compiled launcher for triton compiled CUDA kernels (#148561)
Putting this up for a first pass review, though I will likely make a bunch of changes before landing to add more features, etc.

This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal here is to take a cubin file and some metadata from a CompiledKernel from `triton`, and launch the cubin file directly.

Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66

Normally, using triton's CompiledKernel.make_launcher(), we would pay the cost of codegenning C++ and running it at compile time. With this new approach, we can use one statically compiled library to launch the kernel.

The tradeoff here is that this new kernel launcher will not be able to use codegen to deal with different lengths/types of arguments. So we use templating to handle up to 10 arguments for now. We also allocate 8 bytes on the stack per argument no matter the argument type, which can take more memory than codegenning. On the other hand, we improve compile time on cold and warm start by not having to call the C++ compiler at all.

This diff does not add the launcher to torch, but introduces a basic test suite.

A list of TODOs that are not yet complete, will do in separate diff:
- Handle `nvTmaDesc` and `cuTensorMap`, which triton handles
- Embed the grid logic instead of passing in gridX,Y,Z. With https://github.com/pytorch/pytorch/pull/147583, we should be able to handle all of the grid logic directly in _StaticCudaLauncher.launch_kernel, and get rid of the python evaluation.
- Handle launch_enter and exit hooks? (Not sure if inductor has these)
- Benchmarking to see if there's runtime performance loss
- Hooking it up with a config to inductor
- Testing harness to test against torch generated triton kernels

Differential Revision: [D69926783](https://our.internmc.facebook.com/intern/diff/D69926783/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148561
Approved by: https://github.com/aorenste, https://github.com/syed-ahmed
2025-03-14 19:12:13 +00:00
Shangdi Yu
cf19efd3d9 Support basic TorchBind in aot_compile and aoti_compile_and_package (#148506)
Summary:
**Codegen**

- Skip some codegen parts for torchbind (such as arg decleration) because they are loaded in proxy executor, so we do not need to declare torchbind args in cpp code
- Added a helper method to get the schema of CallTorchBind HOP. The returned schema is only the schema of `obj.method()`.

**Serialization**
Add support for torchbind object in serialization

- For CallTorchBind HOP, we need to handle it specially because of it's schema. The output serialized args is in the format of `(obj, method, *args, **kwargs)`.
- it.TorchBindObject inputs are serialized to `as_custom_obj` Argument.

**Packaging**

Add torchbind objects file and `custom_objs_config.json` file to generated files output of `aot_compile`.

The json file is stored in the `data/aotinductor/<model_name>` folder in pt2 archive.

The torchbind objects are stored in data/constants/ folder in pt2 archive.
The format of torchbind objects are `f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"`. e.g. `custom_obj_0`.
CustomClassHolder objects implement their own pickle methods.

Note that this `custom_objs_config.json` file is different from the `model_constants_config.json` file produced in package_sigmoid(). The keys in `custom_objs_config` directly correspond to the arg name in extern nodes json.
The key in `model_constants_config.json` produced by `package_sigmoid` is the attribute name in the user mode code.

This is required for both internal and OSS torchbind support.
For OSS torchbind support, we also need to package torchbind_constants into the .pt2 output.

**Work Left**
We still need to add torchbind support in ProxyExecutor for inductor.aoti_load_package to work. See other diffs in the stack.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r schema
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile
```

Differential Revision: D69490718

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148506
Approved by: https://github.com/angelayi
2025-03-11 20:55:18 +00:00
Ke Wen
ef6296e7f2 [PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj

Differential Revision: [D70937982](https://our.internmc.facebook.com/intern/diff/D70937982)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148590
Approved by: https://github.com/eqy, https://github.com/Aidyn-A, https://github.com/fduwjj
2025-03-11 18:36:12 +00:00
PyTorch MergeBot
a95eb0c0a7 Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"
This reverts commit 2149f6c684.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/ZainRizvi due to Breaking internally, see D70873275. Discussed reverting this with Ke. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2712001270))
2025-03-10 22:38:40 +00:00
Jason Ansel
5d4e7d58b4 [fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```
after:
```
20003454 function calls (19203257 primitive calls) in 8.936 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148261
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260
2025-03-10 16:06:11 +00:00
Jason Ansel
bf752c36da [fx] Move Node._update_args_kwargs to C++ (#148260)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```
after:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148260
Approved by: https://github.com/oulgen
ghstack dependencies: #148243
2025-03-10 16:06:02 +00:00
Jason Ansel
bec7bdad47 [fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
2025-03-10 16:05:53 +00:00
Ke Wen
2149f6c684 [PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj

Differential Revision: [D70835197](https://our.internmc.facebook.com/intern/diff/D70835197)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148590
Approved by: https://github.com/eqy, https://github.com/Aidyn-A, https://github.com/fduwjj
2025-03-09 07:32:23 +00:00
PyTorch MergeBot
9cb25f0ea2 Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"
This reverts commit 17dbeb11db.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/janeyx99 due to PR break backward compat test ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2708641172))
2025-03-09 03:01:55 +00:00
Ke Wen
17dbeb11db [PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj

Differential Revision: [D70835197](https://our.internmc.facebook.com/intern/diff/D70835197)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148590
Approved by: https://github.com/eqy, https://github.com/Aidyn-A, https://github.com/fduwjj
2025-03-08 20:00:12 +00:00
Tristan Rice
7ffadff286 c10d/ProcessGroup: cleanup abort and shutdown (#148798)
This adds `abort` and `shutdown` to `Backend` and `ProcessGroup` objects. This simplifies the logic in `distributed_c10d.py` by having a default noop implementation for all PGs.

This will be useful for torchft and upcoming versions of NCCL which will handle abort correctly. Currently `torchft` would have to call internal methods `_abort` on the PGNCCL object directly but with this change we can now just call `.abort()` and have it work for any PG implementation.

Test plan:

```
pytest distributed/test_backends.py distributed/test_c10d_common.py distributed/test_c10d_pypg.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148798
Approved by: https://github.com/kwen2501
2025-03-08 18:33:18 +00:00
Sanket Purandare
9841f0ddcf Add support for non functional collectives under FakeTensorMode and fake_pg for memory tracking (#147566)
This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`. It helps eliminate the patching of collectives for memory and runtime estimation.

It also modifies the `ModTracker` to enable the post-backward hook call for modules whose inputs don't require gradients but parameters do.

For the memory tracking, we now enable tracking DTensor dispatcher for custom dispatch functions like `entropy_loss`.
Dispatcher is only enabled for the memory tracking part and disabled as soon as it is done.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
2025-03-08 18:00:49 +00:00
Mikayla Gawarecki
be0ceee1c3 Make record/storage alignment in torch.save configurable (#147788)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787
2025-03-06 12:04:46 +00:00
Marko Radmilac
c65ee728f0 Initial implementation of host memory stats (#147660)
This is an initial attempt to provide some statistics for the pinned host memory allocations flowing through CachingHostAllocator. Many times in the past we have had inexplicable slowdowns that would be much easier to diagnose if we had some host memory characteristics.

This change tries very hard not to disrupt the initial design of the allocator, and it uses existing locking mechanism, whenever possible, to gather statistics "for free". Only deviation from that is on the "slow path" where we incur CUDA calls anyway, so taking a short lock is not going to hurt the performance much, especially in the steady state where most allocations will come from cache.

As mentioned before, this is the first PR, to introduce the concept and to see if it fits the right paradigm. We can always add more later.

Metrics that would require more involved changes to the code base and locks, like requested memory, have been punted for now. I also tried to reuse the Stat structure used in CUDA caching allocator, in order to maintain symmetry.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147660
Approved by: https://github.com/ngimel
2025-03-05 16:13:19 +00:00
wdziurdz
edc3ca577e [Profiler] Add profiler activity for HPU devices (#148182)
Fixes #148181

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148182
Approved by: https://github.com/sraikund16
2025-03-05 01:37:48 +00:00
Mwiza Kunda
b5873292c6 Add overload names to profiler trace (#143114)
Currently, recorded profiler events for aten ops do not store overload names. It would be useful to know which overloads are actually called to analyse performance.
For example, consider the following dispatch trace which occurs if there is a fallthrough kernel registered for aten::add:
```
             [call] op=[aten::add.Tensor], key=[AutogradCPU]
               [redispatch] op=[aten::add.Tensor], key=[Undefined]
                 [call] op=[aten::empty.memory_format], key=[BackendSelect]
                   [redispatch] op=[aten::empty.memory_format], key=[CPU]
                 [call] op=[aten::add.out], key=[CPU]
```

In this case, aten::add.out is a child of aten::add.Tensor, however the current profiler trace provides no way to differentiate aten op calls.

See the added unit test for a more detailed example.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143114
Approved by: https://github.com/sraikund16
2025-03-05 01:00:29 +00:00
Zain Rizvi
f30776c37a [BE] Upgrade to mypy 1.14 (#145966)
Upgrade mypy version

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145966
Approved by: https://github.com/Skylion007
2025-03-04 20:58:26 +00:00
PyTorch MergeBot
92beda54c8 Revert "[fx] Move map_aggregate to C++ (#148243)"
This reverts commit edaff88f69.

Reverted https://github.com/pytorch/pytorch/pull/148243 on behalf of https://github.com/jovianjaison due to breaking internal builds [T216910920] ([comment](https://github.com/pytorch/pytorch/pull/148243#issuecomment-2698724058))
2025-03-04 19:40:21 +00:00
PyTorch MergeBot
17d003fe75 Revert "[fx] Move Node._update_args_kwargs to C++ (#148260)"
This reverts commit 0135f57f4a.

Reverted https://github.com/pytorch/pytorch/pull/148260 on behalf of https://github.com/jovianjaison due to breaking internal builds [T216910920] ([comment](https://github.com/pytorch/pytorch/pull/148243#issuecomment-2698724058))
2025-03-04 19:40:21 +00:00
PyTorch MergeBot
97b9e68bc6 Revert "[fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)"
This reverts commit 29c2de9ae1.

Reverted https://github.com/pytorch/pytorch/pull/148261 on behalf of https://github.com/jovianjaison due to breaking internal builds [T216910920] ([comment](https://github.com/pytorch/pytorch/pull/148243#issuecomment-2698724058))
2025-03-04 19:40:21 +00:00
taozhiwei
16d07988fc add supports_coalescing property in c10d::Backend to determine whether backend supports coalescing (#135338)
1. My company is using privateuseone to connect new hardware device and requires the use of `batch_isend_irecv` function. However, `batch_isend_irecv` is currently only open to CUDA, so I add `supports_coalescing` property in `c10d::Backend` to determine whether backend supports coalescing.
2. If `pg._has_hooks` return True, We don't need to determine if the current device is CUDA. So privateuseone can also support `pg._wait_for_pending_works`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135338
Approved by: https://github.com/kwen2501, https://github.com/albanD
2025-03-04 12:37:06 +00:00
cyy
98bf2f1170 Use Python 3.9 typing (#148157)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148157
Approved by: https://github.com/janeyx99
2025-03-04 03:09:55 +00:00
Jason Ansel
29c2de9ae1 [fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```
after:
```
20003454 function calls (19203257 primitive calls) in 8.936 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148261
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260
2025-03-02 22:42:31 +00:00
Jason Ansel
0135f57f4a [fx] Move Node._update_args_kwargs to C++ (#148260)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```
after:
```
24303536 function calls (23503339 primitive calls) in 10.726 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148260
Approved by: https://github.com/oulgen
ghstack dependencies: #148243
2025-03-02 22:42:31 +00:00
Jason Ansel
edaff88f69 [fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
2025-03-02 22:42:31 +00:00
PyTorch MergeBot
a983b2b11a Revert "Initial implementation of host memory stats (#147660)"
This reverts commit 945e359fc1.

Reverted https://github.com/pytorch/pytorch/pull/147660 on behalf of https://github.com/mradmila due to There is an issue with ambiguous definition of Stat structure when different C++ tools are used. Backing out for now. ([comment](https://github.com/pytorch/pytorch/pull/147660#issuecomment-2692346379))
2025-03-01 18:05:45 +00:00
William Wen
40b3e4a358 [dynamo] expose code execution strategy to python (#148020)
@anijain2305 this can be used to mark a code object to be skipped/run-only (recursively) while tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148020
Approved by: https://github.com/jansel
2025-02-28 21:59:12 +00:00
Marko Radmilac
945e359fc1 Initial implementation of host memory stats (#147660)
This is an initial attempt to provide some statistics for the pinned host memory allocations flowing through CachingHostAllocator. Many times in the past we have had inexplicable slowdowns that would be much easier to diagnose if we had some host memory characteristics.

This change tries very hard not to disrupt the initial design of the allocator, and it uses existing locking mechanism, whenever possible, to gather statistics "for free". Only deviation from that is on the "slow path" where we incur CUDA calls anyway, so taking a short lock is not going to hurt the performance much, especially in the steady state where most allocations will come from cache.

As mentioned before, this is the first PR, to introduce the concept and to see if it fits the right paradigm. We can always add more later.

Metrics that would require more involved changes to the code base and locks, like requested memory, have been punted for now. I also tried to reuse the Stat structure used in CUDA caching allocator, in order to maintain symmetry.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147660
Approved by: https://github.com/ngimel
2025-02-28 18:36:44 +00:00
Luca Wehrstedt
60d94ea22b Add option to limit number of SMs used by matmul kernels (#147966)
Resubmission of #144974 which was reverted for unrelated reasons.

Newer matmul kernels, e.g. those targeting Hopper GPUs, sometime use a "persistent" schedule which consists in launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This allows to eliminate the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is a NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of the are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147966
Approved by: https://github.com/danthe3rd
2025-02-26 12:01:12 +00:00
PyTorch MergeBot
1e894d2635 Revert "Add option to limit number of SMs used by matmul kernels (#144974)"
This reverts commit af2d63637e.

Reverted https://github.com/pytorch/pytorch/pull/144974 on behalf of https://github.com/wdvr due to reverting in order to revert #147548 that causes a merge conflict ([comment](https://github.com/pytorch/pytorch/pull/144974#issuecomment-2683461733))
2025-02-25 22:46:38 +00:00
Luca Wehrstedt
af2d63637e Add option to limit number of SMs used by matmul kernels (#144974)
Newer matmul kernels, e.g. those targeting Hopper GPUs, sometime use a "persistent" schedule which consists in launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This allows to eliminate the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is a NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of the are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144974
Approved by: https://github.com/eqy, https://github.com/albanD
2025-02-25 10:19:19 +00:00
Luca Wehrstedt
5ed1e23e3a Fix type stubs for SymmetricMemory (#146310)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146310
Approved by: https://github.com/yifuwang
2025-02-21 19:59:43 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
William Wen
63e8ad49b8 [dynamo] replace hardcoded eval frame control flags skip_code_recursive_flag/cache_limit_hit_flag (#146355)
This PR and the previous:
- Moves parts of `eval_frame.c` to C++.
- Reduces code duplication in `dynamo__custom_eval_frame` and makes the control flow more clear.
- Enables `convert_frame` to signal to `eval_frame.cpp` in a general manner how to evaluate this frame, recursive frames, and future frames with the same code object (default/compile, skip, run-only). e.g. this will allow us to change skipping/cache limit hit eval_frame behavior directly from convert_frame without requiring changes to C/C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146355
Approved by: https://github.com/jansel
ghstack dependencies: #145603
2025-02-18 21:37:12 +00:00
Yan Zhiwei
ae351d4d0e [Intel GPU] allow_tf32 for oneDNN backend - XPU part (#137570)
# Motivation
Add context variable `torch.bachend.mkldnn.allow_tf32` to control tf32 computation in convolution kernels at XPU side.  The tf32 data type is beneficial to improve the performance of deep learning workloads during training/inference. Current PR uses the [oneDNN API fpmath_mode](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_fpmath_mode.html#the-floating-point-math-mode-attribute) to trigger the tf32 acceleration in convolution kernels.

# Valiadation
* ut to test context variable
`python test/xpu/test_conv.py -k test_mkldnn_allow_tf32_get_set`

* Runtime exemplification
```
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.649902
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.151855
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_data,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_undef::undef::: dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.167969
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.26709
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.219971

```
According to the field `fpmath:tf32` in verbose, we could see that, current context setting utils could successfully trigger tf32 computation in conv forward/backward_data/backward_weights kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137570
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/atalman, https://github.com/malfet

Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
2025-02-17 01:46:43 +00:00
Animesh Jain
9dc702875d [dynamo][mappingproxy][inspect] Support existing types.MappingProxyType (#147217)
Fixes https://github.com/pytorch/pytorch/issues/147162

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147217
Approved by: https://github.com/williamwen42
2025-02-15 07:59:33 +00:00
PyTorch MergeBot
9a883007a2 Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)"
This reverts commit c7515da7b0.

Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940))
2025-02-13 18:04:26 +00:00
Daniel Galvez
c7515da7b0 Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)
This is a new PR for #130386 , which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00
Hyunho Yeo
5f621c5879 [MTIA] (3/n) Implement PyTorch APIs to query/reset device peak memory usage (#146710)
Summary: Public summary (shared with Github): This diff implements a C++-Python binding to enable `reset_peak_memory_stats`.

Test Plan: The test is implemented in the following diff.

Reviewed By: yuhc

Differential Revision: D68988673

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146710
Approved by: https://github.com/nautsimon
2025-02-10 16:57:09 +00:00
Eddie Yan
9ee506bd93 [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee, https://github.com/malfet
2025-02-06 19:04:50 +00:00
Brian Hirsh
e68f5087d8 update _unsafe_set_version_counter to accept lists of tensors (#137921)
See the comment [here](https://github.com/pytorch/pytorch/issues/132014#issuecomment-2379547400) (cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @XilunWu @rec) - this PR updates `_unsafe_set_version_counter` to accept a list of tensors, for overhead-sensitive users (e.g. distributed) who need to hide VC bumps from autograd on a large list of tensors without wanting to suffer the overhead of going from python->C++ separately for every tensor in the list.

I left the binding in pybind, and used a `std::vector`. if we **really** need to optimize overhead even further, we could write a manual cpython binding.

I use this updated API in the next PR to fix FSDP2, so that it properly hides the VC of all `all_gather_buffer` tensors in its call to `split_with_sizes_copy.out(all_gather_buffers)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137921
Approved by: https://github.com/awgu, https://github.com/albanD
2025-02-04 04:51:11 +00:00
Burak Turk
d89c7ea401 add WaitCounter type interface and get rid of type errors (#146175)
Summary: as titled

Differential Revision: D68960123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146175
Approved by: https://github.com/andriigrynenko, https://github.com/Skylion007
2025-02-01 23:24:52 +00:00
Nikita Shulga
e56dcf2772 [CPUInductor] Fix SVE256 detection (#146207)
This PR removes `torch.cpu._is_arm_sve_supported()` and replaces is with stable `torch.backends.cpu.get_cpu_capability()`

I should have reviewed https://github.com/pytorch/pytorch/pull/134672 more thoroughly, because it introduced duplicate, but slightly different API for detecting CPU architectures, which resulted in runtime crashes on system that do support SVE128, rather than SVE256

Fixes https://github.com/pytorch/pytorch/issues/145441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146207
Approved by: https://github.com/angelayi
2025-02-01 18:51:34 +00:00
PyTorch MergeBot
c3f71eb61b Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit e2917245fb.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/ZainRizvi due to Sorry but this still fails internally with the same error.  @Chillee or @malfet, can you please help the change get tested? (See D68783351) ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2627886999))
2025-01-31 17:43:09 +00:00
Eddie Yan
e2917245fb [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee, https://github.com/malfet
2025-01-30 22:33:50 +00:00
Ke Wen
51ee9b154e [c10d] Add NCCL memory allocator (#145675)
This PR implements a small UI improvement over #133603.

It prepares a NCCL memory allocator in torch cpp and then pybind's it out, so that user can directly use it.

UI:
```
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145675
Approved by: https://github.com/syed-ahmed, https://github.com/wconstab
2025-01-30 18:19:00 +00:00
PyTorch MergeBot
5fa28bbe40 Revert "[c10d] Add NCCL memory allocator (#145675)"
This reverts commit 18a7a04c4a.

Reverted https://github.com/pytorch/pytorch/pull/145675 on behalf of https://github.com/ZainRizvi due to Sorry but this still fails internally. See D68866823 for details ([comment](https://github.com/pytorch/pytorch/pull/145675#issuecomment-2624900562))
2025-01-30 16:01:52 +00:00