torch::optional -> std::optional (#138987)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138987
Approved by: https://github.com/Skylion007
Authored by Richard Barnes on 2024-10-28 19:09:43 +00:00; committed by PyTorch MergeBot
parent 228963ad60
commit 068f7e7a78
12 changed files with 64 additions and 64 deletions
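
Every hunk below applies the same mechanical rewrite: `torch::optional<T>` becomes `std::optional<T>` and `torch::nullopt` becomes `std::nullopt`. Since `torch::optional` is by this point just an alias for `std::optional`, the change is behavior-preserving. A minimal sketch of the before/after pattern, using a hypothetical dataset rather than one from this diff:

    #include <torch/torch.h>
    #include <optional>

    // Hypothetical dataset illustrating the migrated signature.
    struct SquaresDataset : torch::data::datasets::Dataset<SquaresDataset, int> {
      int get(size_t index) override {
        return static_cast<int>(index * index);
      }
      // Before this commit: torch::optional<size_t> size() const override
      std::optional<size_t> size() const override {
        return 100; // return std::nullopt (formerly torch::nullopt) when unsized
      }
    };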

View File

@@ -33,7 +33,7 @@ struct DummyDataset : datasets::Dataset<DummyDataset, int> {
     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
     return 1 + index;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return size_;
   }
@@ -151,8 +151,8 @@ struct InfiniteStreamDataset
     return batch;
   }
-  torch::optional<size_t> size() const override {
-    return torch::nullopt;
+  std::optional<size_t> size() const override {
+    return std::nullopt;
   }
   size_t counter = 0;
@@ -459,7 +459,7 @@ TEST(DataTest, StackTransformWorksForExample) {
     return {tensor[index], 1 + tensor[index]};
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return tensor.size(0);
   }
@@ -503,7 +503,7 @@ struct TensorStringDataset
     return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
 };
@@ -542,7 +542,7 @@ struct DummyTensorDataset
     return {tensor, static_cast<int>(channels)};
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
 };
@@ -624,7 +624,7 @@ struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
         torch::tensor({static_cast<int64_t>(index)})};
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
 };
@@ -753,7 +753,7 @@ struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
     return 1 + index;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
 };
@@ -806,7 +806,7 @@ struct TestIndexDataset
     }
     return batch;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return data.size();
   }
   std::vector<int> data;
@@ -814,10 +814,10 @@ struct TestIndexDataset
 struct TestIndexSampler : public samplers::Sampler<TestIndex> {
   explicit TestIndexSampler(size_t size) : size_(size) {}
-  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
-  torch::optional<TestIndex> next(size_t batch_size) override {
+  void reset(std::optional<size_t> new_size = std::nullopt) override {}
+  std::optional<TestIndex> next(size_t batch_size) override {
     if (index_ >= size_) {
-      return torch::nullopt;
+      return std::nullopt;
     }
     std::vector<size_t> indices(batch_size);
     std::iota(indices.begin(), indices.end(), size_t(0));
@@ -847,7 +847,7 @@ TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
   samplers::DistributedRandomSampler drs(sample_count);
   std::vector<size_t> res;
-  torch::optional<std::vector<size_t>> idx;
+  std::optional<std::vector<size_t>> idx;
   while ((idx = drs.next(3)).has_value()) {
     res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   }
@@ -879,7 +879,7 @@ TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
   std::vector<size_t> res;
   for (const auto i : c10::irange(num_replicas)) {
     (*samplers[i]).reset();
-    torch::optional<std::vector<size_t>> idx;
+    std::optional<std::vector<size_t>> idx;
     while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
       res.insert(std::end(res), std::begin(*idx), std::end(*idx));
     }
@@ -943,7 +943,7 @@ TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
   samplers::DistributedSequentialSampler dss(sample_count);
   std::vector<size_t> res;
-  torch::optional<std::vector<size_t>> idx;
+  std::optional<std::vector<size_t>> idx;
   while ((idx = dss.next(batch_size)).has_value()) {
     res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   }
@@ -976,7 +976,7 @@ TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
   std::vector<size_t> res;
   for (const auto i : c10::irange(num_replicas)) {
     (*samplers[i]).reset();
-    torch::optional<std::vector<size_t>> idx;
+    std::optional<std::vector<size_t>> idx;
     while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
       res.insert(std::end(res), std::begin(*idx), std::end(*idx));
     }
@@ -1052,8 +1052,8 @@ struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
   torch::data::Example<> get(size_t i) override {
     return {torch::ones(i), torch::ones(i)};
   }
-  torch::optional<size_t> size() const noexcept override {
-    return torch::nullopt;
+  std::optional<size_t> size() const noexcept override {
+    return std::nullopt;
   }
 };
@@ -1150,7 +1150,7 @@ TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
     return 1 + indices.front();
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 10;
   }
 };
@@ -1270,7 +1270,7 @@ TEST(DataLoaderTest, RespectsTimeout) {
     baton->cv.wait_for(lock, 1000 * kMillisecond);
     return 0;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
   std::shared_ptr<Baton> baton;
@@ -1388,7 +1388,7 @@ struct Dataset : datasets::BatchDataset<Dataset, size_t> {
     return indices.front();
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return kNumberOfWorkers;
   }
@@ -1441,7 +1441,7 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
   int get(size_t index) override {
     throw std::invalid_argument("badness");
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
 };
@@ -1467,13 +1467,13 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
 struct D : datasets::StatefulDataset<D, int, size_t> {
-  torch::optional<int> get_batch(size_t) override {
+  std::optional<int> get_batch(size_t) override {
     if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
       return counter++;
     }
-    return torch::nullopt;
+    return std::nullopt;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
   void reset() override {
@@ -1504,14 +1504,14 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
 const int kNumberOfWorkers = 4;
 struct D : datasets::StatefulDataset<D, int, size_t> {
-  torch::optional<int> get_batch(size_t) override {
+  std::optional<int> get_batch(size_t) override {
     std::lock_guard<std::mutex> lock(mutex);
     if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
       return counter++;
     }
-    return torch::nullopt;
+    return std::nullopt;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
   void reset() override {
@@ -1544,13 +1544,13 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
 struct D : datasets::StatefulDataset<D, int, size_t> {
-  torch::optional<int> get_batch(size_t) override {
+  std::optional<int> get_batch(size_t) override {
     if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
       return counter++;
     }
-    return torch::nullopt;
+    return std::nullopt;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
   void reset() override {
@@ -1587,7 +1587,7 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
 struct D : datasets::StatefulDataset<D> {
-  torch::optional<std::vector<Example<>>> get_batch(
+  std::optional<std::vector<Example<>>> get_batch(
       size_t batch_size) override {
     if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
       counter += batch_size;
@@ -1597,9 +1597,9 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
           torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
       return batch;
     }
-    return torch::nullopt;
+    return std::nullopt;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return 100;
   }
   void reset() override {
@@ -1616,7 +1616,7 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
   // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
   // the `Stack` collation stacks the tensors into one.
-  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
+  std::optional<Example<>> batch = d.get_batch(kBatchSize);
   ASSERT_TRUE(batch.has_value());
   ASSERT_EQ(batch->data.size(0), kBatchSize);
   ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
@@ -2117,7 +2117,7 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
  public:
   explicit S(size_t size) : size_(size), index_(0){};
-  void reset(torch::optional<size_t> new_size = torch::nullopt) override {
+  void reset(std::optional<size_t> new_size = std::nullopt) override {
     if (new_size.has_value()) {
       size_ = *new_size;
     }
@@ -2134,10 +2134,10 @@
   }
   // Returns the next batch of indices.
-  torch::optional<std::vector<size_t>> next(size_t batch_size) override {
+  std::optional<std::vector<size_t>> next(size_t batch_size) override {
     const auto remaining_indices = size_ - index_;
     if (remaining_indices == 0) {
-      return torch::nullopt;
+      return std::nullopt;
     }
     auto return_size = std::min(batch_size, remaining_indices);
     std::vector<size_t> index_batch(
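
The sampler tests above all drain batches with the same idiom: `next()` yields a populated `std::optional` until the sampler is exhausted, then `std::nullopt` terminates the loop. A self-contained sketch of that idiom (sizes assumed, not taken from the tests):

    #include <torch/torch.h>
    #include <optional>
    #include <vector>

    void drain_sampler_example() {
      torch::data::samplers::DistributedRandomSampler sampler(/*size=*/10);
      std::vector<size_t> seen;
      std::optional<std::vector<size_t>> idx;
      // has_value() turns false once next() returns std::nullopt.
      while ((idx = sampler.next(/*batch_size=*/3)).has_value()) {
        seen.insert(seen.end(), idx->begin(), idx->end());
      }
    }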

View File

@@ -2329,7 +2329,7 @@ TEST_F(FunctionalTest, Interpolate) {
     auto tensor = torch::rand({2, 3, 32, 32});
     std::vector<int64_t> osize = {8, 10};
     auto expected =
-        at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);
+        at::native::_upsample_nearest_exact2d(tensor, osize, std::nullopt);
     auto options = F::InterpolateFuncOptions()
                        .size(osize)
@@ -2342,8 +2342,8 @@
   {
     auto tensor = torch::rand({2, 3, 32, 32});
     std::vector<int64_t> osize = {8, 10};
-    auto expected = at::native::_upsample_bilinear2d_aa(
-        tensor, osize, false, torch::nullopt);
+    auto expected =
+        at::native::_upsample_bilinear2d_aa(tensor, osize, false, std::nullopt);
     auto options = F::InterpolateFuncOptions()
                        .size(osize)
@@ -2356,8 +2356,8 @@
   {
     auto tensor = torch::rand({2, 3, 32, 32});
     std::vector<int64_t> osize = {8, 10};
-    auto expected = at::native::_upsample_bicubic2d_aa(
-        tensor, osize, false, torch::nullopt);
+    auto expected =
+        at::native::_upsample_bicubic2d_aa(tensor, osize, false, std::nullopt);
     auto options = F::InterpolateFuncOptions()
                        .size(osize)

View File

@@ -381,8 +381,8 @@ TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
 TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
   struct Cloneable : Module {
     std::shared_ptr<Module> clone(
-        const torch::optional<torch::Device>& device =
-            torch::nullopt) const override {
+        const std::optional<torch::Device>& device =
+            std::nullopt) const override {
       return nullptr;
     }
   };

View File

@@ -190,7 +190,7 @@ TEST_F(
   auto output = parallel::data_parallel(
       m,
       input,
-      /*devices=*/torch::nullopt,
+      /*devices=*/std::nullopt,
       /*output_device=*/torch::Device(torch::kCUDA, 1));
   ASSERT_TRUE(output.defined());
   ASSERT_TRUE(output.device().is_cuda());

View File

@@ -750,7 +750,7 @@ TEST_F(RNNTest, UsePackedSequenceAsInput) {
       std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
   // Test passing optional argument to `LSTM::forward_with_packed_input`
-  rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
+  rnn_output = m->forward_with_packed_input(packed_input, std::nullopt);
   ASSERT_TRUE(torch::allclose(
       std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
 }

View File

@@ -317,7 +317,7 @@ struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
     // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
     return 1 + index;
   }
-  torch::optional<size_t> size() const override {
+  std::optional<size_t> size() const override {
     return size_;
   }

View File

@@ -50,7 +50,7 @@ template <
 class BatchDataBuffer {
  public:
   using UnwrappedBatchType = UnwrappedBatch;
-  using BatchType = torch::optional<UnwrappedBatchType>;
+  using BatchType = std::optional<UnwrappedBatchType>;
   using BatchRequestType = typename ExampleSampler::BatchRequestType;
   BatchDataBuffer(
@@ -316,7 +316,7 @@ class ChunkDataset final
       typename ChunkReader::BatchType,
       size_t> {
  public:
-  using BatchType = torch::optional<typename ChunkReader::BatchType>;
+  using BatchType = std::optional<typename ChunkReader::BatchType>;
   using UnwrappedBatchType = typename ChunkReader::BatchType;
   using BatchRequestType = size_t;
   using ChunkSamplerType = ChunkSampler;
@@ -404,7 +404,7 @@ class ChunkDataset final
   /// size is not used for chunk dataset.
   std::optional<size_t> size() const override {
-    return torch::nullopt;
+    return std::nullopt;
   }
   // provide a references to chunk sampler. Used mainly in distributed data

View File

@@ -12,7 +12,7 @@
 namespace torch::data::datasets {
 namespace detail {
 template <bool C, typename T>
-using optional_if_t = std::conditional_t<C, torch::optional<T>, T>;
+using optional_if_t = std::conditional_t<C, std::optional<T>, T>;
 } // namespace detail
 /// A `MapDataset` is a dataset that applies a transform to a source dataset.
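
For context, `optional_if_t` is what lets a `MapDataset` keep an optional batch type when it wraps a stateful source dataset, and a plain batch type otherwise. A compile-time illustration (hypothetical, not part of the header):

    #include <torch/data/datasets/map.h>
    #include <optional>
    #include <type_traits>

    // C == true wraps T in std::optional; C == false leaves T bare.
    static_assert(std::is_same_v<
                  torch::data::datasets::detail::optional_if_t<true, int>,
                  std::optional<int>>);
    static_assert(std::is_same_v<
                  torch::data::datasets::detail::optional_if_t<false, int>,
                  int>);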

View File

@@ -158,17 +158,17 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
   std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
       const Tensor& input,
-      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+      std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
  protected:
   FORWARD_HAS_DEFAULT_ARGS(
-      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
+      {1, AnyValue(std::optional<std::tuple<Tensor, Tensor>>())})
  public:
   std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
   forward_with_packed_input(
       const torch::nn::utils::rnn::PackedSequence& packed_input,
-      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+      std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
   LSTMOptions options;
@@ -191,7 +191,7 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
       const Tensor& batch_sizes,
       const Tensor& sorted_indices,
       int64_t max_batch_size,
-      torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
+      std::optional<std::tuple<Tensor, Tensor>> hx_opt);
 };
 /// A `ModuleHolder` subclass for `LSTMImpl`.
@@ -343,11 +343,11 @@ class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
   std::tuple<Tensor, Tensor> forward(
       const Tensor& input,
-      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+      std::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
  protected:
   FORWARD_HAS_DEFAULT_ARGS(
-      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
+      {1, AnyValue(std::optional<std::tuple<Tensor, Tensor>>())})
  public:
   LSTMCellOptions options;
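
At call sites, the new default means the initial hidden state is passed as `std::nullopt` (or omitted entirely). A hypothetical usage sketch, not taken from this diff:

    #include <torch/torch.h>

    void lstm_forward_example() {
      torch::nn::LSTM lstm(torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20));
      auto input = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
      // hx_opt defaults to an empty std::optional, so these calls are equivalent;
      // the LSTM then zero-initializes the hidden and cell states.
      auto [output1, state1] = lstm->forward(input);
      auto [output2, state2] = lstm->forward(input, std::nullopt);
    }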

View File

@@ -244,7 +244,7 @@ inline std::tuple<Tensor, Tensor> pad_packed_sequence(
     const PackedSequence& sequence,
     bool batch_first = false,
     double padding_value = 0.0,
-    std::optional<int64_t> total_length = torch::nullopt) {
+    std::optional<int64_t> total_length = std::nullopt) {
   int64_t max_seq_length = sequence.batch_sizes().size(0);
   if (total_length.has_value()) {
     int64_t total_length_val = total_length.value();
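
A hypothetical round trip through the packed-sequence helpers shows where that default lands (shapes and names assumed, not from this diff):

    #include <torch/torch.h>

    void packed_sequence_example() {
      namespace rnn_utils = torch::nn::utils::rnn;
      auto padded = torch::randn({/*seq_len=*/5, /*batch=*/2, /*features=*/4});
      auto lengths = torch::tensor({5, 3}); // must be sorted descending by default
      auto packed = rnn_utils::pack_padded_sequence(padded, lengths);
      // total_length defaults to std::nullopt, i.e. unpad to the longest sequence.
      auto [unpacked, out_lengths] = rnn_utils::pad_packed_sequence(packed);
    }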

View File

@@ -607,7 +607,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
     const Tensor& batch_sizes,
     const Tensor& sorted_indices,
     int64_t max_batch_size,
-    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+    std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
   std::tuple<Tensor, Tensor> hx;
   if (!hx_opt.has_value()) {
     int64_t num_directions = options.bidirectional() ? 2 : 1;
@@ -664,7 +664,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
 std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
     const Tensor& input,
-    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+    std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
   auto batch_sizes = torch::Tensor();
   auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
   auto sorted_indices = torch::Tensor();
@@ -680,7 +680,7 @@ std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
 std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::
     forward_with_packed_input(
         const PackedSequence& packed_input,
-        torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+        std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
   const auto& input = packed_input.data();
   const auto& batch_sizes = packed_input.batch_sizes();
   const auto& sorted_indices = packed_input.sorted_indices();
@@ -945,7 +945,7 @@ LSTMCellImpl::LSTMCellImpl(const LSTMCellOptions& options_)
 std::tuple<Tensor, Tensor> LSTMCellImpl::forward(
     const Tensor& input,
-    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+    std::optional<std::tuple<Tensor, Tensor>> hx_opt) {
   this->check_forward_input(input, "input");
   if (hx_opt.has_value()) {
     this->check_forward_input(std::get<0>(hx_opt.value()), "hx[0]");

View File

@@ -242,7 +242,7 @@ def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequenc
 JIT_TO_CPP_DEFAULT = {
     "False": "false",
     "True": "true",
-    "None": "torch::executorch::nullopt",  # UGH this one is type directed
+    "None": "torch::execustd::nullopt",  # UGH this one is type directed
     "[]": "{}",
     "contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
     "long": "torch::executorch::kLong",