# Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
# Summary: Three small changes.
# Reviewed By: ajtulloch
# Differential Revision: D4437131
# fbshipit-source-id: c849e36e1c4d1dce947076349df863fafe62c66d
# File: 287 lines, 9.4 KiB, Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import networkx as nx
|
|
import collections
|
|
import time
|
|
import copy
|
|
from caffe2.python import workspace
|
|
|
|
import logging
|
|
|
|
# Module-level logger for the memonger optimization passes; set to INFO so
# the "experimental" notices emitted below are visible by default.
log = logging.getLogger("memonger")
log.setLevel(logging.INFO)

# Live range of a blob within a linearized list of ops:
#   defined: index of the earliest op that produces the blob (None if never
#            produced in the list)
#   used:    index of the latest op that consumes the blob (None if never
#            consumed in the list)
LiveRange = collections.namedtuple('LiveRange', ["defined", "used"])
|
|
|
|
|
|
def share_grad_blobs(net, losses, param_grads, namescope):
    '''
    Implements similar optimization as Torch's shareGradInput():
    for the gradients that are passed between layers, share blobs between
    operators when possible. This yields significant memory savings with
    deep networks.

    Args:
        net: caffe2 Net; its Proto() is deep-copied, the original is untouched.
        losses: loss blob names; the traversal of gradient ops starts there.
        param_grads: parameter-gradient blob names, excluded from sharing.
        namescope: namescope prefix of the blobs to consider (a trailing "/"
            is appended if missing).

    Returns an optimized protobuf (assign to net._net)
    '''
    def is_grad_blob(b):
        name = str(b)
        # Note: need to look at _{namescope} pattern as it matches
        # to handle the auto-split gradients
        return "_grad" in name and (name.startswith(namescope) or
            name.startswith("_" + namescope)) and name not in param_grads

    def is_grad_op(op):
        # An op takes part in gradient computation if any of its inputs or
        # outputs is a gradient blob.
        # TODO: something smarter
        for inp in op.input:
            if is_grad_blob(inp):
                return True
        for out in op.output:
            if is_grad_blob(out):
                return True
        return False

    start_time = time.time()
    # Logger.warn() is a deprecated alias of warning(); use the modern name.
    log.warning("NOTE: Executing *experimental* memonger to " +
                "optimize gradient memory")

    # Collect ops that have something to do with
    # gradients
    if not namescope.endswith("/"):
        namescope += "/"

    netproto = copy.deepcopy(net.Proto())
    grad_ops = [op for op in netproto.op if is_grad_op(op)]

    # Create mapping from blobs to ops
    blobs_to_ops = collections.defaultdict(list)     # blob -> consumer op indices
    blob_input_count = collections.defaultdict(int)  # blob -> #times consumed so far
    op_inputs = collections.defaultdict(int)         # op index -> #tracked inputs
    op_visit_count = collections.defaultdict(int)    # op index -> #inputs satisfied
    for i, op in enumerate(grad_ops):
        for inp in op.input:
            if is_grad_blob(inp) or inp in losses:
                # Ignore in-place transformation ops (self cycles)
                if inp not in op.output:
                    blobs_to_ops[inp].append(i)
                    op_inputs[i] += 1

    # Traverse operators starting from the loss blobs.
    # Keep tabs on when blobs are seen first and last, and also
    # when operators have their input satisfied. Share blobs only
    # under same branch, avoiding problems with parallel workers.
    output_blobs = set()
    mapping = {}

    def descend(op_idx, free_blobs):
        # DFS step: recycle blobs whose last consumer is this op, assign a
        # free blob to each newly-seen gradient output, then recurse into
        # ops whose inputs are now fully satisfied.
        cur_op = grad_ops[op_idx]
        new_free_blobs = set()
        for inp in cur_op.input:
            if is_grad_blob(inp):
                blob_input_count[inp] += 1
                if blob_input_count[inp] == len(blobs_to_ops[inp]):
                    # This was the last consumer: the (possibly remapped)
                    # storage can be reused downstream.
                    actual_blob = inp if inp not in mapping else mapping[inp]
                    new_free_blobs.add(actual_blob)

        for outp in cur_op.output:
            if is_grad_blob(outp):
                if outp not in output_blobs:
                    # First seen this blob as output, can assign to a free blob
                    for freeb in free_blobs:
                        mapping[outp] = freeb
                        free_blobs.remove(freeb)
                        break

                output_blobs.add(outp)

        # Blobs freed by this op become reusable only for downstream ops.
        free_blobs.update(new_free_blobs)

        first_branch = True
        for outp in cur_op.output:
            for inp_op_idx in blobs_to_ops[outp]:
                op_visit_count[inp_op_idx] += 1

                # Descend only if we have satisfied all inputs
                if op_visit_count[inp_op_idx] == op_inputs[inp_op_idx]:
                    # Hand the free pool only to the first branch so parallel
                    # branches never get assigned the same storage.
                    free_blobs_fwd = free_blobs if first_branch else set()
                    first_branch = False
                    descend(inp_op_idx, free_blobs_fwd)

    # Start DFS from the losses
    for loss in losses:
        for op_idx in blobs_to_ops[loss]:
            descend(op_idx, set())

    # Rename the shared blobs
    shared_blobs = set(mapping.values())
    renamed = {}
    for j, b in enumerate(shared_blobs):
        renamed[b] = namescope + "__m{}_".format(j)

    # Final mapping
    for k, v in mapping.items():
        mapping[k] = renamed[v]

    # Add the originators
    mapping.update(renamed)
    log.info("Remapping {} blobs, using {} shared".format(
        len(mapping), len(renamed),
    ))
    apply_assignments(netproto, mapping)
    log.info("Gradient memory optimization took {} secs".format(
        time.time() - start_time),
    )
    return netproto
