pytorch/torch/utils/mkldnn.py
Junjie Bai c9f380df02 Add aten mkldnn linear operator
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19210

Reviewed By: dzhulgakov

Differential Revision: D14901641

fbshipit-source-id: 8fa68b9941fd93cea0f313a828cba34c5c81ae11
2019-04-26 13:41:57 -07:00

44 lines
1.4 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import functools
import torch
def to_mkldnn(module):
    """Convert the floating-point tensors held by ``module`` to MKL-DNN.

    ``module.apply`` is used to visit ``module`` and the modules it applies
    the callback to. For each visited module:

    * every non-``None`` parameter has its ``.data`` (and the ``.data`` of
      its gradient, when one is present) passed through ``Tensor.to_mkldnn``
      if it is floating point;
    * every non-``None`` buffer is handled the same way;
    * ``torch.nn.Linear`` instances get their ``forward`` replaced by a
      partial of ``torch._C._nn.linear`` bound to the module's weight and
      bias;
    * ``torch.nn.Conv2d`` weights are additionally reordered with
      ``torch._C._nn.mkldnn_reorder_conv2d_weight``.

    Tensors that are not floating point are left untouched.

    Args:
        module: the module to convert; it is modified in place.

    Returns:
        Whatever ``module.apply`` returns for the conversion callback.
    """
    def t_fn(t):
        # Only floating-point tensors are converted. Everything else is
        # handed back unchanged: falling off the end here would return None,
        # and the callers below would then overwrite the stored tensor
        # (e.g. an integer counter buffer) with None.
        if t.is_floating_point():
            return t.to_mkldnn()
        return t

    def m_fn(m):
        # TODO: This is a temporary hack to work around the fact that
        # nn.Linear is decomposed into addmm/matmul. Later we will
        # change nn.Linear to directly call aten linear and we can
        # remove this patch
        if isinstance(m, torch.nn.Linear):
            # The partial holds the Parameter objects themselves, so it sees
            # the converted data assigned to them further down.
            m.forward = functools.partial(
                torch._C._nn.linear,
                weight=m.weight,
                bias=m.bias)

        for param in m._parameters.values():
            if param is not None:
                # Tensors stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = t_fn(param.data)
                if param._grad is not None:
                    param._grad.data = t_fn(param._grad.data)

        for key, buf in m._buffers.items():
            if buf is not None:
                # Rebinding an existing key does not change the dict's size,
                # so this is safe while iterating over items().
                m._buffers[key] = t_fn(buf)

        # Runs after the generic conversion above, so the weight handed to
        # the reorder op has already been through t_fn.
        if isinstance(m, torch.nn.Conv2d):
            m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
                m.weight.data,
                m.padding,
                m.stride,
                m.dilation,
                m.groups)

    return module.apply(m_fn)