[hop] support local_map filtered gradients (#164437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164437
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433
This commit is contained in: parent 3ad88924ad, commit a61d0de9f9
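
Background for the diff below: local_map is a higher-order op that runs a function on the local shards of DTensor inputs. When the wrapped function returns an output that does not require grad, autograd produces no tangent for that output, so the backward graph receives fewer tangent inputs than the forward has outputs; this commit makes the HOP's forward/backward split account for those filtered tangents. A minimal sketch of the underlying autograd behavior, independent of local_map (shapes chosen to match the new test; this is illustrative, not part of the commit):

    import torch

    w = torch.randn(80, 80, requires_grad=True)
    x = torch.randn(80, 80)  # requires_grad=False

    # Mirrors returns_non_param in the test: the second output is computed
    # only from x, so it has no grad path and autograd produces no tangent for it.
    a = torch.matmul(x, w.t())
    b = x + 20

    (a.sum() + b.sum()).backward()
    print(w.grad is not None)  # True: the matmul output contributes a tangent
    print(b.requires_grad)     # False: b's tangent is "filtered out"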
@@ -682,6 +682,39 @@ class GraphModule(torch.nn.Module):
         model = MyModule()
         ap_style_initial_capture(model, (x,))
 
+    @unittest.skipIf(*get_skip_reasons())
+    def test_filtered_gradients(self):
+        @local_map(
+            out_placements=(
+                (Replicate(), Replicate(), Replicate()),
+                (Replicate(), Replicate(), Replicate()),
+            ),
+            in_placements=(
+                (Replicate(), Replicate(), Replicate()),
+                (Replicate(), Replicate(), Replicate()),
+            ),
+            redistribute_inputs=True,
+            in_grad_placements=None,
+            device_mesh=self.mesh,
+        )
+        def returns_non_param(w, x):
+            # x does not require grad, and it is an output, so its corresponding tangent is filtered out
+            return torch.matmul(x, w.t()), x + 20
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = nn.Linear(80, 80)
+
+            def forward(self, x):
+                a, b = returns_non_param(self.w.weight, x)
+                return a.sum() + b.sum()
+
+        model = MyModule()
+        with FakeTensorMode():
+            inputs = (torch.randn(80, 80),)
+            ap_style_initial_capture(model, inputs)
 
 
 if __name__ == "__main__":
     run_tests()
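
A note on the placement tuples in the test above: each per-tensor entry carries one placement per device-mesh dimension, so the triple (Replicate(), Replicate(), Replicate()) implies self.mesh is a 3-D mesh with every tensor fully replicated (no sharding on any dimension). A tiny illustration of that reading (the mesh dimensionality is inferred from the test, not stated in the diff):

    from torch.distributed.tensor import Replicate

    mesh_ndim = 3  # assumption: self.mesh in the test has three dimensions
    placements_per_tensor = (Replicate(),) * mesh_ndim
    print(placements_per_tensor)  # (Replicate(), Replicate(), Replicate())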
@@ -313,9 +313,9 @@ def create_hop_fw_bw(
     new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs
     new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs}
     # Okay because AutoParallel assumes the same sharding between params and grads
-    new_bw_gm.meta["local_map_kwargs"]["in_placements"] = local_map_kwargs[
-        "out_placements"
-    ]
+    new_bw_gm.meta["local_map_kwargs"]["in_placements"] = tuple(
+        [local_map_kwargs["out_placements"][i] for i in filtered_grads_idx]
+    )
     new_bw_gm.meta["local_map_kwargs"]["out_placements"] = local_map_kwargs[
         "in_placements"
     ]
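
The change above is the core fix: the backward graph consumes tangents as inputs, so its in_placements are derived from the forward's out_placements, but now only at the indices whose tangents survive filtering (filtered_grads_idx). A standalone sketch of the remapping with made-up placeholder values (the names mirror the diff; as in the test, only the first output's tangent survives):

    out_placements = ("placement_out0", "placement_out1")
    in_placements = ("placement_w", "placement_x")
    filtered_grads_idx = [0]  # output 1 (x + 20) has no tangent

    bw_kwargs = {}
    # Backward inputs = surviving tangents: pick the matching out_placements.
    bw_kwargs["in_placements"] = tuple(
        out_placements[i] for i in filtered_grads_idx
    )
    # Backward outputs = input grads; per the diff's comment, AutoParallel
    # assumes a grad is sharded the same way as the tensor it belongs to.
    bw_kwargs["out_placements"] = in_placements

    assert bw_kwargs["in_placements"] == ("placement_out0",)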
@@ -344,6 +344,8 @@ def create_hop_fw_bw(
         len(new_bw_gm.graph.find_nodes(op="placeholder")) - num_activations
     )
     assert actual_bw_inputs > 0
+    assert expected_fw_inputs + expected_bw_inputs == len(primals_and_tangents)
+    assert actual_fw_inputs + actual_bw_inputs == len(primals_and_tangents)
     assert len(new_bw_gm.graph.find_nodes(op="output")) == 1
     actual_bw_outputs = len(new_bw_gm.graph.find_nodes(op="output")[0].args[0])
     assert expected_bw_inputs == actual_bw_inputs
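
The new assertions are bookkeeping for the filtered case: the joint graph's flat inputs (primals_and_tangents) must split exactly into the forward's primal inputs and the backward's tangent inputs, after discounting saved activations from the backward's placeholder count. A hedged sketch of the arithmetic with illustrative counts (the concrete numbers are not from the diff):

    # For returns_non_param(w, x) where only the matmul output requires grad:
    num_primals = 2   # w, x
    num_tangents = 1  # tangent for the matmul output; the x + 20 tangent is filtered
    len_primals_and_tangents = num_primals + num_tangents

    expected_fw_inputs = num_primals
    expected_bw_inputs = num_tangents

    # The invariants asserted in create_hop_fw_bw:
    assert expected_fw_inputs + expected_bw_inputs == len_primals_and_tangents
    assert expected_bw_inputs > 0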