pytorch/test/cpp/api/util.h
Peter Goldsborough 2e0dd86903 Make torch::Tensor -> at::Tensor (#10516)
Summary:
This PR removes the `using Tensor = autograd::Variable;` alias from `torch/tensor.h`, which means `torch::Tensor` is now `at::Tensor`. It also fixes up the last remaining uses of `.data()` and tidies up the resulting code. For example, I was able to remove `TensorListView`, so that code like

```
auto loss = torch::stack(torch::TensorListView(policy_loss)).sum() +
    torch::stack(torch::TensorListView(value_loss)).sum();
```

is now

```
auto loss = torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();
```

CC jgehring ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10516

Differential Revision: D9324691

Pulled By: goldsborough

fbshipit-source-id: a7c1cb779c9c829f89cea55f07ac539b00c78449
2018-08-15 21:25:12 -07:00


#pragma once

#include <torch/nn/cloneable.h>

#include <string>
#include <utility>

namespace torch {
namespace test {

// Lets you use a container without making a new class,
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
 public:
  void reset() override {}

  // Registers `module_holder` as a submodule under `name` and returns it.
  template <typename ModuleHolder>
  ModuleHolder add(
      ModuleHolder module_holder,
      std::string name = std::string()) {
    return Module::register_module(std::move(name), module_holder);
  }
};

// Returns true if both tensors point at the same underlying data buffer.
// Compares addresses, not values; assumes both tensors hold floats.
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
  return first.data<float>() == second.data<float>();
}
} // namespace test
} // namespace torch
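
The two helpers above follow patterns that can be sketched without libtorch. The block below is only an illustration under stated assumptions: `FakeTensor` and `FakeContainer` are hypothetical stand-ins invented for this example (they are not part of PyTorch), and they mimic only the behavior relevant here, namely that copying a tensor handle shares its buffer and that `add` registers a child and hands the holder back.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor: a handle to shared float storage.
// This is NOT at::Tensor; it only mimics "copying the handle shares the buffer".
struct FakeTensor {
  std::shared_ptr<std::vector<float>> storage;

  explicit FakeTensor(std::vector<float> values)
      : storage(std::make_shared<std::vector<float>>(std::move(values))) {}

  // Raw pointer to the first element, similar in spirit to data<float>().
  const float* data() const { return storage->data(); }

  // Deep copy: same values, freshly allocated buffer.
  FakeTensor clone() const { return FakeTensor(*storage); }
};

// Same idea as pointer_equal in util.h: compare addresses, not values.
inline bool pointer_equal(const FakeTensor& first, const FakeTensor& second) {
  return first.data() == second.data();
}

// Same idea as SimpleContainer::add: store a child under a name and
// return the holder so the caller can keep using it.
// Assumes Holder is convertible to std::shared_ptr<void>.
class FakeContainer {
 public:
  template <typename Holder>
  Holder add(Holder holder, std::string name) {
    children_[std::move(name)] = holder;
    return holder;
  }

  std::size_t size() const { return children_.size(); }

 private:
  std::map<std::string, std::shared_ptr<void>> children_;
};
```

With these stand-ins, a copied handle compares pointer-equal to its source while a `clone()` does not, even though the values match. That is the distinction a test would rely on when checking whether an operation shared or duplicated storage.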