[FX] Fix _replicate_for_data_parallel (#63821)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D30502115

Pulled By: jamesr66a

fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c
James Reed 2021-08-24 13:44:52 -07:00 committed by Facebook GitHub Bot
parent 5be17ec1fc
commit 4e37a015c7
2 changed files with 20 additions and 0 deletions


@@ -2296,6 +2296,21 @@ class TestFX(JitTestCase):
                             r"Call using an FX-traced Module, line .* of the "
                             r"traced Module's generated forward function:")

    def test_graph_module_replicate_for_dp(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        gm = torch.fx.symbolic_trace(Foo())

        x = torch.randn(5, 3)
        out = gm(x)

        replica = gm._replicate_for_data_parallel()
        out_replica = replica(x)

        torch.testing.assert_allclose(out_replica, out)

    def test_ast_rewriter_rewrites_assert(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):


@@ -656,6 +656,11 @@ class {module_name}(torch.nn.Module):
        orig_str = super().__str__()
        return '\n'.join([orig_str, self._code])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm

# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
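
For context, `_replicate_for_data_parallel` is the per-module hook used by `torch.nn.parallel.replicate` (and hence `nn.DataParallel`); the override added above copies the traced module via `__copy__` and tags the copy with `_is_replica = True`. The sketch below is not part of this PR: it shows, under the assumption of a multi-GPU host, how a traced `GraphModule` can be dropped into `nn.DataParallel` once replication works; the `Foo` module and the two-device list are illustrative only.

import torch
import torch.fx
import torch.nn as nn

class Foo(nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Symbolic tracing produces a torch.fx.GraphModule.
gm = torch.fx.symbolic_trace(Foo())

if torch.cuda.device_count() >= 2:
    # nn.parallel.replicate (used internally by DataParallel) calls
    # _replicate_for_data_parallel() on each module, so the traced
    # module can now be wrapped directly.  (Two GPUs assumed here.)
    dp = nn.DataParallel(gm.cuda(), device_ids=[0, 1])
    out = dp(torch.randn(8, 3, device="cuda:0"))
else:
    # Single-device fallback mirroring the unit test above.
    out = gm(torch.randn(8, 3))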