pytorch/caffe2/python/helpers/pooling.py
Yangqing Jia deb1327b6e Re-apply #266
Summary: Closes https://github.com/caffe2/caffe2/pull/404

Differential Revision: D4943280

Pulled By: Yangqing

fbshipit-source-id: c0988598d8ccb8329feac88382686324b90d4d46
2017-04-25 21:17:04 -07:00

29 lines
787 B
Python

## @package pooling
# Module caffe2.python.helpers.pooling
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
def max_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW",
             **kwargs):
    """Max pooling: forward the arguments to ``model.net.MaxPool``.

    Args:
        model: object exposing a ``net`` attribute with a ``MaxPool`` method.
        blob_in: input blob, passed through as the first positional argument.
        blob_out: output blob, passed through as the second positional
            argument.
        use_cudnn: when true, ``engine='CUDNN'`` is set, replacing any
            ``engine`` value supplied through ``kwargs``.
        order: value forwarded as the ``order`` keyword (default "NCHW").
        **kwargs: any further keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.MaxPool`` returns.
    """
    if use_cudnn:
        # An explicit use_cudnn request takes precedence over a caller engine.
        kwargs.update(engine='CUDNN')
    add_max_pool = model.net.MaxPool
    return add_max_pool(blob_in, blob_out, order=order, **kwargs)
def average_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW",
                 **kwargs):
    """Average pooling: forward the arguments to ``model.net.AveragePool``.

    Args:
        model: object exposing a ``net`` attribute with an ``AveragePool``
            method.
        blob_in: input blob, passed through as the first positional argument.
        blob_out: output blob, passed through as the second positional
            argument.
        use_cudnn: when true, ``engine='CUDNN'`` is set, replacing any
            ``engine`` value supplied through ``kwargs``.
        order: value forwarded as the ``order`` keyword (default "NCHW").
        **kwargs: any further keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.AveragePool`` returns.
    """
    if use_cudnn:
        # An explicit use_cudnn request takes precedence over a caller engine.
        kwargs.update(engine='CUDNN')
    add_average_pool = model.net.AveragePool
    return add_average_pool(blob_in, blob_out, order=order, **kwargs)