Commit Graph

5 Commits

Author SHA1 Message Date
Oguz Ulgen
79ee6bbde3 Support triton.language.dtype with torch.compile (#121690)
Putting this PR as an RFC since I have resorted to some horrible hacks in order to make this work.
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
This means that we need to "rewrite" them for fx graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
2024-03-12 23:21:46 +00:00
Oguz Ulgen
6566b3db67 Add an autotune cache for inductor generated kernels (#120963)
Summary: Inductor currently has a best config cache for kernels that it generates. This is a local cache done via writing to the file system. This diff takes this local cache to remote by reusing the existing triton caching mechanism built via Memcache internally and Redis externally.

Test Plan:
tested locally using `TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE =1`

Look at scuba to verify the local testing: https://fburl.com/scuba/triton_remote_cache/z6pypznk

The plan is to land this diff with this turned off and gradually introduce this.

Differential Revision: D54398076

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120963
Approved by: https://github.com/jansel
2024-03-04 16:58:37 +00:00
xinan.lin
e60bc502b4 [Inductor Intel GPU backend Upstream] Generalize part of Inductor test case (#117513)
Following the RFC https://github.com/pytorch/pytorch/issues/114856, before upstream Intel XPU Inductor Backend, we need to preapre corresponding Inductor test cases. This PR aims to generalize part of Inductor test case so that a new GPU backend can reuse the existing test case with minimal code change.

This Pull Request preferentially generalizes the test cases that cover Inductor's base functionality as follow:
- test/inductor/test_codecache.py
- test/inductor/test_codegen_triton.py
- test/inductor/test_kernel_benchmark.py
- test/inductor/test_torchinductor.py
- test/inductor/test_torchinductor_codegen_dynamic_shapes.py
- test/inductor/test_torchinductor_dynamic_shapes.py
- test/inductor/test_torchinductor_opinfo.py
- test/inductor/test_triton_heuristics.py
- test/inductor/test_triton_wrapper.py

Feature request: https://github.com/pytorch/pytorch/issues/114856

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117513
Approved by: https://github.com/EikanWang, https://github.com/jansel
2024-01-18 08:26:21 +00:00
Yu, Guangye
e9c9b1ed59 [Inductor] Generalize inductor triton backend device agnostic (#109486)
# Motivation
@jansel As discussed before, we expected to generalize some cuda-specific code. This can make inductor more friendly to third-party backend so that we can leverage inductor code as much as possible.

# Solution
To implement this, we give a solution to introduce device runtime abstraction. We wrapper them inside `DeviceInterface` and use `register_interface_for_device` to register each kind of device to inductor. Then use `get_interface_for_device` to fetch the corresponding runtime from device type. Then usage is like this:
```python
device_interface = get_interface_for_device("xpu")
device_interface .is_available() # to check if XPU is available
device_interface .device_count() # to check how much XPU device is available
```
The `DeviceInterface` is a simple abstraction, which enables third-party backends that implement CUDA-like semantics to be integrated with inductor. This can prevent third-party backend from using monkey patch to override some utility functions, like `decode_device` that is hard-coded with CUDA.

# Additional Context
The main code change:
- To leverage AsyncCompile, make it device-agnostic
- Avoid monkey patches, make some utility functions device-agnostic

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109486
Approved by: https://github.com/jansel, https://github.com/jgong5, https://github.com/EikanWang
2023-09-24 07:49:20 +00:00
Oguz Ulgen
1df14f1bf8 Move has_triton to top level triton utils so that dynamo can also access (#109832)
it without creating cyclic dependencies

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109832
Approved by: https://github.com/zou3519
2023-09-22 19:33:41 +00:00