mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Fix smooth_l1_loss backward for fp16 (#166687)
And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687 Approved by: https://github.com/Skylion007 ghstack dependencies: #166214
This commit is contained in:
parent
93a70c717a
commit
4e7232c5da
|
|
@ -139,7 +139,7 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
|||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
scalar_t beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<scalar_t>(norm_val);
|
||||
|
|
|
|||
|
|
@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
|||
onValue:-1.0f
|
||||
offValue:0.0f
|
||||
name:nil];
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
|
||||
if (isWeightsArrayValid) {
|
||||
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
|
||||
secondaryTensor:weightTensor
|
||||
|
|
@ -705,6 +705,7 @@ static void smooth_l1_loss_template(const Tensor& input,
|
|||
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
|
||||
TORCH_CHECK(input.is_mps());
|
||||
TORCH_CHECK(target.is_mps());
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
|
||||
return;
|
||||
|
|
@ -771,7 +772,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
|||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
|
||||
// xn - yn
|
||||
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:targetTensor
|
||||
|
|
@ -797,7 +798,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
|||
name:@"lossTensor"];
|
||||
MPSGraphTensor* outputTensor = lossTensor;
|
||||
if (reduction == Reduction::Mean) {
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
|
||||
dataType:[lossTensor dataType]];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
|
||||
}
|
||||
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
|
||||
|
|
|
|||
|
|
@ -20340,9 +20340,7 @@ op_db: list[OpInfo] = [
|
|||
ref=reference_smooth_l1_loss,
|
||||
sample_inputs_func=sample_inputs_smooth_l1_loss,
|
||||
dtypes=floating_types_and(torch.float16, torch.bfloat16),
|
||||
backward_dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
backward_dtypes=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
|
|
|
|||
|
|
@ -740,8 +740,6 @@ if torch.backends.mps.is_available():
|
|||
"equal": [torch.float16, torch.float32],
|
||||
# 'float' object is not iterable
|
||||
"item": [torch.float16, torch.float32],
|
||||
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
|
||||
"nn.functional.smooth_l1_loss": [torch.float16],
|
||||
# cpu error: grad requires non-empty inputs
|
||||
"randn": [torch.float16, torch.float32],
|
||||
"signal.windows.bartlett": [torch.float32],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user