mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS]: Added op upsample_nearest1d (#81303)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81303 Approved by: https://github.com/malfet
This commit is contained in:
parent
85efdec060
commit
067c8067a3
|
|
@ -916,5 +916,77 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
|
|||
using namespace mps;
|
||||
upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeBilinear, align_corners);
|
||||
}
|
||||
|
||||
// Shared implementation for the MPS 1-D upsampling kernels.
// Resizes `input` along its last dimension to `output_size[0]` using the
// requested MPSGraphResizeMode and writes the result into `output`.
//
// `scales` is accepted for signature parity with the structured-kernel
// entry points but is not consulted here: the meta function has already
// folded any scale factor into `output_size`.
void upsample1d_out_mps(const Tensor& input,
                        IntArrayRef output_size,
                        c10::optional<double> scales,
                        const Tensor& output,
                        MPSGraphResizeMode requested_mode)
{
  // Get stream
  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;
  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  /* sizes */
  int64_t out_size = output_size[0];

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(input);
    // Cache key must uniquely identify the compiled graph: input shape,
    // dtype, target length and resize mode all change the graph.
    // (A previously computed NSString shape key was dead code and has
    // been removed; getMPSShapeString serves the same purpose.)
    string key = string("upsample_1d:") + mps::getMPSShapeString(input_shape) + ":" +
                 getMPSTypeString(input.scalar_type()) +
                 ":size" + to_string(out_size) +
                 ":mode" + to_string(requested_mode);

    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
    if(!cachedGraph) {
      cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

        CachedGraph *newCachedGraph = nil;

        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
          // MPSGraph resize is 2-D: the 1-D signal is presented as an
          // "image" of size (out_size, 1) in CHW layout, so only the
          // length dimension is actually resampled.
          newCachedGraph->outputTensor_ = [mpsGraph resizeTensor:newCachedGraph->inputTensor_
                                                            size:@[ @(out_size), @(1)]
                                                            mode:requested_mode
                                                    centerResult: true
                                                    alignCorners: true
                                                          layout: MPSGraphTensorNamedDataLayoutCHW
                                                            name:nil];
        }
        return newCachedGraph;
      }));
    }
    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
  }
}
|
||||
|
||||
|
||||
// Nearest-neighbour 1-D upsampling, MPS backend (structured-kernel "out"
// implementation). All of the work happens in the shared helper; this
// wrapper only pins the resize mode to nearest.
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
    const Tensor& input,
    IntArrayRef output_size,
    c10::optional<double> scales,
    const Tensor& output) {
  using namespace mps;
  upsample1d_out_mps(input, output_size, scales, output, MPSGraphResizeNearest);
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -10483,6 +10483,7 @@
|
|||
dispatch:
|
||||
CPU: upsample_nearest1d_out_cpu
|
||||
CUDA: upsample_nearest1d_out_cuda
|
||||
MPS: upsample_nearest1d_out_mps
|
||||
|
||||
- func: _upsample_nearest_exact1d.out(Tensor self, int[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
|
|
|
|||
|
|
@ -3300,6 +3300,20 @@ class TestNLLLoss(TestCase):
|
|||
helper(1, 1, 4, 4)
|
||||
helper(7, 5, 3, 2)
|
||||
|
||||
def test_upsample_nearest1d(self):
    # interpolate(mode='nearest') on a 3-D (N, C, L) input dispatches to the
    # MPS upsample_nearest1d kernel; compare against the CPU reference.
    #
    # The helper previously took an extra, unused leading parameter and
    # labelled the 3-D shape (C, H, W) -- both copied from the 2-D test.
    # The shapes constructed are unchanged: (1, 4, 4) and (5, 3, 2).
    def helper(N, C, L):
        # A contiguous ramp of values so every output element is distinct.
        inputCPU = torch.arange(N * C * L, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(N, C, L)
        inputMPS = inputCPU.detach().clone().to('mps')

        outputCPU = torch.nn.functional.interpolate(inputCPU, scale_factor=2.0, mode='nearest')
        outputMPS = torch.nn.functional.interpolate(inputMPS, scale_factor=2.0, mode='nearest')

        self.assertEqual(outputCPU, outputMPS)

    helper(1, 4, 4)
    helper(5, 3, 2)
|
||||
|
||||
# Test concat forward
|
||||
def test_cat1(self):
|
||||
def helper(shape_x, shape_y, shape_z):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user