- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
If (ynumel / YBLOCK) > get_max_ygrids(), the z dimension will be used if znumel is None. However, if (ynumel / YBLOCK) % get_max_ygrids() != 0, there will be program launches with inputs that require masking, and so this needs to be considered when determining if the y dimension has a constant mask.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139751
Approved by: https://github.com/eellison
Co-authored-by: George White <georgew@graphcore.ai>
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. Previously, we would typically check for reductions by `tree.prefix == "r"`. This PR moves the check into a helper function. This makes it easier to generalize the code to multi-dimensional reductions, which could have multiple prefixes like `("r0_", "r1_")`.
Tested by the existing CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141738
Approved by: https://github.com/jansel
Summary:
Fix logic for inserting broadcast on kernel with load going directly to store. In the case where load is going directly to store, we insert a tl.broadcast on the store, regardless of the block size on the load. In the case where a broadcast is not required, the downstream Triton compiler is expected to remove this no-op broadcast instruction.
Test Plan: Added tests under test_torchinductor_strided_blocks.py:test_expand_broadcast in OSS and internal test cases.
Reviewed By: blaine-rister
Differential Revision: D65518033
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141693
Approved by: https://github.com/blaine-rister
- Add in upcast_compute_type on creation of new tensors (loads, constants)
- Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr.
- bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16.
- for masked, avoid calling dtype propagation and just use output dtype.
Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32.
Follow ups:
- We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr.
- Be more consistent on our index expr dtype printing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141495
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
ghstack dependencies: #139945, #140057
A couple changes.
- Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later.
- Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache.
Tests get added later in the stack when everything is implemented.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139945
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
Here's a markdown summary for the PR:
# Add workspace buffer support for Triton templates
## Summary
Adds support for templates to allocate and use temporary workspace buffers
## Key Changes
- Add `WorkspaceArg` support in Triton template system
- Automatic workspace allocation/deallocation around kernel execution
- Zero-initialization support for workspace buffers
- Seamless integration with existing tensor management
## Example Usage
```python
def generate(self, ...):
workspace_arg = WorkspaceArg(
count=1024*1024, # 1MB workspace
zero_fill=True # Zero-initialized
)
return TritonTemplateCaller(..., workspace_arg=workspace_arg)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138050
Approved by: https://github.com/Chillee, https://github.com/eellison
Summary:
- This diff introduces `dtype` attribute to `TritonCSEVariable` and a dtype propagation helper function to infer dtype from input to output for each op.
- There will be a follow-up diff that uses this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.
Test Plan: CI
Differential Revision: D61815079
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136778
Approved by: https://github.com/eellison, https://github.com/blaine-rister
The Triton `AttrsDescriptor` object was refactored in https://github.com/triton-lang/triton/pull/4734. These changes add support for the new `AttrsDescriptor` while maintaining backwards compatibility with the existing version. The main changes are different names for the initialized of the descriptor parameters, and a creation via a static method instead of the class constructor.
Depends on #137458 which removes some unused logic around the old descriptor. Those changes make this PR cleaner, but if for some reason that old logic is still used I can make adjustments.
Use of the new `AttrsDescriptor` depends on https://github.com/triton-lang/triton/pull/4888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137757
Approved by: https://github.com/jansel
The `AttrsDescriptor` class has been present in Triton for almost a year now (introduced [here](72c9833927)), so we should be able to rely on it existing. I am in the process of supporting the new `AttrsDescriptor` class and @jansel suggested I split changes to the existing class out separately to make sure nothing breaks removing the legacy attribute descriptor attributes.
Initially I attempted to remove the branching around detecting whether `AttrsDescriptor` exists but that breaks because PyTorch must build without Triton. So, I went back and updated for the naming introduced in the commit linked above, and also removed two unused attributes `divisible_by_8` and `ids_to_fold` which were removed in Feb 2024 (https://github.com/triton-lang/triton/pull/3122 and https://github.com/triton-lang/triton/pull/3080 respectively).
With these changes only the internal workings of the `AttrsDescriptor` class will differ between supported Triton versions, but the data stored will remain consistent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137458
Approved by: https://github.com/jansel
Previously, instances of `SchedulerNode` and `FusedSchedulerNode` would explicitly check whether the compilation target is Triton when codegen'ing debug strings. Generating debug triton code is instead implemented as a callback set on scheduler nodes by `TritonScheduling`. This makes the codegen more device-agnostic and allows schedulers to customise the codegen output as opposed to it being closely coupled to the debug string codegen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135015
Approved by: https://github.com/jansel
Summary: Previously triton_reshape will generate code with `Min` keyword in it, which is incorrect. This diff updates the triton_reshape function to properly expand `Min` keyword to `<`.
Test Plan:
```
buck2 run @//mode/{opt,mtia,inplace} //glow/fb/fx/fba/tests:test_fba_inductor -- -r test_Min_keyword_in_block_shape
```
Differential Revision: D63850158
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137357
Approved by: https://github.com/blaine-rister, https://github.com/eellison