|
|
|
|
|
|
def topological_sort_traversal(g):
    """Return the nodes of ``g`` in a topological order.

    Thin wrapper around networkx, kept as a named function so callers of
    ``optimize_interference`` can swap in a different ordering strategy.
    """
    ordering = nx.topological_sort(g)
    return ordering
|
|
|
|
|
|
def compute_ranges(linearized_ops):
    """Compute a LiveRange for each blob touched by ``linearized_ops``.

    ``defined`` is the index of the earliest op producing the blob and
    ``used`` the index of the latest op consuming it; either stays None
    when the blob is never produced / consumed within the list.
    """
    blobs = collections.defaultdict(lambda: LiveRange(defined=None, used=None))
    for pos, op in enumerate(linearized_ops):
        # Track the latest consumer seen so far.
        for blob in op.input:
            prev = blobs[blob].used
            blobs[blob] = blobs[blob]._replace(
                used=pos if prev is None else max(prev, pos))
        # Track the earliest producer seen so far.
        for blob in op.output:
            prev = blobs[blob].defined
            blobs[blob] = blobs[blob]._replace(
                defined=pos if prev is None else min(prev, pos))

    return blobs
|
|
|
|
|
|
def is_compatible(candidate_range, assignment, static_blobs):
    """Return True iff the candidate's live range can be appended to
    ``assignment`` (i.e. it begins strictly after the group's last blob
    was last used, and that blob is not pinned in ``static_blobs``).
    """
    name, last_range = assignment[-1]
    if name in static_blobs:
        return False
    # All three endpoints must be known to prove non-overlap.
    have_endpoints = (
        candidate_range.defined is not None
        and last_range.defined is not None
        and last_range.used is not None
    )
    return have_endpoints and candidate_range.defined > last_range.used
|
|
|
|
|
|
def compute_blob_assignments(assignments):
    """Flatten assignment groups into a blob -> canonical-blob dict.

    Within each group, the last blob becomes the canonical name for every
    member (including itself). Singleton groups need no renaming and are
    skipped.
    """
    blob_assignments = {}
    for group in assignments:
        if len(group) == 1:
            continue
        canonical, _ = group[-1]
        for blob, _ in group:
            blob_assignments[blob] = canonical
    return blob_assignments
|
|
|
|
|
|
def compute_assignments(ranges, static_blobs):
    """Greedily group blobs with disjoint live ranges into shared assignments.

    Blobs are processed in order of last use; a blob joins the first existing
    group it is compatible with (see ``is_compatible``), otherwise it starts
    a new group. Returns a list of groups of ``(name, LiveRange)`` pairs.
    """
    # Sort the ranges based on when they are last used.
    # If LiveRange.used is None, then the blob is never used and could
    # be consumed externally. Sort these to the end of the list as opposed
    # to the beginning so that they can be shared as well.
    ordered = sorted(
        ranges.items(),
        key=lambda item: (item[1].used is None, item[1].used),
    )
    assignments = []
    for name, live_range in ordered:
        # First-fit: take the earliest compatible group, if any.
        target = next(
            (group for group in assignments
             if is_compatible(live_range, group, static_blobs)),
            None,
        )
        if target is None:
            assignments.append([(name, live_range)])
        else:
            target.append((name, live_range))
    return assignments
