Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Fixes #112592

1) **File: torch/cuda/random.py**
```
Before:
/content/pytorch/torch/cuda/random.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`: D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D208: Docstring is over-indented
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D209: Multi-line docstring closing quotes should be on a separate line
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D414: Section has no content ('Args')
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:128 in public function `seed`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:128 in public function `seed`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
18
```
```
After:
/content/pytorch/torch/cuda/random.py:1 at module level: D100: Missing docstring in public module
1
```

2) **File: torch/cuda/amp/autocast_mode.py**
```
Before:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`: D102: Missing docstring in public method
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D401: First line should be in imperative mood; try rephrasing (found 'Helper')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D401: First line should be in imperative mood; try rephrasing (found 'Helper')
12
```
```
After:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`: D102: Missing docstring in public method
5
```

3) **File: torch/cuda/amp/grad_scaler.py**
```
Before:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`: D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`: D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`: D400: First line should end with a period (not 'g')
/content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`: D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`: D401: First line should be in imperative mood (perhaps 'Update', not 'Updates')
/content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`: D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`: D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`: D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`: D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`: D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`: D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`: D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`: D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`: D401: First line should be in imperative mood (perhaps 'Load', not 'Loads')
/content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`: D105: Missing docstring in magic method
28
```
```
After:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`: D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`: D105: Missing docstring in magic method
5
```

4) **File: torch/optim/_functional.py**
```
Before:
/content/pytorch/torch/optim/_functional.py:1 at module level: D400: First line should end with a period (not 'e')
1
```
```
After:
0
```

5) **File: torch/optim/__init__.py**
```
Before:
/content/pytorch/torch/optim/__init__.py:1 at module level: D205: 1 blank line required between summary line and description (found 0)
1
```
```
After:
0
```

6) **File: torch/optim/lbfgs.py**
```
Before:
/content/pytorch/torch/optim/lbfgs.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`: D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`: D400: First line should end with a period (not 'c')
/content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/lbfgs.py:285 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
5
```
```
After:
/content/pytorch/torch/optim/lbfgs.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`: D107: Missing docstring in __init__
2
```

7) **File: torch/optim/sparse_adam.py**
```
Before:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`: D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
4
```
```
After:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`: D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`: D107: Missing docstring in __init__
3
```

8) **File: torch/optim/adadelta.py**
```
Before:
/content/pytorch/torch/optim/adadelta.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adadelta.py:82 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`: D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adadelta.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`: D105: Missing docstring in magic method
4
```

9) **File: torch/optim/adagrad.py**
```
Before:
/content/pytorch/torch/optim/adagrad.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`: D102: Missing docstring in public method
/content/pytorch/torch/optim/adagrad.py:100 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`: D202: No blank lines allowed after function docstring (found 1)
7
```
```
After:
/content/pytorch/torch/optim/adagrad.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`: D102: Missing docstring in public method
5
```

10) **File: torch/optim/adam.py**
```
Before:
/content/pytorch/torch/optim/adam.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adam.py:135 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adam.py:281 in public function `adam`: D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/optim/adam.py:281 in public function `adam`: D205: 1 blank line required between summary line and description (found 0)
7
```
```
After:
/content/pytorch/torch/optim/adam.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`: D105: Missing docstring in magic method
4
```

11) **File: torch/optim/adamax.py**
```
Before:
/content/pytorch/torch/optim/adamax.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamax.py:91 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamax.py:203 in public function `adamax`: D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamax.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`: D105: Missing docstring in magic method
4
```

12) **File: torch/optim/adamw.py**
```
Before:
/content/pytorch/torch/optim/adamw.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamw.py:153 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamw.py:304 in public function `adamw`: D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamw.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`: D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`: D105: Missing docstring in magic method
4
```

13) **File: torch/optim/asgd.py**
```
Before:
/content/pytorch/torch/optim/asgd.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`: D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`: D105: Missing docstring in magic method
/content/pytorch/torch/optim/asgd.py:107 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/asgd.py:195 in public function `asgd`: D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/asgd.py:1 at module level: D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`: D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`: D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`: D105: Missing docstring in magic method
4
```

Resolved the docstring errors listed above. I initially made the changes on the main branch of my forked repo, which caused them to show up in my PR for another issue. I have fixed that and hope this PR won't have any conflicts. Kindly review @svekars @jbschlosser. If there are any other issues, please let me know. Thanks!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964
Approved by: https://github.com/kit1980
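For readers who do not have the pydocstyle codes memorized, the sketch below shows the general shape of a D205/D401 fix. It is an illustration written for this summary rather than the literal diff from the PR, using `torch.cuda.manual_seed` as the example with a paraphrased docstring:

```
# Illustrative only -- not the exact diff from this PR.

# Before: triggers D401 (summary not in imperative mood) and
# D205 (no blank line between summary and description).
def manual_seed(seed):
    """Sets the seed for generating random numbers for the current GPU.
    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

# After: imperative first line ("Set"), blank line before the description.
def manual_seed(seed):
    """Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """
```

The same pattern (imperative, period-terminated first line followed by a blank line) covers most of the D205/D400/D401 entries above; the D202 entries are simply removed blank lines after a function docstring.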
torch/optim/lbfgs.py (479 lines, 17 KiB, Python)
import torch
from functools import reduce
from .optimizer import Optimizer

__all__ = ['LBFGS']


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    # w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
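End of the file listing. As an aside (not part of `torch/optim/lbfgs.py`), the `step()` method above only works with a closure that recomputes the loss and gradients; a minimal usage sketch with a placeholder model and data:

```
import torch

# Minimal usage sketch: LBFGS.step() needs a closure that re-evaluates
# the loss and repopulates .grad, since the optimizer gathers gradients
# via _gather_flat_grad() on every function evaluation.
model = torch.nn.Linear(10, 1)
x, y = torch.randn(64, 10), torch.randn(64, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0,
                              line_search_fn="strong_wolfe")

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(5):
    loss = optimizer.step(closure)  # may call closure() several times per step
```

Note that `step(closure)` can invoke the closure multiple times per call, once for each function evaluation performed by the strong Wolfe line search.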