mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This PR is the final step to making `torch::` the only namespace users of the C++ API ever see. Basically, I did:
``` cpp
namespace torch {
using namespace at;
}
```
And then changed `at::` to `torch::` almost everywhere. This worked surprisingly well out of the box. So users can now write `torch::relu` and `torch::log_softmax` and `torch::conv2d` instead of having to know when to use `at::` and when to use `torch::`. This is happy!
Another thing I did was to have `using Dtype = at::ScalarType`, which will be the eventual name anyway.
ebetica ezyang apaszke zdevito
Closes https://github.com/pytorch/pytorch/pull/8911
Reviewed By: ezyang
Differential Revision: D8668230
Pulled By: goldsborough
fbshipit-source-id: a72ccb70fca763c396c4b0997d3c4767c8cf4fd3
105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
"""Script to generate baseline values from PyTorch optimization algorithms"""
|
|
|
|
import argparse
|
|
import math
|
|
|
|
import torch
|
|
import torch.optim
|
|
|
|
|
|
# Preamble of the emitted C++ file: includes plus the opening of the namespace.
HEADER = """
#include <torch/tensor.h>

#include <vector>

namespace expected_parameters {
"""

# Closes the namespace opened by HEADER.
FOOTER = "} // namespace expected_parameters"

# Opening line of one C++ table; `{}` is filled with the optimizer name and
# `{{` renders as a single literal `{` after str.format().
PARAMETERS = "static std::vector<std::vector<torch::Tensor>> {} = {{"

# Maps an optimizer name to a factory that takes the parameters to optimize
# and returns a configured optimizer instance.
OPTIMIZERS = {
    "Adam": lambda p: torch.optim.Adam(p, lr=1.0, weight_decay=1e-6),
    "Adagrad": lambda p: torch.optim.Adagrad(
        p, lr=1.0, weight_decay=1e-6, lr_decay=1e-3
    ),
    "RMSprop": lambda p: torch.optim.RMSprop(
        p, lr=0.1, momentum=0.9, weight_decay=1e-6
    ),
    "SGD": lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.9, weight_decay=1e-6),
}
|
|
|
|
|
|
def weight_init(module):
    """Re-initialize a linear layer's parameters uniformly, in place.

    Written to be passed to ``Module.apply``. Modules that are not
    ``torch.nn.Linear`` are left untouched. For a linear layer, every
    parameter it owns is overwritten with samples drawn uniformly from
    ``[-b, b)`` where ``b = 1 / sqrt(module.weight.size(1))``.

    Args:
        module: The module to (possibly) re-initialize.
    """
    if not isinstance(module, torch.nn.Linear):
        return

    # Second dimension of the weight matrix sets the sampling bound.
    fan_in = module.weight.size(1)
    bound = 1.0 / math.sqrt(fan_in)
    for param in module.parameters():
        param.data.uniform_(-bound, bound)
|
|
|
|
|
|
def run(optimizer_name, iterations, sample_every):
    """Train a small fixed network and record snapshots of its parameters.

    A 2-3-1 sigmoid network in float64 is built under a fixed random seed,
    re-initialized with ``weight_init``, and optimized on a constant input
    batch to minimize the sum of its outputs.

    Args:
        optimizer_name: Key into ``OPTIMIZERS`` selecting the optimizer factory.
        iterations: Number of optimization steps to perform.
        sample_every: A snapshot is taken after each step whose zero-based
            index is a multiple of this value.

    Returns:
        A list of snapshots. Each snapshot is a list holding one flattened
        NumPy array per model parameter, in ``model.parameters()`` order.
    """
    # Seed before building the model so the initial weights are reproducible.
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )
    model = model.to(torch.float64).apply(weight_init)

    optimizer = OPTIMIZERS[optimizer_name](model.parameters())

    # Named `inputs` so the builtin `input` is not shadowed.
    inputs = torch.tensor(
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=torch.float64
    )

    values = []
    for i in range(iterations):
        optimizer.zero_grad()

        # Call the module itself rather than .forward() so that any
        # registered hooks are honored.
        output = model(inputs)
        loss = output.sum()
        loss.backward()

        optimizer.step()

        if i % sample_every == 0:
            # detach() + clone() gives an independent copy, so later
            # optimizer steps cannot alter an already-recorded snapshot.
            values.append(
                [p.detach().clone().flatten().numpy() for p in model.parameters()]
            )

    return values
|
|
|
|
|
|
def emit(optimizer_parameter_map):
    """Print the recorded parameter snapshots as C++ source on stdout.

    The output consists of a marker comment naming this script, ``HEADER``,
    one brace-initialized table per optimizer, and finally ``FOOTER``.

    Args:
        optimizer_parameter_map: Mapping from optimizer name to a list of
            snapshots, where each snapshot is an iterable of per-parameter
            value sequences.
    """
    # The marker word is spliced in at runtime on purpose: spelling it out
    # directly after the "@" would get this script itself recognized as a
    # generated file.
    marker = 'generated'
    print("// @{} from {}".format(marker, __file__))
    print(HEADER)
    for name in optimizer_parameter_map:
        snapshots = optimizer_parameter_map[name]
        print(PARAMETERS.format(name))
        for snapshot in snapshots:
            print(" {")
            for tensor_values in snapshot:
                joined = ", ".join(str(value) for value in tensor_values)
                # Wrap the values in braces to form a C++ initializer list.
                print(" torch::tensor({}),".format("{" + joined + "}"))
            print(" },")
        print("};\n")
    print(FOOTER)
|
|
|
|
|
|
def main():
    """Parse command-line options, run every optimizer, and print the baseline.

    Options:
        -i / --iterations: number of optimization steps (default 1001).
        -s / --sample-every: snapshot interval in steps (default 100).
    """
    # The text must be passed as `description=`: the first positional
    # parameter of ArgumentParser is `prog`, which would put this sentence
    # into the usage line in place of the program name.
    parser = argparse.ArgumentParser(
        description="Produce optimization output baseline from PyTorch"
    )
    parser.add_argument("-i", "--iterations", default=1001, type=int)
    parser.add_argument("-s", "--sample-every", default=100, type=int)
    options = parser.parse_args()

    optimizer_parameter_map = {}
    for optimizer in ["Adam", "Adagrad", "RMSprop", "SGD"]:
        optimizer_parameter_map[optimizer] = run(
            optimizer, options.iterations, options.sample_every
        )

    emit(optimizer_parameter_map)


if __name__ == "__main__":
    main()
|