diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py
index b880beef477..b9280a143e5 100644
--- a/test/distributed/tensor/test_api.py
+++ b/test/distributed/tensor/test_api.py
@@ -319,6 +319,9 @@ class DTensorAPITest(DTensorTestBase):
         )
 
         # check autocast
+        # `distribute_module` is an in-place operation, so we need to create a
+        # new model
+        model = MyModel(10, 10, device=self.device_type)
         dt = distribute_tensor(torch.rand(10), device_mesh, [Replicate()])
         replica_model = distribute_module(
             model,
diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py
index cedadee57b0..a572707be5d 100644
--- a/torch/distributed/tensor/_api.py
+++ b/torch/distributed/tensor/_api.py
@@ -874,6 +874,13 @@ def distribute_module(
 
     torch._C._log_api_usage_once("torch.dtensor.distribute_module")
 
+    already_distributed = getattr(module, "_distribute_module_applied", False)
+    if already_distributed:
+        raise RuntimeError(
+            "distribute_module should only be called once on a module, "
+            "but it has already been called on this module!"
+        )
+
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
     device_type = device_mesh.device_type
     if device_type == "xla":
@@ -967,6 +974,7 @@ def distribute_module(
             f"output_fn should take in 3 arguments, but got {num_args} arguments!"
         )
 
+    module._distribute_module_applied = True  # type: ignore[assignment]
     return module
 
 
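
A minimal sketch (not part of the patch) of how the new guard behaves once applied: the first `distribute_module` call succeeds and marks the module, and a second call on the same module raises the new `RuntimeError`. The single-process gloo setup and the use of `nn.Linear` as a stand-in model are assumptions for illustration only.

```python
# Illustration of the double-call guard added in this patch (assumes the patch
# is applied and a PyTorch build where torch.distributed.tensor is available).
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module

# Single-process process group for demonstration purposes.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
model = nn.Linear(10, 10)  # hypothetical stand-in for the test's MyModel

# First call: parameters are converted to DTensors in place and the module
# is tagged with `_distribute_module_applied = True`.
distribute_module(model, mesh)

# Second call on the same module: with this patch, it should raise.
try:
    distribute_module(model, mesh)
except RuntimeError as e:
    print(f"caught expected error: {e}")

dist.destroy_process_group()
```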