[MPS]: Added op upsample_nearest1d (#81303)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81303
Approved by: https://github.com/malfet
This commit is contained in:
Kulin Seth 2022-07-13 21:39:50 +00:00 committed by PyTorch MergeBot
parent 85efdec060
commit 067c8067a3
3 changed files with 87 additions and 0 deletions

View File

@@ -916,5 +916,77 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
using namespace mps;
upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeBilinear, align_corners);
}
// Shared implementation for 1D upsampling ops on the MPS backend.
// Resizes `input` along its last dimension to `output_size[0]` using the
// requested MPSGraph resize mode, writing into the pre-allocated `output`.
//
// @param input           Source tensor.
// @param output_size     One-element array holding the target output length.
// @param scales          Optional scale factor; currently unused — the resize
//                        is driven entirely by `output_size`.
//                        NOTE(review): confirm scales may be ignored here.
// @param output          Destination tensor, already allocated by the caller.
// @param requested_mode  e.g. MPSGraphResizeNearest / MPSGraphResizeBilinear.
void upsample1d_out_mps(const Tensor& input,
                        IntArrayRef output_size,
                        c10::optional<double> scales,
                        const Tensor& output,
                        MPSGraphResizeMode requested_mode)
{
  // Get stream
  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;
  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  /* sizes */
  int64_t out_size = output_size[0];

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(input);
    // Cache key must uniquely identify input shape, dtype, target size and
    // resize mode so distinct configurations get distinct graphs.
    // (Removed an unused NSString shape-description local that was computed
    // here on every call but never read.)
    string key = string("upsample_1d:") + mps::getMPSShapeString(input_shape) + ":" +
                 getMPSTypeString(input.scalar_type()) +
                 ":size" + to_string(out_size) +
                 ":mode" + to_string(requested_mode);
    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);

    if(!cachedGraph) {
      cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
        CachedGraph *newCachedGraph = nil;
        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
          // MPSGraph's resize is 2D; the trailing size of 1 leaves the extra
          // axis untouched so only the 1D length is resized.
          // NOTE(review): centerResult/alignCorners both true mirrors the 2D
          // MPS upsample path — confirm against CPU nearest semantics.
          newCachedGraph->outputTensor_ = [mpsGraph resizeTensor:newCachedGraph->inputTensor_
                                                            size:@[ @(out_size), @(1)]
                                                            mode:requested_mode
                                                    centerResult:true
                                                    alignCorners:true
                                                          layout:MPSGraphTensorNamedDataLayoutCHW
                                                            name:nil];
        }
        return newCachedGraph;
      }));
    }

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
  }
}
// Registers the MPS kernel for the structured op upsample_nearest1d.out.
// Thin wrapper: forwards to the shared 1D resize helper with
// nearest-neighbor mode. `scales` is passed through unchanged; the helper
// derives the resize from `output_size` — NOTE(review): confirm the scale
// factor can be ignored for exact nearest-neighbor parity with CPU.
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales,
const Tensor& output)
{
using namespace mps;
upsample1d_out_mps(input, output_size, scales, output, MPSGraphResizeNearest);
}
} // namespace native
} // namespace at

View File

@@ -10483,6 +10483,7 @@
dispatch:
CPU: upsample_nearest1d_out_cpu
CUDA: upsample_nearest1d_out_cuda
MPS: upsample_nearest1d_out_mps
- func: _upsample_nearest_exact1d.out(Tensor self, int[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn

View File

@@ -3300,6 +3300,20 @@ class TestNLLLoss(TestCase):
helper(1, 1, 4, 4)
helper(7, 5, 3, 2)
def test_upsample_nearest1d(self):
    """MPS nearest-neighbor 1D upsampling must match the CPU reference.

    Uses a 3D (C, H, W) input: F.interpolate dispatches 3D inputs to the
    upsample_nearest1d kernel (a 4D input would hit the 2D op instead).
    """
    # Fixed: the helper previously took an unused leading parameter N that
    # was never incorporated into the input tensor.
    def helper(C, H, W):
        inputCPU = torch.arange(C * H * W, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(C, H, W)
        inputMPS = inputCPU.detach().clone().to('mps')

        outputCPU = torch.nn.functional.interpolate(inputCPU, scale_factor=2.0, mode='nearest')
        outputMPS = torch.nn.functional.interpolate(inputMPS, scale_factor=2.0, mode='nearest')

        self.assertEqual(outputCPU, outputMPS)

    helper(1, 4, 4)
    helper(5, 3, 2)
# Test concat forward
def test_cat1(self):
def helper(shape_x, shape_y, shape_z):