[MPS] Fix ops with bool issues in macOS Monterey (#94464)

Summary:
- Remove redundant bool casts from scatter/gather
- Make the scatter/gather workarounds for the bool and uint8 data types OS-specific: apply them only on macOS Monterey and skip them starting with macOS Ventura (the recurring pattern is sketched below)
- Make all tensors ranked in scatter
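
The OS gate in the second bullet follows one recurring pattern across all six changed files. A minimal sketch of it as a standalone helper (the name `remap_to_supported_mps_dtype` is hypothetical; `is_macos_13_or_newer()` and the `MPSDataType` constants are the backend's own):

```
// Hypothetical helper condensing the dtype remap repeated throughout this PR.
// uint8 is remapped unconditionally (separate MPSShaderLibrary bug), while bool
// only needs the remap on macOS Monterey; macOS Ventura handles it natively.
static MPSDataType remap_to_supported_mps_dtype(MPSDataType dtype) {
  if (dtype == MPSDataTypeUInt8 || (dtype == MPSDataTypeBool && !is_macos_13_or_newer())) {
    return MPSDataTypeInt8;
  }
  return dtype;
}
```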

Fixes the following tests:
```
test_output_match_slice_scatter_cpu_bool
test_output_match_select_scatter_cpu_bool
test_output_match_diagonal_scatter_cpu_bool
test_output_match_repeat_cpu_bool
test_output_match_rot90_cpu_bool
etc.
```
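
For reference, a minimal standalone repro of the first case might look like this (a sketch against the libtorch C++ API; it assumes a build with MPS support and compares the MPS result to the CPU reference, which is what the consistency tests do):

```
#include <torch/torch.h>
#include <iostream>

int main() {
  // slice_scatter on bool tensors, computed on CPU and on MPS.
  auto self = torch::zeros({8}, torch::dtype(torch::kBool));
  auto src = torch::ones({4}, torch::dtype(torch::kBool));
  auto cpu_out = at::slice_scatter(self, src, /*dim=*/0, /*start=*/2, /*end=*/6, /*step=*/1);
  auto mps_out = at::slice_scatter(self.to(torch::kMPS), src.to(torch::kMPS),
                                   /*dim=*/0, /*start=*/2, /*end=*/6, /*step=*/1);
  // Expected to print "true" with this fix, on Monterey as well as Ventura.
  std::cout << std::boolalpha << cpu_out.equal(mps_out.cpu()) << "\n";
  return 0;
}
```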

Still failing on macOS Monterey (needs additional investigation):
```
test_output_match_scatter_cpu_bool
```
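
The equivalent check for `scatter` itself still mismatches on Monterey (a fragment in the same style as the snippet above; the shapes and indices are an arbitrary illustration):

```
// Still-failing case: scatter with a bool src along dim 0, MPS vs. CPU.
auto opts = torch::dtype(torch::kBool).device(torch::kMPS);
auto self = torch::zeros({4}, opts);
auto index = torch::tensor({0, 2}, torch::dtype(torch::kLong).device(torch::kMPS));
auto src = torch::ones({2}, opts);
auto out = self.scatter(/*dim=*/0, index, src);  // compare out.cpu() against the CPU result
```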

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94464
Approved by: https://github.com/kulinseth
Authored by Denis Vieriu on 2023-02-10 21:36:21 +00:00, committed by PyTorch MergeBot
parent 5b1cedacde
commit 728dfeee48
6 changed files with 71 additions and 65 deletions

View File

@@ -439,7 +439,16 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
   using CachedGraph = mps::MPSUnaryCachedGraph;
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  MPSDataType inputDataType = getMPSScalarType(self.scalar_type());
+  MPSDataType outputDataType = getMPSScalarType(self.scalar_type());
+  if (!is_macos_13_or_newer()) {
+    if (self.scalar_type() == kBool) {
+      inputDataType = MPSDataTypeInt8;
+    }
+    if (result.scalar_type() == kBool) {
+      outputDataType = MPSDataTypeInt8;
+    }
+  }
   @autoreleasepool {
     NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
     // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
@@ -454,7 +463,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
       MPSGraph* mpsGraph = make_mps_graph();
       newCachedGraph = new CachedGraph(mpsGraph);
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
       MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
                                                         axes:ns_dims
                                                         name:nil];
@@ -466,8 +475,10 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
   }
   // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
-  Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+  Placeholder inputPlaceholder = Placeholder(
+      cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
+  Placeholder outputPlaceholder = Placeholder(
+      cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType);
   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -656,12 +667,15 @@ Tensor& index_select_out_mps(const Tensor & self,
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   auto inputType = getMPSDataType(self.scalar_type());
   auto outputType = getMPSDataType(output.scalar_type());
-  if (inputType == MPSDataTypeUInt8 || inputType == MPSDataTypeBool) {
-    inputType = MPSDataTypeInt8;
+  if (inputType == MPSDataTypeUInt8 ||
+      (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
+    inputType = MPSDataTypeInt8;
   }
-  if (outputType == MPSDataTypeUInt8 || outputType == MPSDataTypeBool) {
-    outputType = MPSDataTypeInt8;
+  if (outputType == MPSDataTypeUInt8 ||
+      (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
+    outputType = MPSDataTypeInt8;
   }
   @autoreleasepool {
     string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
@@ -792,10 +806,11 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
   }
   Placeholder selfPlaceholder = Placeholder(
-      cachedGraph->inputTensor_, self, /*mpsShape*/nullptr, /*gatherTensorData=*/true, inputDataType);
+      cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
   Placeholder maskPlaceholder = Placeholder(
-      cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nullptr, /*gatherTensorData=*/true, maskDataType);
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, self);
+      cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType);
+  Placeholder outputPlaceholder = Placeholder(
+      cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType);
   // Create dictionary of inputs and outputs
   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{

View File

@@ -71,6 +71,16 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
   }
   auto stream = at::mps::getCurrentMPSStream();
+  auto inputDataType = getMPSDataType(expanded_tensor.scalar_type());
+  auto outputDataType = getMPSDataType(result.scalar_type());
+  if (!is_macos_13_or_newer()) {
+    if (expanded_tensor.scalar_type() == kBool) {
+      inputDataType = MPSDataTypeInt8;
+    }
+    if (result.scalar_type() == kBool) {
+      outputDataType = MPSDataTypeInt8;
+    }
+  }
   @autoreleasepool {
     string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
@@ -84,7 +94,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
       MPSGraph* mpsGraph = make_mps_graph();
       newCachedGraph = new CachedGraph(mpsGraph);
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, expanded_tensor);
+      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
       MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
                                             withMultiplier:getMPSShape(repeats)
                                                       name:nil];
@@ -97,8 +107,10 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
     cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
   }
-  Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, expanded_tensor);
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+  Placeholder selfPlaceholder = Placeholder(
+      cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
+  Placeholder outputPlaceholder = Placeholder(
+      cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType);
   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
     selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()

View File

@@ -51,11 +51,13 @@ TORCH_IMPL_FUNC(gather_out_mps)
     if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
       needSlice = true;
   }
-  // input and output types are always the same
-  auto dtype = getMPSDataType(self.scalar_type());
-  // workaround for UInt8 and Bool issues in MPS backend
-  if (dtype == MPSDataTypeUInt8 || dtype == MPSDataTypeBool) {
-    dtype = MPSDataTypeInt8;
+  auto input_type = getMPSDataType(self.scalar_type());
+  auto output_type = getMPSDataType(output.scalar_type());
+  if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
+    input_type = MPSDataTypeInt8;
   }
+  if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
+    output_type = MPSDataTypeInt8;
+  }
   string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
   CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
@@ -68,7 +70,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
       MPSGraph* mpsGraph = make_mps_graph();
       newCachedGraph = new CachedGraph(mpsGraph);
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dtype, input_shape);
+      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self));
       MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
       MPSGraphTensor* getInput = inputTensor;
@@ -111,9 +113,9 @@ TORCH_IMPL_FUNC(gather_out_mps)
     cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
   }
-  Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, dtype);
+  Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
   Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, dtype);
+  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, output_type);
   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
     selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),

View File

@@ -392,7 +392,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
   if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
     outputDataType = MPSDataTypeInt8;
   }
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, nil, false, outputDataType);
+  Placeholder outputPlaceholder = Placeholder(
+      cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
   NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
   for (auto& inputPlaceholder : inputPlaceholders) {

View File

@@ -53,7 +53,7 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
                                                       dataType: inputType] autorelease];
   if (needsScatter) {
     auto updatesType = getMPSScalarType(src.scalar_type());
-    if (updatesType == MPSDataTypeUInt8 || updatesType == MPSDataTypeBool) {
+    if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
      updatesType = MPSDataTypeInt8;
    }
@@ -69,10 +69,10 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
     strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int);
     feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]);
   }
-  // Workaround for MPSShaderLibrary bug
-  // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
-  auto outputType = getMPSDataType(output.scalar_type());
-  if (outputType == MPSDataTypeUInt8) {
+  // Workaround for MPSShaderLibrary bug in macOS Monterey
+  // This is fixed in macOS Ventura
+  auto outputType = getMPSScalarType(output.scalar_type());
+  if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
     outputType = MPSDataTypeInt8;
   }
   MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer
@@ -505,7 +505,6 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
 static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size,
                                           const IntArrayRef& stride, int64_t offset,
                                           const IntArrayRef& base_shape, bool needsScatter,
-                                          const bool needsBoolCast,
                                           MPSGraphTensor* updatesTensor)
 {
   MPSGraph* mpsGraph = cachedGraph->graph();
@@ -548,23 +547,9 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
                                              name: nil];
   MPSGraphTensor *inputTensor = cachedGraph->inputTensor;
-  // Workaround for bool scatter/gather deficiency
-  // See https://github.com/pytorch/pytorch/issues/82663
-  if (needsBoolCast) {
-    inputTensor = [mpsGraph castTensor:inputTensor
-                                toType:MPSDataTypeInt8
-                                  name:@"Cast away from bool"];
-  }
   if (!needsScatter) {
     MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
     if (outputTensor) {
-      if (needsBoolCast) {
-        outputTensor = [mpsGraph castTensor:outputTensor
-                                     toType:MPSDataTypeBool
-                                       name:@"Cast back to bool"];
-      }
       return outputTensor;
     }
   }
@@ -597,14 +582,6 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
                                     withShapeTensor: shapeTensor
                                                name: nil];
     }
-    // Workaround for bool scatter/gather deficiency
-    // See https://github.com/pytorch/pytorch/issues/82663
-    if (needsBoolCast) {
-      outputTensor = [mpsGraph castTensor:outputTensor
-                                   toType:MPSDataTypeBool
-                                     name:@"Cast back to bool"];
-    }
   }
   return outputTensor;
 }
@@ -660,13 +637,13 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
       MPSGraph* mpsGraph = make_mps_graph();
       MPSGraphTensor* updatesTensor = nil;
       newCachedGraph = new ViewCachedGraph(mpsGraph);
-      // Workaround for MPSShaderLibrary bug
-      // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
+      // Workaround for MPSShaderLibrary bug in macOS Monterey
+      // This is fixed in macOS Ventura
       auto inputType = getMPSScalarType(self.scalar_type());
-      if (inputType == MPSDataTypeUInt8) {
+      if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
         inputType = MPSDataTypeInt8;
       }
-      auto needsBoolCast = inputType == MPSDataTypeBool;
       // Self is the input tensor we are creating view of
       newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
       newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]);
@@ -675,10 +652,10 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
       }
       if (needsScatter) {
         auto updatesType = getMPSScalarType(updates.scalar_type());
-        if (updatesType == MPSDataTypeUInt8) {
-          updatesType = MPSDataTypeInt8;
+        if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
+          updatesType = MPSDataTypeInt8;
         }
-        newCachedGraph->updatesTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, updatesType);
+        newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
         updatesTensor = newCachedGraph->updatesTensor;
         if (inputType != updatesType) {
           updatesTensor = [mpsGraph castTensor:updatesTensor
@@ -686,7 +663,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
                                          name:@"castUpdatesTensor"];
         }
       }
-      newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, needsBoolCast, updatesTensor);
+      newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
     }
     return newCachedGraph;
   }));
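
On the third summary bullet ("make all tensors ranked in scatter"): the hunk above swaps `mpsGraphUnrankedPlaceHolder` for `mpsGraphRankedPlaceHolder`. In terms of the underlying MPSGraph API the distinction is roughly the following (a sketch; `placeholderWithShape:dataType:name:` is the real MPSGraph method, where a nil shape produces an unranked tensor, and `updatesType`/`self` are the variables from the diff above):

```
// Unranked: only the dtype is declared, so MPSGraph cannot propagate a shape
// through the scatter chain at graph-construction time.
MPSGraphTensor* unranked = [mpsGraph placeholderWithShape:nil
                                                 dataType:updatesType
                                                     name:nil];
// Ranked: the shape is pinned to a flat [self.numel()] tensor, matching the new code.
MPSGraphTensor* ranked = [mpsGraph placeholderWithShape:@[ @(self.numel()) ]
                                               dataType:updatesType
                                                   name:nil];
```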

View File

@@ -8718,7 +8718,7 @@ class TestConsistency(TestCase):
         'diag': ['f32', 'i32'],
         'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'diagflat': ['f32', 'i32'],
-        'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
+        'diagonal_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'dist': ['f32'],
         'dot': ['f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8840,25 +8840,25 @@ class TestConsistency(TestCase):
         'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
         'remainder' : ['f32', 'f16'],
-        'repeat': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'repeat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'repeat_interleave': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'resize_': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'resize_as_': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'resolve_conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'resolve_neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'rot90': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'rot90': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
         'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
         'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
+        'select_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
         'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
         'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],
         'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
+        'slice_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'softmax': ['f32'],
         'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9144,7 +9144,6 @@ class TestConsistency(TestCase):
         'pow': [torch.int64],
         'select_scatter': [torch.uint8],
         'sigmoid': [torch.int64],
-        'slice_scatter': [torch.uint8],
         'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below