mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[PyTorch] Fix a bunch of structured kernel refcounting (#71140)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71140
Structured kernels need to use the borrowing variants of the build APIs to TensorIterator. (I am working on a debug check for this, but it is currently too strict, and relaxing it does not catch these bugs.)
ghstack-source-id: 147191022
Test Plan: CI
Reviewed By: bdhirsh
Differential Revision: D33520003
fbshipit-source-id: 3b0ff9036acdb78ae6fc7489ed0ed487d5ff080f
(cherry picked from commit 80ef4e14e3)
This commit is contained in:
parent
b98e955b24
commit
c59942ac73
|
|
@ -877,14 +877,8 @@ void TensorIteratorBase::build_borrowing_binary_float_op(
|
|||
.add_input(b));
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_comparison_op(
|
||||
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
||||
TensorIteratorConfig config;
|
||||
|
||||
static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
|
||||
config.set_check_mem_overlap(true);
|
||||
config.add_owned_output(out);
|
||||
config.add_owned_input(a);
|
||||
config.add_owned_input(b);
|
||||
config.allow_cpu_scalars(true);
|
||||
config.promote_inputs_to_common_dtype(true);
|
||||
|
||||
|
|
@ -905,7 +899,38 @@ void TensorIteratorBase::build_comparison_op(
|
|||
if (out.defined() && out.scalar_type() != kBool) {
|
||||
config.cast_common_dtype_to_outputs(true);
|
||||
}
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_comparison_op(
|
||||
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
||||
TensorIteratorConfig config;
|
||||
set_up_comparison_op_config(config, out);
|
||||
|
||||
config.add_owned_output(out);
|
||||
config.add_owned_input(a);
|
||||
config.add_owned_input(b);
|
||||
build(config);
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_borrowing_comparison_op(
|
||||
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
||||
TensorIteratorConfig config;
|
||||
set_up_comparison_op_config(config, out);
|
||||
|
||||
config.add_borrowed_output(out);
|
||||
config.add_borrowed_input(a);
|
||||
config.add_borrowed_input(b);
|
||||
build(config);
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
|
||||
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
||||
TensorIteratorConfig config;
|
||||
set_up_comparison_op_config(config, out);
|
||||
|
||||
config.add_borrowed_output(out);
|
||||
config.add_borrowed_input(a);
|
||||
config.add_owned_input(b);
|
||||
build(config);
|
||||
}
|
||||
|
||||
|
|
@ -944,37 +969,65 @@ void TensorIteratorBase::build_borrowing_binary_op(
|
|||
.add_input(b));
|
||||
}
|
||||
|
||||
// This cannot be a function because TensorIteratorConfig is not
|
||||
// copyable or movable, so it can't be returned from the function.
|
||||
#define UNARY_FLOAT_OP_CONFIG() \
|
||||
TensorIteratorConfig() \
|
||||
.set_check_mem_overlap(true) \
|
||||
.promote_inputs_to_common_dtype(true) \
|
||||
.cast_common_dtype_to_outputs(true) \
|
||||
.enforce_safe_casting_to_output(true) \
|
||||
.promote_integer_inputs_to_float(true)
|
||||
|
||||
void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(TensorIteratorConfig()
|
||||
.set_check_mem_overlap(true)
|
||||
build(UNARY_FLOAT_OP_CONFIG()
|
||||
.add_owned_output(out)
|
||||
.add_owned_input(a)
|
||||
.promote_inputs_to_common_dtype(true)
|
||||
.cast_common_dtype_to_outputs(true)
|
||||
.enforce_safe_casting_to_output(true)
|
||||
.promote_integer_inputs_to_float(true));
|
||||
.add_owned_input(a));
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(UNARY_FLOAT_OP_CONFIG()
|
||||
.add_output(out)
|
||||
.add_input(a));
|
||||
}
|
||||
|
||||
// This cannot be a function because TensorIteratorConfig is not
|
||||
// copyable or movable, so it can't be returned from the function.
|
||||
#define UNARY_OP_CONFIG() \
|
||||
TensorIteratorConfig() \
|
||||
.set_check_mem_overlap(true) \
|
||||
.cast_common_dtype_to_outputs(false) \
|
||||
.enforce_safe_casting_to_output(false) \
|
||||
.check_all_same_dtype(true)
|
||||
|
||||
void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(TensorIteratorConfig()
|
||||
.set_check_mem_overlap(true)
|
||||
build(UNARY_OP_CONFIG()
|
||||
.add_owned_output(out)
|
||||
.add_owned_input(a)
|
||||
.cast_common_dtype_to_outputs(false)
|
||||
.enforce_safe_casting_to_output(false)
|
||||
.check_all_same_dtype(true));
|
||||
.add_owned_input(a));
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(UNARY_OP_CONFIG()
|
||||
.add_output(out)
|
||||
.add_input(a));
|
||||
}
|
||||
|
||||
void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(UNARY_OP_CONFIG()
|
||||
.add_output(out)
|
||||
.add_owned_input(a));
|
||||
}
|
||||
|
||||
// Helper to construct a unary op that forcibly promotes output to boolean.
|
||||
// Only be used when the output tensor must have boolean type.
|
||||
void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
|
||||
void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
|
||||
build(TensorIteratorConfig()
|
||||
.set_check_mem_overlap(true)
|
||||
.check_all_same_dtype(false)
|
||||
.declare_static_dtype(at::kBool)
|
||||
.declare_static_device(a.device())
|
||||
.add_owned_output(out)
|
||||
.add_owned_input(a));
|
||||
.add_output(out)
|
||||
.add_input(a));
|
||||
}
|
||||
|
||||
TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
||||
|
|
@ -1015,12 +1068,6 @@ TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase&
|
|||
return iter;
|
||||
}
|
||||
|
||||
TensorIterator TensorIterator::unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
|
||||
TensorIterator iter;
|
||||
iter.build_unary_force_boolean_op(out, a);
|
||||
return iter;
|
||||
}
|
||||
|
||||
#define NULLARY_OP_CONFIG() \
|
||||
TensorIteratorConfig() \
|
||||
.set_check_mem_overlap(true) \
|
||||
|
|
|
|||
|
|
@ -435,9 +435,21 @@ public:
|
|||
void build_borrowing_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
|
||||
TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
|
||||
void build_unary_float_op(const TensorBase& out, const TensorBase& a);
|
||||
void build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a);
|
||||
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
|
||||
void build_unary_op(const TensorBase& out, const TensorBase& a);
|
||||
void build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
|
||||
// Odd special case needed for pow. Has to borrow the output because
|
||||
// it's a structured kernel, but the argument is potentially a copy.
|
||||
void build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a);
|
||||
void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
|
||||
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
|
||||
void build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
|
||||
TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
|
||||
void build_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
|
||||
void build_borrowing_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
|
||||
TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
|
||||
// Another special case: we need to own the second argument for comparison ops.
|
||||
void build_borrowing_except_last_argument_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
|
||||
void build_ternary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b, const TensorBase& c);
|
||||
|
||||
#undef TORCH_DISALLOW_TEMPORARIES
|
||||
|
|
@ -571,7 +583,6 @@ struct TORCH_API TensorIterator final : public TensorIteratorBase {
|
|||
static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
|
||||
static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
|
||||
static TensorIterator nullary_op(TensorBase& out);
|
||||
static TensorIterator unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
|
||||
static TensorIterator borrowing_nullary_op(const TensorBase& out);
|
||||
static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
|
||||
static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
|
||||
|
|
|
|||
|
|
@ -211,13 +211,13 @@ void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor&
|
|||
TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \
|
||||
const Tensor& result = maybe_get_output(); \
|
||||
comparison_op_check(self, other, result); \
|
||||
build_comparison_op(result, self, other); \
|
||||
build_borrowing_comparison_op(result, self, other); \
|
||||
} \
|
||||
\
|
||||
TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \
|
||||
auto other_tensor = \
|
||||
native::wrapped_scalar_tensor_and_check_convert(other, self); \
|
||||
build_comparison_op(maybe_get_output(), self, other_tensor); \
|
||||
build_borrowing_except_last_argument_comparison_op(maybe_get_output(), self, other_tensor); \
|
||||
}
|
||||
|
||||
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(eq);
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
|
|||
"Integers to negative integer powers are not allowed.");
|
||||
|
||||
auto common_dtype = at::result_type(base, exp);
|
||||
build_unary_op(maybe_get_output(), base.to(common_dtype));
|
||||
build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
|
||||
}
|
||||
|
||||
TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ const OptionalScalarRef max) {
|
|||
TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
|
||||
}
|
||||
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC2(isin, Tensor_Tensor) (
|
||||
|
|
@ -61,14 +61,14 @@ TORCH_META_FUNC(isposinf) (const Tensor& self) {
|
|||
TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
|
||||
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
|
||||
"isposinf does not support non-boolean outputs.");
|
||||
build_unary_force_boolean_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(isneginf) (const Tensor& self) {
|
||||
TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
|
||||
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
|
||||
"isneginf does not support non-boolean outputs.");
|
||||
build_unary_force_boolean_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
static void check_unsupported_complex(const char* name, const Tensor& self) {
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ namespace meta {
|
|||
// For complex inputs, the output type should be the same as input type.
|
||||
#define CREATE_UNARY_FLOAT_META_FUNC(func) \
|
||||
TORCH_META_FUNC(func) (const Tensor& self) { \
|
||||
build_unary_float_op(maybe_get_output(), self); \
|
||||
build_borrowing_unary_float_op(maybe_get_output(), self); \
|
||||
}
|
||||
|
||||
CREATE_UNARY_FLOAT_META_FUNC(acos)
|
||||
|
|
@ -73,13 +73,13 @@ CREATE_UNARY_FLOAT_META_FUNC(tanh)
|
|||
|
||||
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
|
||||
TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
|
||||
build_unary_float_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_float_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
// These are normal unary ops that preserve dtype
|
||||
#define CREATE_UNARY_META_FUNC(func) \
|
||||
TORCH_META_FUNC(func) (const Tensor& self) { \
|
||||
build_unary_op(maybe_get_output(), self); \
|
||||
build_borrowing_unary_op(maybe_get_output(), self); \
|
||||
}
|
||||
CREATE_UNARY_META_FUNC(bitwise_not)
|
||||
CREATE_UNARY_META_FUNC(frac)
|
||||
|
|
@ -90,41 +90,41 @@ TORCH_META_FUNC(neg)(const Tensor& self) {
|
|||
TORCH_CHECK(self.scalar_type() != kBool,
|
||||
"Negation, the `-` operator, on a bool tensor is not supported. "
|
||||
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(trunc) (const Tensor& self) {
|
||||
// Note: this is consistent with NumPy
|
||||
TORCH_CHECK(!self.is_complex(),
|
||||
"trunc is not supported for complex inputs");
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(floor) (const Tensor& self) {
|
||||
// Note: this is consistent with NumPy
|
||||
TORCH_CHECK(!self.is_complex(),
|
||||
"floor is not supported for complex inputs");
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(sign) (const Tensor& self) {
|
||||
TORCH_CHECK(!self.is_complex(),
|
||||
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(signbit) (const Tensor& self) {
|
||||
TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
|
||||
TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
|
||||
"signbit does not support non-boolean outputs.");
|
||||
build_unary_force_boolean_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC(ceil) (const Tensor& self) {
|
||||
// Note: this is consistent with NumPy
|
||||
TORCH_CHECK(!self.is_complex(),
|
||||
"ceil is not supported for complex inputs");
|
||||
build_unary_op(maybe_get_output(), self);
|
||||
build_borrowing_unary_op(maybe_get_output(), self);
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ void mkldnn_matmul(
|
|||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const c10::optional<Tensor>& result_opt){
|
||||
const Tensor& result_opt){
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -126,9 +126,7 @@ inline bool checksize(const Tensor& mat1, const Tensor& mat2){
|
|||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const c10::optional<Tensor>& result_opt) {
|
||||
c10::MaybeOwned<Tensor> result_maybe_owned = at::borrow_from_optional_tensor(result_opt);
|
||||
const Tensor& result = *result_maybe_owned;
|
||||
const Tensor& result) {
|
||||
return (
|
||||
at::globalContext().userEnabledMkldnn() &&
|
||||
mat1.scalar_type() == kBFloat16 &&
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ TORCH_API void mkldnn_matmul(
|
|||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const c10::optional<Tensor>& result_opt);
|
||||
const Tensor& result_opt);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user