Merge autogradpp into PyTorch (#7074)

* Dump autogradpp into PyTorch
* Fixed up CMake for autogradpp/C++ API
* Made cereal a submodule
* Change search location of autogradpp's mnist directory
* Add test_api to CI
* Download MNIST from the internet instead of storing in repo
* Fix warnings

parent 3407708b81
commit af71fb882f
1  .gitignore  vendored

@@ -55,6 +55,7 @@ test/.coverage
 test/data/legacy_serialized.pt
 test/data/linear.pt
 .mypy_cache
+test/cpp/api/mnist

 # IPython notebook checkpoints
 .ipynb_checkpoints
3  .gitmodules  vendored

@@ -77,3 +77,6 @@
 [submodule "third_party/ideep"]
 	path = third_party/ideep
 	url = https://github.com/intel/ideep.git
+[submodule "third_party/cereal"]
+	path = third_party/cereal
+	url = https://github.com/USCiLab/cereal
@@ -55,4 +55,6 @@ if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
   else
     "$CPP_BUILD"/libtorch/bin/test_jit "[cpu]"
   fi
+  python tools/download_mnist.py -d test/cpp/api/mnist
+  "$CPP_BUILD"/libtorch/bin/test_api
 fi
258  test/cpp/api/container_t.cpp  Normal file

@@ -0,0 +1,258 @@
#include "test.h"

AUTOGRAD_CONTAINER_CLASS(TestModel) {
 public:
  void initialize_containers() override {
    add(Linear(10, 3).make(), "l1");
    add(Linear(3, 5).make(), "l2");
    add(Linear(5, 100).make(), "l3");
  }

  variable_list forward(variable_list input) override { return input; };
};

AUTOGRAD_CONTAINER_CLASS(NestedModel) {
 public:
  void initialize_containers() override {
    add(Linear(5, 20).make(), "l1");
    add(TestModel().make(), "test");
  }

  void initialize_parameters() override {
    add(Var(DefaultTensor(at::kFloat).tensor({3, 2, 21}), false), "param");
  }

  variable_list forward(variable_list input) override { return input; };
};

CASE("containers/conv2d/even") {
  auto model = Conv2d(3, 2, 3).stride(2).make();
  auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 4);
  EXPECT(s.ndimension() == 0);
  for (auto i = 0; i < 4; i++) {
    EXPECT(y.size(i) == 2);
  }

  EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
};

CASE("containers/conv2d/uneven") {
  auto model = Conv2d(3, 2, IntVec({3, 2})).stride(2).make();
  auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 4}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 4);
  EXPECT(s.ndimension() == 0);
  for (auto i = 0; i < 4; i++) {
    EXPECT(y.size(i) == 2);
  }

  EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
};

CASE("containers/conv1d/even") {
  auto model = Conv1d(3, 2, 3).stride(2).make();
  auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 3);
  EXPECT(s.ndimension() == 0);
  for (auto i = 0; i < 3; i++) {
    EXPECT(y.size(i) == 2);
  }

  EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
};

CASE("containers/conv3d/even") {
  auto model = Conv3d(3, 2, 3).stride(2).make();
  auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 5, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 5);
  EXPECT(s.ndimension() == 0);
  for (auto i = 0; i < 5; i++) {
    EXPECT(y.size(i) == 2);
  }

  EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
};

CASE("containers/linear/basic1") {
  auto model = Linear(5, 2).make();
  auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 2);
  EXPECT(s.ndimension() == 0);
  EXPECT(y.size(0) == 10);
  EXPECT(y.size(1) == 2);

  EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
};

CASE("containers/linear/sequential") {
  auto model = ContainerList()
      .append(Linear(10, 3).make())
      .append(Linear(3, 5).make())
      .append(Linear(5, 100).make())
      .make();

  auto x = Var(at::CPU(at::kFloat).randn({1000, 10}));
  for (auto layer : *model) {
    x = layer->forward({x})[0];
    x = x.clamp_min(0);  // relu
  }

  backward(x);
  EXPECT(x.ndimension() == 2);
  EXPECT(x.size(0) == 1000);
  EXPECT(x.size(1) == 100);
  EXPECT(x.data().min().toCFloat() == 0);
};

CASE("containers/linear/simple") {
  auto model = SimpleContainer().make();
  auto l1 = model->add(Linear(10, 3).make(), "l1");
  auto l2 = model->add(Linear(3, 5).make(), "l2");
  auto l3 = model->add(Linear(5, 100).make(), "l3");

  auto x = Var(at::CPU(at::kFloat).randn({1000, 10}));
  x = l1->forward({x})[0].clamp_min(0);
  x = l2->forward({x})[0].clamp_min(0);
  x = l3->forward({x})[0].clamp_min(0);

  backward(x);
  EXPECT(x.ndimension() == 2);
  EXPECT(x.size(0) == 1000);
  EXPECT(x.size(1) == 100);
  EXPECT(x.data().min().toCFloat() == 0);
};

CASE("containers/clone") {
  auto model = TestModel().make();

  auto model2 = model->clone();
  auto m1param = model->parameters();
  auto m2param = model2->parameters();
  for (auto& param : m1param) {
    EXPECT(param.second.allclose(m2param[param.first]));
    param.second.data().mul_(2);
  }
  for (auto& param : m1param) {
    EXPECT(!param.second.allclose(m2param[param.first]));
  }
};

CASE("containers/embedding/basic") {
  int dict_size = 10;
  auto model = Embedding(dict_size, 2).make();
  // Cannot get gradients to change indices (input) - only for embedding params
  auto x = Var(at::CPU(at::kLong).tensor({10}).fill_(dict_size - 1), false);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 2);
  EXPECT(s.ndimension() == 0);
  EXPECT(y.size(0) == 10);
  EXPECT(y.size(1) == 2);

  EXPECT(model->parameters()["weight"].grad().numel() == 2 * dict_size);
};

CASE("containers/embedding/list") {
  auto model = Embedding(6, 4).make();
  auto x = Var(at::CPU(at::kLong).tensor({2, 3}).fill_(5), false);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 3);
  EXPECT(y.size(0) == 2);
  EXPECT(y.size(1) == 3);
  EXPECT(y.size(2) == 4);
};

CASE("containers/cuda/1") {
  CUDA_GUARD;
  auto model = Linear(5, 2).make();
  model->cuda();
  auto x = Var(at::CUDA(at::kFloat).randn({10, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 2);
  EXPECT(s.ndimension() == 0);
  EXPECT(y.size(0) == 10);
  EXPECT(y.size(1) == 2);

  EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
};

CASE("containers/cuda/2") {
  CUDA_GUARD;
  auto model = Linear(5, 2).make();
  model->cuda();
  model->cpu();
  auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(y.ndimension() == 2);
  EXPECT(s.ndimension() == 0);
  EXPECT(y.size(0) == 10);
  EXPECT(y.size(1) == 2);

  EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
};

CASE("containers/dropout/1") {
  auto dropout = Dropout(0.5).make();
  Variable x = Var(at::CPU(at::kFloat).ones(100));
  Variable y = dropout->forward({x})[0];

  backward(y);
  EXPECT(y.ndimension() == 1);
  EXPECT(y.size(0) == 100);
  EXPECT(y.sum().toCFloat() < 130);  // Probably
  EXPECT(y.sum().toCFloat() > 70);   // Probably

  dropout->eval();
  y = dropout->forward({x})[0];
  EXPECT(y.data().sum().toCFloat() == 100);
};

CASE("containers/param") {
  auto model = NestedModel().make();
  EXPECT(model->param("param").size(0) == 3);
  EXPECT(model->param("param").size(1) == 2);
  EXPECT(model->param("param").size(2) == 21);
  EXPECT(model->param("l1.bias").size(0) == 20);
  EXPECT(model->param("l1.weight").size(0) == 20);
  EXPECT(model->param("l1.weight").size(1) == 5);
  EXPECT(model->param("test.l1.bias").size(0) == 3);
  EXPECT(model->param("test.l1.weight").size(0) == 3);
  EXPECT(model->param("test.l1.weight").size(1) == 10);
  EXPECT(model->param("test.l2.bias").size(0) == 5);
  EXPECT(model->param("test.l2.weight").size(0) == 5);
  EXPECT(model->param("test.l2.weight").size(1) == 3);
  EXPECT(model->param("test.l3.bias").size(0) == 100);
  EXPECT(model->param("test.l3.weight").size(0) == 100);
  EXPECT(model->param("test.l3.weight").size(1) == 5);
}
355  test/cpp/api/integration_t.cpp  Normal file

@@ -0,0 +1,355 @@
#include "test.h"

class CartPole {
  // Translated from openai/gym's cartpole.py
 public:
  double gravity = 9.8;
  double masscart = 1.0;
  double masspole = 0.1;
  double total_mass = (masspole + masscart);
  double length = 0.5;  // actually half the pole's length;
  double polemass_length = (masspole * length);
  double force_mag = 10.0;
  double tau = 0.02;  // seconds between state updates;

  // Angle at which to fail the episode
  double theta_threshold_radians = 12 * 2 * M_PI / 360;
  double x_threshold = 2.4;
  int steps_beyond_done = -1;

  at::Tensor state;
  double reward;
  bool done;
  int step_ = 0;

  at::Tensor getState() {
    return state;
  }

  double getReward() {
    return reward;
  }

  bool isDone() {
    return done;
  }

  void reset() {
    state = at::CPU(at::kFloat).tensor({4}).uniform_(-0.05, 0.05);
    steps_beyond_done = -1;
    step_ = 0;
  }

  CartPole() {
    reset();
  }

  void step(int action) {
    auto x = state[0].toCFloat();
    auto x_dot = state[1].toCFloat();
    auto theta = state[2].toCFloat();
    auto theta_dot = state[3].toCFloat();

    auto force = (action == 1) ? force_mag : -force_mag;
    auto costheta = std::cos(theta);
    auto sintheta = std::sin(theta);
    auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass;
    auto thetaacc = (gravity * sintheta - costheta * temp) /
        (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
    auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;

    x = x + tau * x_dot;
    x_dot = x_dot + tau * xacc;
    theta = theta + tau * theta_dot;
    theta_dot = theta_dot + tau * thetaacc;
    state[0] = x;
    state[1] = x_dot;
    state[2] = theta;
    state[3] = theta_dot;
    done = x < -x_threshold || x > x_threshold ||
        theta < -theta_threshold_radians || theta > theta_threshold_radians ||
        step_ > 200;

    if (!done) {
      reward = 1.0;
    } else if (steps_beyond_done == -1) {
      // Pole just fell!
      steps_beyond_done = 0;
      reward = 0;
    } else {
      if (steps_beyond_done == 0) {
        assert(false);  // Can't do this
      }
    }
    step_++;
  }
};

template <typename M, typename F, typename O>
bool test_mnist(uint32_t batch_size, uint32_t num_epochs, bool useGPU,
                M&& model, F&& forward_op, O&& optim) {
  std::cout << "Training MNIST for " << num_epochs
            << " epochs, rest your eyes for a bit!\n";
  struct MNIST_Reader {
    FILE* fp_;

    MNIST_Reader(const char* path) {
      fp_ = fopen(path, "rb");
      if (!fp_) throw std::runtime_error("failed to open file");
    }

    ~MNIST_Reader() { if (fp_) fclose(fp_); }

    int32_t read_int() {
      uint8_t buf[4];
      if (fread(buf, sizeof(buf), 1, fp_) != 1) throw std::runtime_error("failed to read an integer");
      return int32_t(buf[0] << 24 | buf[1] << 16 | buf[2] << 8 | buf[3]);
    }

    uint8_t read_byte() {
      uint8_t i;
      if (fread(&i, sizeof(i), 1, fp_) != 1) throw std::runtime_error("failed to read a byte");
      return i;
    }
  };

  auto readData = [&](std::string fn) {
    MNIST_Reader rd(fn.c_str());

    /* int image_magic = */ rd.read_int();
    int image_count = rd.read_int();
    int image_rows = rd.read_int();
    int image_cols = rd.read_int();

    auto data = at::CPU(at::kFloat).tensor({image_count, 1, image_rows, image_cols});
    auto a_data = data.accessor<float, 4>();

    for (int c = 0; c < image_count; c++) {
      for (int i = 0; i < image_rows; i++) {
        for (int j = 0; j < image_cols; j++) {
          a_data[c][0][i][j] = float(rd.read_byte()) / 255;
        }
      }
    }

    return data.toBackend(useGPU ? at::kCUDA : at::kCPU);
  };

  auto readLabels = [&](std::string fn) {
    MNIST_Reader rd(fn.c_str());
    /* int label_magic = */ rd.read_int();
    int label_count = rd.read_int();

    auto data = at::CPU(at::kLong).tensor({label_count});
    auto a_data = data.accessor<int64_t, 1>();

    for (int i = 0; i < label_count; ++i) {
      a_data[i] = long(rd.read_byte());
    }
    return data.toBackend(useGPU ? at::kCUDA : at::kCPU);
  };

  auto trdata = readData("test/cpp/api/mnist/train-images-idx3-ubyte");
  auto trlabel = readLabels("test/cpp/api/mnist/train-labels-idx1-ubyte");
  auto tedata = readData("test/cpp/api/mnist/t10k-images-idx3-ubyte");
  auto telabel = readLabels("test/cpp/api/mnist/t10k-labels-idx1-ubyte");

  if (useGPU) {
    model->cuda();
  }

  for (auto epoch = 0U; epoch < num_epochs; epoch++) {
    auto shuffled_inds = std::vector<int>(trdata.size(0));
    for (int i = 0; i < trdata.size(0); i++) {
      shuffled_inds[i] = i;
    }
    std::random_shuffle(shuffled_inds.begin(), shuffled_inds.end());

    auto inp = (useGPU ? at::CUDA : at::CPU)(at::kFloat).tensor({batch_size, 1, trdata.size(2), trdata.size(3)});
    auto lab = (useGPU ? at::CUDA : at::CPU)(at::kLong).tensor({batch_size});
    for (auto p = 0U; p < shuffled_inds.size() - batch_size; p++) {
      inp[p % batch_size] = trdata[shuffled_inds[p]];
      lab[p % batch_size] = trlabel[shuffled_inds[p]];

      if (p % batch_size != batch_size - 1) continue;
      Variable x = forward_op(Var(inp));
      Variable y = Var(lab, false);
      Variable loss = at::nll_loss(x, y);

      optim->zero_grad();
      backward(loss);
      optim->step();
    }
  }

  no_grad_guard guard;
  auto result = std::get<1>(forward_op(Var(tedata, false)).max(1));
  Variable correct = (result == Var(telabel)).toType(at::kFloat);
  std::cout << "Num correct: " << correct.data().sum().toCFloat()
            << " out of " << telabel.size(0) << std::endl;
  return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
};

CASE("integration/RL/cartpole") {
  std::cout << "Training episodic policy gradient with a critic for up to 3000"
               " episodes, rest your eyes for a bit!\n";
  auto model = SimpleContainer().make();
  auto linear = model->add(Linear(4, 128).make(), "linear");
  auto policyHead = model->add(Linear(128, 2).make(), "policy");
  auto valueHead = model->add(Linear(128, 1).make(), "action");
  auto optim = Adam(model, 1e-3).make();

  std::vector<Variable> saved_log_probs;
  std::vector<Variable> saved_values;
  std::vector<float> rewards;

  auto forward = [&](variable_list inp) {
    auto x = linear->forward(inp)[0].clamp_min(0);
    Variable actions = policyHead->forward({x})[0];
    Variable value = valueHead->forward({x})[0];
    return std::make_tuple(at::softmax(actions, -1), value);
  };

  auto selectAction = [&](at::Tensor state) {
    // Only work on single state right now, change index to gather for batch
    auto out = forward({Var(state, false)});
    auto probs = Variable(std::get<0>(out));
    auto value = Variable(std::get<1>(out));
    auto action = probs.data().multinomial(1)[0].toCInt();
    // Compute the log prob of a multinomial distribution.
    // This should probably be actually implemented in autogradpp...
    auto p = probs / probs.sum(-1, true);
    auto log_prob = p[action].log();
    saved_log_probs.push_back(log_prob);
    saved_values.push_back(value);
    return action;
  };

  auto finishEpisode = [&]() {
    auto R = 0.;
    for (int i = rewards.size() - 1; i >= 0; i--) {
      R = rewards[i] + 0.99 * R;
      rewards[i] = R;
    }
    auto r_t = at::CPU(at::kFloat).tensorFromBlob(rewards.data(), {static_cast<int64_t>(rewards.size())});
    r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);

    std::vector<at::Tensor> policy_loss;
    std::vector<at::Tensor> value_loss;
    for (auto i = 0U; i < saved_log_probs.size(); i++) {
      auto r = rewards[i] - saved_values[i].toCFloat();
      policy_loss.push_back(-r * saved_log_probs[i]);
      value_loss.push_back(at::smooth_l1_loss(saved_values[i], Var(at::CPU(at::kFloat).scalarTensor(at::Scalar(rewards[i])), false)));
    }
    auto loss = at::stack(policy_loss).sum() + at::stack(value_loss).sum();

    optim->zero_grad();
    backward(loss);
    optim->step();

    rewards.clear();
    saved_log_probs.clear();
    saved_values.clear();
  };

  auto env = CartPole();
  double running_reward = 10.0;
  for (auto episode = 0; ; episode++) {
    env.reset();
    auto state = env.getState();
    int t = 0;
    for ( ; t < 10000; t++) {
      auto action = selectAction(state);
      env.step(action);
      state = env.getState();
      auto reward = env.getReward();
      auto done = env.isDone();

      rewards.push_back(reward);
      if (done) break;
    }

    running_reward = running_reward * 0.99 + t * 0.01;
    finishEpisode();
    /*
    if (episode % 10 == 0) {
      printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
             episode, t, running_reward);
    }
    */
    if (running_reward > 150) break;
    EXPECT(episode < 3000);
  }
}

CASE("integration/mnist") {  // ~ will make it run last :D
  CUDA_GUARD;
  auto model = SimpleContainer().make();
  auto conv1 = model->add(Conv2d(1, 10, 5).make(), "conv1");
  auto conv2 = model->add(Conv2d(10, 20, 5).make(), "conv2");
  auto drop = Dropout(0.3).make();
  auto drop2d = Dropout2d(0.3).make();
  auto linear1 = model->add(Linear(320, 50).make(), "linear1");
  auto linear2 = model->add(Linear(50, 10).make(), "linear2");

  auto forward = [&](Variable x) {
    x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2})).clamp_min(0);
    x = conv2->forward({x})[0];
    x = drop2d->forward({x})[0];
    x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);

    x = x.view({-1, 320});
    x = linear1->forward({x})[0].clamp_min(0);
    x = drop->forward({x})[0];
    x = linear2->forward({x})[0];
    x = at::log_softmax(x, 1);
    return x;
  };

  auto optim = SGD(model, 1e-2).momentum(0.5).make();

  EXPECT(test_mnist(
      32,    // batch_size
      3,     // num_epochs
      true,  // useGPU
      model, forward, optim));
};

CASE("integration/mnist_batchnorm") {  // ~ will make it run last :D
  CUDA_GUARD;
  auto model = SimpleContainer().make();
  auto conv1 = model->add(Conv2d(1, 10, 5).make(), "conv1");
  auto batchnorm2d = model->add(
      BatchNorm(10).stateful().make(),
      "batchnorm2d");
  auto conv2 = model->add(Conv2d(10, 20, 5).make(), "conv2");
  auto linear1 = model->add(Linear(320, 50).make(), "linear1");
  auto batchnorm1 = model->add(
      BatchNorm(50).stateful().make(),
      "batchnorm1");
  auto linear2 = model->add(Linear(50, 10).make(), "linear2");

  auto forward = [&](Variable x) {
    x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2})).clamp_min(0);
    x = batchnorm2d->forward({x})[0];
    x = conv2->forward({x})[0];
    x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);

    x = x.view({-1, 320});
    x = linear1->forward({x})[0].clamp_min(0);
    x = batchnorm1->forward({x})[0];
    x = linear2->forward({x})[0];
    x = at::log_softmax(x, 1);
    return x;
  };

  auto optim = SGD(model, 1e-2).momentum(0.5).make();

  EXPECT(test_mnist(
      32,    // batch_size
      3,     // num_epochs
      true,  // useGPU
      model, forward, optim));
};
1308  test/cpp/api/lest.hpp  Normal file

File diff suppressed because it is too large
35  test/cpp/api/misc_t.cpp  Normal file

@@ -0,0 +1,35 @@
#include "test.h"

CASE("misc/no_grad/1") {
  no_grad_guard guard;
  auto model = Linear(5, 2).make();
  auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
  auto y = model->forward({x})[0];
  Variable s = y.sum();

  backward(s);
  EXPECT(!model->parameters()["weight"].grad().defined());
};

CASE("misc/random/seed_cpu") {
  int size = 100;
  setSeed(7);
  auto x1 = Var(at::CPU(at::kFloat).randn({size}));
  setSeed(7);
  auto x2 = Var(at::CPU(at::kFloat).randn({size}));

  auto l_inf = (x1.data() - x2.data()).abs().max().toCFloat();
  EXPECT(l_inf < 1e-10);
};

CASE("misc/random/seed_cuda") {
  CUDA_GUARD;
  int size = 100;
  setSeed(7);
  auto x1 = Var(at::CUDA(at::kFloat).randn({size}));
  setSeed(7);
  auto x2 = Var(at::CUDA(at::kFloat).randn({size}));

  auto l_inf = (x1.data() - x2.data()).abs().max().toCFloat();
  EXPECT(l_inf < 1e-10);
};
99  test/cpp/api/optim_t.cpp  Normal file

@@ -0,0 +1,99 @@
#include "test.h"

bool test_optimizer_xor(Optimizer optim, std::shared_ptr<ContainerList> model) {
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto bs = 4U;
    auto inp = at::CPU(at::kFloat).tensor({bs, 2});
    auto lab = at::CPU(at::kFloat).tensor({bs});
    for (auto i = 0U; i < bs; i++) {
      auto a = std::rand() % 2;
      auto b = std::rand() % 2;
      auto c = a ^ b;
      inp[i][0] = a;
      inp[i][1] = b;
      lab[i] = c;
    }
    // forward
    auto x = Var(inp);
    auto y = Var(lab, false);
    for (auto layer : *model) x = layer->forward({x})[0].sigmoid_();
    Variable loss = at::binary_cross_entropy(x, y);

    optim->zero_grad();
    backward(loss);
    optim->step();

    running_loss = running_loss * 0.99 + loss.data().sum().toCFloat() * 0.01;
    if (epoch > 3000) {
      return false;
    }
    epoch++;
  }
  return true;
}

CASE("optim/sgd") {
  auto model = ContainerList()
      .append(Linear(2, 8).make())
      .append(Linear(8, 1).make())
      .make();

  auto optim = SGD(model, 1e-1).momentum(0.9).nesterov().weight_decay(1e-6).make();
  EXPECT(test_optimizer_xor(optim, model));
}

CASE("optim/adagrad") {
  auto model = ContainerList()
      .append(Linear(2, 8).make())
      .append(Linear(8, 1).make())
      .make();

  auto optim = Adagrad(model, 1.0).weight_decay(1e-6).lr_decay(1e-3).make();
  EXPECT(test_optimizer_xor(optim, model));
}

CASE("optim/rmsprop") {
  {
    auto model = ContainerList()
        .append(Linear(2, 8).make())
        .append(Linear(8, 1).make())
        .make();

    auto optim = RMSprop(model, 1e-1).momentum(0.9).weight_decay(1e-6).make();
    EXPECT(test_optimizer_xor(optim, model));
  }

  {
    auto model = ContainerList()
        .append(Linear(2, 8).make())
        .append(Linear(8, 1).make())
        .make();

    auto optim = RMSprop(model, 1e-1).centered().make();
    EXPECT(test_optimizer_xor(optim, model));
  }
}

CASE("optim/adam") {
  auto model = ContainerList()
      .append(Linear(2, 8).make())
      .append(Linear(8, 1).make())
      .make();

  auto optim = Adam(model, 1.0).weight_decay(1e-6).make();
  EXPECT(test_optimizer_xor(optim, model));
}

CASE("optim/amsgrad") {
  auto model = ContainerList()
      .append(Linear(2, 8).make())
      .append(Linear(8, 1).make())
      .make();

  auto optim = Adam(model, 0.1).weight_decay(1e-6).amsgrad().make();
  EXPECT(test_optimizer_xor(optim, model));
}
168  test/cpp/api/rnn_t.cpp  Normal file

@@ -0,0 +1,168 @@
#include "test.h"

template <typename R, typename Func>
bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
  auto nhid = 32;
  auto model = SimpleContainer().make();
  auto l1 = model->add(Linear(1, nhid).make(), "l1");
  auto rnn = model->add(model_maker(nhid), "rnn");
  auto lo = model->add(Linear(nhid, 1).make(), "lo");

  auto optim = Adam(model, 1e-2).make();

  auto forward_op = [&](Variable x) {
    auto T = x.size(0);
    auto B = x.size(1);
    x = x.view({T * B, 1});
    x = l1->forward({x})[0].view({T, B, nhid}).tanh_();
    x = rnn->forward({x})[0][T - 1];
    x = lo->forward({x})[0];
    return x;
  };

  if (cuda) {
    model->cuda();
  }

  float running_loss = 1;
  int epoch = 0;
  auto max_epoch = 1500;
  while (running_loss > 1e-2) {
    auto bs = 16U;
    auto nlen = 5U;
    auto inp = at::CPU(at::kFloat).rand({nlen, bs, 1}).round().toType(at::kFloat);
    auto lab = inp.sum(0);

    if (cuda) {
      inp = inp.toBackend(at::kCUDA);
      lab = lab.toBackend(at::kCUDA);
    }

    auto x = Var(inp);
    auto y = Var(lab, false);
    x = forward_op(x);
    Variable loss = at::mse_loss(x, y);

    optim->zero_grad();
    backward(loss);
    optim->step();

    running_loss = running_loss * 0.99 + loss.toCFloat() * 0.01;
    if (epoch > max_epoch) {
      return false;
    }
    epoch++;
  }
  return true;
};

CASE("RNN/LSTM/sizes") {
  auto model = LSTM(128, 64).nlayers(2).dropout(0.2).make();
  Variable x = Var(at::CPU(at::kFloat).randn({10, 16, 128}));
  auto tup = model->forward({x});
  auto y = x.mean();

  auto out = tup[0];
  auto hids = tup[1];

  backward(y);
  EXPECT(out.ndimension() == 3);
  EXPECT(out.size(0) == 10);
  EXPECT(out.size(1) == 16);
  EXPECT(out.size(2) == 64);

  EXPECT(hids.ndimension() == 4);
  EXPECT(hids.size(0) == 2);   // 2 layers
  EXPECT(hids.size(1) == 2);   // c and h
  EXPECT(hids.size(2) == 16);  // Batch size of 16
  EXPECT(hids.size(3) == 64);  // 64 hidden dims

  // Something is in the hiddens
  EXPECT(hids.norm().toCFloat() > 0);

  Variable diff = model->forward({x, hids})[1] - hids;

  // Hiddens changed
  EXPECT(diff.data().abs().sum().toCFloat() > 1e-3);
};

CASE("RNN/LSTM/outputs") {
  // Make sure the outputs match pytorch outputs
  auto model = LSTM(2, 2).make();
  for (auto& v : model->parameters()) {
    float size = v.second.numel();
    auto p = static_cast<float*>(v.second.data().storage()->data());
    for (size_t i = 0; i < size; i++) {
      p[i] = i / size;
    }
  }

  Variable x = Var(at::CPU(at::kFloat).tensor({3, 4, 2}));
  float size = x.data().numel();
  auto p = static_cast<float*>(x.data().storage()->data());
  for (size_t i = 0; i < size; i++) {
    p[i] = (size - i) / size;
  }

  auto out = model->forward({x});
  EXPECT(out[0].ndimension() == 3);
  EXPECT(out[0].size(0) == 3);
  EXPECT(out[0].size(1) == 4);
  EXPECT(out[0].size(2) == 2);

  auto flat = out[0].data().view(3 * 4 * 2);
  float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239, 0.4183,
      0.5147, 0.6822, 0.8064, 0.6726, 0.7968, 0.6620, 0.7860, 0.6501, 0.7741,
      0.7889, 0.9003, 0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
  for (size_t i = 0; i < 3 * 4 * 2; i++) {
    EXPECT(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
  }

  EXPECT(out[1].ndimension() == 4);  // T x (hx, cx) x B x 2
  EXPECT(out[1].size(0) == 1);
  EXPECT(out[1].size(1) == 2);
  EXPECT(out[1].size(2) == 4);
  EXPECT(out[1].size(3) == 2);
  flat = out[1].data().view(16);
  float h_out[] = {0.7889, 0.9003, 0.7769, 0.8905, 0.7635, 0.8794, 0.7484,
      0.8666, 1.1647, 1.6106, 1.1425, 1.5726, 1.1187, 1.5329, 1.0931, 1.4911};
  for (size_t i = 0; i < 16; i++) {
    EXPECT(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
  }
};

CASE("integration/RNN/LSTM") {
  EXPECT(test_RNN_xor<LSTM>([](int s) { return LSTM(s, s).nlayers(2).make(); }));
};

CASE("integration/RNN/GRU") {
  EXPECT(test_RNN_xor<GRU>([](int s) { return GRU(s, s).nlayers(2).make(); }));
};

CASE("integration/RNN/RNN/Relu") {
  EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Relu).nlayers(2).make(); }));
};

CASE("integration/RNN/RNN/Tanh") {
  EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Tanh).nlayers(2).make(); }));
};

CASE("integration/RNN/cuda/LSTM") {
  CUDA_GUARD;
  EXPECT(test_RNN_xor<LSTM>([](int s) { return LSTM(s, s).nlayers(2).make(); }, true));
};

CASE("integration/RNN/cuda/GRU") {
  CUDA_GUARD;
  EXPECT(test_RNN_xor<GRU>([](int s) { return GRU(s, s).nlayers(2).make(); }, true));
};

CASE("integration/RNN/cuda/RNN/Relu") {
  CUDA_GUARD;
  EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Relu).nlayers(2).make(); }, true));
};

CASE("integration/RNN/cuda/RNN/Tanh") {
  CUDA_GUARD;
  EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Tanh).nlayers(2).make(); }, true));
};
261  test/cpp/api/serialization_t.cpp  Normal file

@@ -0,0 +1,261 @@
#include "test.h"

#include "cereal/archives/portable_binary.hpp"

CASE("serialization/undefined") {
  auto x = at::Tensor();

  EXPECT(!x.defined());

  auto y = at::CPU(at::kFloat).randn({5});

  std::stringstream ss;
  save(ss, &x);
  load(ss, &y);

  EXPECT(!y.defined());
}

CASE("serialization/cputypes") {
  for (int i = 0; i < static_cast<int>(at::ScalarType::NumOptions); i++) {
    if (i == static_cast<int>(at::ScalarType::Half)) {
      // XXX can't serialize half tensors at the moment since contiguous() is
      // not implemented for this type;
      continue;
    } else if (i == static_cast<int>(at::ScalarType::Undefined)) {
      // We can't construct a tensor for this type. This is tested in
      // serialization/undefined anyway.
      continue;
    }

    auto x =
        at::getType(at::kCPU, static_cast<at::ScalarType>(i)).ones({5, 5});
    auto y = at::Tensor();

    std::stringstream ss;
    save(ss, &x);
    load(ss, &y);

    EXPECT(y.defined());
    EXPECT(x.sizes().vec() == y.sizes().vec());
    if (at::isIntegralType(static_cast<at::ScalarType>(i))) {
      EXPECT(x.equal(y));
    } else {
      EXPECT(x.allclose(y));
    }
  }
}

CASE("serialization/binary") {
  auto x = at::CPU(at::kFloat).randn({5, 5});
  auto y = at::Tensor();

  std::stringstream ss;
  {
    cereal::BinaryOutputArchive archive(ss);
    archive(x);
  }
  {
    cereal::BinaryInputArchive archive(ss);
    archive(y);
  }

  EXPECT(y.defined());
  EXPECT(x.sizes().vec() == y.sizes().vec());
  EXPECT(x.allclose(y));
}

CASE("serialization/portable_binary") {
  auto x = at::CPU(at::kFloat).randn({5, 5});
  auto y = at::Tensor();

  std::stringstream ss;
  {
    cereal::PortableBinaryOutputArchive archive(ss);
    archive(x);
  }
  {
    cereal::PortableBinaryInputArchive archive(ss);
    archive(y);
  }

  EXPECT(y.defined());
  EXPECT(x.sizes().vec() == y.sizes().vec());
  EXPECT(x.allclose(y));
}

CASE("serialization/resized") {
  auto x = at::CPU(at::kFloat).randn({11, 5});
  x.resize_({5, 5});
  auto y = at::Tensor();

  std::stringstream ss;
  {
    cereal::BinaryOutputArchive archive(ss);
    archive(x);
  }
  {
    cereal::BinaryInputArchive archive(ss);
    archive(y);
  }

  EXPECT(y.defined());
  EXPECT(x.sizes().vec() == y.sizes().vec());
  EXPECT(x.allclose(y));
}

CASE("serialization/sliced") {
  auto x = at::CPU(at::kFloat).randn({11, 5});
  x = x.slice(0, 1, 3);
  auto y = at::Tensor();

  std::stringstream ss;
  {
    cereal::BinaryOutputArchive archive(ss);
    archive(x);
  }
  {
    cereal::BinaryInputArchive archive(ss);
    archive(y);
  }

  EXPECT(y.defined());
  EXPECT(x.sizes().vec() == y.sizes().vec());
  EXPECT(x.allclose(y));
}

CASE("serialization/noncontig") {
  auto x = at::CPU(at::kFloat).randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = at::Tensor();

  std::stringstream ss;
  {
    cereal::BinaryOutputArchive archive(ss);
    archive(x);
  }
  {
    cereal::BinaryInputArchive archive(ss);
    archive(y);
  }

  EXPECT(y.defined());
  EXPECT(x.sizes().vec() == y.sizes().vec());
  EXPECT(x.allclose(y));
}

CASE("serialization/xor") {
  // We better be able to save and load a XOR model!
  auto makeModel = []() {
    return ContainerList()
        .append(Linear(2, 8).make())
        .append(Linear(8, 1).make())
        .make();
  };
  auto getLoss = [](std::shared_ptr<ContainerList> model, uint32_t bs) {
    auto inp = at::CPU(at::kFloat).tensor({bs, 2});
    auto lab = at::CPU(at::kFloat).tensor({bs});
    for (auto i = 0U; i < bs; i++) {
      auto a = std::rand() % 2;
      auto b = std::rand() % 2;
      auto c = a ^ b;
      inp[i][0] = a;
      inp[i][1] = b;
      lab[i] = c;
    }

    // forward
    auto x = Var(inp);
    auto y = Var(lab, false);
    for (auto layer : *model) x = layer->forward({x})[0].sigmoid_();
    return at::binary_cross_entropy(x, y);
  };

  auto model = makeModel();
  auto model2 = makeModel();
  auto model3 = makeModel();
  auto optim = SGD(model, 1e-1).momentum(0.9).nesterov().weight_decay(1e-6).make();

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    Variable loss = getLoss(model, 4);
    optim->zero_grad();
    backward(loss);
    optim->step();

    running_loss = running_loss * 0.99 + loss.data().sum().toCFloat() * 0.01;
    EXPECT(epoch < 3000);
    epoch++;
  }

  std::stringstream ss;
  save(ss, model);
  load(ss, model2);

  auto loss = getLoss(model2, 100);
  EXPECT(loss.toCFloat() < 0.1);

  CUDA_GUARD;
  model2->cuda();
  ss.clear();
  save(ss, model2);
  load(ss, model3);

  loss = getLoss(model3, 100);
  EXPECT(loss.toCFloat() < 0.1);
}

CASE("serialization/optim") {
  auto model1 = Linear(5, 2).make();
  auto model2 = Linear(5, 2).make();
  auto model3 = Linear(5, 2).make();

  // Models 1, 2, 3 will have the same params
  std::stringstream ss;
  save(ss, model1);
  load(ss, model2);
  ss.seekg(0, std::ios::beg);
  load(ss, model3);

  // Make some optimizers with momentum (and thus state)
  auto optim1 = SGD(model1, 1e-1).momentum(0.9).make();
  auto optim2 = SGD(model2, 1e-1).momentum(0.9).make();
  auto optim2_2 = SGD(model2, 1e-1).momentum(0.9).make();
  auto optim3 = SGD(model3, 1e-1).momentum(0.9).make();
  auto optim3_2 = SGD(model3, 1e-1).momentum(0.9).make();

  auto x = Var(at::CPU(at::kFloat).ones({10, 5}), true);

  auto step = [&](Optimizer optim, Container model) {
    optim->zero_grad();
    auto y = model->forward({x})[0].sum();
    backward(y);
    optim->step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);
  ss.clear();
  save(ss, optim3);
  load(ss, optim3_2);
  step(optim3_2, model3);

  auto param1 = model1->parameters();
  auto param2 = model2->parameters();
  auto param3 = model3->parameters();
  for (auto& p : param1) {
    auto name = p.first;
    // Model 1 and 3 should be the same
    EXPECT(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
    EXPECT(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
  }
}
10  test/cpp/api/test.cpp  Normal file

@@ -0,0 +1,10 @@
#include "test.h"

lest::tests& specification() {
  static lest::tests tests;
  return tests;
}

int main(int argc, char* argv[]) {
  return lest::run(specification(), argc, argv);
}
14  test/cpp/api/test.h  Normal file

@@ -0,0 +1,14 @@
#pragma once

#include "lest.hpp"
#include <torch/autograd.h>

using namespace autograd;

#define CASE(name) lest_CASE(specification(), name)

#define CUDA_GUARD                                        \
  if (!hasCuda()) {                                       \
    std::cerr << "No cuda, skipping test" << std::endl;   \
    return;                                               \
  }

extern lest::tests& specification();
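To show how this harness is wired together, here is a hypothetical test-file sketch (not part of this commit; the file name `example_t.cpp` and the CASE names are made up) using the `CASE` and `CUDA_GUARD` macros above, with lest's auto-registration enabled in the CMake changes further below:

```
// example_t.cpp (hypothetical): with lest_FEATURE_AUTO_REGISTER=1, each CASE
// self-registers with the shared specification() defined in test.cpp.
#include "test.h"

CASE("example/linear/shape") {
  // Containers are configured with a builder-style API and created via make().
  auto model = Linear(4, 3).make();
  auto x = Var(at::CPU(at::kFloat).randn({2, 4}), true);
  auto y = model->forward({x})[0];
  EXPECT(y.size(0) == 2);
  EXPECT(y.size(1) == 3);
};

CASE("example/linear/cuda") {
  CUDA_GUARD;  // prints a notice and returns early when no GPU is present
  auto model = Linear(4, 3).make();
  model->cuda();
  auto x = Var(at::CUDA(at::kFloat).randn({2, 4}));
  EXPECT(model->forward({x})[0].size(1) == 3);
};
```

A file like this would also need to be listed among the `add_executable(test_api ...)` sources in the CMake changes below.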
1  third_party/cereal  vendored  Submodule

@@ -0,0 +1 @@
Subproject commit 51cbda5f30e56c801c07fe3d3aba5d7fb9e6cca4
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
 cmake_policy(VERSION 3.0)

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)

@@ -155,10 +155,19 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/type.cpp
   ${TORCH_SRC_DIR}/csrc/jit/interpreter_autograd_function.cpp
   ${TORCH_SRC_DIR}/csrc/Exceptions.cpp
+  ${TORCH_SRC_DIR}/csrc/api/src/detail.cpp
+  ${TORCH_SRC_DIR}/csrc/api/src/containers.cpp
+  ${TORCH_SRC_DIR}/csrc/api/src/optimizers.cpp
   )

 add_library(torch SHARED ${TORCH_SRCS})

+target_compile_options(torch PRIVATE -Wall -Wextra)
+
+if ($ENV{WERROR})
+  target_compile_options(torch PRIVATE -Werror)
+endif()
+
 target_link_libraries(torch
   ${TORCH_CUDA_LIBRARIES}
   ${ATEN_LIBRARY}

@@ -169,6 +178,10 @@ set(COMMON_INCLUDES
   "${ATEN_INCLUDE_DIR}/TH"
   "${ATEN_BUILD_INCLUDE_DIR}"
   "${ATEN_BUILD_PATH}/src/TH"
+  "${TORCH_SRC_DIR}/csrc/api/"
+  "${TORCH_SRC_DIR}/csrc/api/include"
+  "${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/
+  "${TORCH_SRC_DIR}/../"
   "${CMAKE_CURRENT_SOURCE_DIR}"
   "${CUDA_INCLUDE_DIRS}")

@@ -193,9 +206,9 @@ install(TARGETS torch
   LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}"
   ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}")

-set(TORCH_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp)
+# JIT Tests. TODO: Put into test/cpp/jit folder

-add_executable(test_jit ${TORCH_TEST_SRCS})
+add_executable(test_jit ${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp)

 target_link_libraries(test_jit torch)

@@ -204,7 +217,19 @@ target_include_directories(test_jit PUBLIC
  "${TORCH_SRC_DIR}/../third_party/catch/single_include"
  "${COMMON_INCLUDES}")

install(TARGETS test_jit
  RUNTIME DESTINATION "${TORCH_INSTALL_BIN_DIR}"
  LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}"
  ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}")
# API Tests

set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api")

add_executable(test_api
  ${TORCH_API_TEST_DIR}/test.cpp
  ${TORCH_API_TEST_DIR}/container_t.cpp
  ${TORCH_API_TEST_DIR}/misc_t.cpp
  ${TORCH_API_TEST_DIR}/rnn_t.cpp
  ${TORCH_API_TEST_DIR}/integration_t.cpp
  ${TORCH_API_TEST_DIR}/optim_t.cpp
  ${TORCH_API_TEST_DIR}/serialization_t.cpp
)

target_compile_options(test_api PRIVATE -Dlest_FEATURE_AUTO_REGISTER=1)
target_link_libraries(test_api torch)
81  tools/download_mnist.py  Normal file

@@ -0,0 +1,81 @@
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys
import urllib

try:
    from urllib.error import URLError
    from urllib.request import urlretrieve
except ImportError:
    from urllib2 import URLError
    from urllib import urlretrieve

RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def report_download_progress(chunk_number, chunk_size, file_size):
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


def download(destination_path, url):
    if os.path.exists(destination_path):
        print('{} already exists, skipping ...'.format(destination_path))
    else:
        print('Downloading {} ...'.format(url))
        try:
            urlretrieve(
                url, destination_path, reporthook=report_download_progress)
        except URLError:
            raise RuntimeError('Error downloading resource!')
        finally:
            # Just a newline.
            print()


def unzip(zipped_path):
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        print('{} already exists, skipping ... '.format(unzipped_path))
        return
    with gzip.open(zipped_path, 'rb') as zipped_file:
        with open(unzipped_path, 'wb') as unzipped_file:
            unzipped_file.write(zipped_file.read())
            print('Unzipped {} ...'.format(zipped_path))


def main():
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
            download(path, url)
            unzip(path)
    except KeyboardInterrupt:
        print('Interrupted')

    print('Done')


if __name__ == '__main__':
    main()
50  torch/csrc/api/README.md  Normal file

@@ -0,0 +1,50 @@
# AUTOGRADPP

This is an experimental C++ frontend to pytorch's C++ backend. Use at your own
risk.

How to build:
```
git submodule update --init --recursive

cd pytorch
# On Linux:
python setup.py build
# On macOS (may need to prefix with `MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++` when using anaconda)
LDSHARED="cc -dynamiclib -undefined dynamic_lookup" python setup.py build

cd ..; mkdir -p build; cd build
cmake .. -DPYTHON_EXECUTABLE:FILEPATH=$(which python) # helpful if you use anaconda
make -j
```

# Stuff

- Check out the [MNIST example](https://github.com/ebetica/autogradpp/blob/eee977ddd377c484af5fce09ae8676410bb6fcce/tests/integration_t.cpp#L320-L355),
  which tries to replicate PyTorch's MNIST model + training loop
- The principled way to write a model is probably something like
```
AUTOGRAD_CONTAINER_CLASS(MyModel) {
  // This does a 2D convolution, followed by global mean pooling, followed by a linear.
 public:
  void initialize_containers() override {
    myConv_ = add(Conv2d(1, 50, 3).stride(2).make(), "conv");
    myLinear_ = add(Linear(50, 1).make(), "linear");
  }
  variable_list forward(variable_list x) override {
    auto v = myConv_->forward(x)[0];
    v = v.mean(-1).mean(-1);
    return myLinear_->forward({v});
  }
 private:
  Container myLinear_;
  Container myConv_;
};
```

Some things are not implemented:
- SGD, Adagrad, RMSprop, and Adam are the only optimizers implemented
- Bidirectional, batch first, and PackedSequence are not implemented for LSTMs
- Sparse Tensors might work but are very untested

Otherwise, lots of other things work. There may be breaking API changes.
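For illustration, here is a minimal usage sketch (not part of this commit) of the `Sequential`, `Functional`, and `Linear` containers declared in `torch/csrc/api/include/torch/containers.h` below; the function name `run_mlp` is made up for the example:

```
// Sketch only: chains Linear layers with a Functional relu inside Sequential.
#include <torch/autograd.h>

using namespace autograd;

variable_list run_mlp(Variable input) {
  auto model = Sequential()
      .append(Linear(10, 32).make())
      .append(Functional([](Variable x) { return x.clamp_min(0); }).make())
      .append(Linear(32, 1).make())
      .make();
  // Sequential::forward threads the variable_list through each child in order.
  return model->forward({input});
}
```

This mirrors the `ContainerList().append(...).make()` pattern used throughout the tests above; `Sequential` additionally provides a real `forward` that applies each child in sequence.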
4  torch/csrc/api/include/torch/autograd.h  Normal file

@@ -0,0 +1,4 @@
#pragma once
#include "torch/containers.h"
#include "torch/optimizers.h"
#include "torch/serialization.h"
446
torch/csrc/api/include/torch/containers.h
Normal file
446
torch/csrc/api/include/torch/containers.h
Normal file
|
|
@ -0,0 +1,446 @@
|
|||
#pragma once
|
||||
|
||||
#include "detail.h"
|
||||
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
|
||||
#define AUTOGRAD_CONTAINER_CLASS(Type) \
|
||||
class Type : public autograd::Container_CRTP<Type>
|
||||
|
||||
namespace autograd {
|
||||
class ContainerImpl {
|
||||
public:
|
||||
// Only construct parameters in initialize_parameters, and
|
||||
// containers in initialize_containers. Most of the time, the containers are
|
||||
// the only thing you need to add.
|
||||
// You are guaranteed that containers are added before parameters.
|
||||
virtual void initialize_containers(){};
|
||||
virtual void initialize_parameters(){};
|
||||
virtual void reset_parameters(){};
|
||||
|
||||
virtual variable_list forward(variable_list) = 0;
|
||||
virtual Container clone() const = 0;
|
||||
|
||||
std::map<std::string, Variable> parameters() const;
|
||||
Variable& param(std::string const&);
|
||||
|
||||
virtual void cuda();
|
||||
virtual void cpu();
|
||||
void train();
|
||||
void eval();
|
||||
|
||||
at::Type& DefaultTensor(at::ScalarType s);
|
||||
|
||||
std::unordered_map<std::string, Container> children_;
|
||||
std::unordered_map<std::string, Variable> params_;
|
||||
bool cuda_ = false;
|
||||
bool train_ = true;
|
||||
|
||||
template <class Archive>
|
||||
void save(Archive& ar) const {
|
||||
auto params = parameters();
|
||||
std::size_t size = params.size();
|
||||
ar(size);
|
||||
for (auto& p : params) {
|
||||
ar(p.first, p.second);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Archive>
|
||||
void load(Archive& ar) {
|
||||
auto params = parameters();
|
||||
std::size_t size;
|
||||
ar(size);
|
||||
std::string name;
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
ar(name);
|
||||
ar(params[name]);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
Container add(Container, std::string const&);
|
||||
// Be careful when registering Tensors that are not variables
|
||||
Variable& add(Variable, std::string const&);
|
||||
};
|
||||
|
||||
template <class Derived>
|
||||
class Container_CRTP : public ContainerImpl {
|
||||
public:
|
||||
std::shared_ptr<Derived> make() const {
|
||||
auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
|
||||
ptr->initialize_containers();
|
||||
ptr->initialize_parameters();
|
||||
ptr->reset_parameters();
|
||||
return ptr;
|
||||
}
|
||||
|
||||
Container clone() const override {
|
||||
auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
|
||||
ptr->children_.clear();
|
||||
ptr->params_.clear();
|
||||
ptr->initialize_containers();
|
||||
ptr->initialize_parameters();
|
||||
auto newParams = ptr->parameters();
|
||||
for (auto& param : parameters()) {
|
||||
newParams[param.first].data().copy_(param.second.data());
|
||||
}
|
||||
if (cuda_) {
|
||||
ptr->cuda();
|
||||
} else {
|
||||
ptr->cpu();
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
};
|
||||
|
||||
template <class Derived>
|
||||
class ContainerListImpl : public Container_CRTP<Derived> {
|
||||
// Lets you use a container like a vector without making a new class,
|
||||
// just for simple implementations
|
||||
public:
|
||||
virtual variable_list forward(variable_list) override {
|
||||
throw std::runtime_error(
|
||||
"ContainerList has no forward, maybe you"
|
||||
" wanted to subclass and override this function?");
|
||||
}
|
||||
|
||||
Container add(Container m) {
|
||||
return append(m).children_.back();
|
||||
}
|
||||
|
||||
ContainerListImpl<Derived>& append(Container m) {
|
||||
children_.push_back(m);
|
||||
ContainerImpl::add(children_.back(), std::to_string(size() - 1));
|
||||
return *this;
|
||||
}
|
||||
|
||||
Container& operator[](int index) {
|
||||
return children_[index];
|
||||
}
|
||||
|
||||
int size() {
|
||||
return children_.size();
|
||||
}
|
||||
|
||||
std::vector<Container>::iterator begin() {
|
||||
return children_.begin();
|
||||
}
|
||||
|
||||
std::vector<Container>::iterator end() {
|
||||
return children_.end();
|
||||
}
|
||||
|
||||
std::vector<Container> children_;
|
||||
};
|
||||
|
||||
class ContainerList : public ContainerListImpl<ContainerList> {};
|
||||
|
||||
class Sequential : public ContainerListImpl<Sequential> {
|
||||
// Mimics nn.Sequential from pytorch.
|
||||
public:
|
||||
variable_list forward(variable_list input) override {
|
||||
for (auto& container : children_) {
|
||||
input = container->forward(input);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
Container add(Container m, std::string name = "") {
|
||||
return append(m, name).children_.back();
|
||||
}
|
||||
|
||||
Sequential& append(Container m, std::string name = "") {
|
||||
if (name == "") {
|
||||
name = std::to_string(size());
|
||||
}
|
||||
children_.push_back(m);
|
||||
ContainerImpl::add(children_.back(), name);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
AUTOGRAD_CONTAINER_CLASS(SimpleContainer) {
|
||||
// Lets you use a container without making a new class,
|
||||
// for experimental implementations
|
||||
public:
|
||||
virtual variable_list forward(variable_list) override {
|
||||
throw std::runtime_error(
|
||||
"SimpleContainer has no forward, maybe you"
|
||||
" wanted to subclass and override this function?");
|
||||
}
|
||||
using ContainerImpl::add;
|
||||
};
|
||||
|
||||
AUTOGRAD_CONTAINER_CLASS(Functional) {
|
||||
// Lets you create a container from a function, designed for use in
|
||||
// Sequential.
|
||||
public:
|
||||
Functional(std::function<variable_list(variable_list)> fun) : fun_(fun){};
|
||||
Functional(std::function<Variable(Variable)> fun)
|
||||
: fun_([fun](variable_list input) {
|
||||
return variable_list({fun(input[0])});
|
||||
}){};
|
||||
|
||||
variable_list forward(variable_list input) override {
|
||||
return fun_(input);
|
||||
};
|
||||
|
||||
std::function<variable_list(variable_list)> fun_;
|
||||
};
|
||||
|
||||
AUTOGRAD_CONTAINER_CLASS(Linear) {
|
||||
public:
|
||||
Linear(uint32_t nin, uint32_t nout) : nin(nin), nout(nout) {}
|
||||
|
||||
variable_list forward(variable_list) override;
|
||||
void reset_parameters() override;
|
||||
void initialize_parameters() override;
|
||||
AUTOGRAD_KWARG(Linear, bool, no_bias, false, true);
|
||||
|
||||
Variable weight, bias;
|
||||
uint32_t nin, nout;
|
||||
};
|
||||
|
||||
AUTOGRAD_CONTAINER_CLASS(Embedding) {
|
||||
public:
|
||||
Embedding(uint32_t num_embeddings, uint32_t embedding_dim)
|
||||
: num_embeddings(num_embeddings), embedding_dim(embedding_dim) {}
|
||||
|
||||
variable_list forward(variable_list) override;
|
||||
void reset_parameters() override;
|
||||
void initialize_parameters() override;
|
||||
|
||||
Variable weight;
|
||||
uint32_t num_embeddings, embedding_dim;
|
||||
};
|
||||
|
||||

AUTOGRAD_CONTAINER_CLASS(Conv) {
 private:
  Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan)
      : Nd_(Nd),
        in_channels_(in_chan),
        out_channels_(out_chan),
        stride_(makeTup(1, 1)),
        padding_(makeTup(0)),
        dilation_(makeTup(1, 1)),
        dilated_(false),
        output_padding_(makeTup(0)) {}

 public:
  Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan, int ks)
      : Conv(Nd, in_chan, out_chan) {
    ks_ = makeTup(ks, 1);
  }

  Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan, IntVec ks)
      : Conv(Nd, in_chan, out_chan) {
    ks_ = makeTup(ks);
  }

  void reset_parameters() override;
  variable_list forward(variable_list) override;
  void initialize_parameters() override;

  template <typename T>
  Conv& stride(T s) {
    stride_ = makeTup(s, 1);
    return *this;
  }
  template <typename T>
  Conv& padding(T s) {
    padding_ = makeTup(s);
    return *this;
  }
  template <typename T>
  Conv& dilation(T s) {
    dilation_ = makeTup(s, 1);
    return *this;
  }
  template <typename T>
  Conv& output_padding(T s) {
    output_padding_ = makeTup(s);
    return *this;
  }

  AUTOGRAD_KWARG(Conv, bool, transposed, false, true)
  AUTOGRAD_KWARG(Conv, bool, no_bias, false, true)
  AUTOGRAD_KWARG(Conv, int, groups, 1, 1)

  Variable weight, bias;
  uint32_t Nd_;
  uint32_t in_channels_;
  uint32_t out_channels_;
  IntVec ks_;
  IntVec stride_;
  IntVec padding_;
  IntVec dilation_;
  bool dilated_;
  IntVec output_padding_;

 protected:
  IntVec makeTup(int x, int def = 0) {
    IntVec ret;
    if (Nd_ == 1) {
      ret.push_back(x);
      ret.push_back(def);
    } else {
      for (auto i = 0U; i < Nd_; i++)
        ret.push_back(x);
    }
    return ret;
  }
  IntVec makeTup(IntVec x) {
    return x;
  }
};

class Conv1d : public Conv {
 public:
  Conv1d(uint32_t i, uint32_t o, int ks) : Conv(1, i, o, ks) {}
  Conv1d(uint32_t i, uint32_t o, IntVec ks) : Conv(1, i, o, ks) {}
};

class Conv2d : public Conv {
 public:
  Conv2d(uint32_t i, uint32_t o, int ks) : Conv(2, i, o, ks) {}
  Conv2d(uint32_t i, uint32_t o, IntVec ks) : Conv(2, i, o, ks) {}
};

class Conv3d : public Conv {
 public:
  Conv3d(uint32_t i, uint32_t o, int ks) : Conv(3, i, o, ks) {}
  Conv3d(uint32_t i, uint32_t o, IntVec ks) : Conv(3, i, o, ks) {}
};
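
// Usage sketch (illustrative): the setters on Conv chain, so a layer can be
// configured fluently before make():
//   auto conv = Conv2d(3, 16, 3).padding(1).no_bias(true).make();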

AUTOGRAD_CONTAINER_CLASS(BatchNorm) {
 public:
  BatchNorm(uint32_t num_features) : num_features_(num_features) {}

  AUTOGRAD_KWARG(BatchNorm, double, eps, 1e-5, 1e-5)
  AUTOGRAD_KWARG(BatchNorm, double, momentum, 0.1, 0.1)
  AUTOGRAD_KWARG(BatchNorm, bool, affine, true, true)
  AUTOGRAD_KWARG(BatchNorm, bool, stateful, false, true)

  void reset_parameters() override;
  variable_list forward(variable_list) override;
  void initialize_parameters() override;

  Variable weight;
  Variable bias;
  Variable running_mean;
  Variable running_var;

 protected:
  uint32_t num_features_;
};

AUTOGRAD_CONTAINER_CLASS(Dropout) {
 public:
  Dropout(double p = 0.5) : p_(p) {
    assert(p < 1 && p >= 0);
  }
  variable_list forward(variable_list) override;

 protected:
  double p_;
};

AUTOGRAD_CONTAINER_CLASS(Dropout2d) {
 public:
  Dropout2d(double p = 0.5) : p_(p) {
    assert(p < 1 && p >= 0);
  }
  variable_list forward(variable_list) override;

 protected:
  double p_;
};

template <typename Derived>
class RNNBase : public Container_CRTP<Derived> {
 public:
  // These must line up with the CUDNN mode codes
  enum RNNMode : int64_t { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
  RNNBase(uint32_t input_size, uint32_t hidden_size)
      : input_size_(input_size), hidden_size_(hidden_size) {}

  AUTOGRAD_KWARG(RNNBase, RNNMode, mode, RNNMode::LSTM, RNNMode::LSTM)
  AUTOGRAD_KWARG(RNNBase, uint32_t, nlayers, 1, 1);
  AUTOGRAD_KWARG(RNNBase, bool, no_bias, false, true)
  AUTOGRAD_KWARG(RNNBase, float, dropout, 0, 0)

  bool flatten_parameters(); // Flatten for cudnn

  variable_list forward(variable_list) override;
  void initialize_containers() override;
  void reset_parameters() override;

  void cpu() override;
  void cuda() override;

  std::vector<Container> i2h;
  std::vector<Container> h2h;

 protected:
  uint32_t input_size_;
  uint32_t hidden_size_;
  uint32_t gate_size_;
  // This is copied from PyTorch, to determine whether the weights are flat
  // enough for the fast CUDNN route; otherwise we have to use non-flattened
  // weights, which are much slower.
  // https://github.com/pytorch/pytorch/blob/1848cad10802db9fa0aa066d9de195958120d863/torch/nn/modules/rnn.py#L159-L165
  // TODO: since we are in C++, we could check directly whether the parameters
  // are flat, instead of relying on data pointers.
  std::vector<void*> data_ptrs_;
  Variable flat_weight_;
  Container dropout_module;

  variable_list CUDNN_forward(variable_list);
  variable_list autograd_forward(variable_list);

  variable_list cell_forward(variable_list, int);
  variable_list LSTM_cell_forward(variable_list, int);
  variable_list GRU_cell_forward(variable_list, int);
  variable_list RNN_RELU_cell_forward(variable_list, int);
  variable_list RNN_TANH_cell_forward(variable_list, int);
};

// We must instantiate these templates so we can put implementations in the .cpp
class LSTM;
template class RNNBase<LSTM>;
class LSTM : public RNNBase<LSTM> {
 public:
  LSTM(uint32_t inp_size, uint32_t hid_size) : RNNBase(inp_size, hid_size) {
    mode_ = RNNBase::RNNMode::LSTM;
  }
};

class GRU;
template class RNNBase<GRU>;
class GRU : public RNNBase<GRU> {
 public:
  GRU(uint32_t inp_size, uint32_t hid_size) : RNNBase(inp_size, hid_size) {
    mode_ = RNNBase::RNNMode::GRU;
  }
};

class RNN;
template class RNNBase<RNN>;
class RNN : public RNNBase<RNN> {
 public:
  enum Mode { Tanh, Relu };
  RNN(uint32_t inp_size, uint32_t hid_size, Mode mode = Mode::Tanh)
      : RNNBase(inp_size, hid_size) {
    if (mode == Mode::Tanh) {
      mode_ = RNNBase::RNNMode::RNN_TANH;
    } else if (mode == Mode::Relu) {
      mode_ = RNNBase::RNNMode::RNN_RELU;
    } else {
      throw std::runtime_error("RNN Mode not supported");
    }
  }
};
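
// Usage sketch (illustrative): forward() takes {input, hidden} and returns
// {output, new_hidden}; the hidden state may be left out on the first call:
//   auto rnn = LSTM(128, 64).nlayers(2).make();
//   auto out = rnn->forward({x});  // out[0]: outputs, out[1]: hidden state
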
} // namespace autograd

torch/csrc/api/include/torch/detail.h (new file)
@@ -0,0 +1,70 @@
#pragma once

#include <map>
#include <memory>

#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/grad_mode.h"

// for AutoGPU. Usage:
//   AutoGPU gpu_raii(1);
// While this object is in scope, all of your GPU tensors will go to GPU 1
#include "torch/csrc/utils/auto_gpu.h"

#define AUTOGRAD_OPTIMIZER_CLASS(Type) \
  class Type : public autograd::Optimizer_CRTP<Type>
#define AUTOGRAD_KWARG(CLS, TYP, NAME, DEFAULT, OPTION) \
  TYP NAME##_ = DEFAULT;                                \
  CLS& NAME(TYP x = OPTION) {                           \
    NAME##_ = x;                                        \
    return *this;                                       \
  }
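
// For example (illustrative), AUTOGRAD_KWARG(Linear, bool, no_bias, false, true)
// expands to a member `bool no_bias_ = false;` plus a chainable setter
// `Linear& no_bias(bool x = true)`, which is what enables builder-style
// configuration such as Linear(10, 5).no_bias(true).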

namespace {
namespace tag = torch::autograd;
using IntVec = decltype(std::declval<at::IntList>().vec());
} // namespace

namespace autograd {
namespace detail {
extern tag::Engine engine;
}

class ContainerImpl;
class OptimizerImpl;
using Variable = tag::Variable;
using variable_list = tag::variable_list;
using Tensor = at::Tensor;
using Container = std::shared_ptr<ContainerImpl>;
using Optimizer = std::shared_ptr<OptimizerImpl>;

void backward(Tensor loss, bool keep_graph = false);

inline Variable Var(at::Tensor data, bool requires_grad = true) {
  return tag::make_variable(data, requires_grad);
}

// This is thread local!!!
inline void set_grad_enabled(bool val = true) {
  tag::GradMode::set_enabled(val);
}

// RAII thread local lock that stops future execution from building gradients
class no_grad_guard {
 public:
  no_grad_guard() {
    tag::GradMode::set_enabled(false);
  }

  ~no_grad_guard() {
    tag::GradMode::set_enabled(true);
  }
};
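
// Usage sketch (illustrative):
//   {
//     no_grad_guard guard;
//     auto y = model->forward({x})[0];  // no graph is built in this scope
//   }  // gradient mode is re-enabled when the guard is destroyed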

void setSeed(uint64_t seed);

int getNumGPUs();
bool hasCuda();
bool hasCudnn();

} // namespace autograd

torch/csrc/api/include/torch/optimizers.h (new file)
@@ -0,0 +1,139 @@
#pragma once

#include "torch/containers.h"
#include "torch/detail.h"

#include "cereal/access.hpp"
#include "cereal/cereal.hpp"

namespace autograd {
class OptimizerImpl {
 public:
  OptimizerImpl(Container model) : model_(model) {}
  virtual void init_state() {}
  virtual void step() = 0;
  void zero_grad();

  void set_model(Container model);

 protected:
  OptimizerImpl() {}
  Container model_;
};

template <class Derived>
class Optimizer_CRTP : public OptimizerImpl {
 public:
  Optimizer_CRTP(Container model) : OptimizerImpl(model) {}
  std::shared_ptr<Derived> make() const {
    auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
    ptr->init_state();
    return ptr;
  }

 protected:
  Optimizer_CRTP() {}
};

AUTOGRAD_OPTIMIZER_CLASS(SGD) {
 public:
  SGD(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
  AUTOGRAD_KWARG(SGD, double, momentum, 0, 0);
  AUTOGRAD_KWARG(SGD, double, dampening, 0, 0);
  AUTOGRAD_KWARG(SGD, double, weight_decay, 0, 0);
  AUTOGRAD_KWARG(SGD, bool, nesterov, false, true);
  double lr_;
  void step() override;
  void init_state() override;

  template <class Archive>
  void serialize(Archive& ar) {
    ar(CEREAL_NVP(momentum_buffers_));
  }

 private:
  friend class cereal::access;
  SGD() {}
  std::unordered_map<std::string, at::Tensor> momentum_buffers_;
};
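
// Usage sketch (illustrative): optimizers use the same chainable-kwarg
// pattern as containers:
//   auto optim = SGD(model, /*lr=*/0.1).momentum(0.9).nesterov(true).make();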

AUTOGRAD_OPTIMIZER_CLASS(Adagrad) {
 public:
  Adagrad(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
  AUTOGRAD_KWARG(Adagrad, double, lr_decay, 0, 0);
  AUTOGRAD_KWARG(Adagrad, double, weight_decay, 0, 0);
  double lr_;
  void step() override;
  void init_state() override;

  template <class Archive>
  void serialize(Archive& ar) {
    ar(CEREAL_NVP(sum_));
    ar(CEREAL_NVP(step_));
  }

 private:
  friend class cereal::access;
  Adagrad() {}
  std::unordered_map<std::string, at::Tensor> sum_;
  std::unordered_map<std::string, double> step_;
};

AUTOGRAD_OPTIMIZER_CLASS(RMSprop) {
 public:
  RMSprop(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
  AUTOGRAD_KWARG(RMSprop, double, alpha, 0.99, 0.99);
  AUTOGRAD_KWARG(RMSprop, double, eps, 1e-8, 1e-8);
  AUTOGRAD_KWARG(RMSprop, double, weight_decay, 0, 0);
  AUTOGRAD_KWARG(RMSprop, double, momentum, 0, 0);
  AUTOGRAD_KWARG(RMSprop, bool, centered, false, true);

  double lr_;
  void step() override;
  void init_state() override;

  template <class Archive>
  void serialize(Archive& ar) {
    ar(CEREAL_NVP(square_avg_buffer_));
    ar(CEREAL_NVP(momentum_buffer_));
    ar(CEREAL_NVP(grad_avg_buffer_));
  }

 private:
  friend class cereal::access;
  RMSprop() {}
  std::unordered_map<std::string, at::Tensor> square_avg_buffer_;
  std::unordered_map<std::string, at::Tensor> momentum_buffer_;
  std::unordered_map<std::string, at::Tensor> grad_avg_buffer_;
};

AUTOGRAD_OPTIMIZER_CLASS(Adam) {
 public:
  Adam(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
  AUTOGRAD_KWARG(Adam, double, beta1, 0.9, 0.9);
  AUTOGRAD_KWARG(Adam, double, beta2, 0.999, 0.999);
  AUTOGRAD_KWARG(Adam, double, weight_decay, 0, 0);
  AUTOGRAD_KWARG(Adam, double, eps, 1e-8, 1e-8);
  AUTOGRAD_KWARG(Adam, bool, amsgrad, false, true);
  double lr_;
  void step() override;
  void init_state() override;

  template <class Archive>
  void serialize(Archive& ar) {
    ar(CEREAL_NVP(step_buffer_),
       CEREAL_NVP(exp_avg_buffer_),
       CEREAL_NVP(exp_avg_sq_buffer_),
       CEREAL_NVP(max_exp_avg_sq_buffer_));
  }

 private:
  friend class cereal::access;
  Adam() {}
  std::unordered_map<std::string, int> step_buffer_;
  std::unordered_map<std::string, at::Tensor> exp_avg_buffer_;
  std::unordered_map<std::string, at::Tensor> exp_avg_sq_buffer_;
  std::unordered_map<std::string, at::Tensor> max_exp_avg_sq_buffer_;
};

} // namespace autograd

torch/csrc/api/include/torch/serialization.h (new file)
@@ -0,0 +1,236 @@
#pragma once

#include <fstream>

#include "cereal/archives/binary.hpp"
#include "cereal/types/polymorphic.hpp"

#include "cereal/types/string.hpp"
#include "cereal/types/unordered_map.hpp"
#include "cereal/types/vector.hpp"

namespace autograd {

// Some convenience functions for saving and loading
template <typename T>
void save(std::ostream& stream, T const& obj) {
  cereal::BinaryOutputArchive archive(stream);
  archive(*obj);
}
template <typename T>
void load(std::istream& stream, T& obj) {
  cereal::BinaryInputArchive archive(stream);
  archive(*obj);
}
template <typename T>
void save(std::ostream& stream, T const* obj) {
  cereal::BinaryOutputArchive archive(stream);
  archive(*obj);
}
template <typename T>
void load(std::istream& stream, T* obj) {
  cereal::BinaryInputArchive archive(stream);
  archive(*obj);
}
template <typename T>
void save(std::string const& path, T const& obj) {
  std::ofstream os(path, std::ios::binary);
  autograd::save(os, obj);
}
template <typename T>
void load(std::string const& path, T& obj) {
  std::ifstream is(path, std::ios::binary);
  autograd::load(is, obj);
}
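
// Usage sketch (illustrative): these templates dereference their argument, so
// they are meant for pointer-like handles, e.g. the shared_ptr returned by
// an optimizer's make():
//   autograd::save("optim.bin", optim);
//   autograd::load("optim.bin", optim);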

namespace detail {

// We use our own hard-coded type<->id mapping so that serialization is robust
// wrt changes in ATen; see e.g. https://git.io/vxd6R
// The mapping is consistent with the ScalarType enum as of pytorch version
// v0.1.11-7675-ge94c67e.
inline int32_t scalarTypeId(at::ScalarType type) {
  switch (type) {
    case at::ScalarType::Byte: return 0;
    case at::ScalarType::Char: return 1;
    case at::ScalarType::Short: return 2;
    case at::ScalarType::Int: return 3;
    case at::ScalarType::Long: return 4;
    case at::ScalarType::Half: return 5;
    case at::ScalarType::Float: return 6;
    case at::ScalarType::Double: return 7;
    case at::ScalarType::Undefined: return 8;
    default:
      throw std::runtime_error(
          "Unknown scalar type: " + std::to_string(static_cast<int>(type)));
  }
}

inline at::ScalarType scalarTypeFromId(int32_t id) {
  switch (id) {
    case 0: return at::ScalarType::Byte;
    case 1: return at::ScalarType::Char;
    case 2: return at::ScalarType::Short;
    case 3: return at::ScalarType::Int;
    case 4: return at::ScalarType::Long;
    case 5: return at::ScalarType::Half;
    case 6: return at::ScalarType::Float;
    case 7: return at::ScalarType::Double;
    case 8: return at::ScalarType::Undefined;
    default:
      throw std::runtime_error("Unknown scalar type id: " + std::to_string(id));
  }
}

inline int32_t backendId(at::Backend backend) {
  switch (backend) {
    case at::Backend::CPU: return 0;
    case at::Backend::CUDA: return 1;
    case at::Backend::SparseCPU: return 2;
    case at::Backend::SparseCUDA: return 3;
    case at::Backend::Undefined: return 4;
    default:
      throw std::runtime_error(
          "Unknown backend: " + std::to_string(static_cast<int>(backend)));
  }
}

inline at::Backend backendFromId(int32_t id) {
  switch (id) {
    case 0: return at::Backend::CPU;
    case 1: return at::Backend::CUDA;
    case 2: return at::Backend::SparseCPU;
    case 3: return at::Backend::SparseCUDA;
    case 4: return at::Backend::Undefined;
    default:
      throw std::runtime_error("Unknown backend id: " + std::to_string(id));
  }
}

} // namespace detail
} // namespace autograd

// This is super ugly and I don't know how to simplify it
CEREAL_REGISTER_TYPE(autograd::SGD);
CEREAL_REGISTER_POLYMORPHIC_RELATION(autograd::OptimizerImpl, autograd::SGD);
CEREAL_REGISTER_TYPE(autograd::Adagrad);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    autograd::OptimizerImpl,
    autograd::Adagrad);
CEREAL_REGISTER_TYPE(autograd::RMSprop);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    autograd::OptimizerImpl,
    autograd::RMSprop);
CEREAL_REGISTER_TYPE(autograd::Adam);
CEREAL_REGISTER_POLYMORPHIC_RELATION(autograd::OptimizerImpl, autograd::Adam);

namespace cereal {

namespace agimpl {

template <class Archive>
void saveBinary(Archive& archive, void const* data, std::size_t size) {
  // In general, there's no direct `saveBinary`-like method on archives
  std::vector<char> v(
      static_cast<char const*>(data), static_cast<char const*>(data) + size);
  archive(v);
}
template <>
inline void
saveBinary(BinaryOutputArchive& archive, void const* data, std::size_t size) {
  // Writes to output stream without extra copy
  archive.saveBinary(data, size);
}

template <class Archive>
void loadBinary(Archive& archive, void* data, std::size_t size) {
  // In general, there's no direct `loadBinary`-like method on archives
  std::vector<char> v(size);
  archive(v);
  std::memcpy(data, v.data(), size);
}
template <>
inline void
loadBinary(BinaryInputArchive& archive, void* data, std::size_t size) {
  // Read from input stream without extra copy
  archive.loadBinary(data, size);
}

} // namespace agimpl

// Gradients will not be saved for variables
template <class Archive>
void save(Archive& archive, at::Tensor const& tensor) {
  if (!tensor.defined()) {
    int32_t typeId = ::autograd::detail::scalarTypeId(at::ScalarType::Undefined);
    archive(CEREAL_NVP(typeId));
    return;
  } else {
    int32_t typeId = ::autograd::detail::scalarTypeId(tensor.type().scalarType());
    archive(CEREAL_NVP(typeId));
  }
  auto sizes = std::vector<int64_t>();
  auto buf = std::vector<uint8_t>();
  for (auto s : tensor.sizes()) {
    sizes.push_back(s);
  }
  auto contig = tensor.toBackend(at::kCPU).contiguous();
  int32_t backend = ::autograd::detail::backendId(tensor.type().backend());

  archive(CEREAL_NVP(backend), CEREAL_NVP(sizes));
  agimpl::saveBinary(
      archive,
      contig.data_ptr(),
      tensor.numel() * tensor.type().elementSizeInBytes());
}

/**
 * We follow these rules for loading:
 * 1. If the tensor is defined and has the same ScalarType as the saved tensor,
 *    then we simply copy the data into the tensor, with resizing.
 * 2. Otherwise, overwrite the provided tensor with the right type and backend.
 **/
template <class Archive>
void load(Archive& archive, at::Tensor& tensor) {
  at::ScalarType type;
  int32_t typeId;
  archive(CEREAL_NVP(typeId));
  type = ::autograd::detail::scalarTypeFromId(typeId);
  if (type == at::ScalarType::Undefined) {
    tensor = at::Tensor();
    return;
  }

  int32_t backendId;
  auto sizes = std::vector<int64_t>();
  auto buf = std::vector<uint8_t>();
  archive(CEREAL_NVP(backendId), CEREAL_NVP(sizes));

  at::Backend backend = ::autograd::detail::backendFromId(backendId);
  if (!tensor.defined() || tensor.type().scalarType() != type) {
    tensor = at::getType(backend, type).tensor();
  }
  tensor.resize_(sizes);

  if (tensor.type().is_cuda()) {
    // should actually use cudaMemcpy, probably
    auto cputensor = at::CPU(tensor.type().scalarType()).tensor(sizes);
    agimpl::loadBinary(
        archive,
        cputensor.data_ptr(),
        cputensor.numel() * cputensor.type().elementSizeInBytes());
    tensor.copy_(cputensor);
  } else {
    agimpl::loadBinary(
        archive,
        tensor.data_ptr(),
        tensor.numel() * tensor.type().elementSizeInBytes());
  }
}

template <class Archive>
void load(Archive& archive, tag::Variable& var) {
  load(archive, var.data());
}

} // namespace cereal

torch/csrc/api/src/containers.cpp (new file)
@@ -0,0 +1,633 @@
#include "torch/containers.h"
|
||||
|
||||
namespace autograd {
|
||||
std::map<std::string, Variable> ContainerImpl::parameters() const {
|
||||
std::map<std::string, Variable> ret;
|
||||
for (auto pair : children_) {
|
||||
auto& name = pair.first;
|
||||
auto& child = pair.second;
|
||||
for (auto& p : child->parameters()) {
|
||||
ret[name + "." + p.first] = p.second;
|
||||
}
|
||||
}
|
||||
for (auto pair : params_) {
|
||||
ret[pair.first] = pair.second;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Variable& ContainerImpl::param(std::string const& name) {
|
||||
ContainerImpl* container = this;
|
||||
auto begin = 0;
|
||||
while (true) {
|
||||
auto dot_pos = name.find('.', begin);
|
||||
if (dot_pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto child_name = name.substr(begin, dot_pos - begin);
|
||||
auto it = container->children_.find(child_name);
|
||||
if (it == container->children_.end()) {
|
||||
throw std::runtime_error("No such child: " + child_name);
|
||||
}
|
||||
|
||||
container = it->second.get();
|
||||
begin = dot_pos + 1; // Skip the dot
|
||||
}
|
||||
|
||||
auto param_name = name.substr(begin);
|
||||
auto it = container->params_.find(param_name);
|
||||
if (it == params_.end()) {
|
||||
throw std::runtime_error("No such param: " + param_name);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
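
// Usage sketch (illustrative): param() resolves dotted paths through child
// containers, matching the names produced by parameters(); here "l1" stands
// for any registered child:
//   auto& w = model->param("l1.weight");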

void ContainerImpl::cuda() {
  for (auto& pair : children_) {
    pair.second->cuda();
  }
  cuda_ = true;
  auto copied = params_;
  params_.clear();
  initialize_parameters();
  for (auto pair : params_) {
    pair.second.data().copy_(copied[pair.first].data());
  }
}

void ContainerImpl::cpu() {
  for (auto& pair : children_) {
    pair.second->cpu();
  }
  cuda_ = false;
  auto copied = params_;
  params_.clear();
  initialize_parameters();
  for (auto pair : params_) {
    pair.second.data().copy_(copied[pair.first].data());
  }
}

void ContainerImpl::train() {
  for (auto& pair : children_) {
    pair.second->train();
  }
  train_ = true;
}

void ContainerImpl::eval() {
  for (auto& pair : children_) {
    pair.second->eval();
  }
  train_ = false;
}

Container ContainerImpl::add(Container m, std::string const& name) {
  if (this->children_.find(name) != this->children_.end()) {
    throw std::runtime_error("Trying to add container that already exists");
  }
  if (std::find(name.begin(), name.end(), '.') != name.end()) {
    // We can't allow containers with dots in their names, as that would make
    // their parameters not findable with parameters().
    throw std::runtime_error("Trying to add container with a '.' in its name");
  }
  this->children_[name] = std::move(m);
  return this->children_[name];
}

Variable& ContainerImpl::add(Variable v, std::string const& name) {
  if (this->params_.find(name) != this->params_.end()) {
    throw std::runtime_error("Trying to add parameter that already exists");
  }
  if (std::find(name.begin(), name.end(), '.') != name.end()) {
    // We can't allow parameters with dots in their names, as that would make
    // them not findable with parameters().
    throw std::runtime_error("Trying to add parameter with a '.' in its name");
  }
  this->params_[name] = v;
  return this->params_[name];
}

at::Type& ContainerImpl::DefaultTensor(at::ScalarType s) {
  if (cuda_)
    return at::CUDA(s);
  else
    return at::CPU(s);
}

variable_list Linear::forward(variable_list input) {
  auto x = input[0];
  if (x.ndimension() == 2 && !no_bias_) {
    // Fused op is marginally faster
    assert(x.size(1) == weight.size(1));
    return variable_list({at::addmm(bias, x, weight.t())});
  }

  auto output = x.matmul(weight.t());
  if (!no_bias_) {
    output += bias;
  }
  return variable_list({output});
}

void Linear::reset_parameters() {
  auto stdv = 1.0 / std::sqrt(weight.size(1));
  for (auto& p : parameters()) {
    p.second.data().uniform_(-stdv, stdv);
  }
}

void Linear::initialize_parameters() {
  weight = this->add(
      Var(DefaultTensor(at::kFloat).tensor({nout, nin}), true), "weight");
  if (!no_bias_) {
    bias =
        this->add(Var(DefaultTensor(at::kFloat).tensor({nout}), true), "bias");
  }
}

variable_list Embedding::forward(variable_list input) {
  auto x = input[0];
  return variable_list({at::embedding(weight, x, -1, false, false)});
}

void Embedding::reset_parameters() {
  for (auto& p : parameters()) {
    p.second.data().normal_(0, 1);
  }
}

void Embedding::initialize_parameters() {
  weight = this->add(
      Var(DefaultTensor(at::kFloat).tensor({num_embeddings, embedding_dim}),
          true),
      "weight");
}

void Conv::initialize_parameters() {
  if (!transposed_) {
    for (auto pad : output_padding_) {
      if (pad != 0) {
        throw std::runtime_error(
            "Only transposed convolutions support output padding!");
      }
    }
  }

  IntVec wsize;
  if (transposed_) {
    wsize.push_back(in_channels_);
    wsize.push_back(out_channels_ / groups_);
  } else {
    wsize.push_back(out_channels_);
    wsize.push_back(in_channels_ / groups_);
  }
  wsize.insert(wsize.end(), ks_.begin(), ks_.end());
  weight =
      this->add(Var(DefaultTensor(at::kFloat).tensor(wsize), true), "weight");
  if (!no_bias_) {
    bias = this->add(
        Var(DefaultTensor(at::kFloat).tensor({out_channels_}), true), "bias");
  } else {
    assert(!bias.defined());
  }
}

void Conv::reset_parameters() {
  auto n = in_channels_;
  for (auto k : ks_)
    n *= k;
  auto stdv = 1.0 / std::sqrt(n);
  for (auto& p : parameters()) {
    p.second.data().uniform_(-stdv, stdv);
  }
}

variable_list Conv::forward(variable_list input) {
  auto x = input[0];
  if (Nd_ == 1) {
    assert(x.ndimension() == 3);
    x = x.unsqueeze(-1); // TODO: Use conv1d once available
  } else if (Nd_ == 2) {
    assert(x.ndimension() == 4);
  } else if (Nd_ == 3) {
    assert(x.ndimension() == 5);
  } else {
    throw std::runtime_error("Only Conv{1,2,3}d are supported");
  }

  Variable out;
  if (Nd_ == 1 || Nd_ == 2) {
    if (transposed_) {
      out = at::conv_transpose2d(
          x,
          weight,
          bias,
          stride_,
          padding_,
          output_padding_,
          groups_,
          dilation_);
    } else {
      out = at::conv2d(x, weight, bias, stride_, padding_, dilation_, groups_);
    }
  } else if (Nd_ == 3) {
    if (transposed_) {
      out = at::conv_transpose3d(
          x,
          weight,
          bias,
          stride_,
          padding_,
          output_padding_,
          groups_,
          dilation_);
    } else {
      out = at::conv3d(x, weight, bias, stride_, padding_, dilation_, groups_);
    }
  }

  return variable_list({out});
}

void BatchNorm::initialize_parameters() {
  if (affine_) {
    weight = this->add(
        Var(DefaultTensor(at::kFloat).tensor(num_features_), true), "weight");
    bias = this->add(
        Var(DefaultTensor(at::kFloat).tensor(num_features_), true), "bias");
  }

  if (stateful_) {
    running_mean = Var(DefaultTensor(at::kFloat).zeros({num_features_}), false);
    running_var = Var(DefaultTensor(at::kFloat).ones({num_features_}), false);
  }
}

void BatchNorm::reset_parameters() {
  if (affine_) {
    weight.data().uniform_();
    bias.data().zero_();
  }

  if (stateful_) {
    running_mean.data().zero_();
    running_var.data().fill_(1);
  }
}

variable_list BatchNorm::forward(variable_list inputs) {
  auto& input = inputs[0];
  auto& running_mean = (stateful_ ? this->running_mean : inputs[1]);
  auto& running_var = (stateful_ ? this->running_var : inputs[2]);

  if (train_) {
    const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
    if (input.numel() / num_channels <= 1) {
      throw std::runtime_error(
          "BatchNorm expected more than 1 value per channel when training!");
    }
  }

  auto output = at::batch_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      train_,
      momentum_,
      eps_,
      hasCudnn());

  return variable_list({output});
}

template <typename Derived>
void RNNBase<Derived>::initialize_containers() {
  auto gate_size = hidden_size_;
  if (mode_ == RNNMode::LSTM) {
    gate_size *= 4;
  } else if (mode_ == RNNMode::GRU) {
    gate_size *= 3;
  }

  for (auto i = 0U; i < nlayers_; i++) {
    auto input_size = (i == 0) ? input_size_ : hidden_size_;
    i2h.push_back(this->add(
        Linear(input_size, gate_size).no_bias(no_bias_).make(),
        "i2h_" + std::to_string(i)));
    h2h.push_back(this->add(
        Linear(hidden_size_, gate_size).no_bias(no_bias_).make(),
        "h2h_" + std::to_string(i)));
  }
  if (dropout_ > 0)
    dropout_module = Dropout(dropout_).make();
  this->flatten_parameters();
}

template <typename Derived>
void RNNBase<Derived>::reset_parameters() {
  auto stdv = 1.0 / std::sqrt(hidden_size_);
  for (auto& p : this->parameters()) {
    p.second.data().uniform_(-stdv, stdv);
  }
}

template <typename Derived>
variable_list RNNBase<Derived>::GRU_cell_forward(variable_list inputs, int i) {
  auto x = inputs[0];
  auto hx = inputs[1].defined()
      ? inputs[1]
      : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));

  auto gi = i2h[i]->forward({x})[0];
  auto gh = h2h[i]->forward({hx})[0];
  auto gic = gi.chunk(3, 1);
  auto ghc = gh.chunk(3, 1);

  auto reset_gate = (gic[0] + ghc[0]).sigmoid_();
  auto input_gate = (gic[1] + ghc[1]).sigmoid_();
  auto new_gate = (gic[2] + reset_gate * ghc[2]).tanh_();
  auto hy = new_gate + input_gate * (hx - new_gate);

  return variable_list({hy});
}

template <typename Derived>
variable_list RNNBase<Derived>::RNN_TANH_cell_forward(
    variable_list inputs,
    int i) {
  auto x = inputs[0];
  auto hx = inputs[1].defined()
      ? inputs[1]
      : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));

  auto h = (i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0]).tanh();
  return variable_list({h});
}

template <typename Derived>
variable_list RNNBase<Derived>::RNN_RELU_cell_forward(
    variable_list inputs,
    int i) {
  auto x = inputs[0];
  auto hx = inputs[1].defined()
      ? inputs[1]
      : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));

  auto h = (i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0]).clamp_min(0);
  return variable_list({h});
}

template <typename Derived>
variable_list RNNBase<Derived>::LSTM_cell_forward(variable_list inputs, int i) {
  auto x = inputs[0];
  auto hid = inputs[1].defined()
      ? inputs[1]
      : Var(this->DefaultTensor(at::kFloat)
                .zeros({2, x.size(0), hidden_size_}));
  auto hx = hid[0];
  auto cx = hid[1];

  auto gates = i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0];

  auto chunked = gates.chunk(4, 1);
  auto in_gate = chunked[0].sigmoid();
  auto forget_gate = chunked[1].sigmoid();
  auto cell_gate = chunked[2].tanh();
  auto out_gate = chunked[3].sigmoid();

  auto cy = (forget_gate * cx) + (in_gate * cell_gate);
  auto hy = out_gate * cy.tanh();

  return variable_list({at::stack({hy, cy}, 0)});
}

template <typename Derived>
variable_list RNNBase<Derived>::cell_forward(variable_list inputs, int i) {
  if (mode_ == RNNMode::LSTM)
    return LSTM_cell_forward(inputs, i);
  else if (mode_ == RNNMode::GRU)
    return GRU_cell_forward(inputs, i);
  else if (mode_ == RNNMode::RNN_TANH)
    return RNN_TANH_cell_forward(inputs, i);
  else if (mode_ == RNNMode::RNN_RELU)
    return RNN_RELU_cell_forward(inputs, i);
  else
    throw std::runtime_error("No such RNN mode");
}

template <typename Derived>
variable_list RNNBase<Derived>::autograd_forward(variable_list inputs) {
  auto inp = inputs[0];

  std::vector<Tensor> hidden;
  for (size_t i = 0; i < nlayers_; i++) {
    hidden.push_back(inputs[1].defined() ? inputs[1][i] : tag::Variable());
  }

  auto output =
      Var(this->DefaultTensor(at::kFloat)
              .zeros({inp.size(0), inp.size(1), hidden_size_}),
          false);
  for (auto t = 0U; t < inp.size(0); t++) {
    auto x = inp.select(0, t);
    for (size_t i = 0; i < nlayers_; i++) {
      auto layer_output = cell_forward({x, hidden[i]}, i);
      hidden[i] = layer_output[0];
      if (mode_ == RNNMode::LSTM) {
        x = hidden[i][0];
      } else {
        x = hidden[i];
      }
      auto output_slice = output.select(0, t);
      output_slice.copy_(x);
      if (dropout_ > 0 && i != nlayers_ - 1) {
        x = dropout_module->forward({x})[0];
      }
    }
  }

  auto hidout = at::stack(hidden, 0);
  return variable_list({output, hidout});
}

template <typename Derived>
bool RNNBase<Derived>::flatten_parameters() {
  data_ptrs_.clear();
  auto anyParam = i2h[0]->params_.begin()->second;
  if (!anyParam.is_cuda() || !at::cudnn_is_acceptable(anyParam)) {
    return false;
  }
  std::unordered_set<void*> unique_data_ptrs;
  auto params = this->parameters();
  for (auto& p : params) {
    unique_data_ptrs.insert(p.second.data().data_ptr());
  }
  // TODO PyTorch says: if any parameters alias, we fall back to the slower,
  // copying code path. This is a sufficient check, because overlapping
  // parameter buffers that don't completely alias would break the assumptions
  // of the uniqueness check in Module.named_parameters(). But I'm not sure if
  // this is the case for us.
  if (unique_data_ptrs.size() != params.size()) {
    return false;
  }

  std::vector<Tensor> weight_list;
  for (size_t i = 0; i < nlayers_; i++) {
    weight_list.push_back(i2h[i]->param("weight"));
    weight_list.push_back(h2h[i]->param("weight"));
    if (!no_bias_) {
      weight_list.push_back(i2h[i]->param("bias"));
      weight_list.push_back(h2h[i]->param("bias"));
    }
  }
  auto weight_stride0 = no_bias_ ? 2 : 4;

  {
    no_grad_guard guard;
    flat_weight_ = at::_cudnn_rnn_flatten_weight(
        weight_list,
        weight_stride0,
        input_size_,
        mode_,
        hidden_size_,
        nlayers_,
        false,
        false); // batch_first and bidirectional, unsupported
  }
  for (auto& p : params) {
    data_ptrs_.emplace_back(p.second.data().data_ptr());
  }
  return true;
}

template <typename Derived>
variable_list RNNBase<Derived>::CUDNN_forward(variable_list inputs) {
  std::vector<Tensor> weight_list;
  for (size_t i = 0; i < nlayers_; i++) {
    weight_list.push_back(i2h[i]->param("weight"));
    weight_list.push_back(h2h[i]->param("weight"));
    if (!no_bias_) {
      weight_list.push_back(i2h[i]->param("bias"));
      weight_list.push_back(h2h[i]->param("bias"));
    }
  }
  auto weight_stride0 = no_bias_ ? 2 : 4;

  auto x = inputs[0];
  Variable hx, cx;
  if (!inputs[1].defined()) {
    hx = x.type().zeros({nlayers_, x.size(1), hidden_size_});
    if (mode_ == RNNMode::LSTM) {
      cx = x.type().zeros({nlayers_, x.size(1), hidden_size_});
    }
  } else {
    hx = mode_ == RNNMode::LSTM ? inputs[1][0] : inputs[1];
    cx = mode_ == RNNMode::LSTM ? inputs[1][1] : Variable();
  }
  auto dropout_state = x.type().tensor();

  std::vector<void*> weight_data_ptrs;
  auto params = this->parameters();
  for (auto& p : params) {
    weight_data_ptrs.emplace_back(p.second.data().data_ptr());
  }
  if (weight_data_ptrs != data_ptrs_) {
    std::cerr << "Parameters are unflattened! Code path might be super slow. "
                 "Please call flatten_parameters() when you muck around with "
                 "storages!"
              << std::endl;
    flat_weight_ = Variable();
  }

  // tup = std::tuple of output, hy, cy, reserve, new_weight_buf
  auto tup = _cudnn_rnn(
      x,
      weight_list,
      weight_stride0,
      flat_weight_,
      hx,
      cx,
      mode_,
      hidden_size_,
      nlayers_,
      false, // batch first
      0, // TODO waiting on dropout state descriptor in C++ pytorch
      this->train_,
      false, // bidirectional
      {}, // packing not supported
      dropout_state // TODO waiting on dropout state descriptor in C++ pytorch
  );

  Variable hidout = mode_ == RNNMode::LSTM
      ? at::stack({std::get<1>(tup), std::get<2>(tup)}, 0)
      : std::get<1>(tup);
  Variable output = std::get<0>(tup);
  return variable_list({output, hidout});
}

template <typename Derived>
variable_list RNNBase<Derived>::forward(variable_list inputs) {
  variable_list inp;
  inp.push_back(inputs[0]);
  if (inputs.size() > 1) {
    inp.push_back(inputs[1]);
  } else {
    inp.push_back(Variable());
  }

  // Dropout descriptors aren't in C++ in PyTorch yet...
  auto output = at::cudnn_is_acceptable(inp[0]) && dropout_ == 0
      ? CUDNN_forward(inp)
      : autograd_forward(inp);

  return output;
}

template <typename Derived>
void RNNBase<Derived>::cuda() {
  Container_CRTP<Derived>::cuda();
  flatten_parameters();
}

template <typename Derived>
void RNNBase<Derived>::cpu() {
  Container_CRTP<Derived>::cpu();
  flatten_parameters();
}

variable_list Dropout::forward(variable_list inputs) {
  if (p_ == 0 || !this->train_)
    return inputs;
  variable_list lst;
  for (auto x : inputs) {
    auto noise = x.data().type().tensor(x.sizes());
    noise = (noise.uniform_(0, 1) > p_)
                .toType(x.type().scalarType())
                .mul_(1. / (1 - p_));
    lst.push_back(x * Var(noise));
  }
  return lst;
}
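
// Note on the scaling above (inverted dropout): each kept activation is
// multiplied by 1 / (1 - p) so the expected value is unchanged; e.g. with
// p = 0.5 survivors are doubled, and E[mask * x / (1 - p)] = x.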

variable_list Dropout2d::forward(variable_list inputs) {
  if (p_ == 0 || !this->train_)
    return inputs;
  variable_list lst;
  for (auto x : inputs) {
    // Zeroes out whole channels: the noise broadcasts over the spatial dims
    auto noise = x.data().type().tensor({x.size(0), x.size(1), 1, 1});
    noise = (noise.uniform_(0, 1) > p_)
                .toType(x.type().scalarType())
                .mul_(1. / (1 - p_));
    lst.push_back(x * Var(noise));
  }
  return lst;
}

} // namespace autograd

torch/csrc/api/src/detail.cpp (new file)
@@ -0,0 +1,71 @@
#include <ATen/Config.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <stdexcept>

#if AT_CUDA_ENABLED()
#include <THC/THCTensorRandom.h>
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#include "torch/detail.h"

namespace autograd {
namespace detail {
tag::Engine engine;
}

void backward(Variable loss, bool keep_graph) {
  tag::edge_list edgelst;
  tag::variable_list varlst;
  edgelst.emplace_back(loss.grad_fn(), loss.output_nr());
  varlst.emplace_back(Var(at::ones_like(loss.data()), false));
  // create_graph should be set to true when we want to support double bwd
  detail::engine.execute(edgelst, varlst, keep_graph, false);
}

void backward(Tensor loss, bool keep_graph) {
  Variable tmp(loss);
  backward(tmp, keep_graph);
}
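
// Usage sketch (illustrative): backward() seeds the gradient with ones and
// runs the engine, so it is typically called on a scalar loss:
//   auto loss = model->forward({x})[0].sum();
//   backward(loss);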

void setSeed(uint64_t seed) {
  at::globalContext().defaultGenerator(at::Backend::CPU).manualSeed(seed);
#if AT_CUDA_ENABLED()
  if (getNumGPUs() > 0) {
    THCRandom_manualSeedAll(at::globalContext().lazyInitCUDA(), seed);
  }
#endif
}

int getNumGPUs() {
#if AT_CUDA_ENABLED()
  int count;
  auto err = cudaGetDeviceCount(&count);
  if (err == cudaErrorNoDevice) {
    return 0;
  } else if (err != cudaSuccess) {
    std::string msg = "CUDA error (";
    msg += std::to_string(err);
    msg += "): ";
    msg += cudaGetErrorString(err);
    throw std::runtime_error(msg);
  }
  return count;
#else
  return 0;
#endif
}

bool hasCuda() {
  return getNumGPUs() > 0;
}

bool hasCudnn() {
  return hasCuda() && AT_CUDNN_ENABLED();
}

} // namespace autograd

torch/csrc/api/src/optimizers.cpp (new file)
@@ -0,0 +1,199 @@
#include "torch/optimizers.h"
|
||||
|
||||
namespace autograd {
|
||||
|
||||
void OptimizerImpl::zero_grad() {
|
||||
for (auto p : model_->parameters()) {
|
||||
auto& grad = p.second.grad();
|
||||
if (grad.defined()) {
|
||||
grad = grad.detach();
|
||||
torch::autograd::as_variable_ref(grad).data().zero_();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OptimizerImpl::set_model(Container model) {
|
||||
model_ = model;
|
||||
}
|
||||
|
||||
void SGD::step() {
|
||||
for (auto& pair : model_->parameters()) {
|
||||
auto& name = pair.first;
|
||||
auto& grad = pair.second.grad();
|
||||
auto& p = pair.second.data();
|
||||
if (!grad.defined())
|
||||
continue;
|
||||
|
||||
auto d_p = torch::autograd::as_variable_ref(grad).data();
|
||||
if (weight_decay_ > 0) {
|
||||
d_p.add_(p, weight_decay_);
|
||||
};
|
||||
|
||||
if (momentum_ != 0) {
|
||||
at::Tensor buf;
|
||||
if (momentum_buffers_.find(name) == momentum_buffers_.end()) {
|
||||
buf = momentum_buffers_[name] = at::zeros_like(p);
|
||||
buf.mul_(momentum_).add_(d_p);
|
||||
} else {
|
||||
buf = momentum_buffers_[name];
|
||||
buf.mul_(momentum_).add_(d_p, 1 - dampening_);
|
||||
}
|
||||
|
||||
if (nesterov_) {
|
||||
d_p = d_p.add(buf, momentum_);
|
||||
} else {
|
||||
d_p = buf;
|
||||
}
|
||||
}
|
||||
|
||||
p.add_(d_p, -lr_);
|
||||
}
|
||||
}
|
||||
|
||||
void SGD::init_state() {
|
||||
momentum_buffers_.clear();
|
||||
}
|
||||
|
||||
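
// Usage sketch (illustrative): a typical training step with any of these
// optimizers:
//   optim->zero_grad();
//   auto loss = model->forward({input})[0].sum();
//   backward(loss);
//   optim->step();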

/// Adapted from
/// https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py
void Adagrad::step() {
  for (auto& pair : model_->parameters()) {
    auto& name = pair.first;
    auto& grad = pair.second.grad();
    auto& p = pair.second.data();
    if (!grad.defined())
      continue;

    auto d_p = torch::autograd::as_variable_ref(grad).data();
    if (weight_decay_ > 0) {
      d_p.add_(p, weight_decay_);
    }
    auto& step = step_[name];
    step += 1.0;
    auto clr = lr_ / (1.0 + (step - 1.0) * lr_decay_);
    at::Tensor buf;
    if (sum_.find(name) == sum_.end()) {
      buf = sum_[name] = at::zeros_like(p);
    } else {
      buf = sum_[name];
    }

    buf.addcmul_(d_p, d_p, 1.0);
    at::Tensor std = buf.sqrt().add_(1e-10);
    p.addcdiv_(d_p, std, -clr);
  }
}

void Adagrad::init_state() {
  sum_.clear();
  step_.clear();
}

/// Adapted from
/// https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
void RMSprop::step() {
  for (auto& pair : model_->parameters()) {
    auto& name = pair.first;
    auto& grad = pair.second.grad();
    auto& p = pair.second.data();
    if (!grad.defined())
      continue;

    if (square_avg_buffer_.find(name) == square_avg_buffer_.end()) {
      square_avg_buffer_[name] = at::zeros_like(p);
      if (momentum_) {
        momentum_buffer_[name] = at::zeros_like(p);
      }
      if (centered_) {
        grad_avg_buffer_[name] = at::zeros_like(p);
      }
    }

    auto d_p = torch::autograd::as_variable_ref(grad).data();
    if (weight_decay_ > 0) {
      d_p.add_(p, weight_decay_);
    }

    auto& square_avg = square_avg_buffer_[name];
    square_avg.mul_(alpha_).addcmul_(d_p, d_p, 1.0 - alpha_);

    at::Tensor avg;
    if (centered_) {
      auto& grad_avg = grad_avg_buffer_[name];
      grad_avg.mul_(alpha_).add_(d_p, 1.0 - alpha_);
      avg = square_avg.addcmul(grad_avg, grad_avg, -1.0).sqrt().add_(eps_);
    } else {
      avg = square_avg.sqrt().add_(eps_);
    }

    if (momentum_ > 0) {
      auto& buf = momentum_buffer_[name];
      buf.mul_(momentum_).addcdiv_(d_p, avg);
      p.add_(buf, -lr_);
    } else {
      p.addcdiv_(d_p, avg, -lr_);
    }
  }
}

void RMSprop::init_state() {
  square_avg_buffer_.clear();
  momentum_buffer_.clear();
  grad_avg_buffer_.clear();
}

void Adam::step() {
  for (auto& pair : model_->parameters()) {
    auto& name = pair.first;
    auto& grad = pair.second.grad();
    auto& p = pair.second.data();
    if (!grad.defined())
      continue;

    if (step_buffer_.find(name) == step_buffer_.end()) {
      step_buffer_[name] = 0;
      exp_avg_buffer_[name] = at::zeros_like(p);
      exp_avg_sq_buffer_[name] = at::zeros_like(p);
      if (amsgrad_) {
        max_exp_avg_sq_buffer_[name] = at::zeros_like(p);
      }
    }

    auto& step = step_buffer_[name];
    auto& exp_avg = exp_avg_buffer_[name];
    auto& exp_avg_sq = exp_avg_sq_buffer_[name];

    step += 1;

    auto d_p = torch::autograd::as_variable_ref(grad).data();
    if (weight_decay_ > 0) {
      d_p.add_(p, weight_decay_);
    }

    exp_avg.mul_(beta1_).add_(d_p, 1 - beta1_);
    exp_avg_sq.mul_(beta2_).addcmul_(d_p, d_p, 1 - beta2_);

    at::Tensor denom;
    if (amsgrad_) {
      auto& max_exp_avg_sq = max_exp_avg_sq_buffer_[name];
      at::max_out(max_exp_avg_sq, max_exp_avg_sq, exp_avg_sq);
      denom = max_exp_avg_sq.sqrt().add_(eps_);
    } else {
      denom = exp_avg_sq.sqrt().add_(eps_);
    }

    auto bias_correction1 = 1 - std::pow(beta1_, step);
    auto bias_correction2 = 1 - std::pow(beta2_, step);
    auto step_size = lr_ * std::sqrt(bias_correction2) / bias_correction1;

    p.addcdiv_(exp_avg, denom, -step_size);
  }
}

void Adam::init_state() {
  step_buffer_.clear();
  exp_avg_buffer_.clear();
  exp_avg_sq_buffer_.clear();
  // Also reset the amsgrad buffer so re-initialized state is consistent
  max_exp_avg_sq_buffer_.clear();
}

} // namespace autograd