nll_loss_forward: port to structured kernel (#61443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61443

For more information, see #55070.

This PR also adds a new type, `OptionalTensorRef`, as a replacement for `c10::optional<Tensor>&`, in order to avoid the reference count manipulations that are inevitable with the latter. I have confirmed using Godbolt/Compiler Explorer that this class does indeed avoid manipulating the reference count of the `intrusive_ptr` inside the `Tensor` it refers to:

1. [P429709479](https://www.internalfb.com/phabricator/paste/view/P429709479) - Given a `const Tensor&` in scope, an `OptionalTensorRef` can be constructed without bumping the refcount.
2. [P429709883](https://www.internalfb.com/phabricator/paste/view/P429709883) - Given an `OptionalTensorRef`, a `const Tensor&` can be produced without bumping the refcount.
3. [P429710335](https://www.internalfb.com/phabricator/paste/view/P429710335) - When an `OptionalTensorRef` is destructed, the refcount is not decremented.
4. [P429769525](https://www.internalfb.com/phabricator/paste/view/P429769525) - An `OptionalTensorRef` can be copy-assigned without refcount manipulation.
5. [P429769882](https://www.internalfb.com/phabricator/paste/view/P429769882) - An `OptionalTensorRef` can be move-assigned without refcount manipulation.

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D29780666

Pulled By: SplitInfinity

fbshipit-source-id: 7af157215300e9254d635433cbd583f7329fe064
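As a usage illustration only (a minimal sketch, not part of this commit; the helper function below is hypothetical), the point of `OptionalTensorRef` is that a kernel can borrow an optional tensor argument and read it back as a `const Tensor&` without any refcount traffic on the underlying `TensorImpl`:

#include <ATen/ATen.h>  // assumes the OptionalTensorRef header added below is in scope

// Hypothetical helper: borrow an optional weight without refcount bumps.
// OptionalTensorRef(src) adopts the TensorImpl pointer with
// DontIncreaseRefcount, and its destructor releases it again without
// decrementing, so the owning `weight` Tensor is unaffected.
double weighted_sum_sketch(const at::Tensor& values, const at::Tensor& weight) {
  at::OptionalTensorRef maybe_weight;              // empty: has_value() == false
  if (weight.defined()) {
    maybe_weight = at::OptionalTensorRef(weight);  // borrow, no incref
  }
  at::Tensor v = values;
  if (maybe_weight.has_value()) {
    v = v * maybe_weight.getTensorRef();           // read back, still no incref
  }
  return v.sum().item<double>();
}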
This commit is contained in:
parent f0df0207ec
commit 1c80b5220b
@@ -1,3 +1,49 @@
#pragma once

#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>

namespace at {
class TORCH_API OptionalTensorRef {
 public:
  OptionalTensorRef() {}

  ~OptionalTensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  OptionalTensorRef(const Tensor& src)
      : ref_(c10::intrusive_ptr<TensorImpl>(
            src.unsafeGetTensorImpl(),
            c10::raw::DontIncreaseRefcount{})) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
  }

  OptionalTensorRef(const OptionalTensorRef& rhs)
      : OptionalTensorRef(rhs.ref_) {}

  OptionalTensorRef& operator=(const OptionalTensorRef& rhs) {
    // Need to call unsafeReleaseTensorImpl on ref_ since we are reassigning it
    // (which does not call the destructor).
    ref_.unsafeReleaseTensorImpl();
    ref_ = Tensor(c10::intrusive_ptr<TensorImpl>(
        rhs.ref_.unsafeGetTensorImpl(), c10::raw::DontIncreaseRefcount{}));
    return *this;
  }

  bool has_value() const {
    return ref_.defined();
  }

  const Tensor& getTensorRef() const & {
    return ref_;
  }

  operator bool() const {
    return ref_.defined();
  }

 private:
  Tensor ref_;
};
} // namespace at
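A quick way to see the refcount-neutrality claims from the summary against the class above (an illustrative sketch, not part of the diff; it relies only on the public `Tensor::use_count()` accessor):

#include <ATen/ATen.h>
#include <cassert>

// Sketch: the wrapper is refcount-neutral across construction, read-back,
// copy assignment, and destruction.
void refcount_neutrality_sketch() {
  at::Tensor t = at::ones({2, 3});
  const size_t before = t.use_count();            // owner's count, typically 1
  {
    at::OptionalTensorRef ref(t);                 // borrows t's TensorImpl, no incref
    const at::Tensor& view = ref.getTensorRef();  // no incref either
    assert(ref && view.defined());

    at::OptionalTensorRef other;
    other = ref;                                  // copy assignment, still no incref
    assert(other.has_value());
  }                                               // both refs destructed, no decref
  assert(t.use_count() == before);
}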
@@ -2,10 +2,57 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>

namespace at {
namespace meta {
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index) {
  const Tensor& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() == 1,
      "1D target tensor expected, multi-target not supported");
  TORCH_CHECK(
      self.size(0) == target.size(0),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")")

  const auto n_classes = self.size(-1);

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  const auto n_dims = self.dim();
  const auto batch_size = self.size(0);

  if (reduction == Reduction::None && n_dims == 2) {
    set_output(0, {batch_size}, self.options());
  } else {
    // produce scalar output when reducing or input is 1d
    set_output(0, {}, self.options());
  }

  set_output(1, {}, self.options());
}
} // namespace meta

namespace native {

namespace {
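For callers, the shape logic in the meta function above means: with `Reduction::None` and a 2D input, the first output holds one loss per sample; otherwise it is a 0-dim scalar, and `total_weight` is always a 0-dim scalar. A small usage sketch through the public functional op (not part of the diff; the sizes are arbitrary):

#include <ATen/ATen.h>

// Sketch: exercising the shapes declared by the meta function above.
void nll_loss_forward_shapes_sketch() {
  at::Tensor logp = at::log_softmax(at::randn({4, 10}), /*dim=*/1);
  at::Tensor target = at::randint(/*high=*/10, {4}, at::kLong);

  // reduction = None on a 2D input -> per-sample output of shape {4}
  auto none = at::nll_loss_forward(
      logp, target, /*weight=*/{}, at::Reduction::None, /*ignore_index=*/-100);
  // std::get<0>(none).sizes() == {4}; std::get<1>(none) is a 0-dim scalar

  // reduction = Mean -> scalar output and scalar total_weight
  auto mean = at::nll_loss_forward(
      logp, target, /*weight=*/{}, at::Reduction::Mean, /*ignore_index=*/-100);
  // std::get<0>(mean).dim() == 0; std::get<1>(mean).dim() == 0
}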
@@ -26,8 +73,8 @@ inline scalar_t* optional_data(const Tensor& source) {

template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
@@ -44,7 +91,6 @@ static void nll_loss_out_frame(

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    output.resize_({batch_size});

    auto input_acc = input.accessor<scalar_t, 2>();
    auto target_acc = target.accessor<target_t, 1>();
@@ -74,9 +120,6 @@ static void nll_loss_out_frame(
    return;
  }

  // produce scalar output when reducing or input is 1d
  output.resize_({});

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

@@ -158,38 +201,13 @@ static void nll_loss_out_frame(
}

void nll_loss_forward_out_cpu_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  TORCH_CHECK(
      input.dim() > 0 && input.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() == 1,
      "1D target tensor expected, multi-target not supported");
  TORCH_CHECK(
      input.size(0) == target.size(0),
      "size mismatch (got input: ",
      input.sizes(),
      ", target: ",
      target.sizes(),
      ")")

  const auto n_classes = input.size(-1);

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  total_weight.resize_({});

  AT_DISPATCH_FLOATING_TYPES_AND(
      ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
        if (target.scalar_type() == kByte) {
@@ -377,35 +395,17 @@ void nll_loss_backward_out_cpu_template(

} // namespace

std::tuple<Tensor&, Tensor&> nll_loss_forward_out_cpu(const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss_forward_cpu(
    const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  at::native::nll_loss_forward_out_cpu(
      self, target, weight, reduction, ignore_index, output, total_weight);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss_backward_out_cpu(const Tensor& grad_output,
@@ -273,43 +273,16 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
}

void nll_loss_forward_out_cuda_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const c10::optional<Tensor>& weight_opt,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  TORCH_CHECK(
      target.dim() == 1,
      "1D target tensor expected, multi-target not supported");

  int64_t n_classes = input.size(-1);
  int64_t n_dims = input.dim();

  TORCH_CHECK(n_dims > 0 && n_dims <= 2, "input tensor should be 1D or 2D");
  int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
  int64_t num_targets = target.size(0);
  TORCH_CHECK(
      batch_size == num_targets,
      "size mismatch (got input: ",
      input.sizes(),
      ", target: ",
      target.sizes(),
      ")")

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  auto weight_ = weight.defined() ? weight.contiguous() : weight;

@@ -616,30 +589,17 @@ void nll_loss_backward_out_cuda_template(

} // namespace

std::tuple<Tensor&, Tensor&> nll_loss_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
TORCH_IMPL_FUNC(nll_loss_forward_out_cuda)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  nll_loss_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::make_tuple(output, total_weight);
      output, total_weight, self, target, weight, reduction, ignore_index);
}

Tensor& nll_loss_backward_out_cuda(const Tensor& grad_output,
@@ -8251,15 +8251,14 @@

- func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
  python_module: nn
  structured: True
  dispatch:
    CPU: nll_loss_forward_out_cpu
    CUDA: nll_loss_forward_out_cuda

- func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
  python_module: nn
  dispatch:
    CPU: nll_loss_forward_cpu
    CUDA: nll_loss_forward_cuda
  structured_delegate: nll_loss_forward.output

- func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
  python_module: nn
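In the schema above, `structured: True` on the out variant and `structured_delegate: nll_loss_forward.output` on the functional variant mean both entries are served by the same meta/impl pair: the meta function declares output shapes via `set_output`, and the generated wrapper either adopts the caller-supplied out tensors or allocates fresh ones before invoking `TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)`. Roughly (a hand-written approximation of the functional path, not the literal codegen output):

#include <ATen/ATen.h>
#include <tuple>

// Approximation only: the functional variant delegated via
// `structured_delegate` behaves roughly like this.
std::tuple<at::Tensor, at::Tensor> nll_loss_forward_functional_sketch(
    const at::Tensor& self,
    const at::Tensor& target,
    const c10::optional<at::Tensor>& weight,
    int64_t reduction,
    int64_t ignore_index) {
  // 1. The meta function runs first: it checks the arguments and calls
  //    set_output(...) with shapes derived from `self` and `reduction`,
  //    which for the functional variant amounts to allocating the outputs.
  at::Tensor output = (reduction == at::Reduction::None && self.dim() == 2)
      ? at::empty({self.size(0)}, self.options())
      : at::empty({}, self.options());
  at::Tensor total_weight = at::empty({}, self.options());

  // 2. The impl function then fills the pre-allocated outputs; it no longer
  //    resizes them or repeats the checks. Here we go through the public
  //    out variant, which dispatches to the same structured kernel.
  at::nll_loss_forward_out(
      output, total_weight, self, target, weight, reduction, ignore_index);
  return std::make_tuple(output, total_weight);
}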
@@ -282,12 +282,6 @@ class intrusive_ptr final {
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
@@ -317,6 +311,12 @@ class intrusive_ptr final {
  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }
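The two hunks above relocate the `DontIncreaseRefcount` tagged constructor within `intrusive_ptr` so that it is reachable by `OptionalTensorRef`, which needs to wrap a raw `TensorImpl*` without incrementing its count. The borrow-and-detach pattern it enables looks like this (an illustrative sketch; `Node` is a made-up type, not from the codebase):

#include <c10/util/intrusive_ptr.h>

// Minimal sketch of the borrow pattern OptionalTensorRef relies on,
// using a toy intrusive_ptr_target type.
struct Node : c10::intrusive_ptr_target {
  int value = 0;
};

void borrow_sketch() {
  auto owner = c10::make_intrusive<Node>();      // refcount == 1

  // Wrap the same object without touching the refcount (tagged dispatch).
  c10::intrusive_ptr<Node> borrowed(
      owner.get(), c10::raw::DontIncreaseRefcount{});

  borrowed->value = 42;                          // use it like a normal ptr

  // Detach before 'borrowed' is destroyed so the count is not decremented;
  // this mirrors what ~OptionalTensorRef does via unsafeReleaseTensorImpl().
  borrowed.release();
  // 'owner' still holds the single reference and frees the Node at scope end.
}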
@@ -5,7 +5,8 @@ from tools.codegen.model import (Argument, BaseTy, BaseType, ListType,

from tools.codegen.api.types import (ArgName, BaseCType, Binding, ArrayRefCType,
                                     ConstRefCType, OptionalCType, NamedCType,
                                     tensorT, scalarT, intArrayRefT, dimnameListT)
                                     tensorT, scalarT, intArrayRefT, dimnameListT,
                                     optionalTensorRefT)
from tools.codegen.api import cpp

from typing import Union, List
@@ -32,10 +33,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
        raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
@@ -3,7 +3,7 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
                                     Expr, MutRefCType, OptionalCType,
                                     NamedCType, SpecialArgName, tensorT,
                                     memoryFormatT, tensorOptionsT, scalarTypeT,
                                     boolT, deviceT, layoutT)
                                     boolT, deviceT, layoutT, optionalTensorRefT)

# This file implements a small program synthesis engine that implements
# conversions between one API to another.
@@ -94,6 +94,10 @@ def translate(
            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
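The new `translate()` rule above emits, at each call site that must pass a `const c10::optional<Tensor>&` into a structured kernel, an expression equivalent to the following hand-written approximation (names are illustrative, not copied from generated output):

#include <ATen/ATen.h>

// Approximation of the conversion the emitted f-string expands to:
// only a defined tensor becomes a non-empty OptionalTensorRef.
void pass_optional_weight_sketch(const c10::optional<at::Tensor>& weight) {
  at::OptionalTensorRef weight_ref;                // empty by default
  if (weight.has_value() && weight->defined()) {
    weight_ref = at::OptionalTensorRef(*weight);   // borrow, no incref
  }
  // weight_ref is what the TORCH_META_FUNC / TORCH_IMPL_FUNC signatures
  // above accept as `const OptionalTensorRef weight_opt`.
  (void)weight_ref;
}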
@@ -38,6 +38,7 @@ stringT = BaseCppType('c10', 'string_view')
generatorT = BaseCppType('at', 'Generator')
scalarTypeT = BaseCppType('at', 'ScalarType')
tensorT = BaseCppType('at', 'Tensor')
optionalTensorRefT = BaseCppType('at', 'OptionalTensorRef')
tensorListT = BaseCppType('at', 'TensorList')
dimnameT = BaseCppType('at', 'Dimname')
dimnameListT = BaseCppType('at', 'DimnameList')