pytorch/caffe2/python/gradient_checker.py
Paul Jesse Hellemn b875fb281c
Update from facebook (#7451)
* [bootcamp] Improve "Shape" operator to support axes specification

Improve the Shape operator of Caffe2 to support x.shape(tensor, axes), which takes an optional int array "axes" as input. For example, x.shape(tensor, [1, 0]) returns the dimensions for axes 1 and 0, in the specified order. In the current version, "axes" may contain duplicates and may have arbitrary length.
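
The described behavior can be sketched in NumPy as follows. This is only an illustration of the semantics stated above, not the Caffe2 operator itself; the function name `shape_with_axes` is hypothetical.

```python
import numpy as np


def shape_with_axes(tensor, axes=None):
    """Hypothetical stand-in: return the full shape, or only the listed axes."""
    full_shape = np.array(tensor.shape, dtype=np.int64)
    if axes is None:
        return full_shape
    # Duplicates and arbitrary length are allowed; order follows `axes`.
    return full_shape[list(axes)]


x = np.zeros((2, 3, 5))
print(shape_with_axes(x))             # [2 3 5]
print(shape_with_axes(x, [1, 0]))     # [3 2]
print(shape_with_axes(x, [2, 2, 0]))  # [5 5 2]
```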

* Back out "Add barrier net that runs before training nets"

Original commit changeset: b373fdc9c30f. Need additional changes to some callers to support barrier failures.

* Change warning to verbose log to reduce log spam

The `LOG(WARNING)` was a bit spammy for regular use, so let's just make it a `VLOG`.

* Extract the shared code from different caffe2_benchmark binaries

The OSS benchmark and Internal benchmark will share most functions in the benchmark.

* Support MFR in sequence training

As titled.

* Make knowledge distillation work with using logged prediction feature as teacher label.

1) Add loading raw dense feature as teacher label.
2) Optional calibration function for teacher label
3) Add teacher label into generic unit test
4) Deprecated TTSN workflow version using feature_options to config teacher label

* [C2/CUDA]: unjoined cross entropy sigmoid

as desc

* Add async_scheduling executor into deferrable_net_exec_test

Add async_scheduling into tests and fix some exception cases

* Fix Event disabled error

When disabling event in RNN ops make sure we don't call Finish on disabled
event from op's RunAsync

* cuda ensure cpu output op can handle both TensorCPU and TensorCUDA

as desc.

* [C2 Core] Infer input device option in C2 hypothesis_test checkers

Improve how we choose default device options for input blobs.
Previously they defaulted to the device where the op lives, which is not necessarily correct.

For example:
CopyCPUToGPU

* [C2 Op]SplitByLengthsOp CPU/GPU implementation

[C2 Op]SplitByLengthsOp CPU/GPU implementation

* fix undefined symbol error

not sure why we're getting undefined symbol even with link_whole = True
Need to figure out why but need this workaround for now

* Add tools in DAIPlayground platform to help debugging models

Add additional tools to allow Playground to override individual methods defined in AnyExp.  This will allow users to create a module that specifically changes the behavior of certain default methods.  An example included in this diff is deactivating the test model and checkpointing.  When debugging any model problems, switching off components helps me quickly narrow down the location of the bug.  The technique is used extensively in task T27038712 (Steady memory increase in EDPM, eventually resulting in gloo/cuda.cu:34: out of memory)

* add shape and type inference for int8 conversion operator

* Fix flaky test for group_norm

Fix flaky test for group_norm

* Fix group_norm_op_test flaky

Fix group_norm_op_test flaky

* Implementation of composite learning rate policy

In many state-of-the-art deep learning works, people use a simple trick to
schedule the learning rate: use a fixed learning rate until the error plateaus,
then switch to a different fixed learning rate, and so on. In this diff,
we implement a simple version of the composite learning rate. The user gives
a set of learning rate policies and corresponding iteration numbers, and the
optimizer changes the learning rate policy based on the number of iterations so far.

For example, the user gives two learning rate policies, FixedLearningRate
and PolyLearningRate, with an iteration number of 1k. Then for the first 1k iterations
we use FixedLearningRate. For the following iterations, we use PolyLearningRate.
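
The switching logic described above can be sketched in plain Python. This is a hedged illustration, not the Caffe2 implementation: the helper names (`fixed_lr`, `poly_lr`, `composite_lr`), the polynomial decay formula, and the choice to restart each sub-policy's iteration counter at zero are all assumptions.

```python
def fixed_lr(base_lr):
    """Hypothetical fixed policy: always returns base_lr."""
    def policy(it):
        return base_lr
    return policy


def poly_lr(base_lr, power, max_iter):
    """Hypothetical polynomial decay from base_lr to 0 over max_iter steps."""
    def policy(it):
        progress = min(it, max_iter) / float(max_iter)
        return base_lr * (1.0 - progress) ** power
    return policy


def composite_lr(schedule):
    """schedule: list of (num_iter, policy) pairs.

    Each policy is used for its num_iter iterations; the last policy also
    covers every iteration beyond the end of the schedule (assumption).
    """
    def policy(it):
        start = 0
        for num_iter, sub_policy in schedule[:-1]:
            if it < start + num_iter:
                return sub_policy(it - start)
            start += num_iter
        return schedule[-1][1](it - start)
    return policy


# First 1k iterations use the fixed policy, later ones the polynomial policy.
lr = composite_lr([(1000, fixed_lr(0.1)), (1000, poly_lr(0.1, 1.0, 1000))])
```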

* Split two use cases of CachedReader into two classes, DBFileReader and CachedReader

# Use Cases:

1). input: DB file -> output: DatasetReader.

Use DBFileReader.

2). input: Reader -> build cache DB file -> output: DatasetReader.

Use CachedReader.

# Changes to CachedReader:

1). Move db_path to the constructor.
Because in the mock reader, the cache will always be built ahead of time.

# Changes to tests:

1). Make a separate TestCase class for CachedReader and DBFileReader.

2). Make it possible to add more test functions by adding setUp, tearDown and _make_temp_path.

3). Make delete db_path more general. `db_path` could be a file for `log_file_db`, but could also be a directory for `leveldb`.

* Back out "On Mobile phones, call GlobalInit with no arguments in predictor in case we need to perform initialization"

Original commit changeset: 4489c6133f11

* Fix LARS bug

Fixed a bug in the LARS implementation which caused all subsequent blobs not using LARS to have the LARS learning rate multiplier applied to them.

* [tum] support sparse init & add uniformFill option

as title

* Propagate exception for async nets

Capture the exception when an exception is thrown in async nets and re-throw it after wait().  This allows exceptions to be propagated up to the caller.

This diff was a part of D7752068.  We split the diff so that C2 core files changes are in a separate diff.
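
The capture-and-rethrow pattern described above can be sketched with a plain Python thread. This is a generic illustration, not the Caffe2 async net executor; the class `AsyncTask` and its method names are hypothetical.

```python
import threading


class AsyncTask(object):
    """Hypothetical wrapper: runs fn on a worker thread and re-raises
    any exception it threw when the caller invokes wait()."""

    def __init__(self, fn):
        self._fn = fn
        self._exc = None
        self._thread = threading.Thread(target=self._run)

    def _run(self):
        try:
            self._fn()
        except Exception as exc:
            # Capture instead of letting the worker thread swallow it.
            self._exc = exc

    def start(self):
        self._thread.start()

    def wait(self):
        self._thread.join()
        if self._exc is not None:
            # Propagate the worker's exception to the caller.
            raise self._exc
```

A caller would wrap `wait()` in its usual `try`/`except` block, exactly as it would for a synchronous call.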

* Automatic update of fbcode/onnx to 69894f207dfcd72d1e70497d387201cec327efbc

Previous import was 403ccfbd0161c38f0834413d790bad0874afbf9a

Included changes:
- **[69894f2](https://github.com/onnx/onnx/commit/69894f2)**: Use op schema.all tensor types in random like definitions (#865) <Scott McKay>
- **[b9d6b90](https://github.com/onnx/onnx/commit/b9d6b90)**: Clarify random like operators (#846) <Scott McKay>
- **[fc6b5fb](https://github.com/onnx/onnx/commit/fc6b5fb)**: Refactor shape inference implementation (#855) <anderspapitto>
- **[b7d8dc8](https://github.com/onnx/onnx/commit/b7d8dc8)**: fix cmake warning message (#863) <Eric S. Yu>
- **[f585c5d](https://github.com/onnx/onnx/commit/f585c5d)**: add pytorch-operator test for tile (#831) <Wenhao Hu>
- **[993fe70](https://github.com/onnx/onnx/commit/993fe70)**: add install step (#832) <Eric S. Yu>
- **[68bc26c](https://github.com/onnx/onnx/commit/68bc26c)**: add type inference for traditional ml ops except classifier ops. (#857) <Ke Zhang>
- **[9cc0cda](https://github.com/onnx/onnx/commit/9cc0cda)**: fix string representation of scalar types (#858) <G. Ramalingam>
- **[1078925](https://github.com/onnx/onnx/commit/1078925)**: fix y in pow test case to scalar (#852) <Wenhao Hu>
- **[c66fb6f](https://github.com/onnx/onnx/commit/c66fb6f)**: Add some math function shape inference (#845) <anderspapitto>
- **[ff667d1](https://github.com/onnx/onnx/commit/ff667d1)**: Refactor return type and docs for ONNXIFI_BACKEND_DIRECTX_ID (#853) <Marat Dukhan>
- **[11c6876](https://github.com/onnx/onnx/commit/11c6876)**: clear initializer names when clear initializer (#849) <Wenhao Hu>
- **[73c34ae](https://github.com/onnx/onnx/commit/73c34ae)**: Clarify FeatureVectorizer description. (#843) <Scott McKay>
- **[1befb9b](https://github.com/onnx/onnx/commit/1befb9b)**: Remove useless text in docs (#850) <Lu Fang>
- **[e84788f](https://github.com/onnx/onnx/commit/e84788f)**: Fix SELU attributes' default values (#839) <Lu Fang>
- **[ebac046](https://github.com/onnx/onnx/commit/ebac046)**: Add tile test case (#823) <Wenhao Hu>
- **[8b7a925](https://github.com/onnx/onnx/commit/8b7a925)**: a few more shape inference functions (#772) <anderspapitto>
- **[9718f42](https://github.com/onnx/onnx/commit/9718f42)**: Make the coefficient non optional for LinearClassifier (#836) <Jaliya Ekanayake>
- **[ef083d0](https://github.com/onnx/onnx/commit/ef083d0)**: Add save_tensor and load_tensor functions for Protos (#770) <Lu Fang>
- **[45ceb55](https://github.com/onnx/onnx/commit/45ceb55)**: Check if CMAKE_BUILD_TYPE set before project(). (#812) <Sergii Dymchenko>
- **[4b3d2b0](https://github.com/onnx/onnx/commit/4b3d2b0)**: [WIP] reenable shape inference tests (#834) <anderspapitto>
- **[22d17ee](https://github.com/onnx/onnx/commit/22d17ee)**: RNN tests: LSTM, GRU, SimpleRNN (#739) <Peyman Manikashani>
- **[de65b95](https://github.com/onnx/onnx/commit/de65b95)**: dimension denotation (#443) <Tian Jin>
- **[eccc76e](https://github.com/onnx/onnx/commit/eccc76e)**: fix field number issue in onnx operator proto and enable its build (#829) <Ke Zhang>
- **[d582beb](https://github.com/onnx/onnx/commit/d582beb)**: disable shape inference test to unbreak ci (#830) <Lu Fang>
- **[485b787](https://github.com/onnx/onnx/commit/485b787)**: function proto for composite op. (#802) <Ke Zhang>
- **[cd58928](https://github.com/onnx/onnx/commit/cd58928)**: specify defaults for attributes of Affine op (#820) <G. Ramalingam>
- **[7ee2cf9](https://github.com/onnx/onnx/commit/7ee2cf9)**: merge the dummy backend back into the main one (#743) <anderspapitto>
- **[1c03a5a](https://github.com/onnx/onnx/commit/1c03a5a)**: [Proposal] ONNX Interface for Framework Integration (previously ONNX Backend API) header and docs (#551) <Marat Dukhan>
- **[3769a98](https://github.com/onnx/onnx/commit/3769a98)**: Rename real model test case from VGG-16 to ZFNet (#821) <Lu Fang>

* [C2]ReluN Op

relu n op.

tf reference: https://www.tensorflow.org/api_docs/python/tf/nn/relu6
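
For reference, a NumPy sketch of a ReLU clipped at `n`, following the relu6 definition `min(max(x, 0), 6)` from the linked TensorFlow page. The function names and the gradient convention at the boundaries (zero at exactly 0 and exactly `n`) are assumptions, not taken from the Caffe2 operator.

```python
import numpy as np


def relu_n(x, n=6.0):
    """Clip activations to the range [0, n]."""
    return np.minimum(np.maximum(x, 0.0), n)


def relu_n_grad(x, dy, n=6.0):
    """Pass the upstream gradient only where 0 < x < n (assumed convention)."""
    return dy * ((x > 0.0) & (x < n))


x = np.array([-1.0, 0.0, 3.0, 6.0, 7.0])
print(relu_n(x))                         # [0. 0. 3. 6. 6.]
print(relu_n_grad(x, np.ones_like(x)))   # [0. 0. 1. 0. 0.]
```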

* Call destructor when assigning a blob value

* Add executor overrides

Add executor overrides flag to enable migration to async_scheduling executor

* Add barrier net that runs before training nets - attempt #2

Add a synchronize barrier net that is run before training nets.  With this net, shards that are faster will wait for other shards before they start training.  This reduces the chance of the faster shards timing out during GLOO AllReduce.
Removed explicit data_parallel_model.py.synchronize call in holmes workflow.

This change was landed previously but caused errors for some EDPM workflows - See https://fb.facebook.com/groups/1426530000692545/permalink/1906766366002237/ - because EDPM assumes any call to CreateOrCloneCommonWorld and Gloo ops are wrapped in exception handlers but in this case exception thrown in the barrier init net is not handled.

To address this issue, we add _CreateOrCloneCommonWorld to the param_init_net instead of a new barrier init net.  Since errors in the param_init_net run are handled gracefully and trigger a re-rendezvous, it should fix the problem.

* Handle empty nets in async_scheduling

Make sure we don't get stuck on empty nets

* use CUDA_ARCH for conditional compile

* [C2 fix] infer function for ensure_cpu_output_op

* Update group_norm test to reduce flaky test

* Fix lr_multiplier for GPU
2018-05-10 23:14:27 -07:00

312 lines
12 KiB
Python

## @package gradient_checker
# Module caffe2.python.gradient_checker
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from caffe2.python import core, workspace, net_drawer
from caffe2.proto import caffe2_pb2


def _get_grad_blob(grad_map, input_to_check):
    grad_blob = grad_map[input_to_check]

    if isinstance(grad_blob, core.BlobReference):
        return workspace.blobs[grad_blob]

    # If grad_blob is not a single blob, it should be a gradient slice.
    # To make it comparable with the estimated gradient, which is dense,
    # we need to first convert grad_blob to a dense gradient.
    assert isinstance(grad_blob, core.GradientSlice)
    dense_grad = 'tmp_dense_grad'
    sparse_to_dense_op = core.CreateOperator(
        'SparseToDense',
        [grad_blob.indices, grad_blob.values, input_to_check],
        dense_grad,
    )
    workspace.RunOperatorOnce(sparse_to_dense_op)
    return workspace.blobs[dense_grad]


def _get_grad(net, outputs, outputs_with_grad, input_values, inputs_with_grads):
    grad_net = net.Clone(net.Name() + "_copy")
    grad_map = grad_net.AddGradientOperators(outputs_with_grad)

    for name, value in (input_values or {}).items():
        workspace.blobs[name] = value

    for input_to_check in inputs_with_grads:
        assert input_to_check in grad_map, (
            '{} has no gradient, cannot check net gradient.'.format(
                input_to_check))
        assert str(input_to_check) in workspace.blobs

    workspace.RunNetOnce(grad_net)
    forward_results = [(output, workspace.blobs[output]) for output in outputs]
    grads = {input_to_check: _get_grad_blob(grad_map, input_to_check)
             for input_to_check in inputs_with_grads}

    return forward_results, grads, grad_net


def _assert_close(value1, value2, threshold, err_msg=''):
    np.testing.assert_allclose(
        value1, value2,
        atol=threshold, rtol=threshold,
        err_msg=err_msg,
    )

    delta = np.abs(value1 - value2).flatten()
    return np.mean(delta), max(delta)


class NetGradientChecker(object):
    @staticmethod
    def CompareNets(nets, outputs, outputs_with_grad_ids,
                    inputs_with_grads, input_values=None,
                    threshold=0.0000001, print_net_images=False):
        def _get_output_with_grad_names(net_outputs):
            return [net_outputs[i] for i in outputs_with_grad_ids]

        if print_net_images:
            for i, net in enumerate(nets):
                png = net_drawer.GetPydotGraph(net).create_png()
                with open("caffe2_net_forward_" + str(i) + net.Name() + ".png",
                          'wb') \
                        as f:
                    f.write(png)

        results = [
            _get_grad(net, net_outputs,
                      _get_output_with_grad_names(net_outputs),
                      input_values, inputs_with_grads)
            for net, net_outputs in zip(nets, outputs)
        ]

        if print_net_images:
            _, _, backward_nets = zip(*results)
            for i, net in enumerate(backward_nets):
                png = net_drawer.GetPydotGraph(net).create_png()
                with open("caffe2_net_" + str(i) + net.Name() + ".png", 'wb') \
                        as f:
                    f.write(png)

        first_net_results, first_net_grads, _ = results[0]
        for net_results, net_grads, _ in results[1:]:
            assert len(net_results) == len(first_net_results)
            for idx, ((blob1, blob_value1), (blob2, blob_value2)) in enumerate(
                    zip(first_net_results, net_results)):
                _assert_close(
                    blob_value1, blob_value2, threshold,
                    err_msg="Different forward pass results for output id {}. "
                    "Corresponding output blobs: {} and {}".format(
                        idx, blob1, blob2))

            assert net_grads.keys() == first_net_grads.keys()
            for blob, blob_grad_value in net_grads.items():
                _assert_close(
                    first_net_grads[blob], blob_grad_value, threshold,
                    err_msg="Different gradients for input {}".format(blob))

    @staticmethod
    def Check(net, outputs_with_grad, input_values,
              input_to_check, step_size=0.0001,
              threshold=0.05, print_net=True):

        net_results, net_grads, full_net = _get_grad(
            net, [], outputs_with_grad, input_values, [input_to_check])
        analytic_grad = net_grads[input_to_check]

        def GetLoss(new_value):
            workspace.blobs[input_to_check] = new_value
            workspace.RunNetOnce(full_net)
            return sum([
                workspace.blobs[output]
                for output in outputs_with_grad
            ]).sum()

        def GetValue(dim, delta):
            input_value = input_values[input_to_check].copy()
            input_value.flat[dim] += delta
            return input_value

        grad_estimate = np.zeros_like(input_values[input_to_check])
        for dim in range(input_values[input_to_check].size):
            pos_loss = GetLoss(GetValue(dim, step_size))
            neg_loss = GetLoss(GetValue(dim, -step_size))
            grad_estimate.flat[dim] = (pos_loss - neg_loss) / step_size / 2

        err_msg = "Error in gradient check for net_copy {}".format(
            net.Name())
        if print_net:
            err_msg += ": {}".format(net.Proto())

        return _assert_close(analytic_grad, grad_estimate, threshold, err_msg)


class GradientChecker:
    """A gradient checker in Python.

    This is not the most efficient way to check gradients, as the Python
    interface involves many back-and-forth copy operations. Use at your
    own risk.
    """

    def __init__(
        self,
        stepsize,
        threshold,
        device_option=None,
        workspace_name="gradient_check"
    ):
        self._stepsize = stepsize
        self._threshold = threshold
        self._device_option = device_option or caffe2_pb2.DeviceOption()
        self._workspace_name = workspace_name

    def GetLossAndGrad(
        self, op, grad_ops, x, input_name, grad_name, outputs_with_grads
    ):
        # First, feed in the current input. Note that we are not changing
        # anything else, so we don't need to feed in others.
        workspace.FeedBlob(input_name, x, self._device_option)
        # Run.
        workspace.RunOperatorOnce(op)
        loss = 0.
        # Get Loss and feed in the gradients, run gradient ops.
        for idx in outputs_with_grads:
            name = op.output[idx]
            arr = workspace.FetchBlob(name)
            loss += (arr**2).sum()
            workspace.FeedBlob(name + '_grad', arr, self._device_option)
        loss /= 2.
        # Run gradient ops
        workspace.RunOperatorsOnce(grad_ops)
        # Get gradients
        if isinstance(grad_name, core.GradientSlice):
            workspace.FeedBlob('zeros', np.zeros_like(x, dtype=np.float32))
            workspace.FeedBlob('ones', np.ones(1, dtype=np.float32))
            gv_cpu_op = core.CreateOperator(
                'EnsureCPUOutput', grad_name.values, grad_name.values + '_cpu',
                device_option=self._device_option
            )
            gi_cpu_op = core.CreateOperator(
                'EnsureCPUOutput', grad_name.indices, grad_name.indices + '_cpu',
                device_option=self._device_option
            )
            sparse_to_dense_op = core.CreateOperator(
                'ScatterWeightedSum',
                [
                    'zeros', 'ones', grad_name.indices + '_cpu',
                    grad_name.values + '_cpu', 'ones'
                ],
                'zeros',
            )
            workspace.RunOperatorOnce(gv_cpu_op)
            workspace.RunOperatorOnce(gi_cpu_op)
            workspace.RunOperatorOnce(sparse_to_dense_op)
            grad = workspace.FetchBlob('zeros')
        else:
            grad = workspace.FetchBlob(grad_name)
        return loss, grad

    def CheckSimple(
        self,
        op,
        inputs,
        input_to_check,
        outputs_with_grads,
        grad_ops=None,
        input_device_options=None
    ):
        """Checks the operator in a very simple fashion by stacking a sum of
        squares on the top.

        Inputs:
          op: the operator to be checked.
          inputs: the input data in numpy arrays.
          input_to_check: an index specifying which input blob we should
              check.
          outputs_with_grads: indices specifying which output blobs we
              need to check gradients with. For these outputs, we will collect a
              squared sum and also feed in their gradients.
          grad_ops: the gradient operators. If not given, we will get the
              gradient operators from the gradient registry.
          input_device_options: an optional mapping from input names to
              DeviceOptions (to override the default DeviceOption)
        Outputs:
          a tuple (ret, grad, grad_estimate): ret is True if the check passes
              and False otherwise; grad is the analytic gradient and
              grad_estimate is the numerically estimated gradient.
        """
        # Entering the checker workspace
        old_ws_name = workspace.CurrentWorkspace()
        if self._workspace_name != old_ws_name:
            workspace.SwitchWorkspace(self._workspace_name, True)

        op.device_option.CopyFrom(self._device_option)
        if grad_ops is None:
            # TODO(jiayq): use the gradient registration instead of the old
            # hack.
            # Note: g_input is only bound on this path; it is used below even
            # when the caller supplies grad_ops explicitly.
            grad_ops, g_input = core.GradientRegistry.GetGradientForOp(
                op, [s + '_grad' for s in op.output])

        dims_to_check = inputs[input_to_check].size
        _input_device_options = input_device_options or \
            core.InferOpBlobDevicesAsDict(op)[0]
        # First, feed in the input.
        for i, arr in enumerate(inputs):
            workspace.FeedBlob(
                op.input[i], arr,
                _input_device_options.get(
                    op.input[i], self._device_option))

        # Get the loss and gradient for the original.
        input_name = op.input[input_to_check]
        grad_name = g_input[input_to_check]
        loss, grad = self.GetLossAndGrad(
            op, grad_ops, inputs[input_to_check], input_name, grad_name,
            outputs_with_grads
        )
        grad_estimate = np.zeros_like(inputs[input_to_check])
        if grad_estimate.shape != grad.shape:
            raise Exception(
                "Mismatched gradient shapes: estimated ({}), grad ({})".format(
                    grad_estimate.shape, grad.shape))

        for current_dim in range(dims_to_check):
            # Loss at x + stepsize
            inputs[input_to_check].flat[current_dim] += self._stepsize
            pos_loss, _ = self.GetLossAndGrad(
                op, grad_ops, inputs[input_to_check], input_name,
                grad_name, outputs_with_grads
            )
            # Loss at x - stepsize
            inputs[input_to_check].flat[current_dim] -= self._stepsize * 2
            neg_loss, _ = self.GetLossAndGrad(
                op, grad_ops, inputs[input_to_check], input_name,
                grad_name, outputs_with_grads
            )
            # Recover the value
            inputs[input_to_check].flat[current_dim] += self._stepsize
            # Central difference estimate of the gradient
            grad_estimate.flat[current_dim] = (
                pos_loss - neg_loss) / self._stepsize / 2
        # Now, check correctness
        fail_mat = ~np.isclose(
            grad, grad_estimate, atol=self._threshold, rtol=self._threshold)
        if np.any(fail_mat):
            idx = np.flatnonzero(fail_mat)
            print('Failed. [idx, grad, grad_estimate] are:')
            print(np.vstack([idx, grad.flat[idx], grad_estimate.flat[idx]]).T)
            ret = False
        else:
            ret = True
        # After finishing, clean up.
        if self._workspace_name != old_ws_name:
            # We reset the workspace to make sure everything intermediate is
            # cleaned up. Note that there is no need to delete a workspace -
            # when empty it takes a very limited amount of memory.
            workspace.ResetWorkspace()
            workspace.SwitchWorkspace(old_ws_name)
        return ret, grad, grad_estimate