mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Removes unnecessary eager-mode call to convert_to_tensor in record_gradient.
PiperOrigin-RevId: 170944265
This commit is contained in:
parent
add6d2d03c
commit
d4ea993cae
|
|
@ -524,7 +524,7 @@ _grad_fn_accepts_none_for_indices = {
|
|||
}
|
||||
|
||||
|
||||
def _record_gradient(op_name, inputs, attrs, results, ctx, name):
|
||||
def _record_gradient(op_name, inputs, attrs, results, name):
|
||||
"""Records gradients for a TensorFlow operation.
|
||||
|
||||
Args:
|
||||
|
|
@ -534,7 +534,6 @@ def _record_gradient(op_name, inputs, attrs, results, ctx, name):
|
|||
attrs: A tuple with alternating string attr names and attr values for this
|
||||
operation.
|
||||
results: The results of the operation (as a flat list).
|
||||
ctx: The value of context.context().
|
||||
name: Customized name for the operation.
|
||||
|
||||
Returns:
|
||||
|
|
@ -572,7 +571,6 @@ def _record_gradient(op_name, inputs, attrs, results, ctx, name):
|
|||
"output_grads", orig_outputs, "gradients", result)
|
||||
return result
|
||||
|
||||
inputs = [ops.internal_convert_to_tensor(x, ctx=ctx) for x in inputs]
|
||||
tape.record_operation(op_name, results, inputs, [], grad_fn)
|
||||
if _tracing:
|
||||
print("Computed op", (name if name else op_name), "inputs", inputs,
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
|||
|
||||
|
||||
def record_gradient(unused_op_name, unused_inputs, unused_attrs, unused_results,
|
||||
unused_ctx, unused_name):
|
||||
unused_name):
|
||||
"""Import backprop if you want gradients recorded."""
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ string GenEagerPythonOp::Code() {
|
|||
" if not _result:\n"
|
||||
" return _op\n");
|
||||
}
|
||||
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
|
||||
strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
|
||||
|
||||
// Compute graph-mode attrs.
|
||||
if (op_def_.attr_size() > 0) {
|
||||
|
|
@ -511,7 +511,7 @@ string GenEagerPythonOp::Code() {
|
|||
if (num_outs_ > 0) {
|
||||
strings::StrAppend(&result_, " _execute.record_gradient(\n", " \"",
|
||||
op_def_.name(),
|
||||
"\", _inputs_flat, _attrs, _result, _ctx, name)\n");
|
||||
"\", _inputs_flat, _attrs, _result, name)\n");
|
||||
if (num_outs_ == 1 && !output_sizes[0].empty()) {
|
||||
// Single list result.
|
||||
} else if (num_outs_ == 1) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user