[hop] support local_map filtered gradients (#164437)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164437
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433
Authored by Simon Fan on 2025-10-08 13:15:29 -07:00; committed by PyTorch MergeBot
parent 3ad88924ad
commit a61d0de9f9
2 changed files with 38 additions and 3 deletions


@@ -682,6 +682,39 @@ class GraphModule(torch.nn.Module):
         model = MyModule()
         ap_style_initial_capture(model, (x,))
 
+    @unittest.skipIf(*get_skip_reasons())
+    def test_filtered_gradients(self):
+        @local_map(
+            out_placements=(
+                (Replicate(), Replicate(), Replicate()),
+                (Replicate(), Replicate(), Replicate()),
+            ),
+            in_placements=(
+                (Replicate(), Replicate(), Replicate()),
+                (Replicate(), Replicate(), Replicate()),
+            ),
+            redistribute_inputs=True,
+            in_grad_placements=None,
+            device_mesh=self.mesh,
+        )
+        def returns_non_param(w, x):
+            # x does not require grad, and it is an output, so its corresponding tangent is filtered out
+            return torch.matmul(x, w.t()), x + 20
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = nn.Linear(80, 80)
+
+            def forward(self, x):
+                a, b = returns_non_param(self.w.weight, x)
+                return a.sum() + b.sum()
+
+        model = MyModule()
+        with FakeTensorMode():
+            inputs = (torch.randn(80, 80),)
+            ap_style_initial_capture(model, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
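
Note: the comment inside returns_non_param is the scenario this test exercises. The second output, x + 20, depends only on x, which does not require grad, so autograd produces no tangent for it and the backward graph receives fewer gradient inputs than the forward has outputs. A minimal sketch of that filtering with plain eager autograd (no local_map HOP or AutoParallel capture involved; the names below are illustrative):

import torch

# Mirror returns_non_param with plain tensors: y depends on w (requires grad),
# z depends only on x (does not require grad).
w = torch.randn(80, 80, requires_grad=True)
x = torch.randn(80, 80)

y = torch.matmul(x, w.t())  # gradient flows back to w through this output
z = x + 20                  # no grad-requiring input, so no tangent for this output

# Only outputs that require grad keep a tangent slot in the backward.
filtered_grads_idx = [i for i, out in enumerate((y, z)) if out.requires_grad]
assert filtered_grads_idx == [0]

# The backward only needs a gradient for y; z contributes nothing.
(grad_w,) = torch.autograd.grad(y.sum() + z.sum(), (w,))
assert grad_w.shape == w.shape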


@@ -313,9 +313,9 @@ def create_hop_fw_bw(
     new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs
     new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs}
     # Okay because Autoparallel assumes same sharding between param and grads
-    new_bw_gm.meta["local_map_kwargs"]["in_placements"] = local_map_kwargs[
-        "out_placements"
-    ]
+    new_bw_gm.meta["local_map_kwargs"]["in_placements"] = tuple(
+        [local_map_kwargs["out_placements"][i] for i in filtered_grads_idx]
+    )
     new_bw_gm.meta["local_map_kwargs"]["out_placements"] = local_map_kwargs[
         "in_placements"
     ]
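
The reassignment above is the heart of the change: the backward region only receives tangents for outputs that survived filtering, so its in_placements becomes the forward's out_placements restricted to filtered_grads_idx, while its out_placements still mirrors the forward's in_placements. A rough sketch of that bookkeeping with made-up placements (the dict and index list stand in for the real local_map_kwargs and filtered_grads_idx):

from torch.distributed.tensor import Replicate, Shard

# Hypothetical forward metadata: two outputs and two inputs on a 1-D mesh.
local_map_kwargs = {
    "out_placements": ((Shard(0),), (Replicate(),)),
    "in_placements": ((Shard(0),), (Replicate(),)),
}

# Suppose only output 0 requires grad, so only its tangent reaches the backward.
filtered_grads_idx = [0]

bw_in_placements = tuple(
    local_map_kwargs["out_placements"][i] for i in filtered_grads_idx
)
bw_out_placements = local_map_kwargs["in_placements"]

assert bw_in_placements == ((Shard(0),),)                   # one tangent, placed like output 0
assert bw_out_placements == ((Shard(0),), (Replicate(),))   # gradients for both inputs
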
@@ -344,6 +344,8 @@ def create_hop_fw_bw(
         len(new_bw_gm.graph.find_nodes(op="placeholder")) - num_activations
     )
     assert actual_bw_inputs > 0
+    assert expected_fw_inputs + expected_bw_inputs == len(primals_and_tangents)
+    assert actual_fw_inputs + actual_bw_inputs == len(primals_and_tangents)
     assert len(new_bw_gm.graph.find_nodes(op="output")) == 1
     actual_bw_outputs = len(new_bw_gm.graph.find_nodes(op="output")[0].args[0])
     assert expected_bw_inputs == actual_bw_inputs
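
For the test above, a plausible accounting behind these checks (the counts are my reading of the variable names, not taken from the source): returns_non_param has two primals (w, x) and two outputs, but only the matmul output requires grad, so exactly one tangent survives filtering and the joint graph sees three inputs in total.

# Hypothetical counts matching the test case; the names echo the asserts above.
num_primals = 2                    # w, x
filtered_grads_idx = [0]           # only the matmul output keeps a tangent
num_tangents = len(filtered_grads_idx)

expected_fw_inputs = num_primals                       # 2 forward placeholders
expected_bw_inputs = num_tangents                      # 1 backward placeholder (plus saved activations)
len_primals_and_tangents = num_primals + num_tangents  # 3

assert expected_fw_inputs + expected_bw_inputs == len_primals_and_tangents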