mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
torch::optional -> std::optional (#138987)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138987 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
228963ad60
commit
068f7e7a78
|
|
@ -33,7 +33,7 @@ struct DummyDataset : datasets::Dataset<DummyDataset, int> {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
||||||
return 1 + index;
|
return 1 + index;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return size_;
|
return size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -151,8 +151,8 @@ struct InfiniteStreamDataset
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t counter = 0;
|
size_t counter = 0;
|
||||||
|
|
@ -459,7 +459,7 @@ TEST(DataTest, StackTransformWorksForExample) {
|
||||||
return {tensor[index], 1 + tensor[index]};
|
return {tensor[index], 1 + tensor[index]};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return tensor.size(0);
|
return tensor.size(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -503,7 +503,7 @@ struct TensorStringDataset
|
||||||
return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
|
return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -542,7 +542,7 @@ struct DummyTensorDataset
|
||||||
return {tensor, static_cast<int>(channels)};
|
return {tensor, static_cast<int>(channels)};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -624,7 +624,7 @@ struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
|
||||||
torch::tensor({static_cast<int64_t>(index)})};
|
torch::tensor({static_cast<int64_t>(index)})};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -753,7 +753,7 @@ struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
||||||
return 1 + index;
|
return 1 + index;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -806,7 +806,7 @@ struct TestIndexDataset
|
||||||
}
|
}
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return data.size();
|
return data.size();
|
||||||
}
|
}
|
||||||
std::vector<int> data;
|
std::vector<int> data;
|
||||||
|
|
@ -814,10 +814,10 @@ struct TestIndexDataset
|
||||||
|
|
||||||
struct TestIndexSampler : public samplers::Sampler<TestIndex> {
|
struct TestIndexSampler : public samplers::Sampler<TestIndex> {
|
||||||
explicit TestIndexSampler(size_t size) : size_(size) {}
|
explicit TestIndexSampler(size_t size) : size_(size) {}
|
||||||
void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
|
void reset(std::optional<size_t> new_size = std::nullopt) override {}
|
||||||
torch::optional<TestIndex> next(size_t batch_size) override {
|
std::optional<TestIndex> next(size_t batch_size) override {
|
||||||
if (index_ >= size_) {
|
if (index_ >= size_) {
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
std::vector<size_t> indices(batch_size);
|
std::vector<size_t> indices(batch_size);
|
||||||
std::iota(indices.begin(), indices.end(), size_t(0));
|
std::iota(indices.begin(), indices.end(), size_t(0));
|
||||||
|
|
@ -847,7 +847,7 @@ TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
|
||||||
samplers::DistributedRandomSampler drs(sample_count);
|
samplers::DistributedRandomSampler drs(sample_count);
|
||||||
|
|
||||||
std::vector<size_t> res;
|
std::vector<size_t> res;
|
||||||
torch::optional<std::vector<size_t>> idx;
|
std::optional<std::vector<size_t>> idx;
|
||||||
while ((idx = drs.next(3)).has_value()) {
|
while ((idx = drs.next(3)).has_value()) {
|
||||||
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
||||||
}
|
}
|
||||||
|
|
@ -879,7 +879,7 @@ TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
|
||||||
std::vector<size_t> res;
|
std::vector<size_t> res;
|
||||||
for (const auto i : c10::irange(num_replicas)) {
|
for (const auto i : c10::irange(num_replicas)) {
|
||||||
(*samplers[i]).reset();
|
(*samplers[i]).reset();
|
||||||
torch::optional<std::vector<size_t>> idx;
|
std::optional<std::vector<size_t>> idx;
|
||||||
while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
|
while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
|
||||||
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
||||||
}
|
}
|
||||||
|
|
@ -943,7 +943,7 @@ TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
|
||||||
samplers::DistributedSequentialSampler dss(sample_count);
|
samplers::DistributedSequentialSampler dss(sample_count);
|
||||||
|
|
||||||
std::vector<size_t> res;
|
std::vector<size_t> res;
|
||||||
torch::optional<std::vector<size_t>> idx;
|
std::optional<std::vector<size_t>> idx;
|
||||||
while ((idx = dss.next(batch_size)).has_value()) {
|
while ((idx = dss.next(batch_size)).has_value()) {
|
||||||
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
||||||
}
|
}
|
||||||
|
|
@ -976,7 +976,7 @@ TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
|
||||||
std::vector<size_t> res;
|
std::vector<size_t> res;
|
||||||
for (const auto i : c10::irange(num_replicas)) {
|
for (const auto i : c10::irange(num_replicas)) {
|
||||||
(*samplers[i]).reset();
|
(*samplers[i]).reset();
|
||||||
torch::optional<std::vector<size_t>> idx;
|
std::optional<std::vector<size_t>> idx;
|
||||||
while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
|
while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
|
||||||
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
res.insert(std::end(res), std::begin(*idx), std::end(*idx));
|
||||||
}
|
}
|
||||||
|
|
@ -1052,8 +1052,8 @@ struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
|
||||||
torch::data::Example<> get(size_t i) override {
|
torch::data::Example<> get(size_t i) override {
|
||||||
return {torch::ones(i), torch::ones(i)};
|
return {torch::ones(i), torch::ones(i)};
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const noexcept override {
|
std::optional<size_t> size() const noexcept override {
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -1150,7 +1150,7 @@ TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
||||||
return 1 + indices.front();
|
return 1 + indices.front();
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 10;
|
return 10;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -1270,7 +1270,7 @@ TEST(DataLoaderTest, RespectsTimeout) {
|
||||||
baton->cv.wait_for(lock, 1000 * kMillisecond);
|
baton->cv.wait_for(lock, 1000 * kMillisecond);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
std::shared_ptr<Baton> baton;
|
std::shared_ptr<Baton> baton;
|
||||||
|
|
@ -1388,7 +1388,7 @@ struct Dataset : datasets::BatchDataset<Dataset, size_t> {
|
||||||
return indices.front();
|
return indices.front();
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return kNumberOfWorkers;
|
return kNumberOfWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1441,7 +1441,7 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
|
||||||
int get(size_t index) override {
|
int get(size_t index) override {
|
||||||
throw std::invalid_argument("badness");
|
throw std::invalid_argument("badness");
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -1467,13 +1467,13 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
|
||||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||||
|
|
||||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||||
torch::optional<int> get_batch(size_t) override {
|
std::optional<int> get_batch(size_t) override {
|
||||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||||
return counter++;
|
return counter++;
|
||||||
}
|
}
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
void reset() override {
|
void reset() override {
|
||||||
|
|
@ -1504,14 +1504,14 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
|
||||||
const int kNumberOfWorkers = 4;
|
const int kNumberOfWorkers = 4;
|
||||||
|
|
||||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||||
torch::optional<int> get_batch(size_t) override {
|
std::optional<int> get_batch(size_t) override {
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||||
return counter++;
|
return counter++;
|
||||||
}
|
}
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
void reset() override {
|
void reset() override {
|
||||||
|
|
@ -1544,13 +1544,13 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
|
||||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||||
|
|
||||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||||
torch::optional<int> get_batch(size_t) override {
|
std::optional<int> get_batch(size_t) override {
|
||||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||||
return counter++;
|
return counter++;
|
||||||
}
|
}
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
void reset() override {
|
void reset() override {
|
||||||
|
|
@ -1587,7 +1587,7 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
|
||||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||||
|
|
||||||
struct D : datasets::StatefulDataset<D> {
|
struct D : datasets::StatefulDataset<D> {
|
||||||
torch::optional<std::vector<Example<>>> get_batch(
|
std::optional<std::vector<Example<>>> get_batch(
|
||||||
size_t batch_size) override {
|
size_t batch_size) override {
|
||||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||||
counter += batch_size;
|
counter += batch_size;
|
||||||
|
|
@ -1597,9 +1597,9 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
|
||||||
torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
|
torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return 100;
|
return 100;
|
||||||
}
|
}
|
||||||
void reset() override {
|
void reset() override {
|
||||||
|
|
@ -1616,7 +1616,7 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
|
||||||
|
|
||||||
// Notice that the `get_batch()` of the dataset returns a vector<Example>, but
|
// Notice that the `get_batch()` of the dataset returns a vector<Example>, but
|
||||||
// the `Stack` collation stacks the tensors into one.
|
// the `Stack` collation stacks the tensors into one.
|
||||||
torch::optional<Example<>> batch = d.get_batch(kBatchSize);
|
std::optional<Example<>> batch = d.get_batch(kBatchSize);
|
||||||
ASSERT_TRUE(batch.has_value());
|
ASSERT_TRUE(batch.has_value());
|
||||||
ASSERT_EQ(batch->data.size(0), kBatchSize);
|
ASSERT_EQ(batch->data.size(0), kBatchSize);
|
||||||
ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
|
ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
|
||||||
|
|
@ -2117,7 +2117,7 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
|
||||||
public:
|
public:
|
||||||
explicit S(size_t size) : size_(size), index_(0){};
|
explicit S(size_t size) : size_(size), index_(0){};
|
||||||
|
|
||||||
void reset(torch::optional<size_t> new_size = torch::nullopt) override {
|
void reset(std::optional<size_t> new_size = std::nullopt) override {
|
||||||
if (new_size.has_value()) {
|
if (new_size.has_value()) {
|
||||||
size_ = *new_size;
|
size_ = *new_size;
|
||||||
}
|
}
|
||||||
|
|
@ -2134,10 +2134,10 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the next batch of indices.
|
// Returns the next batch of indices.
|
||||||
torch::optional<std::vector<size_t>> next(size_t batch_size) override {
|
std::optional<std::vector<size_t>> next(size_t batch_size) override {
|
||||||
const auto remaining_indices = size_ - index_;
|
const auto remaining_indices = size_ - index_;
|
||||||
if (remaining_indices == 0) {
|
if (remaining_indices == 0) {
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
auto return_size = std::min(batch_size, remaining_indices);
|
auto return_size = std::min(batch_size, remaining_indices);
|
||||||
std::vector<size_t> index_batch(
|
std::vector<size_t> index_batch(
|
||||||
|
|
|
||||||
|
|
@ -2329,7 +2329,7 @@ TEST_F(FunctionalTest, Interpolate) {
|
||||||
auto tensor = torch::rand({2, 3, 32, 32});
|
auto tensor = torch::rand({2, 3, 32, 32});
|
||||||
std::vector<int64_t> osize = {8, 10};
|
std::vector<int64_t> osize = {8, 10};
|
||||||
auto expected =
|
auto expected =
|
||||||
at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);
|
at::native::_upsample_nearest_exact2d(tensor, osize, std::nullopt);
|
||||||
|
|
||||||
auto options = F::InterpolateFuncOptions()
|
auto options = F::InterpolateFuncOptions()
|
||||||
.size(osize)
|
.size(osize)
|
||||||
|
|
@ -2342,8 +2342,8 @@ TEST_F(FunctionalTest, Interpolate) {
|
||||||
{
|
{
|
||||||
auto tensor = torch::rand({2, 3, 32, 32});
|
auto tensor = torch::rand({2, 3, 32, 32});
|
||||||
std::vector<int64_t> osize = {8, 10};
|
std::vector<int64_t> osize = {8, 10};
|
||||||
auto expected = at::native::_upsample_bilinear2d_aa(
|
auto expected =
|
||||||
tensor, osize, false, torch::nullopt);
|
at::native::_upsample_bilinear2d_aa(tensor, osize, false, std::nullopt);
|
||||||
|
|
||||||
auto options = F::InterpolateFuncOptions()
|
auto options = F::InterpolateFuncOptions()
|
||||||
.size(osize)
|
.size(osize)
|
||||||
|
|
@ -2356,8 +2356,8 @@ TEST_F(FunctionalTest, Interpolate) {
|
||||||
{
|
{
|
||||||
auto tensor = torch::rand({2, 3, 32, 32});
|
auto tensor = torch::rand({2, 3, 32, 32});
|
||||||
std::vector<int64_t> osize = {8, 10};
|
std::vector<int64_t> osize = {8, 10};
|
||||||
auto expected = at::native::_upsample_bicubic2d_aa(
|
auto expected =
|
||||||
tensor, osize, false, torch::nullopt);
|
at::native::_upsample_bicubic2d_aa(tensor, osize, false, std::nullopt);
|
||||||
|
|
||||||
auto options = F::InterpolateFuncOptions()
|
auto options = F::InterpolateFuncOptions()
|
||||||
.size(osize)
|
.size(osize)
|
||||||
|
|
|
||||||
|
|
@ -381,8 +381,8 @@ TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
|
||||||
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
|
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
|
||||||
struct Cloneable : Module {
|
struct Cloneable : Module {
|
||||||
std::shared_ptr<Module> clone(
|
std::shared_ptr<Module> clone(
|
||||||
const torch::optional<torch::Device>& device =
|
const std::optional<torch::Device>& device =
|
||||||
torch::nullopt) const override {
|
std::nullopt) const override {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ TEST_F(
|
||||||
auto output = parallel::data_parallel(
|
auto output = parallel::data_parallel(
|
||||||
m,
|
m,
|
||||||
input,
|
input,
|
||||||
/*devices=*/torch::nullopt,
|
/*devices=*/std::nullopt,
|
||||||
/*output_device=*/torch::Device(torch::kCUDA, 1));
|
/*output_device=*/torch::Device(torch::kCUDA, 1));
|
||||||
ASSERT_TRUE(output.defined());
|
ASSERT_TRUE(output.defined());
|
||||||
ASSERT_TRUE(output.device().is_cuda());
|
ASSERT_TRUE(output.device().is_cuda());
|
||||||
|
|
|
||||||
|
|
@ -750,7 +750,7 @@ TEST_F(RNNTest, UsePackedSequenceAsInput) {
|
||||||
std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
|
std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
|
||||||
|
|
||||||
// Test passing optional argument to `LSTM::forward_with_packed_input`
|
// Test passing optional argument to `LSTM::forward_with_packed_input`
|
||||||
rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
|
rnn_output = m->forward_with_packed_input(packed_input, std::nullopt);
|
||||||
ASSERT_TRUE(torch::allclose(
|
ASSERT_TRUE(torch::allclose(
|
||||||
std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
|
std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -317,7 +317,7 @@ struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
|
||||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||||
return 1 + index;
|
return 1 + index;
|
||||||
}
|
}
|
||||||
torch::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return size_;
|
return size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ template <
|
||||||
class BatchDataBuffer {
|
class BatchDataBuffer {
|
||||||
public:
|
public:
|
||||||
using UnwrappedBatchType = UnwrappedBatch;
|
using UnwrappedBatchType = UnwrappedBatch;
|
||||||
using BatchType = torch::optional<UnwrappedBatchType>;
|
using BatchType = std::optional<UnwrappedBatchType>;
|
||||||
using BatchRequestType = typename ExampleSampler::BatchRequestType;
|
using BatchRequestType = typename ExampleSampler::BatchRequestType;
|
||||||
|
|
||||||
BatchDataBuffer(
|
BatchDataBuffer(
|
||||||
|
|
@ -316,7 +316,7 @@ class ChunkDataset final
|
||||||
typename ChunkReader::BatchType,
|
typename ChunkReader::BatchType,
|
||||||
size_t> {
|
size_t> {
|
||||||
public:
|
public:
|
||||||
using BatchType = torch::optional<typename ChunkReader::BatchType>;
|
using BatchType = std::optional<typename ChunkReader::BatchType>;
|
||||||
using UnwrappedBatchType = typename ChunkReader::BatchType;
|
using UnwrappedBatchType = typename ChunkReader::BatchType;
|
||||||
using BatchRequestType = size_t;
|
using BatchRequestType = size_t;
|
||||||
using ChunkSamplerType = ChunkSampler;
|
using ChunkSamplerType = ChunkSampler;
|
||||||
|
|
@ -404,7 +404,7 @@ class ChunkDataset final
|
||||||
|
|
||||||
/// size is not used for chunk dataset.
|
/// size is not used for chunk dataset.
|
||||||
std::optional<size_t> size() const override {
|
std::optional<size_t> size() const override {
|
||||||
return torch::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// provide a references to chunk sampler. Used mainly in distributed data
|
// provide a references to chunk sampler. Used mainly in distributed data
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
namespace torch::data::datasets {
|
namespace torch::data::datasets {
|
||||||
namespace detail {
|
namespace detail {
|
||||||
template <bool C, typename T>
|
template <bool C, typename T>
|
||||||
using optional_if_t = std::conditional_t<C, torch::optional<T>, T>;
|
using optional_if_t = std::conditional_t<C, std::optional<T>, T>;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
/// A `MapDataset` is a dataset that applies a transform to a source dataset.
|
/// A `MapDataset` is a dataset that applies a transform to a source dataset.
|
||||||
|
|
|
||||||
|
|
@ -158,17 +158,17 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
|
||||||
|
|
||||||
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
|
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FORWARD_HAS_DEFAULT_ARGS(
|
FORWARD_HAS_DEFAULT_ARGS(
|
||||||
{1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
|
{1, AnyValue(std::optional<std::tuple<Tensor, Tensor>>())})
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
|
std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
|
||||||
forward_with_packed_input(
|
forward_with_packed_input(
|
||||||
const torch::nn::utils::rnn::PackedSequence& packed_input,
|
const torch::nn::utils::rnn::PackedSequence& packed_input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
||||||
|
|
||||||
LSTMOptions options;
|
LSTMOptions options;
|
||||||
|
|
||||||
|
|
@ -191,7 +191,7 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
|
||||||
const Tensor& batch_sizes,
|
const Tensor& batch_sizes,
|
||||||
const Tensor& sorted_indices,
|
const Tensor& sorted_indices,
|
||||||
int64_t max_batch_size,
|
int64_t max_batch_size,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A `ModuleHolder` subclass for `LSTMImpl`.
|
/// A `ModuleHolder` subclass for `LSTMImpl`.
|
||||||
|
|
@ -343,11 +343,11 @@ class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
|
||||||
|
|
||||||
std::tuple<Tensor, Tensor> forward(
|
std::tuple<Tensor, Tensor> forward(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FORWARD_HAS_DEFAULT_ARGS(
|
FORWARD_HAS_DEFAULT_ARGS(
|
||||||
{1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
|
{1, AnyValue(std::optional<std::tuple<Tensor, Tensor>>())})
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LSTMCellOptions options;
|
LSTMCellOptions options;
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ inline std::tuple<Tensor, Tensor> pad_packed_sequence(
|
||||||
const PackedSequence& sequence,
|
const PackedSequence& sequence,
|
||||||
bool batch_first = false,
|
bool batch_first = false,
|
||||||
double padding_value = 0.0,
|
double padding_value = 0.0,
|
||||||
std::optional<int64_t> total_length = torch::nullopt) {
|
std::optional<int64_t> total_length = std::nullopt) {
|
||||||
int64_t max_seq_length = sequence.batch_sizes().size(0);
|
int64_t max_seq_length = sequence.batch_sizes().size(0);
|
||||||
if (total_length.has_value()) {
|
if (total_length.has_value()) {
|
||||||
int64_t total_length_val = total_length.value();
|
int64_t total_length_val = total_length.value();
|
||||||
|
|
|
||||||
|
|
@ -607,7 +607,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
|
||||||
const Tensor& batch_sizes,
|
const Tensor& batch_sizes,
|
||||||
const Tensor& sorted_indices,
|
const Tensor& sorted_indices,
|
||||||
int64_t max_batch_size,
|
int64_t max_batch_size,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
||||||
std::tuple<Tensor, Tensor> hx;
|
std::tuple<Tensor, Tensor> hx;
|
||||||
if (!hx_opt.has_value()) {
|
if (!hx_opt.has_value()) {
|
||||||
int64_t num_directions = options.bidirectional() ? 2 : 1;
|
int64_t num_directions = options.bidirectional() ? 2 : 1;
|
||||||
|
|
@ -664,7 +664,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
|
||||||
|
|
||||||
std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
|
std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
||||||
auto batch_sizes = torch::Tensor();
|
auto batch_sizes = torch::Tensor();
|
||||||
auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
|
auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
|
||||||
auto sorted_indices = torch::Tensor();
|
auto sorted_indices = torch::Tensor();
|
||||||
|
|
@ -680,7 +680,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
|
||||||
std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::
|
std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::
|
||||||
forward_with_packed_input(
|
forward_with_packed_input(
|
||||||
const PackedSequence& packed_input,
|
const PackedSequence& packed_input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
||||||
const auto& input = packed_input.data();
|
const auto& input = packed_input.data();
|
||||||
const auto& batch_sizes = packed_input.batch_sizes();
|
const auto& batch_sizes = packed_input.batch_sizes();
|
||||||
const auto& sorted_indices = packed_input.sorted_indices();
|
const auto& sorted_indices = packed_input.sorted_indices();
|
||||||
|
|
@ -945,7 +945,7 @@ LSTMCellImpl::LSTMCellImpl(const LSTMCellOptions& options_)
|
||||||
|
|
||||||
std::tuple<Tensor, Tensor> LSTMCellImpl::forward(
|
std::tuple<Tensor, Tensor> LSTMCellImpl::forward(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
|
||||||
this->check_forward_input(input, "input");
|
this->check_forward_input(input, "input");
|
||||||
if (hx_opt.has_value()) {
|
if (hx_opt.has_value()) {
|
||||||
this->check_forward_input(std::get<0>(hx_opt.value()), "hx[0]");
|
this->check_forward_input(std::get<0>(hx_opt.value()), "hx[0]");
|
||||||
|
|
|
||||||
|
|
@ -242,7 +242,7 @@ def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequenc
|
||||||
JIT_TO_CPP_DEFAULT = {
|
JIT_TO_CPP_DEFAULT = {
|
||||||
"False": "false",
|
"False": "false",
|
||||||
"True": "true",
|
"True": "true",
|
||||||
"None": "torch::executorch::nullopt", # UGH this one is type directed
|
"None": "torch::execustd::nullopt", # UGH this one is type directed
|
||||||
"[]": "{}",
|
"[]": "{}",
|
||||||
"contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
|
"contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
|
||||||
"long": "torch::executorch::kLong",
|
"long": "torch::executorch::kLong",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user