[TP] Add warning when module is distributed twice (#147006)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147006
Approved by: https://github.com/XilunWu
Ke Wen 2025-02-12 15:50:24 -08:00 committed by PyTorch MergeBot
parent 3e4172d985
commit 4879f8f919
2 changed files with 11 additions and 0 deletions


@@ -319,6 +319,9 @@ class DTensorAPITest(DTensorTestBase):
         )
         # check autocast
+        # `distribute_module` is an in-place operation, so we need to create a
+        # new model
+        model = MyModel(10, 10, device=self.device_type)
         dt = distribute_tensor(torch.rand(10), device_mesh, [Replicate()])
         replica_model = distribute_module(
             model,

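Note: as the test change above indicates, `distribute_module` mutates the module it is given (its parameters are swapped for DTensors in place) and returns that same object, which is why the test builds a fresh `MyModel` rather than reusing one that was already distributed. Below is a minimal sketch of that pattern; the `ToyModel` class and the 2-rank mesh are illustrative assumptions, not part of this commit, and the snippet needs a distributed launch (e.g. torchrun) to actually run.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module

# Hypothetical toy module, used only for illustration.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

mesh = init_device_mesh("cuda", (2,))  # assumes 2 ranks / GPUs

model = ToyModel()
distributed = distribute_module(model, mesh)
# distribute_module works in place: the returned module is the same object,
# with its parameters replaced by DTensors.
assert distributed is model

# To distribute "again" (e.g. for another test case), create a new module
# instead of reusing `model`, which has already been distributed.
model2 = ToyModel()
distributed2 = distribute_module(model2, mesh)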

@@ -874,6 +874,13 @@ def distribute_module(
     torch._C._log_api_usage_once("torch.dtensor.distribute_module")
+    already_distributed = getattr(module, "_distribute_module_applied", False)
+    if already_distributed:
+        raise RuntimeError(
+            "distribute_module should only be called once on a module, "
+            "but it has already been called on this module!"
+        )
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
     device_type = device_mesh.device_type
     if device_type == "xla":
@@ -967,6 +974,7 @@ def distribute_module(
                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
             )
+    module._distribute_module_applied = True  # type: ignore[assignment]
     return module
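With this change, a second `distribute_module` call on the same module object trips the new `_distribute_module_applied` flag, which is set just before the module is returned. A rough illustration of the resulting behavior is below; the `nn.Linear` module and the 2-rank mesh are assumptions for the example, which again needs a distributed launch to run.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module

mesh = init_device_mesh("cuda", (2,))  # assumes 2 ranks / GPUs
model = nn.Linear(10, 10)

# First call: parameters become DTensors and the module is flagged
# via module._distribute_module_applied = True.
distribute_module(model, mesh)

try:
    # A second call on the same object now fails instead of silently
    # re-distributing already-distributed parameters.
    distribute_module(model, mesh)
except RuntimeError as err:
    print(err)  # "distribute_module should only be called once on a module, ..."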