[PyTorch] Fix a bunch of structured kernel refcounting (#71140)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71140

Structured kernels need to use the borrowing variants of the build APIs to TensorIterator. (I am working on a debug check for this, but it is currently too strict, and relaxing it does not catch these bugs.)
ghstack-source-id: 147191022

Test Plan: CI

Reviewed By: bdhirsh

Differential Revision: D33520003

fbshipit-source-id: 3b0ff9036acdb78ae6fc7489ed0ed487d5ff080f
(cherry picked from commit 80ef4e14e3)
This commit is contained in:
Scott Wolchok 2022-01-19 16:24:11 -08:00 committed by PyTorch MergeBot
parent b98e955b24
commit c59942ac73
8 changed files with 107 additions and 51 deletions

View File

@ -877,14 +877,8 @@ void TensorIteratorBase::build_borrowing_binary_float_op(
.add_input(b));
}
void TensorIteratorBase::build_comparison_op(
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
TensorIteratorConfig config;
static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
config.set_check_mem_overlap(true);
config.add_owned_output(out);
config.add_owned_input(a);
config.add_owned_input(b);
config.allow_cpu_scalars(true);
config.promote_inputs_to_common_dtype(true);
@ -905,7 +899,38 @@ void TensorIteratorBase::build_comparison_op(
if (out.defined() && out.scalar_type() != kBool) {
config.cast_common_dtype_to_outputs(true);
}
}
void TensorIteratorBase::build_comparison_op(
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
TensorIteratorConfig config;
set_up_comparison_op_config(config, out);
config.add_owned_output(out);
config.add_owned_input(a);
config.add_owned_input(b);
build(config);
}
void TensorIteratorBase::build_borrowing_comparison_op(
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
TensorIteratorConfig config;
set_up_comparison_op_config(config, out);
config.add_borrowed_output(out);
config.add_borrowed_input(a);
config.add_borrowed_input(b);
build(config);
}
void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
TensorIteratorConfig config;
set_up_comparison_op_config(config, out);
config.add_borrowed_output(out);
config.add_borrowed_input(a);
config.add_owned_input(b);
build(config);
}
@ -944,37 +969,65 @@ void TensorIteratorBase::build_borrowing_binary_op(
.add_input(b));
}
// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
#define UNARY_FLOAT_OP_CONFIG() \
TensorIteratorConfig() \
.set_check_mem_overlap(true) \
.promote_inputs_to_common_dtype(true) \
.cast_common_dtype_to_outputs(true) \
.enforce_safe_casting_to_output(true) \
.promote_integer_inputs_to_float(true)
void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
build(UNARY_FLOAT_OP_CONFIG()
.add_owned_output(out)
.add_owned_input(a)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.promote_integer_inputs_to_float(true));
.add_owned_input(a));
}
void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
build(UNARY_FLOAT_OP_CONFIG()
.add_output(out)
.add_input(a));
}
// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
#define UNARY_OP_CONFIG() \
TensorIteratorConfig() \
.set_check_mem_overlap(true) \
.cast_common_dtype_to_outputs(false) \
.enforce_safe_casting_to_output(false) \
.check_all_same_dtype(true)
void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
build(UNARY_OP_CONFIG()
.add_owned_output(out)
.add_owned_input(a)
.cast_common_dtype_to_outputs(false)
.enforce_safe_casting_to_output(false)
.check_all_same_dtype(true));
.add_owned_input(a));
}
void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
build(UNARY_OP_CONFIG()
.add_output(out)
.add_input(a));
}
void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
build(UNARY_OP_CONFIG()
.add_output(out)
.add_owned_input(a));
}
// Helper to construct a unary op that forcibly promotes output to boolean.
// Only be used when the output tensor must have boolean type.
void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
.check_all_same_dtype(false)
.declare_static_dtype(at::kBool)
.declare_static_device(a.device())
.add_owned_output(out)
.add_owned_input(a));
.add_output(out)
.add_input(a));
}
TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
@ -1015,12 +1068,6 @@ TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase&
return iter;
}
TensorIterator TensorIterator::unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
TensorIterator iter;
iter.build_unary_force_boolean_op(out, a);
return iter;
}
#define NULLARY_OP_CONFIG() \
TensorIteratorConfig() \
.set_check_mem_overlap(true) \

View File

@ -435,9 +435,21 @@ public:
void build_borrowing_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
void build_unary_float_op(const TensorBase& out, const TensorBase& a);
void build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a);
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
void build_unary_op(const TensorBase& out, const TensorBase& a);
void build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
// Odd special case needed for pow. Has to borrow the output because
// it's a structured kernel, but the argument is potentially a copy.
void build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a);
void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
void build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
void build_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
void build_borrowing_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
// Another special case: we need to own the second argument for comparison ops.
void build_borrowing_except_last_argument_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
void build_ternary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b, const TensorBase& c);
#undef TORCH_DISALLOW_TEMPORARIES
@ -571,7 +583,6 @@ struct TORCH_API TensorIterator final : public TensorIteratorBase {
static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
static TensorIterator nullary_op(TensorBase& out);
static TensorIterator unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
static TensorIterator borrowing_nullary_op(const TensorBase& out);
static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);

