[MPS] Fix smooth_l1_loss backward for fp16 (#166687)
Also enable the fp16 implementation for CPU, which simplifies the OpInfo definition for the op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687
Approved by: https://github.com/Skylion007
ghstack dependencies: #166214
This commit is contained in:
parent 93a70c717a
commit 4e7232c5da
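To see what the commit enables end to end, here is a minimal sketch (assuming a PyTorch build that includes this change): fp16 smooth_l1_loss now differentiates on CPU, where the backward kernel previously raised "not implemented for 'Half'".

    import torch
    import torch.nn.functional as F

    # fp16 backward on CPU: before this change the backward kernel raised
    # RuntimeError: "smooth_l1_backward_cpu_out" not implemented for 'Half'.
    x = torch.randn(8, dtype=torch.float16, requires_grad=True)
    y = torch.randn(8, dtype=torch.float16)
    F.smooth_l1_loss(x, y, beta=1.0).backward()
    print(x.grad.dtype)  # torch.float16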
@@ -139,7 +139,7 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
       }
     );
   } else {
-    AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
       auto norm_val = norm.to<scalar_t>();
       scalar_t beta_val(beta);
       auto norm_val_vec = Vectorized<scalar_t>(norm_val);
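The CPU fix is the dispatch change above: AT_DISPATCH_ALL_TYPES_AND(kHalf, ...) adds Half to the dtypes the backward kernel is instantiated for. A hedged sketch of what that kernel computes per element (sum reduction, so the norm factor is 1), checked against autograd:

    import torch
    import torch.nn.functional as F

    # Per-element gradient of smooth L1 (sum reduction):
    #   d/dx = (x - y) / beta  if |x - y| < beta, else sign(x - y)
    x = torch.randn(16, dtype=torch.float16, requires_grad=True)
    y = torch.randn(16, dtype=torch.float16)
    beta = 1.0
    F.smooth_l1_loss(x, y, beta=beta, reduction="sum").backward()

    diff = (x - y).detach()
    expected = torch.where(diff.abs() < beta, diff / beta, diff.sign())
    torch.testing.assert_close(x.grad, expected)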
@@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
                                                  onValue:-1.0f
                                                 offValue:0.0f
                                                     name:nil];
-    oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
+    oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
    if (isWeightsArrayValid) {
      oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
                                               secondaryTensor:weightTensor
@@ -705,6 +705,7 @@ static void smooth_l1_loss_template(const Tensor& input,
   TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
   TORCH_CHECK(input.is_mps());
   TORCH_CHECK(target.is_mps());
+  TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
   if ((input.numel() == 0) || (target.numel() == 0)) {
     reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
     return;
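The empty-input branch above mirrors the CPU semantics: a mean over zero elements is undefined, so the output is filled with NaN, while the empty sum is zero. A quick check:

    import torch
    import torch.nn.functional as F

    x, y = torch.empty(0), torch.empty(0)
    print(F.smooth_l1_loss(x, y, reduction="mean"))  # tensor(nan)
    print(F.smooth_l1_loss(x, y, reduction="sum"))   # tensor(0.)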
@@ -771,7 +772,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
     MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
     MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);

-    MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
+    MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
     // xn - yn
     MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
                                                         secondaryTensor:targetTensor
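With the beta constant now created in the input tensor's dtype instead of hard-coded float32, the fp16 backward graph is type-consistent on MPS. A guarded sketch (runs only where an MPS device is present):

    import torch
    import torch.nn.functional as F

    if torch.backends.mps.is_available():
        x = torch.randn(8, dtype=torch.float16, device="mps", requires_grad=True)
        y = torch.randn(8, dtype=torch.float16, device="mps")
        F.smooth_l1_loss(x, y, beta=0.5).backward()
        print(x.grad.dtype)  # torch.float16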
@@ -797,7 +798,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
                                                        name:@"lossTensor"];
     MPSGraphTensor* outputTensor = lossTensor;
     if (reduction == Reduction::Mean) {
-      MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
+      MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
+                                                        dataType:[lossTensor dataType]];
       outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
     }
     MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
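The Mean branch divides the loss gradient by numel(), now using the loss tensor's dtype rather than a fixed float32 constant, so the mean-reduction gradient is simply the sum-reduction gradient scaled by 1/N. A sketch verifying that relationship:

    import torch
    import torch.nn.functional as F

    x = torch.randn(10, requires_grad=True)
    y = torch.randn(10)

    F.smooth_l1_loss(x, y, reduction="sum").backward()
    grad_sum = x.grad.clone()
    x.grad = None

    F.smooth_l1_loss(x, y, reduction="mean").backward()
    torch.testing.assert_close(x.grad, grad_sum / x.numel())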
@@ -20340,9 +20340,7 @@ op_db: list[OpInfo] = [
            ref=reference_smooth_l1_loss,
            sample_inputs_func=sample_inputs_smooth_l1_loss,
            dtypes=floating_types_and(torch.float16, torch.bfloat16),
-           backward_dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           backward_dtypes=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
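Because the CPU backward now supports fp16, the default backward_dtypes matches the CUDA set, and the dtypesIfCUDA/backward_dtypesIfCUDA overrides become redundant. For reference, floating_types_and is an internal test helper (so treat the exact output below as an assumption) that extends the default floating dtypes with the listed extras:

    import torch
    from torch.testing._internal.common_dtype import floating_types_and

    # Default floating dtypes (float32, float64) plus the listed extras.
    print(floating_types_and(torch.float16, torch.bfloat16))
    # (torch.float32, torch.float64, torch.float16, torch.bfloat16)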
@@ -740,8 +740,6 @@ if torch.backends.mps.is_available():
         "equal": [torch.float16, torch.float32],
         # 'float' object is not iterable
         "item": [torch.float16, torch.float32],
-        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
-        "nn.functional.smooth_l1_loss": [torch.float16],
         # cpu error: grad requires non-empty inputs
         "randn": [torch.float16, torch.float32],
         "signal.windows.bartlett": [torch.float32],