Commit Graph

109 Commits

Author SHA1 Message Date
Soumith Chintala
75754beca3 Revert D14577575: [pytorch][PR] Fix lack of state init for adagrad and add share_memory flag
Differential Revision:
D14577575

Original commit changeset: 12440079ac96

fbshipit-source-id: 935106385e608471dc280fc61cfedf19d330812d
2019-04-26 15:43:04 -07:00
kirayue
af06d6342c Add SGDR(Stochastic Gradient Descent with Warm Restarts) scheduler (#17226)
Summary:
Because of merge error with master in #15042, open a new PR for ezyang.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17226

Differential Revision: D14418145

Pulled By: mrshenli

fbshipit-source-id: 099ba225b28e6aba71760b81b2153ad1c40fbaae
2019-04-25 09:26:31 -07:00
Kaiyu Shi
444f792fa6 Fix lack of state init for adagrad and add share_memory flag (#17679)
Summary:
The current code initialize the `state` in `__init__` method, but the initialization process is not invoked in `add_parameter_group`.

I followed the same approach in other Optimizers to init the `state`.

```python
import torch

emb = torch.nn.Embedding(10,10)
emb2 = torch.nn.Embedding(10,10)

optim = torch.optim.Adagrad(emb.parameters())
print(optim.state[emb.weight])  # already initialized

optim.add_param_group({'params': emb2.parameters()})
print(optim.state[emb2.weight])  # empty dict

loss = emb2.weight.sum() + emb.weight.sum()
loss.backward()
optim.step()  # raised KeyError
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17679

Differential Revision: D14577575

Pulled By: ezyang

fbshipit-source-id: 12440079ac964b9eedad48e393d47f558babe300
2019-04-23 12:22:19 -07:00
Chandler Zuo
e3f1504621 Fix the Division by Zero Bug of CosineAnnealingLR (#19180)
Summary:
Added the formula for the corner case. Updated unit tests.

Fixes #17913
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19180

Differential Revision: D14942023

Pulled By: ezyang

fbshipit-source-id: 167c109b97a7830d5b24541dc91e4788d531feec
2019-04-23 09:54:28 -07:00
Bado Lee
36084908e4 Fix lr_scheduler's last_epoch value at the time of initialization (BC BREAKING!) (#7889)
Summary:
Hello everyone :) !!

I've found that lr_scheduler was initialized with last_epoch as -1.
This causes that even after the first step (not the one in init but explicit step of scheduler),
learning rate of scheduler's optimizer remains as the previous.
```python
>>> import torch
>>> cc = torch.nn.Conv2d(10,10,3)
>>> myinitial_lr = 0.1
>>> myoptimizer = torch.optim.Adam(cc.parameters(), lr=myinitial_lr)
>>> mylrdecay = 0.5
>>> myscheduler = torch.optim.lr_scheduler.ExponentialLR(myoptimizer,mylrdecay)

>>> myscheduler.get_lr()
[0.2]    # this is because of  get_lr calculates lr by 0.1 * 0.5^-1
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.1    # this is not consistent with get_lr value
>>> myscheduler.last_epoch
-1

>>> myscheduler.step()
>>> myscheduler.get_lr()
[0.1]    # this should be the value right after the init, not after first step
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.1    # since this is after first step, it should have been decayed as 0.05
>>> myscheduler.last_epoch
0

>>> myscheduler.step()
>>> myscheduler.last_epoch
1
>>> myscheduler.get_lr()
[0.05]
>>> myscheduler.optimizer.param_groups[0]["lr"]
0.05
>>> myscheduler.last_epoch
1
```

First problem is, even after the init of lr_scheduler, you get the inconsistent parameter values.

The second problem is, you are stuck with same learning rate in the first 2 epochs if the step function of lr_scheduler is not called in the beginning of the epoch loop.
Of course, you can avoid this by calling lr_scheduler's step in the beginning,
but I don't think this is proper use since, incase of optimizer, step is called in the end of the iteration loop.

I've simply avoided all above issues by setting last_epoch as 0 after the initialization.

This also makes sense when you init with some value of last_epoch which is not -1.
For example, if you want to init with last epoch 10,
lr should not be set with decayed 1 step further. Which is
last_epoch gets +1 in the previous code.
base_lr * self.gamma ** self.last_epoch

Instead, it should be set with step 10 exact value.

I hope this fix find it's way with all your help :)
I'm really looking forward & excited to become a contributor for pytorch!
Pytorch Rocks!!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/7889

Differential Revision: D15012769

Pulled By: ezyang

fbshipit-source-id: 258fc3009ea7b7390a3cf2e8a3682eafb506b08b
2019-04-23 08:54:09 -07:00
barrh
557b1b362f Fix copied optimizer (#19308)
Summary:
Add the defaults field to the copied object.
Prior to this patch, optimizer.__getattr__ has excluded the defaults
attribute of optimizer source object, required by some LR schedulers. (e.g. CyclicLR with momentum)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19308

Differential Revision: D15012801

Pulled By: soumith

fbshipit-source-id: 95801b269f6f9d78d531d4fed95c973b280cc96f
2019-04-19 10:27:01 -07:00
Jon Malmaud
1b25fdbcd0 More type stubs (#18511)
Summary:
Added stubs for:

* The `device` module
* The `cuda` module
* Parts of the `optim` module
* Began adding stubs for the `autograd` module. I'll annotate more later but `no_grad` and friends are probably the most used exports from it so it seemed like a good place to start.

This would close #16996, although comments on that issue reference other missing stubs so maybe it's worth keeping open as an umbrella issue.

The big remaining missing package is `nn`.

Also added a `py.typed` file so mypy will pick up on the type stubs. That closes #17639.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18511

Differential Revision: D14715053

Pulled By: ezyang

fbshipit-source-id: 9e4882ac997063650e6ce47604b3eaf1232c61c9
2019-04-01 16:03:58 -07:00
Edward Yang
173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report an import unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00
Søren Rasmussen
95d3825e48 ReduceLrOnPlateau: best=current -> best=copy(current) (#16364) (#16697)
Summary:
Fixes #16364
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16697

Differential Revision: D14680879

Pulled By: soumith

fbshipit-source-id: c50c22f3eacea4474fb3a04fe85fbf11d5a177c9
2019-03-29 06:56:51 -07:00
Sam Pepose
8635078d9e Adds Cyclical Learning Rate and Momentum (#18001)
Summary:
This implements a cyclical learning rate (CLR) schedule with an optional inverse cyclical momentum. More info about CLR: https://github.com/bckenstler/CLR

This is finishing what #2016 started. Resolves #1909.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18001

Differential Revision: D14451845

Pulled By: sampepose

fbshipit-source-id: 8f682e0c3dee3a73bd2b14cc93fcf5f0e836b8c9
2019-03-27 19:56:04 -07:00
Neta Zmora
1c76746f61 SGD: remove unneeded multiply-add initialization operations (#18114)
Summary:
The momentum buffer is initialized to the value of
d_p, but the current code takes the long way to do this:
1. Create a buffer of zeros
2. Multiply the buffer by the momentum coefficient
3. Add d_p to the buffer

All of these can be collapsed into a single step:
1. Create a clone of d_p
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18114

Differential Revision: D14509122

Pulled By: ezyang

fbshipit-source-id: 4a79b896201d5ff20770b7ae790c244ba744edb8
2019-03-19 10:34:17 -07:00
Chandler Zuo
096ee8467c Redefine scheduler to set learning rate using recursive formula (#14010)
Summary:
Modified step_lr for StepLR, MultiStepLR, ExponentialLR and CosineAnnealingLR. In this way, multiple schedulers can be used simultaneously to modify the learning rates.

Related issue: https://github.com/pytorch/pytorch/issues/13022

Added unit tests combining multiple schedulers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14010

Reviewed By: ezyang

Differential Revision: D13494941

Pulled By: chandlerzuo

fbshipit-source-id: 7561270245639ba1f2c00748f8e4a5f7dec7160c
2018-12-18 16:44:31 -08:00
Jerry Ma
7956e9718b Add name for required optimizer parameter. (#13202)
Summary:
Small change -- the benefit is that the docs will show
``<required parameter>`` instead of ``<object object>``
for these required parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13202

Reviewed By: SsnL

Differential Revision: D12826252

Pulled By: jma127

fbshipit-source-id: 5f2c8495e5c56920377e4e012b8711e8f2a6e30e
2018-10-29 15:02:21 -07:00
Soumith Chintala
cf235e0894 fix lint after new flake8 release added new style constraints (#13047)
Summary:
fix lint after new flake8 release added new style constraints
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13047

Differential Revision: D10527804

Pulled By: soumith

fbshipit-source-id: 6f4d02662570b6339f69117b61037c8394b0bbd8
2018-10-24 09:03:38 -07:00
Jerry Ma
383d340e88 Small optimization for adam (#12107)
Summary:
Apply weight decay for Adam in-place instead of via copy.

Synced offline with soumith , who mentioned that it should be OK. This is also consistent with other optimizers, e.g. eee01731a5/torch/optim/sgd.py (L93)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12107

Reviewed By: soumith

Differential Revision: D10071787

Pulled By: jma127

fbshipit-source-id: 5fd7939c79039693b225c44c4c80450923b8d673
2018-09-26 21:43:46 -07:00
Jeff Smith
05e06f7de2 migrating deprecated calls without abc module for containers (#11515)
Summary:
Implementing #10540.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11515

Reviewed By: apaszke

Differential Revision: D9771045

Pulled By: jeffreyksmithjr

fbshipit-source-id: 85ea39abaa9b465805a969f122b626b11fc85ef6
2018-09-13 15:09:22 -07:00
Peter Goldsborough
fb4e8088f3 Remove methods that start with an underscore from at::Tensor (#11152)
Summary:
This PR cleans up the `at::Tensor` class by removing all methods that start with an underscore in favor of functions in the `at::` namespace. This greatly cleans up the `Tensor` class and makes it clearer what is the public and non-public API.

For this I changed `native_functions.yaml` and `Declarations.cwrap` to make all underscore methods `variant: function` (or add such a statement to begin with), and then fixed all code locations using the underscore methods.

ezyang colesbury gchanan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11152

Differential Revision: D9683607

Pulled By: goldsborough

fbshipit-source-id: 97f869f788fa56639c05a439e2a33be49f10f543
2018-09-07 11:55:11 -07:00
0phoff
294c065384 Changed serialization mechanism of LambdaLR scheduler (#9927)
Summary:
I opened an issue explaining some of my frustrations with the current state of schedulers.
While most points that I raised in [that issue](https://github.com/pytorch/pytorch/issues/8741#issuecomment-404449697) need to be discussed more thoroughly before being implemented, there are some that are not so difficult to fix.

This PR changes the way the LambdaLR scheduler gets serialized:
> The lr_lambda functions are only saved if the are callable objects (which can be stateful).
> There is no point in saving functions/lambdas as you need their definition before unpickling and they are stateless.

This has the big advantage that the scheduler is serializable, even if you use lambda functions or locally defined functions (aka a function in a function).

Does this functionality need any unit tests?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9927

Differential Revision: D9055505

Pulled By: soumith

fbshipit-source-id: 6c1cec588beedd098ec7d2bce6a9add27f29e48f
2018-07-31 19:39:06 -07:00
rasbt
eee01731a5 Adds the default value for the amsgrad arg to the Adam docstring (#9971)
Summary:
Minor addition to the docstring of `torch.nn.optim.Adam`, adding the default argument description for the `amsgrad` argument to the docstring for concistency.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9971

Differential Revision: D9040820

Pulled By: soumith

fbshipit-source-id: 168744a6bb0d1422331beffd7e694b9d6f61900c
2018-07-28 09:23:45 -07:00
Tongzhou Wang
27455e9c78 Use _six for inf and nan (#9500)
Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math

In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500

Reviewed By: soumith

Differential Revision: D8876229

Pulled By: SsnL

fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7
2018-07-18 10:40:29 -07:00
Ailing
5a3f7810f8 _LRSchedulers getstate include optimizer info (#7757)
* getstate should include optimizer

* remove getstate/setstate functions
2018-05-23 11:43:42 -04:00
Matt Le
c5b9a36f1e Make return uniform in lbfgs step (#7586)
* Make return uniform in lbfgs step

This ensures that we are returning results of the same type
in LBFGS step.

* Adding test case to exercise different exit points

Sets the tolerance_grad to negative infinity and positive
infinity to deterministically excercise the early exit branch

* Fixing lint error
2018-05-16 11:16:46 -04:00
Changhan Wang
a257bd19a2 added state_dict/load_state_dict for ReduceLROnPlateau (#7201) 2018-05-10 12:02:28 +02:00
Domagoj Alagić
f43e067128 Make optimizer not complain about parameters with requires_grad=False (#7419) 2018-05-09 11:34:52 -04:00
Richard Zou
3369828bfa
Clarify patience in ReduceLROnPlateau docs (#7242)
* Clarify patience in ReduceLROnPlateau docs

It's unclear which definition of patience we have. The two ways to
interpret it are:
- How many bad epochs can you see before you start considering changing the learning rate.
- How many bad epochs can you see before you change the learning rate.

This PR clarifies the docs with an example. If `patience = 2`, then
after 2 bad epochs, we begin considering changing the learning rate.
After seeing one more epoch (the 3rd epoch), if that epoch is also bad,
then we change the learning rate after it.

* address comments
2018-05-04 16:39:26 -04:00
Samuel
0c737dff63 fix lbfgs variable names (#7037)
Switches the step/direction variable names (steps and directions are flipped
in the current implementation of the two loop-recursion). This change does
not change the numerical output of the program, but should make it easier
to follow.
2018-04-27 17:47:37 -04:00
Armen
e44f901b55 added functionality for state_dict/load_state_dict for lr_scheduler ( Fixes: #3026 ) (#6342)
* added functionality for state_dict/load_state_dict for lr_scheduler

* fixed linting issues/removed unused import

* refactor lr_scheduler state_dicts/state_dict holds everything __dict__ but optimizer

* changed documentation in lr_scheduler

* Update lr_scheduler.py
2018-04-19 07:09:03 -04:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
Atul Kumar
3e83e3abfe Adding initial_accumulator_value parameter to Adagrad (#6616) 2018-04-16 22:12:36 +02:00
Kento NOZAWA
3b58b859b2 Fix typos in docs (#6389) 2018-04-07 12:41:15 -04:00
Tongzhou Wang
a2880531ea fix SGD lr check (#6244) 2018-04-03 21:29:18 -04:00
Jiaming Liu
31c0e2321a Block set from param_group['params'] (#6031)
* Block set from param_group['params']

This might cause `list(params)` to output in random order. In this case, in `load_state_dict()`, `id_map` would not be matched correctly.

* Update Error Message

* Add Warning on Optimizer Docs

* Update optimizer.py
2018-03-28 07:45:19 -07:00
lazypanda1
063946d2b3 Added parameter range checks for all optimizers (#6000) 2018-03-28 11:22:23 +02:00
li-roy
df88373f88 set default ams param in adam optimizer (#5501) 2018-03-02 11:43:06 +01:00
Sam Gross
30ec06c140
Merge Variable and Tensor classes (#5225)
This replaces the torch.Tensor constructors with factories that produce
Variables. Similarly, functions on the torch module (e.g. torch.randn)
now return Variables.

To keep the PR to a reasonable size, I've left most of the unused tensor
code. Subsequent PRs will remove the dead code, clean-up calls to
torch.autograd.Variable, and rename Variable to Tensor everywhere.

There are some breaking changes because Variable and Tensors had
slightly different semantics. There's a list of those changes here:

 https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge
2018-02-23 18:03:31 -05:00
Marcin Elantkowski
d2ff733cb1 Make ReduceLROnPlateau serializable. (#5300)
* replace lambdas with partial

* flake8
2018-02-20 00:59:14 -05:00
Martin Drawitsch
1fdb3929c9 Fixes for docstrings/sphinx rendering of CosineAnnealingLR and Local Response Normalization (#5254)
* Fix LaTex rendering in CosineAnnealingLR

Backslashes were interpreted by Python as escapes in the string, so \frac
turned into frac, which is not a valid LaTex command.
This could be fixed with double backslashes, but the easiest solution is to
just use a raw (r) docstring.

* Fix sphinx warnings for LRN doc headings

* Move LRN docstring from __init__ to class level

The docstring was not rendered by sphinx at
http://pytorch.org/docs/master/nn.html#torch.nn.LocalResponseNorm
because it was in the constructor.

* Remove superfluous backticks from LRN formula
2018-02-15 10:29:02 -05:00
lazypanda1
a061000250 Added check and test for betas parameter in Adam optimizer (#5147)
* Added check and test for betas parameter in Adam optimizer

* Simplified test
2018-02-11 20:24:43 -05:00
nguyen-binh-minh
188ee3ff0b Fix wrong learning rate evaluation in CosineAnnealingLR in Python 2 (#4656) 2018-01-14 13:10:41 +01:00
Jon Crall
f94f5723e7 fixed spelling (#4598) 2018-01-10 18:48:14 -05:00
Richard Zou
fe70823f8e Fix StepLR docs (#4478) 2018-01-04 12:37:26 -05:00
Dr. Kashif Rasul
859a173502 fix AMSGrad for SparseAdam (#4314) 2017-12-30 13:00:17 +01:00
Vishwak Srinivasan
89acc10f85 Adding description for Optimizers (#4371) 2017-12-28 16:55:52 +01:00
Sam Gross
d605058212
Replace Variable.volatile with torch.no_grad() (#3970)
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()

Fixes #3627
2017-12-18 15:46:13 -05:00
Dr. Kashif Rasul
68c0998cbe added AMSgrad optimizer to Adam and SparseAdam (#4034)
* initial AMSGrad

* added test for amsgrad

* added amsgrad to adam

* fixed tests

* added option to sparse adam

* flake8
2017-12-18 13:24:49 -05:00
Kai Arulkumaran
e9ef20eab5 Add Cosine Annealing LR Scheduler (#3311)
* Add Cosine Annealing LR Scheduler

* Update eta_min in tests to prevent numerical mistakes

* Use non-zero min_eta in test_cos_anneal_lr
2017-12-18 02:43:08 -05:00
Adam Paszke
af9fd35d82 Cast tensors when loading optimizer state dicts (#3658) 2017-11-28 09:56:39 -05:00
Ozan Çağlayan
dd6d04ddf2 doc: Normalize all true/false in docstrings to `True|False` (#3593)
* doc: Normalize all true/false in docstrings to ``True|False``

This makes them more apparent in the documentation.

* doc: fix flake8
2017-11-09 08:12:29 -05:00
SsnL
f76d6c029c Sparse Adam optimizer for sparse gradients (#3137)
* sparse adam

* Favor dense addition over sparse_mask
2017-11-06 14:20:51 -05:00
SsnL
ba05dc5549 dense buffer (#3139) 2017-10-17 00:51:37 +02:00