pytorch/torch/utils/mkldnn.py
Aaron Gokaslan 3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.:
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):` -> `if x > y or y < z:`
  - `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
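
One note on the first change: a bare `assert(0)` is stripped when Python runs with `-O` (all `assert` statements are removed when `__debug__` is False), while `raise AssertionError()` always fires. A minimal before/after sketch, using hypothetical code rather than lines from this PR's diff:

    # Before: silently skipped under `python -O`
    def unreachable_before():
        assert(0)

    # After: always raises, regardless of optimization flags
    def unreachable_after():
        raise AssertionError()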


import torch


class MkldnnLinear(torch.jit.ScriptModule):
    """Linear layer whose weight and bias are stored as MKL-DNN (oneDNN) tensors."""

    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # Bias can be fp32 or bf16 for the OneDNN bf16 path, but for good
            # accuracy we use the fp32 dtype.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d, MkldnnConv2d, and MkldnnConv3d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # Bias can be fp32 or bf16 for the OneDNN bf16 path, but for good
            # accuracy we use the fp32 dtype.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)


class MkldnnConv1d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv2d(_MkldnnConvNd):
    """Conv2d whose weight is reordered (prepacked) into the MKL-DNN blocked layout."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv3d(_MkldnnConvNd):
    """Conv3d whose weight is reordered (prepacked) into the MKL-DNN blocked layout."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnBatchNorm(torch.jit.ScriptModule):
    """Inference-only batch norm whose affine parameters and running stats are MKL-DNN tensors."""

    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )


class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


def to_mkldnn(module, dtype=torch.float):
    """Convert ``module`` and, recursively, its children to MKL-DNN backed equivalents where supported."""
    assert dtype in [torch.float, torch.bfloat16, torch.half], \
        "MKLDNN only supports float, bfloat16, and half now"

    def m_fn(m, d):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For the batchnorm bf16 path, OneDNN requires fp32 weight and bias,
            # so no dtype argument is needed.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)
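

# The sketch below is an illustrative usage example added for this write-up, not
# part of the original torch/utils/mkldnn.py. It assumes a PyTorch build with
# MKL-DNN (oneDNN) support and shows the typical flow: convert an eval-mode
# model with to_mkldnn(), feed it MKL-DNN tensors, and densify the output.
if __name__ == "__main__":
    if torch.backends.mkldnn.is_available():
        with torch.no_grad():
            model = torch.nn.Sequential(
                torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
                torch.nn.BatchNorm2d(8),
            ).eval()  # MkldnnBatchNorm asserts the module is not in training mode

            mkldnn_model = to_mkldnn(model)  # fp32 path; pass torch.bfloat16 for bf16
            x = torch.randn(1, 3, 32, 32).to_mkldnn()  # convolution expects an MKL-DNN input here
            y = mkldnn_model(x).to_dense()  # convert the result back to a dense tensor
            print(y.shape)  # torch.Size([1, 8, 32, 32])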