nll_loss_forward: port to structured kernel (#61443)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61443

For more information, see #55070.
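As a quick orientation, the port follows the structured-kernel split shown in the diffs below: a meta function validates the inputs and declares the output shapes via `set_output`, and each backend supplies only an impl that fills the outputs the meta function allocated. A condensed sketch, abbreviated from this commit's diffs (namespaces and the full shape checks are omitted):

```cpp
// Condensed sketch of the structured-kernel split applied to nll_loss_forward.
// Abbreviated from the diffs below; namespace at::meta / at::native omitted.

// The meta function runs once per call: it checks the inputs and declares the
// shape and options of every output via set_output().
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index) {
  // ... shape/dtype checks (see the full diff) ...
  if (reduction == Reduction::None && self.dim() == 2) {
    set_output(0, {self.size(0)}, self.options());
  } else {
    set_output(0, {}, self.options());
  }
  set_output(1, {}, self.options());
}

// Each backend then provides only an impl that writes into the already
// allocated outputs; the hand-written out/functional wrappers are deleted and
// generated from native_functions.yaml instead.
TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  // ... fill output and total_weight (see the full diff) ...
}
```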

This PR also adds a new type, `OptionalTensorRef`, as a replacement for `const c10::optional<Tensor>&`, in order to avoid the reference count manipulations that are inevitable with the latter. I have confirmed using Godbolt/Compiler Explorer that this class does indeed avoid manipulating the reference count of the `intrusive_ptr` inside the `Tensor` it refers to (a short usage sketch follows the list):

1. [P429709479](https://www.internalfb.com/phabricator/paste/view/P429709479) - Given a `const Tensor&` in scope, an `OptionalTensorRef` can be constructed without bumping refcount.
2. [P429709883](https://www.internalfb.com/phabricator/paste/view/P429709883) - Given an `OptionalTensorRef`, a `const Tensor&` can be produced without bumping refcount.
3. [P429710335](https://www.internalfb.com/phabricator/paste/view/P429710335) - When `OptionalTensorRef` is destructed, the refcount is not decremented.
4. [P429769525](https://www.internalfb.com/phabricator/paste/view/P429769525) - `OptionalTensorRef` can be assigned without refcount manipulation.
5. [P429769882](https://www.internalfb.com/phabricator/paste/view/P429769882) - `OptionalTensorRef` can be move assigned without refcount manipulation.
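A minimal usage sketch (hypothetical caller code, not part of this commit's diffs) showing how the class is meant to be used; it assumes `weight` is a defined `Tensor` and that the new `OptionalTensorRef` header is included:

```cpp
#include <ATen/ATen.h>
// plus the OptionalTensorRef header added by this PR (path omitted here)

// Hypothetical caller; the numbered comments refer to the cases above.
void use_weight(const at::Tensor& weight) {
  // (1) Wrapping an existing const Tensor& does not bump its refcount.
  at::OptionalTensorRef weight_opt(weight);

  if (weight_opt) {  // equivalent to weight_opt.has_value()
    // (2) getTensorRef() hands back a const Tensor& with no refcount bump.
    const at::Tensor& w = weight_opt.getTensorRef();
    (void)w;  // ... use `w` in the kernel ...
  }
  // (3) When weight_opt is destroyed at the end of this scope, the refcount
  //     is not decremented; `weight` still owns its TensorImpl.
}
```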

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D29780666

Pulled By: SplitInfinity

fbshipit-source-id: 7af157215300e9254d635433cbd583f7329fe064
Meghan Lele 2021-07-20 11:44:02 -07:00 committed by Facebook GitHub Bot
parent f0df0207ec
commit 1c80b5220b
8 changed files with 136 additions and 128 deletions

View File

@@ -1,3 +1,49 @@
#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>
namespace at {
class TORCH_API OptionalTensorRef {
public:
OptionalTensorRef() {}
~OptionalTensorRef() {
ref_.unsafeReleaseTensorImpl();
}
OptionalTensorRef(const Tensor& src)
: ref_(c10::intrusive_ptr<TensorImpl>(
src.unsafeGetTensorImpl(),
c10::raw::DontIncreaseRefcount{})) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
}
OptionalTensorRef(const OptionalTensorRef& rhs)
: OptionalTensorRef(rhs.ref_) {}
OptionalTensorRef& operator=(const OptionalTensorRef& rhs) {
// Need to call unsafeReleaseTensorImpl on ref_ since we are reassigning it
// (which does not call the destructor).
ref_.unsafeReleaseTensorImpl();
ref_ = Tensor(c10::intrusive_ptr<TensorImpl>(
rhs.ref_.unsafeGetTensorImpl(), c10::raw::DontIncreaseRefcount{}));
return *this;
}
bool has_value() const {
return ref_.defined();
}
const Tensor& getTensorRef() const & {
return ref_;
}
operator bool() const {
return ref_.defined();
}
private:
Tensor ref_;
};
} // namespace at

View File

@@ -2,10 +2,57 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>
namespace at {
namespace meta {
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
const Tensor& target,
const OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index) {
const Tensor& weight = weight_opt.getTensorRef();
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
TORCH_CHECK(
target.dim() == 1,
"1D target tensor expected, multi-target not supported");
TORCH_CHECK(
self.size(0) == target.size(0),
"size mismatch (got input: ",
self.sizes(),
", target: ",
target.sizes(),
")")
const auto n_classes = self.size(-1);
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"weight tensor should be defined either for all ",
n_classes,
" classes or no classes"
" but got weight tensor of shape: ",
weight.sizes());
const auto n_dims = self.dim();
const auto batch_size = self.size(0);
if (reduction == Reduction::None && n_dims == 2) {
set_output(0, {batch_size}, self.options());
} else {
// produce scalar output when reducing or input is 1d
set_output(0, {}, self.options());
}
set_output(1, {}, self.options());
}
} // namespace meta
namespace native {
namespace {
@@ -26,8 +73,8 @@ inline scalar_t* optional_data(const Tensor& source) {
template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
Tensor& output,
Tensor& total_weight,
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
@@ -44,7 +91,6 @@ static void nll_loss_out_frame(
if (reduction == Reduction::None && n_dims == 2) {
const auto batch_size = input.size(0);
output.resize_({batch_size});
auto input_acc = input.accessor<scalar_t, 2>();
auto target_acc = target.accessor<target_t, 1>();
@@ -74,9 +120,6 @@ static void nll_loss_out_frame(
return;
}
// produce scalar output when reducing or input is 1d
output.resize_({});
auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
@@ -158,38 +201,13 @@ static void nll_loss_out_frame(
}
void nll_loss_forward_out_cpu_template(
Tensor& output,
Tensor& total_weight,
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index) {
TORCH_CHECK(
input.dim() > 0 && input.dim() <= 2, "input tensor should be 1D or 2D");
TORCH_CHECK(
target.dim() == 1,
"1D target tensor expected, multi-target not supported");
TORCH_CHECK(
input.size(0) == target.size(0),
"size mismatch (got input: ",
input.sizes(),
", target: ",
target.sizes(),
")")
const auto n_classes = input.size(-1);
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"weight tensor should be defined either for all ",
n_classes,
" classes or no classes"
" but got weight tensor of shape: ",
weight.sizes());
total_weight.resize_({});
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
if (target.scalar_type() == kByte) {
@@ -377,35 +395,17 @@ void nll_loss_backward_out_cpu_template(
} // namespace
std::tuple<Tensor&, Tensor&> nll_loss_forward_out_cpu(const Tensor& self,
const Tensor& target, const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index,
Tensor& output,
Tensor& total_weight) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
const Tensor& target,
const OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& output,
const Tensor& total_weight) {
const Tensor& weight = weight_opt.getTensorRef();
nll_loss_forward_out_cpu_template(
output, total_weight, self, target, weight, reduction, ignore_index);
return std::tuple<Tensor&, Tensor&>(output, total_weight);
}
std::tuple<Tensor, Tensor> nll_loss_forward_cpu(
const Tensor& self,
const Tensor& target, const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
auto output = at::empty({0}, self.options());
auto total_weight = at::empty({0}, self.options());
at::native::nll_loss_forward_out_cpu(
self, target, weight, reduction, ignore_index, output, total_weight);
return std::make_tuple(output, total_weight);
}
Tensor& nll_loss_backward_out_cpu(const Tensor& grad_output,

View File

@@ -273,43 +273,16 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
}
void nll_loss_forward_out_cuda_template(
Tensor& output,
Tensor& total_weight,
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned =
at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
TORCH_CHECK(
target.dim() == 1,
"1D target tensor expected, multi-target not supported");
int64_t n_classes = input.size(-1);
int64_t n_dims = input.dim();
TORCH_CHECK(n_dims > 0 && n_dims <= 2, "input tensor should be 1D or 2D");
int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
int64_t num_targets = target.size(0);
TORCH_CHECK(
batch_size == num_targets,
"size mismatch (got input: ",
input.sizes(),
", target: ",
target.sizes(),
")")
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"weight tensor should be defined either for all ",
n_classes,
" classes or no classes"
" but got weight tensor of shape: ",
weight.sizes());
auto weight_ = weight.defined() ? weight.contiguous() : weight;
@@ -616,30 +589,17 @@ void nll_loss_backward_out_cuda_template(
} // namespace
std::tuple<Tensor&, Tensor&> nll_loss_forward_out_cuda(
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index,
Tensor& output,
Tensor& total_weight) {
TORCH_IMPL_FUNC(nll_loss_forward_out_cuda)
(const Tensor& self,
const Tensor& target,
const OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& output,
const Tensor& total_weight) {
const Tensor& weight = weight_opt.getTensorRef();
nll_loss_forward_out_cuda_template(
output, total_weight, self, target, weight_opt, reduction, ignore_index);
return std::tuple<Tensor&, Tensor&>(output, total_weight);
}
std::tuple<Tensor, Tensor> nll_loss_forward_cuda(
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index) {
auto output = at::empty({0}, self.options());
auto total_weight = at::empty({0}, self.options());
nll_loss_forward_out_cuda_template(
output, total_weight, self, target, weight_opt, reduction, ignore_index);
return std::make_tuple(output, total_weight);
output, total_weight, self, target, weight, reduction, ignore_index);
}
Tensor& nll_loss_backward_out_cuda(const Tensor& grad_output,

View File

@@ -8251,15 +8251,14 @@
- func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
python_module: nn
structured: True
dispatch:
CPU: nll_loss_forward_out_cpu
CUDA: nll_loss_forward_out_cuda
- func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
python_module: nn
dispatch:
CPU: nll_loss_forward_cpu
CUDA: nll_loss_forward_cuda
structured_delegate: nll_loss_forward.output
- func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn

View File

@@ -282,12 +282,6 @@ class intrusive_ptr final {
// intrusive_ptr out of raw pointers except from inside the make_intrusive(),
// reclaim() and weak_intrusive_ptr::lock() implementations.
// This constructor will not increase the ref counter for you.
// We use the tagged dispatch mechanism to explicitly mark this constructor
// to not increase the refcount
explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
: target_(target) {}
// This constructor will increase the ref counter for you.
// This constructor will be used by the make_intrusive(), and also pybind11,
// which wrap the intrusive_ptr holder around the raw pointer and incref
@@ -317,6 +311,12 @@ class intrusive_ptr final {
intrusive_ptr() noexcept
: intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}
// This constructor will not increase the ref counter for you.
// We use the tagged dispatch mechanism to explicitly mark this constructor
// to not increase the refcount
explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
: target_(target) {}
intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
rhs.target_ = NullType::singleton();
}

View File

@@ -5,7 +5,8 @@ from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,
from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
ConstRefCType, OptionalCType, NamedCType,
tensorT, scalarT, intArrayRefT, dimnameListT)
tensorT, scalarT, intArrayRefT, dimnameListT,
optionalTensorRefT)
from tools.codegen.api import cpp
from typing import Union, List
@@ -32,10 +33,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if t.elem == BaseType(BaseTy.Tensor):
raise AssertionError(
"optional tensor not supported by structured yet; to implement this "
"add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
)
return NamedCType(binds, BaseCType(optionalTensorRefT))
elif t.elem == BaseType(BaseTy.Scalar):
raise AssertionError(
"optional scalar not supported by structured yet"

View File

@@ -3,7 +3,7 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
Expr, MutRefCType, OptionalCType,
NamedCType, SpecialArgName, tensorT,
memoryFormatT, tensorOptionsT, scalarTypeT,
boolT, deviceT, layoutT)
boolT, deviceT, layoutT, optionalTensorRefT)
# This file implements a small program synthesis engine that implements
# conversions between one API to another.
@@ -94,6 +94,10 @@ def translate(
ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = \
f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'
if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'
# Add implicit bindings if the generated code is inside a Tensor method
if method:
ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"

View File

@@ -38,6 +38,7 @@ stringT = BaseCppType('c10', 'string_view')
generatorT = BaseCppType('at', 'Generator')
scalarTypeT = BaseCppType('at', 'ScalarType')
tensorT = BaseCppType('at', 'Tensor')
optionalTensorRefT = BaseCppType('at', 'OptionalTensorRef')
tensorListT = BaseCppType('at', 'TensorList')
dimnameT = BaseCppType('at', 'Dimname')
dimnameListT = BaseCppType('at', 'DimnameList')