|
|
|
|
|
|
def compute_interference_graph(ops):
    """Build a directed graph over op indices: edge i -> j whenever op j
    consumes one of op i's outputs; the edge's 'deps' attribute records the
    shared blob names. Asserts the graph stays acyclic as edges are added.
    """
    g = nx.DiGraph()
    for i, op in enumerate(ops):
        g.add_node(i, op=op)
    # O(n^2) pairwise scan over ops to discover producer -> consumer edges.
    for i, parent_op in enumerate(ops):
        for j, child_op in enumerate(ops):
            if i == j:
                continue
            if any(output in child_op.input for output in parent_op.output):
                deps = set(child_op.input).intersection(parent_op.output)
                g.add_edge(i, j, deps=deps)
                # NOTE(review): checking acyclicity after every edge is
                # expensive (O(V+E) per edge) but pinpoints the offending op.
                assert nx.is_directed_acyclic_graph(g), child_op
    return g
|
|
|
|
|
|
# Result of optimize_interference: the rewritten net, the raw assignment
# groups, and the final blob -> canonical-blob mapping.
Optimization = collections.namedtuple(
    'Optimization', ['net', 'assignments', 'blob_assignments'])
|
|
|
|
|
|
def apply_assignments(net, blob_assignments):
    """Rewrite every op input/output in ``net`` in place, replacing each blob
    name with its canonical name from ``blob_assignments`` (names absent from
    the mapping are left unchanged).
    """
    for op in net.op:
        for pos, blob in enumerate(op.input):
            op.input[pos] = blob_assignments.get(blob, blob)
        for pos, blob in enumerate(op.output):
            op.output[pos] = blob_assignments.get(blob, blob)
|
|
|
|
|
|
def optimize_interference(net, static_blobs,
                          ordering_function=topological_sort_traversal):
    """
    Optimize blob memory in ``net`` by renaming blobs with disjoint live
    ranges to shared canonical blobs:
    1) Build the op interference graph and linearize it with
       ``ordering_function``.
    2) Compute use-def live ranges for every blob in that order.
    3) Group blobs into shared assignments.
    4) Rewrite op inputs/outputs to the canonical blob names.
    ``static_blobs`` are never shared. Returns an Optimization namedtuple;
    the input ``net`` is left untouched (a deep copy is modified).
    """
    net = copy.deepcopy(net)

    # Linearize ops according to the chosen ordering of the interference graph.
    igraph = compute_interference_graph(net.op)
    linearized_ops = [net.op[idx] for idx in ordering_function(igraph)]

    # Persist the linearized order into the NetDef itself. If the graph has
    # multiple valid topological orders and the runtime executed a different
    # one than the one used to compute ranges, blobs could be overwritten
    # before their last use.
    del net.op[:]
    net.op.extend(linearized_ops)

    live_ranges = compute_ranges(linearized_ops)
    groups = compute_assignments(live_ranges, static_blobs)
    blob_assignments = compute_blob_assignments(groups)
    apply_assignments(net, blob_assignments)
    return Optimization(
        net=net,
        blob_assignments=blob_assignments,
        assignments=groups)
|
|
|
|
# Memory footprint in bytes before and after blob sharing
# (produced by compute_statistics).
Statistics = collections.namedtuple(
    'Statistics', ['baseline_nbytes', 'optimized_nbytes'])
|
|
|
|
|
|
def compute_statistics(assignments):
    """Measure the memory saved by the given assignment groups.

    Fetches each blob from the caffe2 workspace to obtain its size, so the
    workspace must already hold all blobs named in ``assignments``.

    Returns a Statistics namedtuple with:
        baseline_nbytes: total bytes if every blob kept its own storage.
        optimized_nbytes: total bytes when each group shares one buffer
            (sized to the group's largest member).
    """
    def blob_nbytes(blob):
        return workspace.FetchBlob(blob).nbytes

    blob_bytes = {
        blob: blob_nbytes(blob) for assignment in assignments
        for (blob, _) in assignment}
    # Bug fix: dict.iteritems() is Python 2-only and raises AttributeError
    # on Python 3; summing over values() is equivalent.
    baseline_nbytes = sum(blob_bytes.values())
    optimized_nbytes = sum(
        max(blob_bytes[blob] for (blob, _) in assignment)
        for assignment in assignments)
    return Statistics(
        baseline_nbytes=baseline_nbytes,
        optimized_nbytes=optimized_nbytes)
|