[BE][Ez]: Remove extra copy in dtensor parallel loss (#148096)

Remove an extra copy of the input to `_log_softmax` when both a dtype change and a memory-format change are needed: fuse the two copies into a single `Tensor.to` call instead.
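
A minimal standalone sketch of the idea (illustrative only, not the DTensor loss-parallel code itself; the tensor shape here is made up): for a non-contiguous half-precision input, the old path materialized one copy for `contiguous()` and a second for the dtype cast, while a single `Tensor.to` call with both `dtype` and `memory_format=torch.contiguous_format` yields the same contiguous tensor in one copy.

import torch

# Illustrative input: a non-contiguous fp16 tensor (e.g. a transposed view).
x = torch.randn(4, 8, dtype=torch.half).t()

# Old path: two separate copies.
old = x.contiguous()         # copy 1: materialize a contiguous fp16 tensor
old = old.to(torch.float32)  # copy 2: cast to the computation dtype

# New path: a single fused copy that casts and makes the result contiguous.
new = x.to(dtype=torch.float32, memory_format=torch.contiguous_format)

assert new.is_contiguous() and new.dtype == torch.float32
assert torch.equal(old, new)

When the input is already contiguous and no cast is needed, `Tensor.to` should return the input unchanged, so the fused call does not add a copy on paths that previously needed none.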

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148096
Approved by: https://github.com/jansel, https://github.com/wconstab
Aaron Gokaslan 2025-02-28 05:42:29 +00:00 committed by PyTorch MergeBot
parent 9b7130b8db
commit 3b4b23ab0b

@@ -125,13 +125,12 @@ def _propagate_tensor_meta(
 # NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
 # with all_reduce manually inserted to perform distributed computation.
 def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
-    x = x.contiguous()
     if half_to_float:
         assert x.dtype == torch.half
     computation_dtype, result_dtype = utils.elementwise_dtypes(
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
-    x = x.to(computation_dtype)
+    x = x.to(dtype=computation_dtype, memory_format=torch.contiguous_format)
     if x.numel() == 0:
         shifted = x
     else: