[OSS][Metal] Support Resnet models

Summary:
This diff adds the missing ops to run the Resnet models from Torchvision. Move the tensors to GPU can significantly improve the perf as show below (iPhone11)

Time running on CPU (ms):

```
forward took: 166.115
forward took: 150.722
forward took: 150.383
forward took: 150.345
forward took: 150.761
forward took: 150.533
forward took: 150.588
forward took: 150.812
forward took: 150.925
forward took: 150.25
```

Time running on GPU (ms):

```
forward took: 39.9355
forward took: 41.3531
forward took: 41.798
forward took: 40.4744
forward took: 39.5181
forward took: 42.6464
forward took: 41.2658
forward took: 40.0862
forward took: 42.3533
forward took: 41.9348
```

Discrepancy in result

```
GPU:
    "(623, 4.6211)",
    "(111, 3.8809)",
    "(499, 3.8555)",
    "(596, 3.8047)",
    "(473, 3.7422)",
    "(846, 3.5762)",
    "(892, 3.5449)",
    "(813, 3.5098)",
    "(446, 3.5020)",
    "(902, 3.4980)"
CPU:
    "(623, 4.4229)",
    "(499, 3.8321)",
    "(596, 3.6192)",
    "(111, 3.5295)",
    "(813, 3.4848)",
    "(584, 3.3979)",
    "(418, 3.3357)",
    "(473, 3.2760)",
    "(846, 3.2745)",
    "(902, 3.2376)"
```

Test Plan: {F340824316}

Reviewed By: IvanKobzarev

Differential Revision: D24416294

fbshipit-source-id: 12c9199ade0b76a7aa8a3838eddc4c19c79b6f37
This commit is contained in:
Tao Xu 2020-10-22 10:47:15 -07:00 committed by Facebook GitHub Bot
parent 93719440b8
commit b63ddd6f57
3 changed files with 122 additions and 0 deletions

View File

@ -160,6 +160,11 @@ Tensor relu(const Tensor& input) {
return mpscnn::relu(input); return mpscnn::relu(input);
} }
Tensor& relu_(Tensor& input) {
TORCH_CHECK(input.is_metal());
return mpscnn::relu_(input);
}
Tensor sigmoid(const Tensor& input) { Tensor sigmoid(const Tensor& input) {
TORCH_CHECK(input.is_metal()); TORCH_CHECK(input.is_metal());
return mpscnn::sigmoid(input); return mpscnn::sigmoid(input);
@ -192,6 +197,14 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal()); return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal());
} }
Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
TORCH_CHECK(input1.is_metal());
TORCH_CHECK(input1.dim() == input2.dim());
TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
return mpscnn::add_(input1, input2.is_metal() ? input2 : input2.metal());
}
Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) { Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
TORCH_CHECK(input1.is_metal()); TORCH_CHECK(input1.is_metal());
TORCH_CHECK(input1.dim() == input2.dim()); TORCH_CHECK(input1.dim() == input2.dim());
@ -223,9 +236,18 @@ Tensor reshape(const Tensor& input, IntArrayRef shape) {
return mpscnn::reshape(input, shape); return mpscnn::reshape(input, shape);
} }
Tensor flatten_using_ints(
const Tensor& input,
int64_t start_dim,
int64_t end_dim) {
TORCH_CHECK(input.is_metal());
return mpscnn::flatten_using_ints(input, start_dim, end_dim);
}
TORCH_LIBRARY_IMPL(aten, Metal, m) { TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl("conv2d", TORCH_FN(conv2d)); m.impl("conv2d", TORCH_FN(conv2d));
m.impl("add.Tensor", TORCH_FN(add_Tensor)); m.impl("add.Tensor", TORCH_FN(add_Tensor));
m.impl("add_.Tensor", TORCH_FN(add__Tensor));
m.impl("addmm", TORCH_FN(addmm)); m.impl("addmm", TORCH_FN(addmm));
m.impl_UNBOXED("empty.memory_format", empty); m.impl_UNBOXED("empty.memory_format", empty);
m.impl("empty_strided", TORCH_FN(empty_strided)); m.impl("empty_strided", TORCH_FN(empty_strided));
@ -233,6 +255,7 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl("max_pool2d", TORCH_FN(max_pool2d)); m.impl("max_pool2d", TORCH_FN(max_pool2d));
m.impl("mul.Tensor", TORCH_FN(mul_Tensor)); m.impl("mul.Tensor", TORCH_FN(mul_Tensor));
m.impl("relu", TORCH_FN(relu)); m.impl("relu", TORCH_FN(relu));
m.impl("relu_", TORCH_FN(relu_));
m.impl("sigmoid", TORCH_FN(sigmoid)); m.impl("sigmoid", TORCH_FN(sigmoid));
m.impl("sub.Tensor", TORCH_FN(sub_Tensor)); m.impl("sub.Tensor", TORCH_FN(sub_Tensor));
m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec)); m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec));
@ -240,6 +263,7 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d)); m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
m.impl("hardtanh_", TORCH_FN(hardtanh_)); m.impl("hardtanh_", TORCH_FN(hardtanh_));
m.impl("reshape", TORCH_FN(reshape)); m.impl("reshape", TORCH_FN(reshape));
m.impl("flatten.using_ints", TORCH_FN(flatten_using_ints));
} }
} // namespace metal } // namespace metal

View File

