[MPS] Add flip (#80214)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD
qqaatw 2022-06-28 19:51:43 +00:00 committed by PyTorch MergeBot
parent a48f3059b7
commit c4da23ed1b
3 changed files with 110 additions and 0 deletions
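
For context, the user-facing behavior this PR enables: torch.flip now runs natively on the MPS backend. A minimal usage sketch, assuming a build and machine where torch.backends.mps.is_available() returns True:

    import torch

    x = torch.arange(6.0, device="mps").reshape(2, 3)
    print(torch.flip(x, dims=(0,)))    # rows reversed
    print(torch.flip(x, dims=(0, 1)))  # rows and columns reversed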


@@ -9,6 +9,7 @@
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Resize.h>
@@ -27,6 +28,92 @@
namespace at {
namespace native {

Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
  using namespace mps;

  Tensor result = at::native::empty_mps(
      self.sizes(),
      self.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  auto total_dims = self.dim();
  // Wrap the dims and check that there are no repeated dims
  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);

  // Collect the axes to reverse; size-1 and stride-0 dims are skipped
  // because reversing along them is a no-op
  NSMutableArray<NSNumber*>* ns_dims = [NSMutableArray<NSNumber*> new];
  for (const auto i : c10::irange(total_dims)) {
    if (flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
      [ns_dims addObject:[NSNumber numberWithInt:i]];
    }
  }

  // Nothing to do; return a copy of the input right away
  if (dims.size() == 0 || self.numel() <= 1) {
    result.copy_(self);
    return result;
  }

  MPSStream* stream = getCurrentMPSStream();

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  @autoreleasepool {
    NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
    // The key identifies the MPSGraph, which is created once and reused whenever
    // the parameters, data types, etc. match an earlier created MPSGraph
    string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]);
    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
        CachedGraph* newCachedGraph = nil;
        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
          MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
                                                             axes:ns_dims
                                                             name:nil];
          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }

    // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    // Run the graph
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }

  return result;
}
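
The axis-filtering loop above drops size-1 and stride-0 dims before building the graph, since reversing along such an axis is a no-op. A quick illustration from Python (a sketch, not part of the diff; works on any backend):

    import torch

    x = torch.randn(1, 4)  # dim 0 has size 1
    # Flipping a size-1 dim leaves the tensor unchanged, which is why the
    # kernel can skip such axes when choosing the MPSGraph reverse axes.
    assert torch.equal(torch.flip(x, dims=(0,)), x)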
Tensor index_select_mps(const Tensor& self,
                        int64_t dim,
                        const Tensor& index) {


@@ -4869,6 +4869,7 @@
  variants: function, method
  dispatch:
    CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip
    MPS: flip_mps

- func: fliplr(Tensor self) -> Tensor
  variants: function, method
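
The MPS: flip_mps line is the dispatch entry that routes torch.flip on MPS tensors to the new kernel. A hedged sanity check, assuming an MPS-enabled build:

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(2, 3, device="mps")
        y = torch.flip(x, dims=(1,))
        print(y.device)  # mps:0, i.e. the op ran natively instead of erroring out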


@@ -3726,6 +3726,28 @@ class TestNLLLoss(TestCase):
        helper((2, 8, 4, 5))

    # Test flip
    def test_flip(self):
        def helper(shape, dims):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            flip_result = torch.flip(x, dims=dims)
            flip_result_cpu = torch.flip(cpu_x, dims=dims)

            self.assertEqual(flip_result, flip_result_cpu)

        helper((2, 8, 4, 5), [0])
        helper((8, 8, 4, 5), [0, 1])
        helper((2, 8, 4, 5), (0, 1, 2, 3))
        helper((2, 3, 3), (-1,))
        # empty dims
        helper((2, 8, 4, 5), [])
        # input.numel() == 1
        helper((1,), (0,))
        # input.numel() == 0
        helper((0,), (0,))

    # Test index select
    def test_index_select(self):
        def helper(shape, dim, index, idx_dtype=torch.int32):
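
To run just the new test locally (assuming this hunk is in test/test_mps.py, which requires macOS with an MPS-capable device):

    python test/test_mps.py -k test_flip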