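"""Offline optimization tool for Caffe2 nets.

Loads a serialized (init_net, predict_net) pair, optionally applies the
Conv+SpatialBN, Mul+Add and Conv+Relu (IDEEP/MKL-DNN) fusions implemented
below, and writes the transformed nets to init_net.pb / predict_net.pb in
the current directory.

Example invocation (file names are placeholders; the flags are the ones
defined in GetArgumentParser):

    python <this_script>.py --init_net init_net.pb --pred_net predict_net.pb \
        --verify_input value_info.json --fuse_bn --fuse_mul_add --fuse_conv_relu
"""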

import argparse
import copy
import json
import logging

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, utils
import caffe2.python._import_c_extension as C

# fuse_first_mul_add() below logs through `log`; define a module-level logger
# here so that call does not raise a NameError.
log = logging.getLogger(__name__)


def pairwise(iterable):
    from itertools import tee
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def last_producer(ops, blob):
    for (i, op) in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    raise ValueError("Failed to find last producer of blob, {}".format(blob))


def blob_uses(net, blob):
    u = []
    for i, op in enumerate(net.op):
        if blob in op.input or blob in op.control_input:
            u.append(i)
    return u


def GetArgumentParser():
    parser = argparse.ArgumentParser(description="Caffe2 optimization")
    parser.add_argument("--init_net",
                        type=argparse.FileType('rb'),
                        help="init net")
    parser.add_argument("--pred_net",
                        type=argparse.FileType('rb'),
                        help="predict net")
    parser.add_argument("--verify_input",
                        type=argparse.FileType('r'),
                        help="input dims for verification")
    parser.add_argument("--fuse_bn", default=False, action='store_true')
    parser.add_argument("--fuse_mul_add", default=False, action='store_true')
    parser.add_argument("--fuse_conv_relu", default=False, action='store_true')
    return parser


def fuse_first_bn(net, params, removed_tensors):
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue

        if current.type not in ("Conv", "ConvTranspose") \
           or next_.type != "SpatialBN":
            continue
        if len(blob_uses(net, current.output[0])) != 1:
            # Can't fuse if more than one user
            continue

        # else, can fuse
        conv = current
        bn = next_
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]

        # Fix fused_conv to ensure we have a bias passed.
        if len(fused_conv.input) != 3:
            bias_name = "{}_bias".format(conv.input[1])
            net.external_input.extend([bias_name])
            fused_conv.input.extend([bias_name])
            for arg in fused_conv.arg:
                if arg.name == "no_bias":
                    arg.i = 0

        conv_weight = params[conv.input[1]]
        conv_bias = params[conv.input[2]] if len(conv.input) == 3 \
            else np.zeros(shape=(conv_weight.shape[0])).astype(np.float32)

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        # First, in test mode BN computes:
        #   (X - running_mean) * (1.0 / sqrt(running_var + eps)) * bn_scale
        #       + bn_bias
        # Distributing the scale, this is
        #   X * bn_scale / sqrt(running_var + eps)
        #       + (bn_bias - running_mean * bn_scale / sqrt(running_var + eps))
        # i.e. just the affine transform
        #   X * A + B
        # where
        #   A = bn_scale / sqrt(running_var + eps)
        #   B = bn_bias - running_mean * A
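        # (Worked example, purely illustrative: with bn_scale = 2,
        # running_var = 3, running_mean = 1, bn_bias = 0 and eps ~ 0,
        # A = 2 / sqrt(3) and B = -2 / sqrt(3), so the BN collapses to
        # X * A + B as claimed.)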
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # This identity should hold if we have correctly fused
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[bn.output[0]])

        # Now, we have that the computation made is the following:
        #   ((X `conv` W) + b) * A + B
        # Then, we can simply fuse this as follows:
        #   (X `conv` (W * A)) + b * A + B
        # which is simply
        #   (X `conv` Q) + C
        # where
        #   Q = W * A
        #   C = b * A + B

        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # BN scale 'A' is of size (S,).
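        # (Concretely: for a Conv the weight is (S, R, K, K) with S output
        # channels and A has size (S,), so A.reshape(-1, 1, 1, 1) scales each
        # output filter; for a ConvTranspose the output channels sit in the
        # second weight dimension, hence the (1, -1, 1, 1) reshape.)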
        A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
            A.reshape(1, -1, 1, 1)

        C = conv_bias * A + B
        Q = conv_weight * A_

        params[fused_conv.input[1]] = Q
        params[fused_conv.input[2]] = C
        new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
        del net.op[:]
        removed_tensors.append(bn.input[1])
        removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_bn(net, params, ignore_failure):
    # Run until we hit a fixed point
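    # (fuse_first_bn rewrites at most one Conv/ConvTranspose + SpatialBN pair
    # per call, so keep iterating until the op count stops shrinking.)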
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_bn(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            if (
                any(op.type == "SpatialBN" for op in next_net.op) and
                not ignore_failure
            ):
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(
                        next_net))
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def fuse_first_mul_add(net, params, removed_tensors):
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if current.type != "Mul" or next_.type != "Add":
            continue

        if next_.input[0] != current.output[0]:
            raise Exception("Failure to fuse")

        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception("Failure to fuse")

        log.info("Fusing at index %s", i)
        mul_ = current
        add_ = next_
        batch_norm = copy.deepcopy(mul_)
        batch_norm.type = "SpatialBN"
        batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
        batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])

        def s(x):
            return "{}{}".format(add_.output[0], x)

        fake_mean = s("_mean")
        fake_var = s("_var")

        del batch_norm.input[:]
        batch_norm.input.extend([mul_.input[0],
                                 mul_.input[1],
                                 add_.input[1],
                                 fake_mean,
                                 fake_var])
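        # With zero mean, unit variance and the tiny epsilon set above,
        # SpatialBN in test mode computes
        #   (X - 0) / sqrt(1 + 1e-9) * scale + bias ~= X * scale + bias,
        # which is exactly what the original Mul + Add pair computed.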
        params[fake_mean] = np.zeros_like(params[mul_.input[1]])
        params[fake_var] = np.ones_like(params[mul_.input[1]])
        net.external_input.extend([fake_mean, fake_var])

        batch_norm.output[0] = add_.output[0]
        new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_mul_add(net, params):
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_mul_add(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def add_tensor(net, name, blob):
    ''' Create a GivenTensor*Fill operator that stores the tensor 'blob'
        under 'name' and append it to 'net'; running 'net' later puts the
        blob into the workspace. uint8 data is stored as an array of
        strings with one element.
    '''
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    shape = blob.shape
    values = blob
    # pass array of uint8 as a string to save storage;
    # storing uint8_t has a large overhead for now
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        values = [str(blob.data)]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])


def gen_init_net_from_blobs(blobs):
    ''' Generate an initialization net based on a blob dict '''
    ret = caffe2_pb2.NetDef()
    for name, blob in blobs.items():
        add_tensor(ret, name, blob)
    return ret


def fuse_conv_relu(net):
    net = copy.deepcopy(net)
    device_option = core.DeviceOption(caffe2_pb2.IDEEP)
    for op in net.op:
        op.device_option.CopyFrom(device_option)
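    # The Conv + Relu fusion itself is done by the C++ MKL-DNN graph
    # transform below; this wrapper only tags every op with the IDEEP
    # device option and round-trips the serialized NetDef through it.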
    new_net = caffe2_pb2.NetDef()
    new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString()))
    return new_net


def Optimize(args):
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
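        # --verify_input is assumed to be a value_info-style JSON mapping
        # blob name -> [type, shape]; only the shape (the last element) is
        # used here, to generate random float32 inputs for verification.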
        value_info = json.load(args.verify_input)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        workspace.ResetWorkspace()
        device_option = core.DeviceOption(caffe2_pb2.IDEEP) \
            if args.fuse_conv_relu else core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                assert np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3,
                                   rtol=1e-3)

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())


if __name__ == '__main__':
    args = GetArgumentParser().parse_args()
    Optimize(args)