[Inductor][Triton] Rework casting logic to avoid illegal bitcast (#147395)

Triton introduced checks for bitcasts where the casted value does not fit into the casted type (e.g. https://github.com/triton-lang/triton/pull/5926, though in this instance I think the issue is related to the type for the broadcast). Some routines in Inductor now perform illegal bitcasts. I reworked the compare and swap w/ index routine used in sort to remove the illegal bitcast (~~I left the bitcast for now, but I think it could probably be removed assuming the reshape does not change the type~~). The explicit cast is correct, and I don't think there are performance issues, but because the cast on the sum is not a bitcast I suppose there could be.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147395
Approved by: https://github.com/eellison
This commit is contained in:
Alex Baden 2025-02-19 19:44:57 +00:00 committed by PyTorch MergeBot
parent 279c7f262e
commit e758d8b4d1

View File

@ -479,8 +479,8 @@ def _compare_and_swap_with_index(
# slice left/right with 'stride' 2**(n_dims - i - 1)
right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
left_mask = (1 - right_mask).to(idtype)
ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape)
iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape)
ileft = tl.reshape(ileft, x.shape)
iright = tl.reshape(iright, x.shape)
left = ileft.to(x.dtype, bitcast=True)