From 4e7232c5daf753e04e8f4189229e3c33888a33e5 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Fri, 31 Oct 2025 14:07:23 -0700
Subject: [PATCH] [MPS] Fix `smooth_l1_loss` backward for fp16 (#166687)

And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687
Approved by: https://github.com/Skylion007
ghstack dependencies: #166214
---
 aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp        | 2 +-
 aten/src/ATen/native/mps/operations/LossOps.mm         | 8 +++++---
 torch/testing/_internal/common_methods_invocations.py  | 4 +---
 torch/testing/_internal/common_mps.py                  | 2 --
 4 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
index 6fad9270bf1..3c3c5a90ec7 100644
--- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
@@ -139,7 +139,7 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
       }
     );
   } else {
-    AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
       auto norm_val = norm.to<scalar_t>();
       scalar_t beta_val(beta);
       auto norm_val_vec = Vectorized<scalar_t>(norm_val);
diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 2ba8860772e..c995b8fc237 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
                                                 onValue:-1.0f
                                                offValue:0.0f
                                                    name:nil];
-    oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
+    oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
     if (isWeightsArrayValid) {
       oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
                                                secondaryTensor:weightTensor
@@ -705,6 +705,7 @@ static void smooth_l1_loss_template(const Tensor& input,
   TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
   TORCH_CHECK(input.is_mps());
   TORCH_CHECK(target.is_mps());
+  TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
   if ((input.numel() == 0) || (target.numel() == 0)) {
     reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
     return;
@@ -771,7 +772,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
       MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
 
-      MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
+      MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
       // xn - yn
       MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
                                                           secondaryTensor:targetTensor
@@ -797,7 +798,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
                                                              name:@"lossTensor"];
       MPSGraphTensor* outputTensor = lossTensor;
       if (reduction == Reduction::Mean) {
-        MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
+        MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
+                                                          dataType:[lossTensor dataType]];
         outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
       }
       MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 47517e8ff9b..fbeccf86f32 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -20340,9 +20340,7 @@ op_db: list[OpInfo] = [
            ref=reference_smooth_l1_loss,
            sample_inputs_func=sample_inputs_smooth_l1_loss,
            dtypes=floating_types_and(torch.float16, torch.bfloat16),
-           backward_dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           backward_dtypes=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index b3289853192..9d3d65aba9a 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -740,8 +740,6 @@ if torch.backends.mps.is_available():
             "equal": [torch.float16, torch.float32],
             # 'float' object is not iterable
             "item": [torch.float16, torch.float32],
-            # "smooth_l1_backward_cpu_out" not implemented for 'Half'
-            "nn.functional.smooth_l1_loss": [torch.float16],
             # cpu error: grad requires non-empty inputs
             "randn": [torch.float16, torch.float32],
             "signal.windows.bartlett": [torch.float32],
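A minimal verification sketch (not part of the patch, shown for illustration only): it exercises the fp16
smooth_l1_loss backward path that this change enables on CPU and fixes on MPS, comparing against an fp32
reference. The tensor shape, beta value, and tolerances below are illustrative assumptions.

# Illustrative check, assuming a PyTorch build that includes this patch.
import torch
import torch.nn.functional as F

def check_fp16_backward(device: str) -> None:
    # fp16 inputs; previously the CPU kernel raised
    # '"smooth_l1_backward_cpu_out" not implemented for Half',
    # and the MPS graph built its beta/numel constants as fp32.
    x = torch.randn(64, device=device, dtype=torch.float16, requires_grad=True)
    y = torch.randn(64, device=device, dtype=torch.float16)
    F.smooth_l1_loss(x, y, beta=1.0).backward()

    # fp32 reference gradient for a loose numerical comparison.
    x32 = x.detach().float().requires_grad_(True)
    F.smooth_l1_loss(x32, y.float(), beta=1.0).backward()
    torch.testing.assert_close(x.grad.float(), x32.grad, rtol=1e-3, atol=1e-3)

check_fp16_backward("cpu")
if torch.backends.mps.is_available():
    check_fp16_backward("mps")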