mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[FX] Fix _replicate_for_data_parallel (#63821)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D30502115 Pulled By: jamesr66a fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c
This commit is contained in:
parent
5be17ec1fc
commit
4e37a015c7
|
|
@ -2296,6 +2296,21 @@ class TestFX(JitTestCase):
|
|||
r"Call using an FX-traced Module, line .* of the "
|
||||
r"traced Module's generated forward function:")
|
||||
|
||||
def test_graph_module_replicate_for_dp(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.relu(x)
|
||||
|
||||
gm = torch.fx.symbolic_trace(Foo())
|
||||
|
||||
x = torch.randn(5, 3)
|
||||
out = gm(x)
|
||||
|
||||
replica = gm._replicate_for_data_parallel()
|
||||
out_replica = replica(x)
|
||||
|
||||
torch.testing.assert_allclose(out_replica, out)
|
||||
|
||||
def test_ast_rewriter_rewrites_assert(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: int, z: int):
|
||||
|
|
|
|||
|
|
@ -656,6 +656,11 @@ class {module_name}(torch.nn.Module):
|
|||
orig_str = super().__str__()
|
||||
return '\n'.join([orig_str, self._code])
|
||||
|
||||
def _replicate_for_data_parallel(self):
|
||||
new_gm = self.__copy__()
|
||||
new_gm._is_replica = True
|
||||
return new_gm
|
||||
|
||||
# workarounds for issues in __torch_function__
|
||||
|
||||
# WAR for __torch_function__ not handling tensor lists,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user