pytorch/test/cpp/api/optim_baseline.py
Peter Goldsborough 13814d6744 Remove use of data() in optimizers (#10490)
Summary:
After talking to users of the C++ API we found that having the tensor type be `autograd::Variable` causes more complications than having it be `at::Tensor`. It used to be a problem because `at::Tensor` didn't have the "autograd API" of variable (e.g. `detach()` or `grad()` methods), but those methods are now on `at::Tensor`. As such, we want to make a last big breaking change to have the tensor type be `at::Tensor`, while factory methods like `torch::ones` will return `Variable`s disguised as `at::Tensor`. This will make many things easier, like calling functions in ATen that take vectors of tensors.

This PR makes a small step in this direction by updating the optimizer classes to not use `.data()` on `Variable` to access the underlying `at::Tensor`. Using `.data()` is effectively a hack to work around our modification rules for tensors that require grad. The proper way of doing things is to guard in-place operations with `with torch.no_grad():` in Python or, equivalently, with `NoGradGuard` in C++.

The next step can then simply redefine `torch::Tensor` to be `at::Tensor`. This transition should be smooth, since all methods available on `Variable` are at this point available on `at::Tensor`.

For this PR I:

1. Modified the implementations of optimizers to not use `.data()`. This means the implementations are now different from PyTorch, which still uses the legacy method of using `.data`.
2. To properly verify (1), I added more fine-grained test cases to our optimizer tests, e.g. `SGD` with and without `weight_decay`, then with `nesterov` etc. Generally more tests = more happy!
3. Minor cleanup of the optimizer codebase

ebetica apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10490

Differential Revision: D9318229

Pulled By: goldsborough

fbshipit-source-id: fb386700f37840542bc5d323f308ea88fe5ea5c5
2018-08-14 13:10:19 -07:00

119 lines
3.9 KiB
Python

"""Script to generate baseline values from PyTorch optimization algorithms"""
import argparse
import math
import sys
import torch
import torch.optim
# Preamble of the emitted C++ text: includes plus the opening of the namespace.
HEADER = (
    "\n"
    "#include <torch/tensor.h>\n"
    "#include <vector>\n"
    "namespace expected_parameters {\n"
)

# Closes the namespace opened at the end of HEADER.
FOOTER = "} // namespace expected_parameters"

# Opening line of one C++ table; the `{}` placeholder takes the variable name
# and the doubled brace renders as a single literal `{`.
PARAMETERS = "static std::vector<std::vector<torch::Tensor>> {} = {{"

# Baseline name -> factory taking an iterable of parameters and returning a
# configured torch.optim optimizer. Insertion order is the iteration order.
OPTIMIZERS = {
    "Adam": lambda params: torch.optim.Adam(params, lr=1.0),
    "Adam_with_weight_decay": lambda params: torch.optim.Adam(
        params, lr=1.0, weight_decay=1e-2
    ),
    "Adam_with_weight_decay_and_amsgrad": lambda params: torch.optim.Adam(
        params, lr=1.0, weight_decay=1e-6, amsgrad=True
    ),
    "Adagrad": lambda params: torch.optim.Adagrad(params, lr=1.0),
    "Adagrad_with_weight_decay": lambda params: torch.optim.Adagrad(
        params, lr=1.0, weight_decay=1e-2
    ),
    "Adagrad_with_weight_decay_and_lr_decay": lambda params: torch.optim.Adagrad(
        params, lr=1.0, weight_decay=1e-6, lr_decay=1e-3
    ),
    "RMSprop": lambda params: torch.optim.RMSprop(params, lr=0.1),
    "RMSprop_with_weight_decay": lambda params: torch.optim.RMSprop(
        params, lr=0.1, weight_decay=1e-2
    ),
    "RMSprop_with_weight_decay_and_centered": lambda params: torch.optim.RMSprop(
        params, lr=0.1, weight_decay=1e-6, centered=True
    ),
    "RMSprop_with_weight_decay_and_centered_and_momentum": lambda params: torch.optim.RMSprop(
        params, lr=0.1, weight_decay=1e-6, centered=True, momentum=0.9
    ),
    "SGD": lambda params: torch.optim.SGD(params, lr=0.1),
    "SGD_with_weight_decay": lambda params: torch.optim.SGD(
        params, lr=0.1, weight_decay=1e-2
    ),
    "SGD_with_weight_decay_and_momentum": lambda params: torch.optim.SGD(
        params, lr=0.1, momentum=0.9, weight_decay=1e-2
    ),
    "SGD_with_weight_decay_and_nesterov_momentum": lambda params: torch.optim.SGD(
        params, lr=0.1, momentum=0.9, weight_decay=1e-6, nesterov=True
    ),
}
def weight_init(module):
    """Re-initialize a ``torch.nn.Linear`` module in place; ignore anything else.

    Every parameter yielded by ``module.parameters()`` is overwritten with
    samples from the uniform distribution on ``[-stdev, stdev]``, where
    ``stdev = 1 / sqrt(module.weight.size(1))``.

    Args:
        module: Any object; only instances of ``torch.nn.Linear`` are modified.

    Returns:
        None. The function is meant to be passed to ``Module.apply``.
    """
    if not isinstance(module, torch.nn.Linear):
        return
    stdev = 1.0 / math.sqrt(module.weight.size(1))
    # In-place writes to tensors that require grad must happen with autograd
    # disabled; no_grad() is the supported replacement for the legacy `.data`.
    with torch.no_grad():
        for p in module.parameters():
            p.uniform_(-stdev, stdev)
def run(optimizer_name, iterations, sample_every):
    """Train a tiny fixed network with one optimizer and record its parameters.

    The random seed is reset to 0, then a float64
    ``Linear(2, 3) -> Sigmoid -> Linear(3, 1) -> Sigmoid`` model is built,
    re-initialized with ``weight_init`` and optimized on a constant 3x2 input,
    using the sum of the model outputs as the loss.

    Args:
        optimizer_name: Key into ``OPTIMIZERS`` selecting the optimizer factory.
        iterations: Number of optimization steps to perform.
        sample_every: Record the parameters after each step whose zero-based
            index is a multiple of this value. Must be positive.

    Returns:
        A list with one entry per recorded step; each entry is a list holding
        one flattened NumPy array per model parameter.

    Raises:
        ValueError: If ``sample_every`` is not positive.
        KeyError: If ``optimizer_name`` is not a key of ``OPTIMIZERS``.
    """
    # Fail early with a clear message instead of a ZeroDivisionError raised
    # from the modulo inside the training loop.
    if sample_every <= 0:
        raise ValueError(
            "sample_every must be a positive integer, got {}".format(sample_every)
        )

    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )
    model = model.to(torch.float64).apply(weight_init)
    optimizer = OPTIMIZERS[optimizer_name](model.parameters())
    inputs = torch.tensor(
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=torch.float64
    )

    values = []
    for i in range(iterations):
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so hooks are honoured.
        loss = model(inputs).sum()
        loss.backward()
        optimizer.step()
        if i % sample_every == 0:
            # clone() is required: without it the NumPy arrays could share
            # storage with the live parameters and change on later steps.
            values.append(
                [p.detach().clone().flatten().numpy() for p in model.parameters()]
            )
    return values
