diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index b96c2d897ba..f978456754d 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -877,14 +877,8 @@ void TensorIteratorBase::build_borrowing_binary_float_op(
       .add_input(b));
 }
 
-void TensorIteratorBase::build_comparison_op(
-    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
-  TensorIteratorConfig config;
-
+static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
   config.set_check_mem_overlap(true);
-  config.add_owned_output(out);
-  config.add_owned_input(a);
-  config.add_owned_input(b);
   config.allow_cpu_scalars(true);
   config.promote_inputs_to_common_dtype(true);
 
@@ -905,7 +899,38 @@ void TensorIteratorBase::build_comparison_op(
   if (out.defined() && out.scalar_type() != kBool) {
     config.cast_common_dtype_to_outputs(true);
   }
+}
 
+void TensorIteratorBase::build_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_owned_output(out);
+  config.add_owned_input(a);
+  config.add_owned_input(b);
+  build(config);
+}
+
+void TensorIteratorBase::build_borrowing_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_borrowed_output(out);
+  config.add_borrowed_input(a);
+  config.add_borrowed_input(b);
+  build(config);
+}
+
+void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
+    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
+  TensorIteratorConfig config;
+  set_up_comparison_op_config(config, out);
+
+  config.add_borrowed_output(out);
+  config.add_borrowed_input(a);
+  config.add_owned_input(b);
   build(config);
 }
 
@@ -944,37 +969,65 @@ void TensorIteratorBase::build_borrowing_binary_op(
       .add_input(b));
 }
 
+// This cannot be a function because TensorIteratorConfig is not
+// copyable or movable, so it can't be returned from the function.
+#define UNARY_FLOAT_OP_CONFIG()                         \
+  TensorIteratorConfig()                                \
+      .set_check_mem_overlap(true)                      \
+      .promote_inputs_to_common_dtype(true)             \
+      .cast_common_dtype_to_outputs(true)               \
+      .enforce_safe_casting_to_output(true)             \
+      .promote_integer_inputs_to_float(true)
+
 void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
-  build(TensorIteratorConfig()
-      .set_check_mem_overlap(true)
+  build(UNARY_FLOAT_OP_CONFIG()
       .add_owned_output(out)
-      .add_owned_input(a)
-      .promote_inputs_to_common_dtype(true)
-      .cast_common_dtype_to_outputs(true)
-      .enforce_safe_casting_to_output(true)
-      .promote_integer_inputs_to_float(true));
+      .add_owned_input(a));
 }
 
+void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_FLOAT_OP_CONFIG()
+      .add_output(out)
+      .add_input(a));
+}
+
+// This cannot be a function because TensorIteratorConfig is not
+// copyable or movable, so it can't be returned from the function.
+#define UNARY_OP_CONFIG()                               \
+  TensorIteratorConfig()                                \
+      .set_check_mem_overlap(true)                      \
+      .cast_common_dtype_to_outputs(false)              \
+      .enforce_safe_casting_to_output(false)            \
+      .check_all_same_dtype(true)
+
 void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
-  build(TensorIteratorConfig()
-      .set_check_mem_overlap(true)
+  build(UNARY_OP_CONFIG()
       .add_owned_output(out)
-      .add_owned_input(a)
-      .cast_common_dtype_to_outputs(false)
-      .enforce_safe_casting_to_output(false)
-      .check_all_same_dtype(true));
+      .add_owned_input(a));
+}
+
+void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_OP_CONFIG()
+      .add_output(out)
+      .add_input(a));
+}
+
+void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
+  build(UNARY_OP_CONFIG()
+      .add_output(out)
+      .add_owned_input(a));
 }
 
 // Helper to construct a unary op that forcibly promotes output to boolean.
 // Only be used when the output tensor must have boolean type.
-void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
+void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
   build(TensorIteratorConfig()
       .set_check_mem_overlap(true)
       .check_all_same_dtype(false)
       .declare_static_dtype(at::kBool)
       .declare_static_device(a.device())
-      .add_owned_output(out)
-      .add_owned_input(a));
+      .add_output(out)
+      .add_input(a));
 }
 
 TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
@@ -1015,12 +1068,6 @@ TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase&
   return iter;
 }
 
-TensorIterator TensorIterator::unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
-  TensorIterator iter;
-  iter.build_unary_force_boolean_op(out, a);
-  return iter;
-}
-
 #define NULLARY_OP_CONFIG()                             \
   TensorIteratorConfig()                                \
       .set_check_mem_overlap(true)                      \
diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h
index f5cd12e3d13..1c485e84f16 100644
--- a/aten/src/ATen/TensorIterator.h
+++ b/aten/src/ATen/TensorIterator.h
@@ -435,9 +435,21 @@ public:
   void build_borrowing_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
   TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
   void build_unary_float_op(const TensorBase& out, const TensorBase& a);
+  void build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
   void build_unary_op(const TensorBase& out, const TensorBase& a);
-  void build_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
+  // Odd special case needed for pow. Has to borrow the output because
+  // it's a structured kernel, but the argument is potentially a copy.
+  void build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a);
+  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
+  void build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
   void build_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
+  void build_borrowing_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
+  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
+  // Another special case: we need to own the second argument for comparison ops.
+  void build_borrowing_except_last_argument_comparison_op(const TensorBase& out, const TensorBase& a, const TensorBase& b);
   void build_ternary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b, const TensorBase& c);
 #undef TORCH_DISALLOW_TEMPORARIES
 
@@ -571,7 +583,6 @@ struct TORCH_API TensorIterator final : public TensorIteratorBase {
   static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
   static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
   static TensorIterator nullary_op(TensorBase& out);
-  static TensorIterator unary_force_boolean_op(const TensorBase& out, const TensorBase& a);
   static TensorIterator borrowing_nullary_op(const TensorBase& out);
   static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
   static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 2d86f06cb49..13c8475eedf 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -211,13 +211,13 @@ void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor&
   TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \
     const Tensor& result = maybe_get_output();                              \
     comparison_op_check(self, other, result);                               \
-    build_comparison_op(result, self, other);                               \
+    build_borrowing_comparison_op(result, self, other);                     \
   }                                                                         \
                                                                             \
   TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \
     auto other_tensor =                                                     \
         native::wrapped_scalar_tensor_and_check_convert(other, self);       \
-    build_comparison_op(maybe_get_output(), self, other_tensor);            \
+    build_borrowing_except_last_argument_comparison_op(maybe_get_output(), self, other_tensor); \
   }
 
 CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(eq);
diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index 70be3001e14..c4b54e05ff1 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -20,7 +20,7 @@ TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
       "Integers to negative integer powers are not allowed.");
 
   auto common_dtype = at::result_type(base, exp);
-  build_unary_op(maybe_get_output(), base.to(common_dtype));
+  build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
 }
 
 TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index c5699b672ee..905a0117d7f 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -30,7 +30,7 @@ const OptionalScalarRef max) {
     TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
   }
 
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC2(isin, Tensor_Tensor) (
@@ -61,14 +61,14 @@ TORCH_META_FUNC(isposinf) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "isposinf does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(isneginf) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "isneginf does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 static void check_unsupported_complex(const char* name, const Tensor& self) {
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 645e7fcf665..f32662c3f97 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -31,7 +31,7 @@ namespace meta {
 // For complex inputs, the output type should be the same as input type.
 #define CREATE_UNARY_FLOAT_META_FUNC(func)                    \
   TORCH_META_FUNC(func) (const Tensor& self) {                \
-    build_unary_float_op(maybe_get_output(), self);           \
+    build_borrowing_unary_float_op(maybe_get_output(), self); \
   }
 
 CREATE_UNARY_FLOAT_META_FUNC(acos)
@@ -73,13 +73,13 @@ CREATE_UNARY_FLOAT_META_FUNC(tanh)
 
 TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
   TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
-  build_unary_float_op(maybe_get_output(), self);
+  build_borrowing_unary_float_op(maybe_get_output(), self);
 }
 
 // These are normal unary ops that preserve dtype
 #define CREATE_UNARY_META_FUNC(func)                    \
   TORCH_META_FUNC(func) (const Tensor& self) {          \
-    build_unary_op(maybe_get_output(), self);           \
+    build_borrowing_unary_op(maybe_get_output(), self); \
   }
 CREATE_UNARY_META_FUNC(bitwise_not)
 CREATE_UNARY_META_FUNC(frac)
@@ -90,41 +90,41 @@ TORCH_META_FUNC(neg)(const Tensor& self) {
   TORCH_CHECK(self.scalar_type() != kBool,
               "Negation, the `-` operator, on a bool tensor is not supported. "
               "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(trunc) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(), "trunc is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(floor) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(), "floor is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(sign) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(),
              "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(signbit) (const Tensor& self) {
   TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
   TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
               "signbit does not support non-boolean outputs.");
-  build_unary_force_boolean_op(maybe_get_output(), self);
+  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(ceil) (const Tensor& self) {
   // Note: this is consistent with NumPy
   TORCH_CHECK(!self.is_complex(), "ceil is not supported for complex inputs");
-  build_unary_op(maybe_get_output(), self);
+  build_borrowing_unary_op(maybe_get_output(), self);
 }
 
 } // namespace meta
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
index 69568a04ad5..79402ea09e1 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.cpp
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -19,7 +19,7 @@ void mkldnn_matmul(
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt){
+    const Tensor& result_opt){
   return false;
 }
 
@@ -126,9 +126,7 @@ inline bool checksize(const Tensor& mat1, const Tensor& mat2){
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt) {
-  c10::MaybeOwned<Tensor> result_maybe_owned = at::borrow_from_optional_tensor(result_opt);
-  const Tensor& result = *result_maybe_owned;
+    const Tensor& result) {
   return (
     at::globalContext().userEnabledMkldnn() &&
     mat1.scalar_type() == kBFloat16 &&
diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h
index f19365c6759..250f4f22842 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.h
+++ b/aten/src/ATen/native/mkldnn/Matmul.h
@@ -16,7 +16,7 @@ TORCH_API void mkldnn_matmul(
 bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
-    const c10::optional<Tensor>& result_opt);
+    const Tensor& result_opt);
 
 }
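
Reviewer note, not part of the patch: the sketch below illustrates the owned-versus-borrowed distinction that the new build_borrowing_* builders, the TORCH_DISALLOW_TEMPORARIES overloads, and the two special-case builders (pow's argument-owning variant and the comparison ops' except-last-argument variant) revolve around. The function keep_argument and its is_temporary flag are hypothetical illustrations; only c10::MaybeOwned, at::Tensor, and the lifetime reasoning in the comments come from the patch or from ATen as I understand it.

// Illustrative sketch only -- not code from this PR.
#include <ATen/ATen.h>
#include <c10/util/MaybeOwned.h>

// A borrowing builder stores only references to its arguments, roughly what
// c10::MaybeOwned<Tensor>::borrowed() does below: cheap, but the caller must
// keep the tensor alive, which is why the borrowing builders get
// TORCH_DISALLOW_TEMPORARIES overloads. An owning builder keeps its own
// reference-counted copy, which is what pow needs for the possibly-temporary
// base.to(common_dtype) and what the Scalar comparison overloads need for
// their wrapped scalar tensor.
c10::MaybeOwned<at::Tensor> keep_argument(const at::Tensor& t, bool is_temporary) {
  if (is_temporary) {
    // Own: take a reference-counted copy so the value outlives the caller's temporary.
    return c10::MaybeOwned<at::Tensor>::owned(at::Tensor(t));
  }
  // Borrow: store a pointer only; no refcount traffic on the hot path.
  return c10::MaybeOwned<at::Tensor>::borrowed(t);
}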