From fa127d9b20720b70c6481ee9c19693714c428446 Mon Sep 17 00:00:00 2001
From: zeshengzong
Date: Tue, 16 Sep 2025 12:07:46 +0000
Subject: [PATCH] Fix `LBFGS` wolfe max iteration (#161488)

Fixes #91581, based on #135026

## Test Result

```bash
pytest test/test_optim.py
.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
---
 test/test_optim.py   | 28 ++++++++++++++++++++++++++++
 torch/optim/lbfgs.py |  9 ++++++++-
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/test/test_optim.py b/test/test_optim.py
index 6dd23d6328c..de185725b5c 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2305,6 +2305,34 @@ class TestOptimRenewed(TestCase):
         for state in optim.state.values():
             self.assertGreater(len(state), 0)
 
+    @parametrize("dtype", [torch.float32])
+    def test_step_iteration(self, device, dtype):
+        def _get_model_and_input_tensor(device, dtype):
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+            model.to(dtype=dtype, device=device)
+            return model, input
+
+        counter = 0
+
+        def fwd_bwd(optim, mod, i):
+            nonlocal counter
+            counter += 1
+            optim.zero_grad()
+            loss = mod(i).sum()
+            loss.backward()
+            return loss
+
+        model, input = _get_model_and_input_tensor(device, dtype)
+        optimizer = torch.optim.LBFGS(
+            model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
+        )
+        optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+        self.assertEqual(counter, 6)
+
 
 instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
 
diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py
index 674aaaf2688..09f5f2ca8c8 100644
--- a/torch/optim/lbfgs.py
+++ b/torch/optim/lbfgs.py
@@ -442,7 +442,14 @@ class LBFGS(Optimizer):
                         return self._directional_evaluate(closure, x, t, d)
 
                     loss, flat_grad, t, ls_func_evals = _strong_wolfe(
-                        obj_func, x_init, t, d, loss, flat_grad, gtd
+                        obj_func,
+                        x_init,
+                        t,
+                        d,
+                        loss,
+                        flat_grad,
+                        gtd,
+                        max_ls=max_eval - current_evals,
                     )
                 self._add_grad(t, d)
                 opt_cond = flat_grad.abs().max() <= tolerance_grad
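
As a usage sketch (not part of the patch itself): the snippet below illustrates the behavior the new `test_step_iteration` pins down, using only the public `torch.optim.LBFGS` API. The model, data, and loss here are placeholders of my choosing; the exact count of 6 closure calls asserted in the test applies to its conv/batch-norm setup, and an easier problem may finish the line search in fewer calls. The point of the fix is that the strong-Wolfe search now receives `max_ls=max_eval - current_evals`, so a single `step()` can no longer keep re-evaluating the closure far past the `max_eval` budget.

```python
import torch

# Placeholder model and data; any differentiable objective works here.
model = torch.nn.Linear(10, 1)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)

# Same knobs as the new test: one outer iteration, a 5-evaluation budget,
# and the strong-Wolfe line search whose evaluation count the fix caps.
optimizer = torch.optim.LBFGS(
    model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
)

num_closure_calls = 0


def closure():
    # LBFGS re-invokes this closure from inside the line search; with the fix,
    # the search is capped at max_ls = max_eval - current_evals evaluations.
    global num_closure_calls
    num_closure_calls += 1
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss


optimizer.step(closure)
# Bounded by the max_eval budget; the patched test expects exactly 6 calls
# for its harder conv/batch-norm objective.
print(num_closure_calls)
```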