From 55a1dc7f88e9371319e06366908a96d9b4fe4d68 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Thu, 13 Apr 2023 23:02:35 +0000
Subject: [PATCH] [dtensor] redistributed by default take self mesh instead
 (#99060)

This PR switches redistribute to use the DTensor's own (self) mesh by
default instead of the global mesh, which is more user-friendly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99060
Approved by: https://github.com/mrshenli
---
 torch/distributed/_tensor/api.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py
index 280d17a8043..228e08a63cd 100644
--- a/torch/distributed/_tensor/api.py
+++ b/torch/distributed/_tensor/api.py
@@ -332,13 +332,12 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
 
         .. note:: `redistribute` is differentiable.
         """
-        # This API perform necessary transformations and get
-        # a new DTensor with the new spec. i.e. for
-        # sharding it's a reshard behavior.
-        # Note that redistribute currently only supports out
+        # NOTE: This redistribute API currently only supports out
         # of place redistribution, i.e. it always create a new
         # DTensor object and leave the original one unchanged.
-        device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
+
+        # if device_mesh is not specified, use the current device_mesh
+        device_mesh = device_mesh or self.device_mesh
         # raise error if new placements not specified
         if placements is None:
             raise RuntimeError("placements is needed for redistribute!")
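
For context, a minimal usage sketch of the new default behavior. The
4-rank mesh, tensor shapes, and variable names below are illustrative
assumptions, not part of this patch, and the snippet presumes an
already-initialized 4-process distributed job:

    import torch
    from torch.distributed._tensor import (
        DeviceMesh,
        Replicate,
        Shard,
        distribute_tensor,
    )

    # assumed setup: 4 ranks, one device mesh spanning all of them
    mesh = DeviceMesh("cuda", list(range(4)))
    dt = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])

    # Before this patch, omitting device_mesh fell back to the *global*
    # mesh; after it, redistribute reshards on dt's own mesh, so the
    # common single-mesh case needs no explicit mesh argument.
    replicated = dt.redistribute(placements=[Replicate()])
    assert replicated.device_mesh == dt.device_mesh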