View File

@ -211,13 +211,13 @@ void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor&
TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \
const Tensor& result = maybe_get_output(); \
comparison_op_check(self, other, result); \
build_comparison_op(result, self, other); \
build_borrowing_comparison_op(result, self, other); \
} \
\
TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \
auto other_tensor = \
native::wrapped_scalar_tensor_and_check_convert(other, self); \
build_comparison_op(maybe_get_output(), self, other_tensor); \
build_borrowing_except_last_argument_comparison_op(maybe_get_output(), self, other_tensor); \
}
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(eq);

View File

@ -20,7 +20,7 @@ TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
"Integers to negative integer powers are not allowed.");
auto common_dtype = at::result_type(base, exp);
build_unary_op(maybe_get_output(), base.to(common_dtype));
build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
}
TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {

View File

@ -30,7 +30,7 @@ const OptionalScalarRef max) {
TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
}
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC2(isin, Tensor_Tensor) (
@ -61,14 +61,14 @@ TORCH_META_FUNC(isposinf) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
"isposinf does not support non-boolean outputs.");
build_unary_force_boolean_op(maybe_get_output(), self);
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}
TORCH_META_FUNC(isneginf) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
"isneginf does not support non-boolean outputs.");
build_unary_force_boolean_op(maybe_get_output(), self);
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}
static void check_unsupported_complex(const char* name, const Tensor& self) {

View File

@ -31,7 +31,7 @@ namespace meta {
// For complex inputs, the output type should be the same as input type.
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
TORCH_META_FUNC(func) (const Tensor& self) { \
build_unary_float_op(maybe_get_output(), self); \
build_borrowing_unary_float_op(maybe_get_output(), self); \
}
CREATE_UNARY_FLOAT_META_FUNC(acos)
@ -73,13 +73,13 @@ CREATE_UNARY_FLOAT_META_FUNC(tanh)
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
build_unary_float_op(maybe_get_output(), self);
build_borrowing_unary_float_op(maybe_get_output(), self);
}
// These are normal unary ops that preserve dtype
#define CREATE_UNARY_META_FUNC(func) \
TORCH_META_FUNC(func) (const Tensor& self) { \
build_unary_op(maybe_get_output(), self); \
build_borrowing_unary_op(maybe_get_output(), self); \
}
CREATE_UNARY_META_FUNC(bitwise_not)
CREATE_UNARY_META_FUNC(frac)
@ -90,41 +90,41 @@ TORCH_META_FUNC(neg)(const Tensor& self) {
TORCH_CHECK(self.scalar_type() != kBool,
"Negation, the `-` operator, on a bool tensor is not supported. "
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(trunc) (const Tensor& self) {
// Note: this is consistent with NumPy
TORCH_CHECK(!self.is_complex(),
"trunc is not supported for complex inputs");
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(floor) (const Tensor& self) {
// Note: this is consistent with NumPy
TORCH_CHECK(!self.is_complex(),
"floor is not supported for complex inputs");
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(sign) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(signbit) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
"signbit does not support non-boolean outputs.");
build_unary_force_boolean_op(maybe_get_output(), self);
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}
TORCH_META_FUNC(ceil) (const Tensor& self) {
// Note: this is consistent with NumPy
TORCH_CHECK(!self.is_complex(),
"ceil is not supported for complex inputs");
build_unary_op(maybe_get_output(), self);
build_borrowing_unary_op(maybe_get_output(), self);
}
} // namespace meta

View File

@ -19,7 +19,7 @@ void mkldnn_matmul(
bool use_mkldnn_bf16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const c10::optional<Tensor>& result_opt){
const Tensor& result_opt){
return false;
}
@ -126,9 +126,7 @@ inline bool checksize(const Tensor& mat1, const Tensor& mat2){
bool use_mkldnn_bf16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const c10::optional<Tensor>& result_opt) {
c10::MaybeOwned<Tensor> result_maybe_owned = at::borrow_from_optional_tensor(result_opt);
const Tensor& result = *result_maybe_owned;
const Tensor& result) {
return (
at::globalContext().userEnabledMkldnn() &&
mat1.scalar_type() == kBFloat16 &&

View File

@ -16,7 +16,7 @@ TORCH_API void mkldnn_matmul(
bool use_mkldnn_bf16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const c10::optional<Tensor>& result_opt);
const Tensor& result_opt);
}