@ -30,6 +30,8 @@ Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size);
Tensor relu(const Tensor& input); Tensor relu(const Tensor& input);
Tensor& relu_(Tensor& input);
Tensor sigmoid(const Tensor& input); Tensor sigmoid(const Tensor& input);
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val); Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val);
@ -44,6 +46,8 @@ Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight);
Tensor add(const Tensor& input1, const Tensor& input2); Tensor add(const Tensor& input1, const Tensor& input2);
Tensor& add_(Tensor& input1, const Tensor& input2);
Tensor sub(const Tensor& input1, const Tensor& input2); Tensor sub(const Tensor& input1, const Tensor& input2);
Tensor mul(const Tensor& input1, const Tensor& input2); Tensor mul(const Tensor& input1, const Tensor& input2);
@ -55,6 +59,8 @@ Tensor upsample_nearest2d_vec(
c10::optional<IntArrayRef> output_size, c10::optional<IntArrayRef> output_size,
c10::optional<ArrayRef<double>> scale_factors); c10::optional<ArrayRef<double>> scale_factors);
Tensor flatten_using_ints(const Tensor & input, int64_t start_dim, int64_t end_dim);
Tensor copy_to_host(const Tensor& input); Tensor copy_to_host(const Tensor& input);
} // namespace mpscnn } // namespace mpscnn

View File

@ -216,11 +216,36 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
return output; return output;
} }
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
MPSImage* X = imageFromTensor(input);
std::vector<int64_t> outputSize = input.sizes().vec();
std::vector<int64_t> textureSize = outputSize;
if (input.dim() == 2) {
textureSize = {outputSize[0], outputSize[1], 1, 1};
}
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
MPSImage* Y = [MPSImage temporaryImageFromSize:input.sizes().vec()
commandBuffer:commandBuffer];
[neuron encodeToCommandBuffer:commandBuffer.buffer
sourceImage:X
destinationImage:Y];
MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
MetalTensor& metalTensor = impl->unsafe_opaque_handle();
metalTensor.texture()->copyFromTexture(Y);
return input;
}
API_AVAILABLE(ios(10.0), macos(10.13)) API_AVAILABLE(ios(10.0), macos(10.13))
Tensor relu(const Tensor& input) { Tensor relu(const Tensor& input) {
return neuronKernel(input, [MPSCNNNeuronOp relu]); return neuronKernel(input, [MPSCNNNeuronOp relu]);
} }
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& relu_(Tensor& input) {
return neuronKernel_(input, [MPSCNNNeuronOp relu]);
}
API_AVAILABLE(ios(10.0), macos(10.13)) API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sigmoid(const Tensor& input) { Tensor sigmoid(const Tensor& input) {
return neuronKernel(input, [MPSCNNNeuronOp sigmoid]); return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
@ -356,12 +381,50 @@ Tensor binaryElementwiseKernel(
return output; return output;
} }
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& binaryElementwiseKernel_(
Tensor& input1,
const Tensor& input2,
NSString* arrayKernel,
NSString* nonarrayKernal) {
MPSImage* X1 = imageFromTensor(input1);
MPSImage* X2 = imageFromTensor(input2);
std::vector<int64_t> outputSize = input1.sizes().vec();
MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
TORCH_CHECK([cb1 isEqual:cb2], @"inputs have different command buffer");
MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1];
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
pipelineState:kernelFor(X1, arrayKernel, nonarrayKernal)];
id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
[encoder setComputePipelineState:state];
[encoder setTexture:[X1 texture] atIndex:0];
[encoder setTexture:[X2 texture] atIndex:1];
[encoder setTexture:[Y texture] atIndex:2];
const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
[X1 markRead];
[X2 markRead];
MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
MetalTensor& metalTensor = impl->unsafe_opaque_handle();
metalTensor.texture()->copyFromTexture(Y);
return input1;
}
API_AVAILABLE(ios(10.0), macos(10.13)) API_AVAILABLE(ios(10.0), macos(10.13))
Tensor add(const Tensor& input1, const Tensor& input2) { Tensor add(const Tensor& input1, const Tensor& input2) {
return binaryElementwiseKernel( return binaryElementwiseKernel(
input1, input2, @"elementwise_add", @"elementwise_add_nonarray"); input1, input2, @"elementwise_add", @"elementwise_add_nonarray");
} }
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& add_(Tensor& input1, const Tensor& input2) {
return binaryElementwiseKernel_(
input1, input2, @"elementwise_add", @"elementwise_add_nonarray");
}
API_AVAILABLE(ios(10.0), macos(10.13)) API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sub(const Tensor& input1, const Tensor& input2) { Tensor sub(const Tensor& input1, const Tensor& input2) {
return binaryElementwiseKernel( return binaryElementwiseKernel(
@ -510,6 +573,35 @@ Tensor upsample_nearest2d_vec(
return output; return output;
} }
Tensor flatten_using_ints(
const Tensor& input,
int64_t start_dim,
int64_t end_dim) {
start_dim = maybe_wrap_dim(start_dim, input.dim());
end_dim = maybe_wrap_dim(end_dim, input.dim());
TORCH_CHECK(
start_dim <= end_dim,
"flatten() has invalid args: start_dim cannot come after end_dim");
std::vector<int64_t> shape;
if (input.dim() == 0) {
return input.reshape({1});
}
if (start_dim == end_dim) {
return input;
}
auto slice_numel =
prod_intlist(input.sizes().slice(start_dim, end_dim - start_dim + 1));
shape.reserve(input.dim() - end_dim + start_dim);
for (int64_t i = 0; i < start_dim; i++) {
shape.push_back(input.size(i));
}
shape.push_back(slice_numel);
for (int64_t i = end_dim + 1; i < input.dim(); i++) {
shape.push_back(input.size(i));
}
return input.reshape(shape);
}
Tensor copy_to_host(const Tensor& input) { Tensor copy_to_host(const Tensor& input) {
MPSImage* X = imageFromTensor(input); MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);