[PyTorch] Fix a bunch of structured kernel refcounting (#71140)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71140

Structured kernels need to use the borrowing variants of the TensorIterator build APIs. (I am working on a debug check for this, but it is currently too strict, and relaxing it does not catch these bugs.)

ghstack-source-id: 147191022

Test Plan: CI

Reviewed By: bdhirsh

Differential Revision: D33520003

fbshipit-source-id: 3b0ff9036acdb78ae6fc7489ed0ed487d5ff080f
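A minimal sketch of the distinction the fix is about (not part of the patch; it only uses the TensorIteratorConfig calls that appear in the diff below): the add_owned_* builders copy the tensor handle into the config, which costs a refcount bump per operand, while the add_borrowed_* builders only record a reference. Borrowing is safe whenever the caller guarantees the tensors outlive the iterator, a guarantee structured kernels already provide for their declared inputs and outputs.

#include <ATen/TensorIterator.h>

// Owning form: safe even if `out` and `a` are short-lived, but pays an extra
// refcount bump per operand.
at::TensorIterator make_owning_iter(const at::TensorBase& out, const at::TensorBase& a) {
  return at::TensorIteratorConfig()
      .set_check_mem_overlap(true)
      .add_owned_output(out)    // copies the handle into the config
      .add_owned_input(a)
      .build();
}

// Borrowing form: no refcount traffic, but `out` and `a` must outlive the
// returned iterator.
at::TensorIterator make_borrowing_iter(const at::TensorBase& out, const at::TensorBase& a) {
  return at::TensorIteratorConfig()
      .set_check_mem_overlap(true)
      .add_borrowed_output(out) // records a reference only
      .add_borrowed_input(a)
      .build();
}
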
Commit 80ef4e14e3 (parent 746487075e)
@@ -877,14 +877,8 @@ void TensorIteratorBase::build_borrowing_binary_float_op(
       .add_input(b));
 }
 
-void TensorIteratorBase::build_comparison_op(
-    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
-  TensorIteratorConfig config;
-
+static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
   config.set_check_mem_overlap(true);
-  config.add_owned_output(out);
-  config.add_owned_input(a);
-  config.add_owned_input(b);
   config.allow_cpu_scalars(true);
   config.promote_inputs_to_common_dtype(true);
 
@@ -905,7 +899,38 @@ void TensorIteratorBase::build_comparison_op(
   if (out.defined() && out.scalar_type() != kBool) {
     config.cast_common_dtype_to_outputs(true);
   }
-
+}
+
+void TensorIteratorBase::build_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_owned_output(out);
+  config.add_owned_input(a);
+  config.add_owned_input(b);
+  build(config);
+}
+
+void TensorIteratorBase::build_borrowing_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_borrowed_output(out);
+  config.add_borrowed_input(a);
+  config.add_borrowed_input(b);
+  build(config);
+}
+
+void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_borrowed_output(out);
+  config.add_borrowed_input(a);
+  config.add_owned_input(b);
   build(config);
 }
 
@@ -944,37 +969,65 @@ void TensorIteratorBase::build_borrowing_binary_op(
       .add_input(b));
 }
 
+// This cannot be a function because TensorIteratorConfig is not
+// copyable or movable, so it can't be returned from the function.
+#define UNARY_FLOAT_OP_CONFIG()                  \
+  TensorIteratorConfig()                         \
+  .set_check_mem_overlap(true)                   \
+  .promote_inputs_to_common_dtype(true)          \
+  .cast_common_dtype_to_outputs(true)            \
+  .enforce_safe_casting_to_output(true)          \
+  .promote_integer_inputs_to_float(true)
+
 void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
-  build(TensorIteratorConfig()
-      .set_check_mem_overlap(true)
+  build(UNARY_FLOAT_OP_CONFIG()
       .add_owned_output(out)
-      .add_owned_input(a)
-      .promote_inputs_to_common_dtype(true)
-      .cast_common_dtype_to_outputs(true)
-      .enforce_safe_casting_to_output(true)
-      .promote_integer_inputs_to_float(true));
+      .add_owned_input(a));
+}
+
+void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_FLOAT_OP_CONFIG()
+      .add_output(out)
+      .add_input(a));
+}
+
+// This cannot be a function because TensorIteratorConfig is not
+// copyable or movable, so it can't be returned from the function.
+#define UNARY_OP_CONFIG()                        \
+  TensorIteratorConfig()                         \
+    .set_check_mem_overlap(true)                 \
+    .cast_common_dtype_to_outputs(false)         \
+    .enforce_safe_casting_to_output(false)       \
+    .check_all_same_dtype(true)
+
+void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_OP_CONFIG()
+      .add_owned_output(out)
+      .add_owned_input(a));
 }
 
-void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
-  build(TensorIteratorConfig()
-      .set_check_mem_overlap(true)
-      .add_owned_output(out)
-      .add_owned_input(a)
-      .cast_common_dtype_to_outputs(false)
-      .enforce_safe_casting_to_output(false)
-      .check_all_same_dtype(true));
+void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_OP_CONFIG()
+      .add_output(out)
+      .add_input(a));
+}
+
+void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_OP_CONFIG()
+      .add_output(out)
+      .add_owned_input(a));
 }
 
 // Helper to construct a unary op that forcibly promotes output to boolean.
 // Only be used when the output tensor must have boolean type.
-void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
+void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
   build(TensorIteratorConfig()
       .set_check_mem_overlap(true)
      .check_all_same_dtype(false)
      .declare_static_dtype(at::kBool)
      .declare_static_device(a.device())
-      .add_owned_output(out)
-      .add_owned_input(a));
+      .add_output(out)
+      .add_input(a));
 }
 
 TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
@@ -1015,12 +1068,6 @@ TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase&
   return iter;
 }
 
-TensorIterator TensorIterator::unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
-  TensorIterator iter;
-  iter.build_unary_force_boolean_op(out, a);
-  return iter;
-}
-
 #define NULLARY_OP_CONFIG()                      \
   TensorIteratorConfig()                         \
     .set_check_mem_overlap(true)                 \
@@ -435,9 +435,21 @@ public:
   void build_borrowing_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
   TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
   void build_unary_float_op(const TensorBase& out, const TensorBase& a);
+  void build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
   void build_unary_op(const TensorBase& out, const TensorBase& a);
-  void build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
+  // Odd special case needed for pow. Has to borrow the output because
+  // it's a structured kernel, but the argument is potentially a copy.
+  void build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a);
+  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
+  void build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
   void build_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
+  void build_borrowing_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
+  // Another special case: we need to own the second argument for comparison ops.
+  void build_borrowing_except_last_argument_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
   void build_ternary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b, const TensorBase& c);
 
 #undef TORCH_DISALLOW_TEMPORARIES
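The TORCH_DISALLOW_TEMPORARIES markers next to each borrowing builder keep callers from passing temporaries that the iterator would then borrow from. The snippet below is not the real macro from TensorIterator.h, only a sketch of the underlying C++ idiom with a hypothetical wrapper type standing in for the real class: pair the const-lvalue overload with deleted rvalue overloads so a temporary argument is rejected at compile time instead of becoming a dangling borrow.

#include <ATen/core/TensorBase.h>

// Hypothetical illustration only; not the macro defined in TensorIterator.h.
struct BorrowingBuilderSketch {
  // Lvalue arguments are fine: the caller keeps them alive.
  void build_borrowing_unary_op(const at::TensorBase& out, const at::TensorBase& a);

  // Rvalue (temporary) arguments are rejected at compile time, because a
  // borrow taken from a temporary would dangle as soon as the full
  // expression ends.
  void build_borrowing_unary_op(at::TensorBase&& out, const at::TensorBase& a) = delete;
  void build_borrowing_unary_op(const at::TensorBase& out, at::TensorBase&& a) = delete;
};

With declarations like these, a call that passes a freshly constructed TensorBase as either argument fails to compile rather than silently creating a dangling reference.
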
@@ -571,7 +583,6 @@ struct TORCH_API TensorIterator final : public TensorIteratorBase {
   static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
   static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
   static TensorIterator nullary_op(TensorBase& out);
-  static TensorIterator unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
   static TensorIterator borrowing_nullary_op(const TensorBase& out);
   static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
   static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
@@ -211,13 +211,13 @@ void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor&
   TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) {   \
     const Tensor& result = maybe_get_output();                                \
     comparison_op_check(self, other, result);                                 \
-    build_comparison_op(result, self, other);                                 \
+    build_borrowing_comparison_op(result, self, other);                       \
   }                                                                           \
                                                                               \
   TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) {   \
     auto other_tensor =                                                       \
         native::wrapped_scalar_tensor_and_check_convert(other, self);         \
-    build_comparison_op(maybe_get_output(), self, other_tensor);              \
+    build_borrowing_except_last_argument_comparison_op(maybe_get_output(), self, other_tensor); \
   }
 
 CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(eq);
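The Scalar overload above is the reason build_borrowing_except_last_argument_comparison_op exists: wrapped_scalar_tensor_and_check_convert materializes a fresh tensor that is local to the meta function, so the iterator must own that operand rather than borrow it. The standalone sketch below (illustrative, not code from the patch) shows the same owned versus borrowed distinction at the c10::MaybeOwned level, which, as far as I can tell, is the mechanism the add_owned_*/add_borrowed_* calls are built on.

#include <ATen/ATen.h>
#include <c10/util/MaybeOwned.h>

int main() {
  at::Tensor self = at::ones({2, 2});

  // Wrapping a Scalar produces a brand-new tensor that lives only in this scope.
  at::Tensor wrapped = at::scalar_tensor(3.0);

  // Owned: copies the handle, so the operand stays valid even if `wrapped`
  // later goes out of scope.
  c10::MaybeOwned<at::Tensor> owned =
      c10::MaybeOwned<at::Tensor>::owned(at::Tensor(wrapped));

  // Borrowed: records only a pointer to `wrapped`; valid only while `wrapped`
  // is still alive, which is why a short-lived wrapped scalar must not be
  // merely borrowed.
  c10::MaybeOwned<at::Tensor> borrowed =
      c10::MaybeOwned<at::Tensor>::borrowed(wrapped);

  // Both dereference to a Tensor while their referents are alive.
  at::Tensor sum = self + *owned + *borrowed;
  (void)sum;
  return 0;
}
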
@@ -20,7 +20,7 @@ TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
               "Integers to negative integer powers are not allowed.");
 
   auto common_dtype = at::result_type(base, exp);
-  build_unary_op(maybe_get_output(), base.to(common_dtype));
+  build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
 }
 
 TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
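The pow change is the odd case called out in TensorIterator.h: base.to(common_dtype) may hand back a brand-new tensor rather than the original, so the output can be borrowed (the structured kernel keeps it alive) while the converted argument has to be owned. A small, self-contained sketch of that Tensor::to behaviour follows; the tensors here are made up purely for illustration.

#include <ATen/ATen.h>

int main() {
  at::Tensor base = at::arange(4, at::kInt);

  // No conversion needed: `same` aliases the original tensor.
  at::Tensor same = base.to(at::kInt);

  // Conversion needed: `converted` is a fresh tensor. In the pow meta function
  // this result is an unnamed temporary, so borrowing it would leave the
  // iterator pointing at a tensor that no longer exists, which is why
  // build_output_borrowing_argument_owning_unary_op owns it.
  at::Tensor converted = base.to(at::kFloat);

  TORCH_CHECK(same.is_same(base));
  TORCH_CHECK(!converted.is_same(base));
  return 0;
}
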
@@ -30,7 +30,7 @@ const OptionalScalarRef max) {
     TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
   }
 
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC2(isin, Tensor_Tensor) (
@@ -61,14 +61,14 @@ TORCH_META_FUNC(isposinf) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "isposinf does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(isneginf) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "isneginf does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 static void check_unsupported_complex(const char* name, const Tensor& self) {
@@ -31,7 +31,7 @@ namespace meta {
 // For complex inputs, the output type should be the same as input type.
 #define CREATE_UNARY_FLOAT_META_FUNC(func)                    \
   TORCH_META_FUNC(func) (const Tensor& self) {                \
-    build_unary_float_op(maybe_get_output(), self);           \
+    build_borrowing_unary_float_op(maybe_get_output(), self); \
   }
 
 CREATE_UNARY_FLOAT_META_FUNC(acos)
@@ -73,13 +73,13 @@ CREATE_UNARY_FLOAT_META_FUNC(tanh)
 
 TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
   TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
-  build_unary_float_op(maybe_get_output(), self);
+  build_borrowing_unary_float_op(maybe_get_output(), self);
 }
 
 // These are normal unary ops that preserve dtype
 #define CREATE_UNARY_META_FUNC(func)                    \
   TORCH_META_FUNC(func) (const Tensor& self) {          \
-    build_unary_op(maybe_get_output(), self);           \
+    build_borrowing_unary_op(maybe_get_output(), self); \
   }
 CREATE_UNARY_META_FUNC(bitwise_not)
 CREATE_UNARY_META_FUNC(frac)
@@ -90,41 +90,41 @@ TORCH_META_FUNC(neg)(const Tensor& self) {
   TORCH_CHECK(self.scalar_type() != kBool,
               "Negation, the `-` operator, on a bool tensor is not supported. "
               "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(trunc) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(),
              "trunc is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(floor) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(),
              "floor is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(sign) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(),
              "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(signbit) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "signbit does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(ceil) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(),
              "ceil is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 } // namespace meta
@@ -19,7 +19,7 @@ void mkldnn_matmul(
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt){
+    const Tensor& result_opt){
   return false;
 }
 
@@ -126,9 +126,7 @@ inline bool checksize(const Tensor& mat1, const Tensor& mat2){
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt) {
-  c10::MaybeOwned<Tensor> result_maybe_owned = at::borrow_from_optional_tensor(result_opt);
-  const Tensor& result = *result_maybe_owned;
+    const Tensor& result) {
   return (
       at::globalContext().userEnabledMkldnn() &&
       mat1.scalar_type() == kBFloat16 &&
@@ -16,7 +16,7 @@ TORCH_API void mkldnn_matmul(
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt);
+    const Tensor& result_opt);
 
 }
 