mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[Inductor][Triton] Rework casting logic to avoid illegal bitcast (#147395)
Triton introduced checks that reject bitcasts where the value being cast does not fit into the target type (e.g. https://github.com/triton-lang/triton/pull/5926, though in this instance I think the issue is related to the type for the broadcast). As a result, some routines in Inductor now perform illegal bitcasts.

I reworked the compare-and-swap-with-index routine used in sort to remove the illegal bitcast (~~I left the bitcast for now, but I think it could probably be removed assuming the reshape does not change the type~~). The explicit cast is correct, and I don't think there are performance issues, but because the cast on the sum is not a bitcast I suppose there could be.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147395
Approved by: https://github.com/eellison
parent 279c7f262e
commit e758d8b4d1
@@ -479,8 +479,8 @@ def _compare_and_swap_with_index(
     # slice left/right with 'stride' 2**(n_dims - i - 1)
     right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
     left_mask = (1 - right_mask).to(idtype)
-    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
-    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
+    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape)
+    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape)
     ileft = tl.reshape(ileft, x.shape)
     iright = tl.reshape(iright, x.shape)
     left = ileft.to(x.dtype, bitcast=True)
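
The idea behind the change can be illustrated outside Triton with a minimal pure-Python sketch. This is not the Inductor code: `bitcast` and `split_pair` are hypothetical helpers, Python's `struct` module stands in for typed tensors, and modelling the result of the sum as int32 is an assumption made here to reproduce a width mismatch with a 16-bit `idtype`. The sketch selects the left and right element of a float16 pair via mask-and-sum, then shows that the final bitcast only succeeds once the sum has been narrowed back to the 16-bit integer type with a value-preserving cast.

```python
import struct

# Little-endian format codes: float16, int16, int32.
F16, I16, I32 = "<e", "<h", "<i"


def bitcast(value, src_fmt, dst_fmt):
    """Reinterpret the bytes of value; raises struct.error if the sizes differ."""
    return struct.unpack(dst_fmt, struct.pack(src_fmt, value))[0]


def split_pair(pair, narrow_before_bitcast):
    """Select the left and right element of a float16 pair via mask-and-sum."""
    # Counterpart of y.to(idtype, bitcast=True): view the floats as int16.
    iy = [bitcast(v, F16, I16) for v in pair]
    right_mask = [0, 1]
    left_mask = [1 - m for m in right_mask]
    # Multiplying by a 0/1 mask and summing keeps exactly one element.
    sums = [
        sum(v * m for v, m in zip(iy, mask)) for mask in (left_mask, right_mask)
    ]
    # Assumption for this sketch: the sum result is held as int32.
    # Packing it as int16 instead models the explicit value-preserving cast.
    src_fmt = I16 if narrow_before_bitcast else I32
    return tuple(bitcast(s, src_fmt, F16) for s in sums)


print(split_pair((1.5, -2.25), narrow_before_bitcast=True))
try:
    split_pair((1.5, -2.25), narrow_before_bitcast=False)
except struct.error as exc:
    print("illegal bitcast:", exc)
```

Narrowing is safe here because each masked sum is just one of the original int16 values, so it always fits in the narrower type. Whether the extra cast costs anything on the GPU is not something this sketch can show.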