mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Triton] [Inductor[ Add tt.descriptor_store to get_tma_stores (#157212)
Summary: Fixes a gap in the Triton update where the traverse would break because `get_tma_stores` didn't handle both TMA APIs. Test Plan: `buck test -m ovr_config//triton:beta 'fbcode//mode/dev-nosan' fbcode//ads_mkl/ops/tests:gdpa_dcpp_test -- --exact 'ads_mkl/ops/tests:gdpa_dcpp_test - test_gdpa_dcpp (ads_mkl.ops.tests.gdpa_dcpp_test.GdpaDCPPTest)'` Rollback Plan: Differential Revision: D77501582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157212 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
b147b6c0e3
commit
aec569da23
|
|
@ -798,6 +798,9 @@ def get_tma_stores(
|
|||
elif op.name == "tt.experimental_descriptor_store":
|
||||
assert len(op.args) >= 1
|
||||
result.add(op.args[0])
|
||||
elif op.name == "tt.descriptor_store":
|
||||
assert len(op.args) >= 1
|
||||
result.add(op.args[0])
|
||||
|
||||
for val in list(result):
|
||||
if val in ops:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user