[Reland] [11/N] Use std::nullopt and std::optional (#132622)

Reland of #132396, which was reverted due to dependency reversion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132622
Approved by: https://github.com/ezyang
commit 6b12dc0224
parent 6f4dc56735
Author: cyy
Date: 2024-08-05 20:36:33 +00:00
Committed by: PyTorch MergeBot
28 changed files with 87 additions and 87 deletions
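The change itself is mechanical: each bare nullopt (and the c10::nullopt / at::nullopt spellings) becomes std::nullopt, and at::optional / optional become std::optional. A minimal sketch of the pattern, assuming a C++17 standard library; the maybe_index function below is illustrative, not taken from this diff:

#include <optional>

// Construct an engaged optional on one branch and return the explicit
// std::nullopt sentinel on the other -- the shape repeated across these
// 28 files.
std::optional<int> maybe_index(bool has_value) {
  // Before this change, the second operand would have been a bare
  // `nullopt` (or `c10::nullopt` / `at::nullopt`).
  return has_value ? std::optional<int>(0) : std::nullopt;
}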


@@ -22,7 +22,7 @@ static std::vector<std::optional<at::Tensor>> get_boxed_opt_tensor_vector() {
   std::vector<std::optional<at::Tensor>> optional_tensors;
   const size_t SIZE = 5;
   for (size_t i = 0; i < SIZE * 2; i++) {
-    auto opt_tensor = (i % 2 == 0) ? std::optional<at::Tensor>(at::empty({0})) : nullopt;
+    auto opt_tensor = (i % 2 == 0) ? std::optional<at::Tensor>(at::empty({0})) : std::nullopt;
     optional_tensors.emplace_back(opt_tensor);
   }
   return optional_tensors;


@@ -401,7 +401,7 @@ inline void FunctionSchema::checkAndNormalizeInputs(
     }
     auto it = kwargs.find(argument.name());
     if (it != kwargs.end()) {
-      checkArg<T>(it->second, argument, nullopt);
+      checkArg<T>(it->second, argument, std::nullopt);
       inputs.push_back(it->second);
       consumed_kwargs++;
       continue;


@@ -70,13 +70,13 @@ public:
   // internal-only for registering stack based kernels
   template<KernelFunction::BoxedKernelFunction* kernel_func>
   Options&& kernel(DispatchKey dispatch_key) && {
-    return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
+    return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
   }
   // internal-only for registering stack based catch-all kernels
   template<KernelFunction::BoxedKernelFunction* kernel_func>
   Options&& catchAllKernel() && {
-    return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
+    return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
   }
   // internal only for registering caffe2 ops


@@ -217,7 +217,7 @@ void GradInterpreterPtr::sendToNextInterpreterImpl(
       op, stack, *base_,
       TransformType::Grad,
       prevGradMode(),
-      nullopt,
+      std::nullopt,
       grad_special_case);
 }
@@ -234,7 +234,7 @@ void JvpInterpreterPtr::sendToNextInterpreterImpl(
   autogradBasedTransformSendToNext(
       op, stack, *base_,
       TransformType::Jvp,
-      nullopt,
+      std::nullopt,
       prevFwdGradMode(),
       grad_special_case);
 }


@@ -103,7 +103,7 @@ convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const
     out = reshape_dim_outof_symint(out_spec[1], lhs.sizes()[*lhs_bdim], out);
     result = std::make_tuple(out, out_spec[1]);
   } else {
-    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt);
+    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
   }
   if (separate_bias) {
     auto A = std::get<0>(result);
@@ -255,7 +255,7 @@ convolution_backward_input_batch_rule(
     const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
     auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output_, dummy_input, weight_, nullopt, stride, padding,
+        grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups * batch_size, mask);
     const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
     return std::make_tuple(grad_input, 1);
@@ -266,7 +266,7 @@ convolution_backward_input_batch_rule(
     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
     auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output_, dummy_input, weight, nullopt, stride, padding,
+        grad_output_, dummy_input, weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
     return std::make_tuple(grad_input, 0);
@@ -279,7 +279,7 @@ convolution_backward_input_batch_rule(
     const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
     auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output, dummy_input, weight_, nullopt, stride, padding,
+        grad_output, dummy_input, weight_, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
     return std::make_tuple(grad_input, 1);
@@ -290,7 +290,7 @@ convolution_backward_input_batch_rule(
       const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
       auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output, dummy_input, weight_, nullopt, stride, padding,
+          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       grad_input = std::get<0>(result); // N(GBI)
     } else {
@@ -301,7 +301,7 @@ convolution_backward_input_batch_rule(
       weight_ = weight_.flatten(0, 2); // (GBI)O
       const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output, dummy_input, weight_, nullopt, stride, padding,
+          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       grad_input = std::get<0>(result); // N(GBI)
     }
@@ -315,9 +315,9 @@ convolution_backward_input_batch_rule(
     TORCH_INTERNAL_ASSERT(input_bdim);
     const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
     const auto result = at::convolution_backward_symint(
-        grad_output, dummy_input, weight, nullopt, stride, padding,
+        grad_output, dummy_input, weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
-    return std::make_tuple(std::get<0>(result), nullopt);
+    return std::make_tuple(std::get<0>(result), std::nullopt);
   }
 }
 static std::tuple<Tensor, std::optional<int64_t>>
@@ -335,7 +335,7 @@ convolution_backward_weight_batch_rule(
     const auto input_ = reshape_dim_into(*input_bdim, 1, input);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output_, input_, dummy_weight, nullopt, stride, padding,
+        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups * batch_size, mask);
     auto grad_weight = std::get<1>(result);
     grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
@@ -349,7 +349,7 @@ convolution_backward_weight_batch_rule(
     const auto out_ch_dim = transposed ? 1 : 0;
     const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output_, input, dummy_weight, nullopt, stride, padding,
+        grad_output_, input, dummy_weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     auto grad_weight = std::get<1>(result);
     grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
@@ -363,7 +363,7 @@ convolution_backward_weight_batch_rule(
       // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
       const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output_, input, dummy_weight, nullopt, stride, padding,
+          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
       grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
@@ -374,7 +374,7 @@ convolution_backward_weight_batch_rule(
       // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
      const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output_, input, dummy_weight, nullopt, stride, padding,
+          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
@@ -390,7 +390,7 @@ convolution_backward_weight_batch_rule(
     const auto in_ch_dim = transposed ? 0 : 1;
     const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
     const auto result = at::convolution_backward_symint(
-        grad_output, input_, dummy_weight, nullopt, stride, padding,
+        grad_output, input_, dummy_weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
     auto grad_weight = std::get<1>(result);
     grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
@@ -404,7 +404,7 @@ convolution_backward_weight_batch_rule(
       // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
       const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output, input_, dummy_weight, nullopt, stride, padding,
+          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
       grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
@@ -413,7 +413,7 @@ convolution_backward_weight_batch_rule(
       // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
       const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
       const auto result = at::convolution_backward_symint(
-          grad_output, input_, dummy_weight, nullopt, stride, padding,
+          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
           dilation, transposed, output_padding, groups, mask);
       auto grad_weight = std::get<1>(result);
       grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
@@ -426,9 +426,9 @@ convolution_backward_weight_batch_rule(
     TORCH_INTERNAL_ASSERT(weight_bdim);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
     const auto result = at::convolution_backward_symint(
-        grad_output, input, dummy_weight, nullopt, stride, padding,
+        grad_output, input, dummy_weight, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
-    return std::make_tuple(std::get<1>(result), nullopt);
+    return std::make_tuple(std::get<1>(result), std::nullopt);
   }
 }
@@ -482,7 +482,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
   input = reshape_dim_into(*input_bdim, 1, input);
   weight = reshape_dim_into(*weight_bdim, 0, weight);
   const auto result = at::convolution_backward_symint(
-      grad_output, input, weight, nullopt, stride, padding, dilation,
+      grad_output, input, weight, std::nullopt, stride, padding, dilation,
       transposed, output_padding, batch_size * groups, output_mask);
   // N(BI), (BO)I -> NBI, BOI
   const auto grad_input = output_mask[0] ?


@@ -34,7 +34,7 @@ int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_
   return tensor.numel() / tensor.size(*maybe_batch_dim);
 }
-optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val) {
+std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val) {
   if (maybe_empty.has_value()) {
     return new_val;
   }
@@ -43,7 +43,7 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
   // NB: assumes the batch dim is at the front of the tensor
-  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : nullopt;
+  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
   auto rank = rankWithoutBatchDim(tensor, bdim);
   auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
   if (has_batch_dim) {
@@ -54,7 +54,7 @@ int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical
 VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
   // NB: assumes the batch dim is at the front of the tensor
-  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : nullopt;
+  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
   auto rank = rankWithoutBatchDim(tensor, bdim);
   VmapDimVector result;
   result.reserve(logical_dims.size());


@@ -109,11 +109,11 @@ static Tensor binary_cross_entropy_plumbing(
     auto target_ = moveBatchDimToFront(target_value, target_bdim);
     self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size);
     target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
-    result = at::binary_cross_entropy(self_, target_, nullopt, Reduction::None);
+    result = at::binary_cross_entropy(self_, target_, std::nullopt, Reduction::None);
     result = makeBatched(result, 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    result = at::binary_cross_entropy(self_value, target_value, nullopt, Reduction::None);
+    result = at::binary_cross_entropy(self_value, target_value, std::nullopt, Reduction::None);
   }
   if (weight.has_value() && weight->defined()) {
     result = result * weight.value();
@@ -153,12 +153,12 @@ static Tensor binary_cross_entropy_backward_plumbing(
     target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
     grad_input = at::binary_cross_entropy_backward(
-        grad_, input_, target_, nullopt, Reduction::None);
+        grad_, input_, target_, std::nullopt, Reduction::None);
     grad_input = makeBatched(grad_input, 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
     grad_input = at::binary_cross_entropy_backward(
-        grad_value, input_value, target_value, nullopt, Reduction::None);
+        grad_value, input_value, target_value, std::nullopt, Reduction::None);
   }
   if (weight_opt.has_value() && weight_opt->defined()) {
     grad_input = grad_input * weight_opt.value();


@@ -130,7 +130,7 @@ grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, c
     out = reshape_dim_outof(0, input.sizes()[*grid_bdim], out);
     result = std::make_tuple(std::move(out), 0);
   } else {
-    result = std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), nullopt);
+    result = std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), std::nullopt);
   }
   return result;
 }


@@ -114,7 +114,7 @@ batch_norm_batch_rule(
   if (bias.defined()) {
     const auto result_logical_rank = rankWithoutBatchDim(
         result0,
-        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(nullopt));
+        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
     auto bias_ = moveBatchDimToFront(bias, bias_bdim);
     bias_ = padRight(bias_, bias_bdim, result_logical_rank);
     result0 = result0 + bias_;
@@ -144,7 +144,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
     const auto dummy_weight = at::ones(input.size(1), input.options());
     const auto result = Func(
         grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
-    return std::make_tuple(std::get<0>(result), nullopt);
+    return std::make_tuple(std::get<0>(result), std::nullopt);
   }
   auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@@ -259,13 +259,13 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
     // NB: output isn't saved...
     auto mean = training ? save_mean : running_mean;
     auto var = training ? save_rstd : (1 / at::sqrt(running_var + eps));
-    const auto normalized_input = (input.transpose(0, 1) - padRight(mean, nullopt, input.dim())) * padRight(var, nullopt, input.dim());
+    const auto normalized_input = (input.transpose(0, 1) - padRight(mean, std::nullopt, input.dim())) * padRight(var, std::nullopt, input.dim());
     const auto expanded_grad_weight = normalized_input * grad_out.transpose(0, 1);
     grad_weight = expanded_grad_weight.sum(range(1, grad_out.dim()));
   }
   if (output_mask[0]) {
     const auto grad_normalized_input = weight.defined() ?
-        grad_out.transpose(0, 1) * padRight(weight, nullopt, grad_out.dim()) : grad_out.transpose(0, 1); // [B0, C, B, *]
+        grad_out.transpose(0, 1) * padRight(weight, std::nullopt, grad_out.dim()) : grad_out.transpose(0, 1); // [B0, C, B, *]
     auto [grad_normalized_input_value, grad_normalized_input_bdim] =
         unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]
@@ -312,25 +312,25 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     const auto bdim_size = input_value.size(*input_bdim);
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_, nullopt, nullopt, N * bdim_size, C, HxW, group, eps);
+    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
     result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
     mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
     rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_value, nullopt, nullopt, N, C, HxW, group, eps);
+    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
     result0 = std::get<0>(result);
     mean = std::get<1>(result);
     rstd = std::get<2>(result);
   }
   if (weight.defined()) {
-    const auto padded_weight = padRight(weight, nullopt, result0.dim() - 1);
+    const auto padded_weight = padRight(weight, std::nullopt, result0.dim() - 1);
     result0 = result0 * padded_weight;
   }
   if (bias.defined()) {
-    const auto padded_bias = padRight(bias, nullopt, result0.dim() - 1);
+    const auto padded_bias = padRight(bias, std::nullopt, result0.dim() - 1);
     result0 = result0 + padded_bias;
   }
@@ -364,7 +364,7 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
         input_.contiguous(),
         mean_.contiguous(),
         rstd_.contiguous(),
-        nullopt, N * bdim_size, C, HxW, group, {true, false, false});
+        std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
     auto result0 = std::get<0>(result);
     result0 = reshape_dim_outof(0, bdim_size, result0);
     return std::make_tuple(result0, 0);
@@ -410,14 +410,14 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
   if (output_mask[1] && weight.defined()) {
     const auto reshaped_input = reshape_dim_outof(1, group, input);
-    const auto normalized_input = (reshaped_input - padRight(mean, nullopt, reshaped_input.dim())) * padRight(rstd, nullopt, reshaped_input.dim());
+    const auto normalized_input = (reshaped_input - padRight(mean, std::nullopt, reshaped_input.dim())) * padRight(rstd, std::nullopt, reshaped_input.dim());
     const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out;
     grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim()));
   }
   if (output_mask[0]) {
     const auto grad_normalized_input = weight.defined() ?
-        grad_out * padRight(weight, nullopt, grad_out.dim() - 1) : grad_out;
+        grad_out * padRight(weight, std::nullopt, grad_out.dim() - 1) : grad_out;
     auto [grad_normalized_input_value, grad_normalized_input_bdim] =
         unwrapTensorAtLevel(grad_normalized_input, cur_level);
@@ -508,7 +508,7 @@ native_layer_norm_batch_rule(
   _check_layer_norm_inputs(normalized_shape, weight, weight_bdim, bias, bias_bdim);
   const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
-  const auto result = at::native_layer_norm_symint(input_, normalized_shape, nullopt, nullopt, eps);
+  const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
   auto result0 = std::get<0>(result);
   const auto mean = std::get<1>(result);
   const auto rstd = std::get<2>(result);
@@ -522,7 +522,7 @@ native_layer_norm_batch_rule(
   if (bias.defined()) {
     const auto result_logical_rank = rankWithoutBatchDim(
         result0,
-        input_bdim.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(nullopt));
+        input_bdim.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
     auto bias_ = moveBatchDimToFront(bias, bias_bdim);
     bias_ = maybePadToLogicalRank(bias_, /*has_bdim*/bias_bdim, result_logical_rank);
     result0 = result0 + bias_;
@@ -540,8 +540,8 @@ static std::tuple<at::Tensor, std::optional<int64_t>> native_layer_norm_backward
   if (!grad_out_bdim.has_value() && !input_bdim.has_value() &&
       !mean_bdim.has_value() && !rstd_bdim.has_value()) {
     const auto result = at::native_layer_norm_backward(
-        grad_out, input, normalized_shape, mean, rstd, nullopt, nullopt, {true, false, false});
-    return std::make_tuple(std::get<0>(result), nullopt);
+        grad_out, input, normalized_shape, mean, rstd, std::nullopt, std::nullopt, {true, false, false});
+    return std::make_tuple(std::get<0>(result), std::nullopt);
   }
   auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@@ -562,7 +562,7 @@ static std::tuple<at::Tensor, std::optional<int64_t>> native_layer_norm_backward
       normalized_shape,
       mean_.contiguous(),
       rstd_.contiguous(),
-      nullopt, nullopt, {true, false, false});
+      std::nullopt, std::nullopt, {true, false, false});
   return std::make_tuple(std::get<0>(result), 0);
 }
@@ -677,7 +677,7 @@ struct CudnnBatchNormBatchRuleHelper {
     auto res = batch_norm_batch_rule<F, Func>(
         input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
         running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
-    return std::tuple_cat(res, std::make_tuple(reserve, nullopt));
+    return std::tuple_cat(res, std::make_tuple(reserve, std::nullopt));
   }
 };


@@ -24,12 +24,12 @@ static Tensor sum_decomp(
 static std::tuple<Tensor, std::optional<int64_t>> _is_all_true_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim) {
-  return std::make_tuple(at::_is_all_true(self), nullopt);
+  return std::make_tuple(at::_is_all_true(self), std::nullopt);
 }
 static std::tuple<Tensor, std::optional<int64_t>> _is_any_true_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim) {
-  return std::make_tuple(at::_is_any_true(self), nullopt);
+  return std::make_tuple(at::_is_any_true(self), std::nullopt);
 }
 static Tensor mean_decomp(
@@ -410,7 +410,7 @@ static Tensor bucketize_decomp_Tensor(
     bool right) {
   // checking logical rank
   TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
-  return at::searchsorted(boundaries, self, out_int32, right, nullopt, nullopt);
+  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
 }
@@ -420,7 +420,7 @@ static Tensor bucketize_decomp_Scalar(
     bool right) {
   // checking logical rank
   TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
-  return at::searchsorted(boundaries, self, out_int32, right, nullopt, nullopt);
+  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
 }
 // Use when the other macros don't work out.


@@ -106,7 +106,7 @@ static std::vector<std::optional<Tensor>> batchIndices(
   }
   if (!indices_batched && self_bdim.has_value()) {
-    indices_.insert(indices_.begin(), nullopt);
+    indices_.insert(indices_.begin(), std::nullopt);
   } else if (indices_batched && !self_bdim.has_value()) {
     // do nothing
   } else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {


@@ -259,7 +259,7 @@ std::tuple<Tensor, std::optional<int64_t>> squeeze_dim_batch_rule(
 std::tuple<Tensor, std::optional<int64_t>> select_batching_rule(const Tensor& self, std::optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
   if (!bdim) {
-    return std::make_tuple(self.select_symint(dim, std::move(index)), nullopt);
+    return std::make_tuple(self.select_symint(dim, std::move(index)), std::nullopt);
   }
   auto _self = moveBatchDimToFront(self, bdim);


@@ -573,7 +573,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
   }
   auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
-  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : nullopt;
+  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
   auto result = at::cat(tensors_to_cat, new_dim);
   return makeBatched(result, new_bdim, get_current_level());
 }


@@ -43,12 +43,12 @@ std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::o
 std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
   auto* batched = maybeGetBatchedImpl(tensor);
   if (!batched) {
-    return std::make_tuple(tensor, nullopt);
+    return std::make_tuple(tensor, std::nullopt);
   }
   if (batched->level() == level) {
     return std::make_tuple(batched->value(), batched->bdim());
   }
-  return std::make_tuple(tensor, nullopt);
+  return std::make_tuple(tensor, std::nullopt);
 }
 bool isBatchedAtLevel(const Tensor& tensor, int64_t level) {


@@ -33,7 +33,7 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim,
 // Given a Tensor that may or may not be a BatchedTensor, unwrap it.
 // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
-// doesn't match, then this returns (tensor, nullopt).
+// doesn't match, then this returns (tensor, std::nullopt).
 // Otherwise, it returns (unwrap(tensor), bdim).
 TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);


@@ -227,7 +227,7 @@ Tensor searchsorted_cpu(
 Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
   TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
-  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, nullopt, nullopt, result);
+  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, std::nullopt, std::nullopt, result);
   return result;
 }


@@ -163,7 +163,7 @@ static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_tw
       tensor.names(),
       names,
       is_aligning_two_tensors);
-  auto result = tensor.rename(nullopt).view(expanded_sizes);
+  auto result = tensor.rename(std::nullopt).view(expanded_sizes);
   at::internal_set_names_inplace(result, names);
   return result;
 }


@@ -241,7 +241,7 @@ TORCH_META_FUNC2(scatter, value_reduce)
  const Tensor& index,
  const Scalar& src,
  const c10::string_view reduce) {
-  scatter_meta_impl(*this, self, dim, index, nullopt, reduce);
+  scatter_meta_impl(*this, self, dim, index, std::nullopt, reduce);
 }
 TORCH_META_FUNC(scatter_add)


@@ -494,9 +494,9 @@ Tensor to(const Tensor& self, Device device, ScalarType dtype, bool non_blocking
   return to_impl(
       self,
       dtype,
-      nullopt,
+      std::nullopt,
       ensure_has_index(device),
-      nullopt,
+      std::nullopt,
       non_blocking,
       copy,
       optional_memory_format);
@@ -506,9 +506,9 @@ Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy, st
   return to_impl(
       self,
       dtype,
-      nullopt,
-      nullopt,
-      nullopt,
+      std::nullopt,
+      std::nullopt,
+      std::nullopt,
       non_blocking,
       copy,
       optional_memory_format);


@@ -214,7 +214,7 @@ Tensor searchsorted_cuda(
 Tensor& bucketize_out_cuda(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
   TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
-  at::native::searchsorted_out_cuda(boundaries, self, out_int32, right, nullopt, nullopt, result);
+  at::native::searchsorted_out_cuda(boundaries, self, out_int32, right, std::nullopt, std::nullopt, result);
   return result;
 }


@@ -702,7 +702,7 @@ Tensor scaled_dot_product_attention(
         attn_mask,
         dropout_p,
         is_causal,
-        c10::nullopt, /*dropout_mask*/
+        std::nullopt, /*dropout_mask*/
         scale));
   }
   return std::get<0>(at::_scaled_dot_product_attention_math(


@@ -9,7 +9,7 @@ using at::operator<<;
 // kNullValue is used to contribute a static hash value any time
 // a node has an Optional<Value> input that is nullopt. It is important
-// to differentiate between HASH(nullopt, something) and HASH(something, nullopt),
+// to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt),
 // and using kNullValue in the hash function in the order of arguments
 // serves this purpose.
 static const torch::lazy::Value kNullValue = torch::lazy::Value();
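The comment's point is that an absent optional input must still occupy its position in the hash; otherwise two nodes whose optional inputs differ only in which slot is empty would hash alike. A standalone sketch of the idea, assuming a boost-style order-sensitive combiner; the names kNullHash, combine, and hash_pair are illustrative, not the torch::lazy API:

#include <cstddef>
#include <functional>
#include <optional>

// Fixed sentinel hash standing in for an absent input, like kNullValue.
constexpr std::size_t kNullHash = 0x9e3779b97f4a7c15ULL;

// Order-sensitive hash combiner (boost::hash_combine style).
std::size_t combine(std::size_t seed, std::size_t h) {
  return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Because the sentinel is folded in at the argument's position,
// hash_pair(std::nullopt, x) != hash_pair(x, std::nullopt) in general.
std::size_t hash_pair(std::optional<int> a, std::optional<int> b) {
  std::size_t seed = 0;
  seed = combine(seed, a ? std::hash<int>{}(*a) : kNullHash);
  seed = combine(seed, b ? std::hash<int>{}(*b) : kNullHash);
  return seed;
}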


@@ -769,7 +769,7 @@ TEST(IValueTest, getSubValues) {
   IValue dict(std::move(m));
-  auto objType = ClassType::create(nullopt, {});
+  auto objType = ClassType::create(std::nullopt, {});
   objType->addAttribute("t1", tv1.type());
   objType->addAttribute("t2", tv2.type());


@@ -244,7 +244,7 @@ struct OptionalCUDAStreamGuard {
     if (r.has_value()) {
       return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
     } else {
-      return nullopt;
+      return std::nullopt;
     }
   }
@@ -256,7 +256,7 @@ struct OptionalCUDAStreamGuard {
     if (r.has_value()) {
       return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
     } else {
-      return nullopt;
+      return std::nullopt;
     }
   }


@@ -61,7 +61,7 @@ struct Slice {
       return i;
     }
   }
-  return c10::nullopt;
+  return std::nullopt;
 }
 bool contains(const T& value) {
   return index(value).has_value();


@@ -1693,7 +1693,7 @@ static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry>
   DotPart ro_dims;
   DotPart lr_dims;
-  auto insert_dim = [&] (mpy::hdl<Dim> d, at::optional<int> lhs_idx, at::optional<int> rhs_idx) {
+  auto insert_dim = [&] (mpy::hdl<Dim> d, std::optional<int> lhs_idx, std::optional<int> rhs_idx) {
     bool reduced = sum.contains(d);
     int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0;
     int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0;
@@ -1732,7 +1732,7 @@ static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry>
       continue;
     }
     auto d = rhs.levels[i];
-    insert_dim(d.dim(), at::nullopt, i);
+    insert_dim(d.dim(), std::nullopt, i);
   }
   if (lr_dims.dims.size() != sum.size()) {


@@ -118,7 +118,7 @@ struct EValue {
     at::ArrayRef<double> as_double_list;
     at::ArrayRef<bool> as_bool_list;
     EValObjectList<at::Tensor> as_tensor_list;
-    EValObjectList<at::optional<at::Tensor>> as_list_optional_tensor;
+    EValObjectList<std::optional<at::Tensor>> as_list_optional_tensor;
   } copyable_union;
   // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
@@ -347,7 +347,7 @@ struct EValue {
   }
   /****** List Optional Tensor Type ******/
-  /*implicit*/ EValue(EValObjectList<at::optional<at::Tensor>> t)
+  /*implicit*/ EValue(EValObjectList<std::optional<at::Tensor>> t)
       : tag(Tag::ListOptionalTensor) {
     payload.copyable_union.as_list_optional_tensor = t;
   }
@@ -356,7 +356,7 @@ struct EValue {
     return tag == Tag::ListOptionalTensor;
   }
-  at::ArrayRef<at::optional<at::Tensor>> toListOptionalTensor() {
+  at::ArrayRef<std::optional<at::Tensor>> toListOptionalTensor() {
     return payload.copyable_union.as_list_optional_tensor.get();
   }
@@ -383,9 +383,9 @@ struct EValue {
    * an uninitialized state.
    */
  template <typename T>
-  inline at::optional<T> toOptional() {
+  inline std::optional<T> toOptional() {
    if (this->isNone()) {
-      return at::nullopt;
+      return std::nullopt;
    }
    return this->to<T>();
  }
@@ -455,15 +455,15 @@ EVALUE_DEFINE_TO(double, toDouble)
 EVALUE_DEFINE_TO(at::string_view, toString)
 EVALUE_DEFINE_TO(at::ScalarType, toScalarType)
 EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat)
-EVALUE_DEFINE_TO(at::optional<at::Tensor>, toOptional<at::Tensor>)
+EVALUE_DEFINE_TO(std::optional<at::Tensor>, toOptional<at::Tensor>)
 EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList)
 EVALUE_DEFINE_TO(
-    at::optional<at::ArrayRef<int64_t>>,
+    std::optional<at::ArrayRef<int64_t>>,
     toOptional<at::ArrayRef<int64_t>>)
 EVALUE_DEFINE_TO(
-    at::optional<at::ArrayRef<double>>,
+    std::optional<at::ArrayRef<double>>,
     toOptional<at::ArrayRef<double>>)
-EVALUE_DEFINE_TO(at::ArrayRef<at::optional<at::Tensor>>, toListOptionalTensor)
+EVALUE_DEFINE_TO(at::ArrayRef<std::optional<at::Tensor>>, toListOptionalTensor)
 EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList)
 #undef EVALUE_DEFINE_TO
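For context, toOptional<T>() maps the None tag to an empty optional and otherwise converts the payload. A simplified stand-in showing that shape, assuming nothing about the real EValue beyond what the hunk above shows; Value here is a hypothetical type, not the executorch struct:

#include <optional>

struct Value {
  bool is_none;
  int payload;

  // None maps to std::nullopt; anything else converts the payload.
  template <typename T>
  std::optional<T> toOptional() const {
    if (is_none) {
      return std::nullopt;
    }
    return static_cast<T>(payload);
  }
};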


@@ -434,12 +434,12 @@ static std::tuple<Tensor, std::optional<int64_t>> unwrapBatched(
     int64_t level) {
   auto* batched = maybeGetBatchedImpl(tensor);
   if (!batched) {
-    return std::make_tuple(tensor, nullopt);
+    return std::make_tuple(tensor, std::nullopt);
   }
   if (batched->level() == level) {
     return std::make_tuple(batched->value(), batched->bdim());
   }
-  return std::make_tuple(tensor, nullopt);
+  return std::make_tuple(tensor, std::nullopt);
 }
 void initFuncTorchBindings(PyObject* module) {