* made it explicit in the docstring of Module.register_forward_hook() that the hook(s) will be called AFTER calling forward().
* added "every time" in docstring of Module.register_forward_pre_hook()
* Add weight normalization implementation
This adds forward "pre-hooks" which get called before the module's
forward() method. Weight norm is implemented as a hook which calculates
the weight variable from the weight_g and weight_v every iteration.
Based on @rtqichen implementation.
* Specify return type
a module that returns a non-standard data structure currently breaks
due to checks for backwards hooks. This refactors the code slightly so
this will only break in the event of backwards hooks.
We were keying hooks by RemovableHandle id. However, we don't hold onto
handles and ids of dead objects can be reused. This replaces id(handle)
with a global counter.
The core autograd Variable, Function, and Engine no longer depend on the
Python API. This let's us implement functions in C++. In the future, we
can also multithread engine and release the GIL for most of the
non-Python backwards.
Here's the command I used to invoke autopep8 (in parallel!):
git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i
Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.
Also configures flake8 to match pep8's behavior.
Also configures TravisCI to check the whole project for lint.
The load_state_dict() function now raises an error if the argument
state_dict has extra keys or is missing keys.
Previously, load_state_dict() ignored extra and missing keys, which made
it hard to notice when you load an invalid state_dict. This could
happen, for example, if you save the state_dict for a DataParallel, but
load it into a single model.
The state_dict() function now only includes the Tensor data from the
paramters, which reduces checkpoint size by not saving gradients.
The register hook calls now return an object that can be used to remove
the hook. For example,
>>> h = module.register_forward_hook(callback)
>>> h.remove() # removes hook
Or as a context manager:
>>> with module.register_forward_hook(callback):
... pass
This makes it easier for libraries to use hooks without worrying about
name collisions.