[TP] Add warning when module is distributed twice (#147006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147006
Approved by: https://github.com/XilunWu
This commit is contained in:
parent 3e4172d985
commit 4879f8f919
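
For context on the new behavior: the diff below makes distribute_module fail fast when it is applied to a module that was already distributed. Below is a minimal sketch of what that looks like from the caller's side, assuming a build where distribute_module is exposed under torch.distributed.tensor; the single-rank gloo group, the 1-element CPU mesh, the port, and the Linear module are illustrative assumptions, not part of the commit. The commit title says "warning", but the guard in the second hunk raises a RuntimeError, and the sketch follows that code:

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module

# Single-process setup so a DeviceMesh can be built locally (illustrative only).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
model = torch.nn.Linear(10, 10)

distribute_module(model, mesh)       # first call: converts parameters in place
try:
    distribute_module(model, mesh)   # second call on the same, already-distributed module
except RuntimeError as err:
    print(err)                       # the guard added in this commit raises here

# (The later sketches below reuse this process group and mesh.)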
@@ -319,6 +319,9 @@ class DTensorAPITest(DTensorTestBase):
         )
 
         # check autocast
+        # `distribute_module` is an in-place operation, so we need to create a
+        # new model
+        model = MyModel(10, 10, device=self.device_type)
         dt = distribute_tensor(torch.rand(10), device_mesh, [Replicate()])
         replica_model = distribute_module(
             model,
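The new comment in this test hunk spells out the reason for rebuilding the model: distribute_module mutates the module it is given and returns that same object, so the previously distributed model cannot simply be distributed again. Here is a short hedged check of that in-place property, reusing the mesh and process group from the sketch above (the Linear module is again illustrative):

from torch.distributed.tensor import DTensor, distribute_module

model = torch.nn.Linear(10, 10)
replica = distribute_module(model, mesh)    # mesh from the earlier sketch

assert replica is model                     # the same object comes back: the call is in place
assert isinstance(model.weight, DTensor)    # parameters were converted to DTensors in place

This is why the test now constructs a fresh MyModel before calling distribute_module a second time.
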
@@ -874,6 +874,13 @@ def distribute_module(
     """
     torch._C._log_api_usage_once("torch.dtensor.distribute_module")
 
+    already_distributed = getattr(module, "_distribute_module_applied", False)
+    if already_distributed:
+        raise RuntimeError(
+            "distribute_module should only be called once on a module, "
+            "but it has already been called on this module!"
+        )
+
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
     device_type = device_mesh.device_type
     if device_type == "xla":
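The guard itself is an attribute-flag pattern: read a private marker with getattr before doing the in-place work, and set the marker once the work is done (the setting half is in the next hunk). A generic, hypothetical sketch of the same pattern, not PyTorch code:

def apply_once(obj, transform, flag="_transform_applied"):
    # Hypothetical helper mirroring the guard in this commit: refuse to run an
    # in-place transform a second time on the same object.
    if getattr(obj, flag, False):
        raise RuntimeError("transform should only be called once on this object!")
    transform(obj)
    setattr(obj, flag, True)
    return obj

Using getattr with a False default keeps the check safe for objects that have never seen the transform.
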
@@ -967,6 +974,7 @@ def distribute_module(
                 f"output_fn should take in 3 arguments, but got {num_args} arguments!"
             )
 
+    module._distribute_module_applied = True  # type: ignore[assignment]
     return module
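
The flag read by the guard above is the private marker set here, just before the module is returned. Continuing the earlier sketches (torch, distribute_module, and mesh from the first one), the marker can be observed directly; it is an internal detail of this change, not public API:

fresh = torch.nn.Linear(10, 10)
assert getattr(fresh, "_distribute_module_applied", False) is False  # no marker yet

distribute_module(fresh, mesh)                   # marker is set on successful return
assert fresh._distribute_module_applied is True  # this is what the guard reads on a repeat call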