pytorch/caffe2/python/workspace.py
bddppq f94ae3ba1d
Update from facebook (#7696)
* Fix handling of empty batches in SumReduceDimsOp

As titled

* Deferrable async_scheduling finishRun fix

Proper order of finishing run operations in deferrable_async_scheduling net

* Simplify exception handling in async_scheduling

Simplify exception handling, no need to busy wait, thread that processes the
last task can finish the run

* [C2]worker_coordinator_memorize_worker_ids

As titled. This is related to T28689868, where the number of blobs we want to create is equal to the number of worker ids

* Add unit test for nets with no type set

* Ignore total length argument in sympolic_pad_packed_sequence

1- There was a mistake in the code that total_length was added to the wrong symbolic function (pack_padded_sequence) instead of (pad_packed_sequence)
2- No need to throw an exception if total_length is given since it is only used to enable data_parallel training on multi-gpus and doesn't have anything to do with onnx export, so just ignore it. https://fburl.com/tk4gciqp

* Add support for MKLDNN to async_scheduling

Just add MKLDNN as a possible CPU option to async_scheduling's pool function

* [AuFL][ensemble] support branch output for prediction

This diff supports using predictions from different branches and thus enables model ensembling (not fully independent).

* Fix a bug in add_loss in layer_model_helper

As titled.

* Support lradaption for adam

1.lr adaption operator
2.apply to dense adam

* Perf tweaks for async_scheduling

Restore single pool option + remove unnecessary (no-ops) calls

* add quantization to SparseSimdAdagradOp

add a bunch of quantization signatures to SparseSimdAdagradOp, implementations to come next

* [sr] [codemod] Change all SR callsites to use new API

@allow-large-files

This diff refactors all callsites of SR to use the slightly changed API introduced in the diff below. Really what this means is that you need to include the correct header. Also if you were using `ClientFactory::newFactory` you need to not prefix it with `ClientFactory::`.

```
cd ~/fbsource/fbcode
find ./ -type f -exec sed -i -e 's:#include "servicerouter/client/cpp2/ClientFactory.h":#include "servicerouter/client/cpp2/ServiceRouter.h":' -e 's:#include <servicerouter/client/cpp2/ClientFactory.h>:#include <servicerouter/client/cpp2/ServiceRouter.h>:' -e 's/ClientFactory::newFactory(/newFactory(/g' {} \;
```

Also manually fixed spots that couldn't be done automatically (or broke because they depended on transitive includes).

* Back out "Fix handling of empty batches in SumReduceDimsOp"

Original commit changeset: 282da1730cc2 This commit is blocking the
Github->fbcode sync, which really needs to get merged ASAP. D7881937 which this
diff depends on will be reverted in the sync D7990948 which causes this to
break. The sync diff cannot be patched with this reversion because it must be
landed against base revision 5c8c099 , and D7881937 must not be included in the
sync diff because it is breaking GPU tests that are not available in sandcastle
: https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-cuda8.0-cudnn6-ubuntu16.04-test/3638/console
for one example.

* Add the flow to support operator benchmark

1) generate model with the operator 2) upload to everstore 3) generate model spec into json file 4) start running the benchmark

* [tum][gpu] Connect DPM trainer with flow and unit tests

This diff:
- Fix some small bugs for Yiming's recent changes to parallelizer, so it suits real use cases.
- Add correct tags to the TUM code, so we can do data parallel transform
- pass extra info when instantiation.
- add unit test for using DPM in TUM model

After this diff, we can do simple box, multi-gpu fully-sync trainer for TUM in Fblearner workflow, but may still need to do speed benchmarking.

* w/o normalized lradaption for adam dense only

The previous lr adaption includes a normalization step when performing the dot product operation. This is not exactly same as what is proposed in the paper. I add normalization as an option. Without it, the operator performs exactly what the paper proposed. With the option, we add the normalization step

* [fb] Use SharedPromise in DeferrableAsyncSchedulingNet

This code is to simplify DeferrableAsyncSchedulingNet by removing condition
variable + small fixes

* [tum] implement cuda sparseLengthsMean and LengthsMean

as title

* Adding an optional parameter to allow use of protobufs in InferShapesAndTypes function.

Adding an optional parameter to allow use of protobufs in InferShapesAndTypes function.

* Move feature_to_index to FeatureSpec.feature_to_index

move feature_to_index to FeatureSpec.feature_to_index to avoid override other fields

* [Caffe2] Rename bytes_moved to bytes_written

Just a rename in preparation for supporting bytes_read.

* [c2] fix ReduceFrontSumOp for empty case by setting 0

otherwise, it may use the results from last iteration when it's empty batch.

* [Caffe2] [Int8] Improve Intel CPU performance

* [Easy] Improve PrependDim op logging

as titled

* DBFileReader expand db_path using os.path.expanduser(..)

Since there are a lot of possible use cases of `DBFileReader` to read from user home path, like `~/local/sample.db`, I want to save people's trouble of calling `os.path.expanduser(db_path)` themselves.

* [Caffe2] Add bytes_read to cost structure

We're adding analytical read bytes to cost functions.  This extends the structure accordingly for all CostInference defined operators.
Additionally, some small bug fixes were performed:
1) Cost functions now extract type information of operands instead of assuming float

* Fix sleef on aarch64 for hhvm

@bypass-lint

Rename flag

* Remove duplicated part in caffe2/ideep/operators/conv_op.cc

should be sync error

* Rename test helper function test_adagrad_sparse_helper to adagrad_sparse_test_helper to avoid confusing pytest
2018-05-19 23:10:48 -07:00

651 lines
20 KiB
Python

## @package workspace
# Module caffe2.python.workspace
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import collections
import contextlib
from google.protobuf.message import Message
from multiprocessing import Process
import os
from collections import defaultdict
import logging
import numpy as np
from past.builtins import basestring
import shutil
import socket
import tempfile
from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils
import caffe2.python._import_c_extension as C
logger = logging.getLogger(__name__)
Blobs = C.blobs
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
HasBlob = C.has_blob
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
GetStats = C.get_stats
operator_tracebacks = defaultdict(dict)
is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
if has_gpu_support:
    NumCudaDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetCudaPeerAccessPattern():
        return np.asarray(C.get_cuda_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
else:
    NumCudaDevices = lambda: 0  # noqa
    GetCUDAVersion = lambda: 0  # noqa
    GetCuDNNVersion = lambda: 0  # noqa
    GetCudaPeerAccessPattern = lambda: np.array([])  # noqa
    GetDeviceProperties = lambda x: None  # noqa
IsNUMAEnabled = C.is_numa_enabled
GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
def _GetFreeFlaskPort():
    """Get a free flask port."""
    # We will prefer to use 5000. If not, we will then pick a random port.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # connect_ex returns 0 when the connection succeeds, i.e. when something
    # is already listening on the port.
    result = sock.connect_ex(('127.0.0.1', 5000))
    sock.close()
    if result != 0:
        # Nothing accepted the connection, so port 5000 appears to be free.
        return 5000
    else:
        s = socket.socket()
        s.bind(('', 0))
        port = s.getsockname()[1]
        s.close()
        # Race condition: between the interval we close the socket and actually
        # start a mint process, another process might have occupied the port. We
        # don't do much here as this is mostly for convenience in research
        # rather than 24x7 service.
        return port
def StartMint(root_folder=None, port=None):
    """Start a mint instance.

    TODO(Yangqing): this does not work well under ipython yet. According to
        https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.
    """
    from caffe2.python.mint import app
    if root_folder is None:
        # Get the root folder from the current workspace
        root_folder = C.root_folder()
    if port is None:
        port = _GetFreeFlaskPort()
    process = Process(
        target=app.main,
        args=(
            ['-p', str(port), '-r', root_folder],
        )
    )
    process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return process
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
        obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
            function.
    Outputs:
        string: the output protobuf string.
    Raises:
        ValueError: if the passed in object is neither a string, a protocol
            buffer message, nor an object with a Proto() function.
    """
    if isinstance(obj, basestring):
        return obj
    else:
        if isinstance(obj, Message):
            # First, see if this object is a protocol buffer, which we can
            # simply serialize with the SerializeToString() call.
            return obj.SerializeToString()
        elif hasattr(obj, 'Proto'):
            return obj.Proto().SerializeToString()
        else:
            raise ValueError("Unexpected argument to StringifyProto of type " +
                             type(obj).__name__)
def ResetWorkspace(root_folder=None):
    if root_folder is None:
        # Reset the workspace, but keep the current root folder setting.
        return C.reset_workspace(C.root_folder())
    else:
        if not os.path.exists(root_folder):
            os.makedirs(root_folder)
        return C.reset_workspace(root_folder)


def CreateNet(net, overwrite=False, input_blobs=None):
    if input_blobs is None:
        input_blobs = []
    for input_blob in input_blobs:
        C.create_blob(input_blob)
    return CallWithExceptionIntercept(
        C.create_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
def Predictor(init_net, predict_net):
    return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net))


def GetOperatorCost(operator, blobs):
    return C.get_operator_cost(StringifyProto(operator), blobs)


def RunOperatorOnce(operator):
    return C.run_operator_once(StringifyProto(operator))


def RunOperatorsOnce(operators):
    for op in operators:
        success = RunOperatorOnce(op)
        if not success:
            return False
    return True
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    try:
        return func(*args, **kwargs)
    except Exception:
        op_id = op_id_fetcher()
        net_tracebacks = operator_tracebacks.get(net_name, None)
        logger.warning(
            'Original python traceback for operator `{}` in network '
            '`{}` in exception above (most recent call last):'.format(
                op_id, net_name))
        if net_tracebacks and op_id in net_tracebacks:
            tb = net_tracebacks[op_id]
            for line in reversed(tb):
                logger.warning(' File "{}", line {}, in {}'.format(
                    line[0], line[1], line[2]))
        raise
def RunNetOnce(net):
    return CallWithExceptionIntercept(
        C.run_net_once,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net),
    )


def RunNet(name, num_iter=1, allow_fail=False):
    """Runs a given net.

    Inputs:
        name: the name of the net, or a reference to the net.
        num_iter: number of iterations to run
        allow_fail: if True, does not assert on net exec failure but returns
            False
    Returns:
        True on success. On failure, raises an exception, or returns False if
        allow_fail is True.
    """
    return CallWithExceptionIntercept(
        C.run_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(name),
        StringifyNetName(name), num_iter, allow_fail,
    )
def RunPlan(plan_or_step):
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    if isinstance(plan_or_step, core.ExecutionStep):
        plan_or_step = core.Plan(plan_or_step)
    return C.run_plan(StringifyProto(plan_or_step))
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False):
    """Infers the shapes and types for the specified nets.

    Inputs:
        nets: the list of nets
        blob_dimensions (optional): a dictionary of blobs and their dimensions.
            If not specified, the workspace blobs are used.
        nets_proto (optional): a boolean flag indicating whether the protobuffer
            representation is passed to the routine.
    Returns:
        A tuple of (shapes, types) dictionaries keyed by blob name.
    """
    if nets_proto:
        net_protos = [StringifyProto(n) for n in nets]
    else:
        net_protos = [StringifyProto(n.Proto()) for n in nets]
    if blob_dimensions is None:
        blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
    else:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions
        )
    blobdesc_proto = caffe2_pb2.TensorShapes()
    blobdesc_proto.ParseFromString(blobdesc_prototxt)
    shapes = {}
    types = {}
    for ts in blobdesc_proto.shapes:
        if not ts.unknown_shape:
            shapes[ts.name] = list(ts.dims)
            types[ts.name] = ts.data_type
    return (shapes, types)
def _StringifyName(name, expected_type):
    if isinstance(name, basestring):
        return name
    assert type(name).__name__ == expected_type, \
        "Expected a string or %s" % expected_type
    return str(name)


def StringifyBlobName(name):
    return _StringifyName(name, "BlobReference")


def StringifyNetName(name):
    return _StringifyName(name, "Net")


def GetNetName(net):
    if isinstance(net, basestring):
        return net
    if type(net).__name__ == "Net":
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    raise Exception("Not a Net object: {}".format(str(net)))
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the workspace.

    Inputs:
        name: the name of the blob.
        arr: either a TensorProto object or a numpy array object to be fed into
            the workspace.
        device_option (optional): the device option to feed the data with.
    Returns:
        True or False, stating whether the feed is successful.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain NumPy strings are weird, let's use objects instead.
        # The builtin `object` is used because the `np.object` alias is
        # deprecated and no longer exists in recent NumPy releases.
        arr = arr.astype(object)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        if arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    if device_option is not None:
        return C.feed_blob(name, arr, StringifyProto(device_option))
    else:
        return C.feed_blob(name, arr)
def FetchBlobs(names):
    """Fetches a list of blobs from the workspace.

    Inputs:
        names: list of names of blobs - strings or BlobReferences
    Returns:
        list of fetched blobs
    """
    return [FetchBlob(name) for name in names]


def FetchBlob(name):
    """Fetches a blob from the workspace.

    Inputs:
        name: the name of the blob - a string or a BlobReference
    Returns:
        Fetched blob (numpy array or string) if successful
    """
    result = C.fetch_blob(StringifyBlobName(name))
    if isinstance(result, tuple):
        raise TypeError(
            "Use FetchInt8Blob to fetch Int8 Blob {}".format(
                StringifyBlobName(name)
            )
        )
    return result


Int8Tensor = collections.namedtuple(
    'Int8Tensor', ['data', 'scale', 'zero_point']
)
def FetchInt8Blob(name):
    """Fetches an Int8 blob from the workspace. It shares its backend
    implementation with FetchBlob, but it is recommended when fetching Int8
    blobs.

    Inputs:
        name: the name of the Int8 blob - a string or a BlobReference
    Returns:
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    result = C.fetch_blob(StringifyBlobName(name))
    assert isinstance(result, tuple), \
        'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
            StringifyBlobName(name))
    return Int8Tensor(*result)
def _Workspace_fetch_int8_blob(ws, name):
    """Fetches an Int8 blob from the workspace. It shares its backend
    implementation with fetch_blob, but it is recommended when fetching Int8
    blobs.

    Inputs:
        name: the name of the Int8 blob - a string or a BlobReference
    Returns:
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    result = ws.fetch_blob(name)
    assert isinstance(result, tuple), \
        'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
            StringifyBlobName(name))
    return Int8Tensor(*result)
C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
def ApplyTransform(transform_key, net):
    """Apply a Transform to a NetDef protobuf object, and returns the new
    transformed NetDef.

    Inputs:
        transform_key: the name of the transform, as it is stored in the registry
        net: a NetDef protobuf object
    Returns:
        Transformed NetDef protobuf object.
    """
    transformed_net = caffe2_pb2.NetDef()
    transformed_str = C.apply_transform(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
    )
    transformed_net.ParseFromString(transformed_str)
    return transformed_net
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Apply a Transform to a NetDef protobuf object, and returns the new
    transformed NetDef, only if it runs faster than the original.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
        transform_key: the name of the transform, as it is stored in the registry
        net: a NetDef protobuf object
        init_net: The net to initialize the workspace.
        warmup_runs (optional):
            Determines how many times the net is run before testing.
            Will be 5 by default.
        main_runs (optional):
            Determines how many times the net is run during testing.
            Will be 10 by default.
        improvement_threshold (optional):
            Determines the factor which the new net needs to be faster
            in order to replace the old. Will be 1.01 by default.
    Returns:
        Either a Transformed NetDef protobuf object, or the original netdef.
    """
    warmup_runs = kwargs['warmup_runs'] if 'warmup_runs' in kwargs else 5
    main_runs = kwargs['main_runs'] if 'main_runs' in kwargs else 10
    improvement_threshold = kwargs['improvement_threshold'] \
        if 'improvement_threshold' in kwargs else 1.01

    transformed_net = caffe2_pb2.NetDef()
    transformed_str = C.apply_transform_if_faster(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
        init_net.SerializeToString(),
        warmup_runs,
        main_runs,
        float(improvement_threshold),
    )
    transformed_net.ParseFromString(transformed_str)
    return transformed_net
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs"""
    return scope.CurrentNameScope()


class _BlobDict(object):
    """Provides python dict compatible way to do fetching and feeding"""

    def __getitem__(self, key):
        return FetchBlob(key)

    def __setitem__(self, key, value):
        return FeedBlob(key, value)

    def __len__(self):
        return len(C.blobs())

    def __iter__(self):
        return C.blobs().__iter__()

    def __contains__(self, item):
        return C.has_blob(item)
blobs = _BlobDict()
################################################################################
# Utilities for immediate mode
#
# Caffe2's immediate mode implements the following behavior: between the two
# function calls StartImmediate() and StopImmediate(), for any operator that is
# called through CreateOperator(), we will also run that operator in a workspace
# that is specific to the immediate mode. The user is explicitly expected to
# make sure that these ops have proper inputs and outputs, i.e. one should not
# run an op where an external input is not created or fed.
#
# Users can use FeedImmediate() and FetchImmediate() to interact with blobs
# in the immediate workspace.
#
# Once StopImmediate() is called, all contents of the immediate workspace are
# freed up so one can continue using normal runs.
#
# The immediate mode is solely for debugging purposes and support will be very
# sparse.
################################################################################
_immediate_mode = False
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
_immediate_root_folder = ''
def IsImmediate():
    return _immediate_mode


@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    current = CurrentWorkspace()
    SwitchWorkspace(workspace_name, True)
    yield
    SwitchWorkspace(current)
def StartImmediate(i_know=False):
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.

    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.

    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call StopImmediate()
    and StartImmediate() manually to flush out everything no longer needed.

    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call StopImmediate() as soon as possible when you no longer
    need it.

    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.

    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.

    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)
def StopImmediate():
    """Stops an immediate mode run."""
    # Phew, that was a dangerous ride.
    global _immediate_mode
    global _immediate_root_folder
    if not IsImmediate():
        return
    with WorkspaceGuard(_immediate_workspace_name):
        ResetWorkspace()
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False


def ImmediateBlobs():
    with WorkspaceGuard(_immediate_workspace_name):
        return Blobs()


def RunOperatorImmediate(op):
    with WorkspaceGuard(_immediate_workspace_name):
        RunOperatorOnce(op)


def FetchImmediate(*args, **kwargs):
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)


def FeedImmediate(*args, **kwargs):
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
# CWorkspace utilities
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    return CallWithExceptionIntercept(
        ws._create_net,
        ws._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
C.Workspace.create_net = _Workspace_create_net_with_exception_intercept
def _Workspace_run(ws, obj):
    if hasattr(obj, 'Proto'):
        obj = obj.Proto()
    if isinstance(obj, caffe2_pb2.PlanDef):
        return ws._run_plan(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.NetDef):
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(obj),
            obj.SerializeToString(),
        )
        # return ws._run_net(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.OperatorDef):
        return ws._run_operator(obj.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(obj)))
C.Workspace.run = _Workspace_run
def _Blob_feed(blob, arg, device_option=None):
    if device_option is not None:
        device_option = StringifyProto(device_option)
    return blob._feed(arg, device_option)
C.Blob.feed = _Blob_feed
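
The exception-intercept pattern used by `CallWithExceptionIntercept` can be tried without a Caffe2 build. The following is a minimal standalone sketch, not the module's own code: `call_with_intercept`, `failing_run`, the `tracebacks` table, and the sample frames are hypothetical stand-ins for the wrapper, a net runner, `operator_tracebacks`, and recorded operator creation sites. It assumes the table maps a net name to a dict from operator position to a list of `(file, line, function)` frames, as the indexing in the module suggests.

```python
import logging
from collections import defaultdict

logger = logging.getLogger("intercept_sketch")

# Hypothetical table: net name -> {operator position -> list of frames},
# where each frame is a (file, line, function) tuple.
tracebacks = defaultdict(dict)


def call_with_intercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Run func; if it raises, log where the failing operator was defined."""
    try:
        return func(*args, **kwargs)
    except Exception:
        op_id = op_id_fetcher()
        frames = tracebacks.get(net_name, {}).get(op_id, [])
        # Walk the stored frames in reverse order, as the module does.
        for filename, lineno, funcname in reversed(frames):
            logger.warning('  File "%s", line %s, in %s',
                           filename, lineno, funcname)
        raise  # re-raise the original exception unchanged


class ListHandler(logging.Handler):
    """Collects formatted log messages in a list so they can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def failing_run(net_name):
    # Stand-in for a net runner whose third operator (position 2) fails.
    raise RuntimeError("operator 2 failed in " + net_name)


handler = ListHandler()
logger.addHandler(handler)

# Pretend operator 2 of "demo_net" was created at these call sites.
tracebacks["demo_net"][2] = [("model.py", 10, "build"),
                             ("train.py", 42, "main")]

caught = None
try:
    call_with_intercept(failing_run, lambda: 2, "demo_net", "demo_net")
except RuntimeError as exc:
    caught = str(exc)
```

The wrapper only adds log lines and always re-raises, so callers see the same exception they would have seen without it.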