Updated call method in base layer

Added a description about enabling eager mode for debugging purpose.
This commit is contained in:
Vishnuvardhan Janapati 2021-06-11 18:14:27 -07:00 committed by GitHub
parent ee37fee4ee
commit 87579ba516
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -477,13 +477,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@doc_controls.for_subclass_implementers
def call(self, inputs, *args, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
"""This is where the layer's forward pass is defined and logic lives.
By default, `layer.call()` run in graph mode for higher
performance. In order to run the `layer.call()` body in eager context,
you need to invoke `model.compile(run_eagerly=True)`. Note that this
will have performance slow down due to eager runtime, and should only
be used for debug purpose.
Note here that `call()` method in `tf.keras` is little bit different
from `keras` API. In `keras` API, you can pass support masking for
layers as additional arguments. Whereas `tf.keras` has `compute_mask()`
method to support masking.
Args:
inputs: Input tensor, or dict/list/tuple of input tensors.
The first positional `inputs` argument is subject to special rules: