pytorch/caffe2/python/memonger.py
2016-07-21 11:26:41 -07:00

146 lines
4.7 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import networkx as nx
import collections
import copy
from caffe2.python import workspace
# Live range of one blob over a linearized list of ops. `defined` is the
# smallest index of an op that outputs the blob and `used` is the largest
# index of an op that takes it as input; either field is None when no such
# op exists.
LiveRange = collections.namedtuple(
    'LiveRange', ('defined', 'used'))
def topological_sort_traversal(g):
    """Return the nodes of the directed acyclic graph `g` in topological order.

    Args:
        g: a networkx directed graph.

    Returns:
        A list of node identifiers such that every edge points from an
        earlier element to a later one.
    """
    # Newer networkx releases hand back a single-use generator here; wrapping
    # it in list() gives callers the same reusable sequence on every version.
    return list(nx.topological_sort(g))
def compute_ranges(linearized_ops):
    """Compute the live range of every blob touched by `linearized_ops`.

    Args:
        linearized_ops: sequence of ops in execution order; each op exposes
            iterable `input` and `output` blob names.

    Returns:
        A defaultdict mapping blob name -> LiveRange, where `defined` is the
        smallest index of an op that outputs the blob and `used` is the
        largest index of an op that reads it. A field is None when the blob
        is never produced (or never consumed) by any op in the list.
    """
    def unset_range():
        return LiveRange(defined=None, used=None)

    live = collections.defaultdict(unset_range)
    for position, op in enumerate(linearized_ops):
        # Reads push the end of the range as late as possible.
        for name in op.input:
            current = live[name]
            last_use = (
                position if current.used is None
                else max(current.used, position))
            live[name] = current._replace(used=last_use)
        # Writes pull the start of the range as early as possible.
        for name in op.output:
            current = live[name]
            first_def = (
                position if current.defined is None
                else min(current.defined, position))
            live[name] = current._replace(defined=first_def)
    return live
def is_compatible(candidate_range, assignment, static_blobs):
    """Tell whether a blob with `candidate_range` may join `assignment`.

    Only the most recently added member of the group is inspected.

    Args:
        candidate_range: object with `defined` and `used` attributes.
        assignment: non-empty list of (blob name, range) pairs.
        static_blobs: container of blob names that must never be shared.

    Returns:
        True only when the last member is not static, all four range
        endpoints are known, and the candidate is defined strictly after the
        last member's final use. False otherwise.
    """
    last_name, last_range = assignment[-1]
    if last_name in static_blobs:
        return False
    endpoints = (
        candidate_range.used,
        candidate_range.defined,
        last_range.defined,
        last_range.used,
    )
    # A missing endpoint means the range is open-ended, so sharing is unsafe.
    if any(endpoint is None for endpoint in endpoints):
        return False
    return candidate_range.defined > last_range.used
def compute_blob_assignments(assignments):
    """Map every blob that shares storage to the name of its sharing group.

    Args:
        assignments: iterable of groups, each a list of (blob name, range)
            pairs.

    Returns:
        A dict from blob name to the comma-joined names of all blobs in its
        group. Blobs that sit alone in a group are omitted.
    """
    blob_assignments = {}
    for assignment in assignments:
        if len(assignment) == 1:
            continue
        # The joined name is the same for every member, so build it once per
        # group instead of once per blob.
        group_name = ",".join([blob for blob, _ in assignment])
        for blob, _ in assignment:
            blob_assignments[blob] = group_name
    return blob_assignments
def compute_assignments(ranges, static_blobs):
    """Greedily group blobs whose live ranges allow them to share storage.

    Blobs are visited in order of their last use and each is appended to the
    first existing group that `is_compatible` accepts; otherwise it starts a
    new group.

    Args:
        ranges: mapping of blob name -> range object with `defined` and
            `used` attributes.
        static_blobs: container of blob names passed through to
            `is_compatible`.

    Returns:
        A list of groups, each a list of (blob name, range) pairs.
    """
    def last_use_key(item):
        # Blobs with no recorded use sort first, matching how Python 2
        # ordered None before integers; Python 3 cannot compare None to int.
        used = item[1].used
        return (used is not None, 0 if used is None else used)

    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    ordered = sorted(ranges.items(), key=last_use_key)
    assignments = []
    for name, range_ in ordered:
        for assignment in assignments:
            if is_compatible(range_, assignment, static_blobs):
                assignment.append((name, range_))
                break
        else:
            # No existing group could take this blob.
            assignments.append([(name, range_)])
    return assignments
def compute_interference_graph(ops):
    """Build the data-dependency graph between `ops`.

    Node `i` stands for `ops[i]` and stores the op in its `op` attribute. An
    edge i -> j is added whenever some output of `ops[i]` is an input of
    `ops[j]`; the edge's `deps` attribute is the set of those blob names.

    Args:
        ops: sequence of ops exposing iterable `input` and `output` names.

    Returns:
        The resulting networkx DiGraph.

    Raises:
        AssertionError: if the resulting graph is not acyclic.
    """
    g = nx.DiGraph()
    for i, op in enumerate(ops):
        g.add_node(i, op=op)
    # Build each op's input set once instead of once per (parent, child) pair.
    input_sets = [set(op.input) for op in ops]
    for i, parent_op in enumerate(ops):
        for j, child_inputs in enumerate(input_sets):
            if i == j:
                continue
            deps = child_inputs.intersection(parent_op.output)
            if deps:
                g.add_edge(i, j, deps=deps)
    assert nx.is_directed_acyclic_graph(g), \
        "op dependency graph is not acyclic"
    return g
# Result bundle of optimize_interference: the rewritten net, the sharing
# groups, and the blob -> group-name mapping used for renaming.
Optimization = collections.namedtuple(
    'Optimization',
    ('net', 'assignments', 'blob_assignments'))
def apply_assignments(net, blob_assignments):
    """Rename, in place, every op input/output of `net` that shares storage.

    Args:
        net: object whose `op` attribute is a sequence of ops with
            index-assignable `input` and `output` sequences of blob names.
        blob_assignments: mapping of blob name -> sharing-group name.

    Returns:
        None; `net` is mutated.
    """
    def canonical_name(blob):
        if blob not in blob_assignments:
            return blob
        # Use a text literal: bytes objects have no .format() on Python 3, so
        # the former b"..." literal raised AttributeError there.
        return "{}_shared".format(blob_assignments[blob])

    for op in net.op:
        for i, input_ in enumerate(op.input):
            op.input[i] = canonical_name(input_)
        for i, output in enumerate(op.output):
            op.output[i] = canonical_name(output)
def optimize_interference(net, static_blobs,
                          ordering_function=topological_sort_traversal):
    """Return a copy of `net` in which compatible blobs share one name.

    Steps:
    1) Build the op dependency graph and linearize it with
       `ordering_function` (a topological sort by default).
    2) Compute the def/use range of every blob along that ordering.
    3) Group blobs into sharing assignments.
    4) Rename the grouped blobs in the copied net.

    Args:
        net: the net to optimize; it is deep-copied and left unmodified.
        static_blobs: blob names forwarded to the assignment step.
        ordering_function: callable taking the dependency graph and returning
            an iterable of op indices.

    Returns:
        An Optimization tuple holding the rewritten net, the assignments and
        the blob -> group-name mapping.
    """
    optimized_net = copy.deepcopy(net)
    dependency_graph = compute_interference_graph(optimized_net.op)
    linearized_ops = [
        optimized_net.op[index]
        for index in ordering_function(dependency_graph)
    ]
    assignments = compute_assignments(
        compute_ranges(linearized_ops), static_blobs)
    blob_assignments = compute_blob_assignments(assignments)
    apply_assignments(optimized_net, blob_assignments)
    return Optimization(
        net=optimized_net,
        assignments=assignments,
        blob_assignments=blob_assignments)
# Byte totals reported by compute_statistics: the sum over all blobs versus
# the sum of the largest blob in each sharing group.
Statistics = collections.namedtuple(
    'Statistics',
    ('baseline_nbytes', 'optimized_nbytes'))
def compute_statistics(assignments):
    """Compare total blob bytes with and without the sharing `assignments`.

    Every blob named in `assignments` is fetched from the workspace and its
    `nbytes` attribute is read.
    NOTE(review): the names in `assignments` are looked up as-is, so the
    workspace presumably still holds the blobs under those names -- confirm
    with callers.

    Args:
        assignments: iterable of non-empty groups, each a list of
            (blob name, range) pairs.

    Returns:
        A Statistics tuple: `baseline_nbytes` is the sum over all distinct
        blobs, `optimized_nbytes` is the sum of the largest blob per group.
    """
    def blob_nbytes(blob):
        return workspace.FetchBlob(blob).nbytes

    blob_bytes = {
        blob: blob_nbytes(blob)
        for assignment in assignments
        for (blob, _) in assignment}
    # .values() works on Python 2 and 3; iteritems() is Python 2 only.
    baseline_nbytes = sum(blob_bytes.values())
    optimized_nbytes = sum(
        max(blob_bytes[blob] for (blob, _) in assignment)
        for assignment in assignments)
    return Statistics(
        baseline_nbytes=baseline_nbytes,
        optimized_nbytes=optimized_nbytes)