def emit(optimizer_parameter_map):
    """Print the baseline tables as C++ source text on stdout.

    Output order: a provenance comment naming this file, ``HEADER``, one
    ``PARAMETERS`` table per entry of the mapping, then ``FOOTER``.

    Args:
        optimizer_parameter_map: Mapping from table name to a sequence of
            samples; each sample is a sequence of iterables of numbers, one
            iterable per tensor.
    """
    # Keep the marker word separate from the "@" sign in this source file;
    # spelled out together they would flag this script itself as generated.
    marker = "generated"
    print("// @{} from {}".format(marker, __file__))
    print(HEADER)
    for table_name, samples in optimizer_parameter_map.items():
        print(PARAMETERS.format(table_name))
        for snapshot in samples:
            print(" {")
            for tensor_values in snapshot:
                joined = ", ".join(str(value) for value in tensor_values)
                initializer = "{" + joined + "}"
                print(" torch::tensor({}),".format(initializer))
            print(" },")
        print("};\n")
    print(FOOTER)
def main(argv=None):
    """Parse command-line options, run every optimizer, and print the result.

    Progress messages go to stderr so that stdout carries only the text
    written by ``emit``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.
    """
    # The text must be passed as `description`: the first positional parameter
    # of ArgumentParser is `prog`, which would replace the program name in the
    # usage line instead of describing the tool.
    parser = argparse.ArgumentParser(
        description="Produce optimization output baseline from PyTorch"
    )
    parser.add_argument(
        "-i",
        "--iterations",
        default=1001,
        type=int,
        help="number of optimization steps per optimizer (default: %(default)s)",
    )
    parser.add_argument(
        "-s",
        "--sample-every",
        default=100,
        type=int,
        help="record the parameters every this many steps (default: %(default)s)",
    )
    options = parser.parse_args(argv)

    optimizer_parameter_map = {}
    for optimizer in OPTIMIZERS:
        sys.stderr.write('Evaluating {} ...\n'.format(optimizer))
        optimizer_parameter_map[optimizer] = run(
            optimizer, options.iterations, options.sample_every
        )
    emit(optimizer_parameter_map)


if __name__ == "__main__":
    main()