pytorch/caffe2/python/helpers/algebra.py
Luke Yeager 3cfc6f26e7 fp16: BatchMatMul
Summary:
Was https://github.com/caffe2/caffe2/pull/1151
Closes https://github.com/caffe2/caffe2/pull/1175

Reviewed By: Yangqing

Differential Revision: D5794634

Pulled By: akyrola

fbshipit-source-id: 911c462824edec3de529a5a4385a4c437e24bf59
2017-09-13 14:35:25 -07:00

26 lines
757 B
Python

## @package algebra
# Module caffe2.python.helpers.algebra
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
def transpose(model, blob_in, blob_out, use_cudnn=False, **kwargs):
    """Forward a Transpose request to ``model.net.Transpose``.

    Args:
        model: Object exposing a ``net`` attribute with a ``Transpose``
            callable.
        blob_in: Input blob, passed through unchanged.
        blob_out: Output blob, passed through unchanged.
        use_cudnn: When truthy, the ``engine`` keyword argument is set to
            ``'CUDNN'``, replacing any ``engine`` value supplied by the
            caller.
        **kwargs: Extra keyword arguments handed to ``model.net.Transpose``.

    Returns:
        Whatever ``model.net.Transpose`` returns.
    """
    # An empty override leaves the caller-supplied kwargs untouched.
    engine_override = {'engine': 'CUDNN'} if use_cudnn else {}
    kwargs.update(engine_override)
    return model.net.Transpose(blob_in, blob_out, **kwargs)
def sum(model, blob_in, blob_out, **kwargs):
    """Forward a Sum request to ``model.net.Sum``.

    Args:
        model: Object exposing a ``net`` attribute with a ``Sum`` callable.
        blob_in: Input blob(s), passed through unchanged.
        blob_out: Output blob, passed through unchanged.
        **kwargs: Extra keyword arguments handed to ``model.net.Sum``.

    Returns:
        Whatever ``model.net.Sum`` returns.
    """
    sum_op = model.net.Sum
    return sum_op(blob_in, blob_out, **kwargs)
def batch_mat_mul(model, blob_in, blob_out,
                  enable_tensor_core=False, **kwargs):
    """Forward a BatchMatMul request to ``model.net.BatchMatMul``.

    Args:
        model: Object exposing a ``net`` attribute with a ``BatchMatMul``
            callable.
        blob_in: Input blob(s), passed through unchanged.
        blob_out: Output blob, passed through unchanged.
        enable_tensor_core: When truthy, the ``engine`` keyword argument is
            set to ``'TENSORCORE'``, replacing any ``engine`` value supplied
            by the caller.
        **kwargs: Extra keyword arguments handed to
            ``model.net.BatchMatMul``.

    Returns:
        Whatever ``model.net.BatchMatMul`` returns.
    """
    op_kwargs = dict(kwargs)
    if enable_tensor_core:
        # Explicit flag wins over any engine passed through kwargs.
        op_kwargs['engine'] = 'TENSORCORE'
    matmul_op = model.net.BatchMatMul
    return matmul_op(blob_in, blob_out, **op_kwargs)