[MPS] Add flip (#80214)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD
parent a48f3059b7
commit c4da23ed1b
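For orientation, here is a minimal usage sketch of what this commit enables: torch.flip dispatching to a native MPS kernel. This is not taken from the diff itself; it assumes a PyTorch build with MPS support on an Apple-silicon (or macOS 12.3+) machine.

    import torch

    if torch.backends.mps.is_available():
        x = torch.arange(6.0, device='mps').reshape(2, 3)
        # Reversal along the last axis now runs on the GPU via the new kernel.
        y = torch.flip(x, dims=(-1,))
        print(y.cpu())  # tensor([[2., 1., 0.],
                        #         [5., 4., 3.]])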
@@ -9,6 +9,7 @@
 #include <ATen/ExpandUtils.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/mps/MPSStream.h>
+#include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/native/Resize.h>
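The single line this hunk adds is marked above as the <ATen/WrapDimUtilsMulti.h> include — an inference from the hunk counts (+1 line) and from the fact that it declares at::dim_list_to_bitset, the helper the new kernel calls below. As a rough illustration of that helper's contract (wrap negative dims, reject repeats), not its actual implementation, a Python sketch:

    def dim_list_to_bitset(dims, ndims):
        # Sketch of at::dim_list_to_bitset's behavior: wrap negative dims
        # into [0, ndims) and reject any dim that appears more than once.
        bitset = [False] * ndims
        for d in dims:
            if not -ndims <= d < ndims:
                raise IndexError(f"dim {d} out of range for {ndims}-dim tensor")
            d %= ndims
            if bitset[d]:
                raise RuntimeError(f"dim {d} appears multiple times in the list of dims")
            bitset[d] = True
        return bitset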
@@ -27,6 +28,92 @@
 namespace at {
 namespace native {
 
+Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
+  using namespace mps;
+
+  Tensor result = at::native::empty_mps(
+                    self.sizes(),
+                    self.scalar_type(),
+                    c10::nullopt,
+                    kMPS,
+                    c10::nullopt,
+                    c10::nullopt);
+
+  auto total_dims = self.dim();
+  // It wraps the dims and checks that there are no repeated dims
+  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
+  NSMutableArray<NSNumber*> * ns_dims = [NSMutableArray<NSNumber*> new];
+
+  for (const auto i : c10::irange(total_dims)) {
+    if(flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
+      [ns_dims addObject:[NSNumber numberWithInt:i]];
+    }
+  }
+
+  // Nothing to do, we return fast
+  if (dims.size() == 0 || self.numel() <= 1) {
+    result.copy_(self);
+    return result;
+  }
+
+  MPSStream* stream = getCurrentMPSStream();
+
+  struct CachedGraph : public MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* inputTensor_ = nil;
+    MPSGraphTensor* outputTensor_ = nil;
+  };
+
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+
+  @autoreleasepool {
+    NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
+    // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
+    string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]);
+    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    if(!cachedGraph) {
+      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+
+        CachedGraph *newCachedGraph = nil;
+
+        @autoreleasepool {
+          MPSGraph* mpsGraph = make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+          MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
+                                                             axes:ns_dims
+                                                             name:nil];
+          newCachedGraph->inputTensor_ = inputTensor;
+          newCachedGraph->outputTensor_ = outputTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
+    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+    };
+
+    // Run the graph
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return result;
+
+}
+
 Tensor index_select_mps(const Tensor & self,
                         int64_t dim,
                         const Tensor & index) {
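Two details of flip_mps worth noting. First, an axis only makes it into ns_dims if it was requested, has more than one element, and is not a stride-0 broadcast dimension — reversing such axes is a no-op. Second, the cache key embeds both the tensor signature and that filtered axis list, so each distinct flip pattern gets its own cached MPSGraph. A Python rendering of the axis filter, directly mirroring the loop above:

    import torch

    def axes_worth_flipping(t, wrapped_dims):
        # Mirrors the ns_dims loop in flip_mps: keep axis i only if it was
        # requested and reversing it would actually change anything.
        return [i for i in range(t.dim())
                if i in wrapped_dims and t.size(i) > 1 and t.stride(i) != 0]

    x = torch.randn(3, 1, 4)
    print(axes_worth_flipping(x, {0, 1, 2}))  # [0, 2] -- axis 1 has size 1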
@@ -4869,6 +4869,7 @@
   variants: function, method
   dispatch:
     CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip
+    MPS: flip_mps
 
 - func: fliplr(Tensor self) -> Tensor
   variants: function, method
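This one-line native_functions.yaml change is what makes the dispatcher route torch.flip to flip_mps for tensors on the MPS device; the other backends are untouched. A parity check one might run by hand (a sketch that assumes an MPS device is present; flip is pure data movement, so the results should match exactly):

    import torch

    cpu = torch.randn(2, 8, 4, 5)
    mps = cpu.to('mps')
    # Same operator, two dispatch keys: the CPU kernel and the new flip_mps.
    assert torch.equal(torch.flip(cpu, dims=(0, 2)),
                       torch.flip(mps, dims=(0, 2)).cpu())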
@@ -3726,6 +3726,28 @@ class TestNLLLoss(TestCase):
 
         helper((2, 8, 4, 5))
 
+    # Test flip
+    def test_flip(self):
+        def helper(shape, dims):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            x = cpu_x.detach().clone().to('mps')
+
+            flip_result = torch.flip(x, dims=dims)
+            flip_result_cpu = torch.flip(cpu_x, dims=dims)
+
+            self.assertEqual(flip_result, flip_result_cpu)
+
+        helper((2, 8, 4, 5), [0])
+        helper((8, 8, 4, 5), [0, 1])
+        helper((2, 8, 4, 5), (0, 1, 2, 3))
+        helper((2, 3, 3), (-1,))
+        # empty dims
+        helper((2, 8, 4, 5), [])
+        # input.numel() == 1
+        helper((1,), (0,))
+        # input.numel() == 0
+        helper((0,), (0,))
+
     # Test index select
     def test_index_select(self):
        def helper(shape, dim, index, idx_dtype=torch.int32):
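The last three helper calls pin down the fast paths: empty dims and inputs with numel() <= 1 are handled by the early copy_ in flip_mps rather than by an MPSGraph. The same cases, reproduced outside the test harness (sketch; assumes an MPS device):

    import torch

    x = torch.randn(2, 8, 4, 5, device='mps')
    # Empty dims: flip is a no-op and takes the copy_ fast path.
    assert torch.equal(torch.flip(x, dims=[]).cpu(), x.cpu())
    # numel() <= 1: likewise nothing to reverse.
    single = torch.randn(1, device='mps')
    assert torch.equal(torch.flip(single, dims=(0,)).cpu(), single.cpu())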