[BE][Ez]: Remove extra copy in dtensor parallel loss (#148096)
Remove an extra copy of the input to `_log_softmax` when there is both a dtype and a memory format change; fuse the two copies into a single `.to()` call instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148096
Approved by: https://github.com/jansel, https://github.com/wconstab
This commit is contained in:
parent 9b7130b8db
commit 3b4b23ab0b
@@ -125,13 +125,12 @@ def _propagate_tensor_meta(
 # NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
 # with all_reduce manually inserted to perform distributed computation.
 def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
-    x = x.contiguous()
     if half_to_float:
         assert x.dtype == torch.half
     computation_dtype, result_dtype = utils.elementwise_dtypes(
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
-    x = x.to(computation_dtype)
+    x = x.to(dtype=computation_dtype, memory_format=torch.contiguous_format)
     if x.numel() == 0:
         shifted = x
     else:
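For illustration, below is a minimal standalone sketch (not taken from the commit) of the pattern the diff applies: a single `Tensor.to()` call with both `dtype` and `memory_format` arguments performs the layout and dtype conversion in one copy, rather than the separate `.contiguous()` and `.to(dtype)` copies. The tensor shape and variable names are illustrative only.

import torch

# Illustrative only: a non-contiguous half-precision input.
x = torch.randn(4, 8).half().t()

# Old pattern: two copies -- one for the memory layout, one for the dtype.
y_two_copies = x.contiguous().to(torch.float32)

# New pattern: one fused copy handles both the dtype and the memory format.
y_fused = x.to(dtype=torch.float32, memory_format=torch.contiguous_format)

assert torch.equal(y_two_copies, y_fused)
assert y_fused.is_contiguous()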