Internal cleanup.

PiperOrigin-RevId: 170922297

commit 5123f29718
parent d0c76cd188
@@ -324,12 +324,18 @@ def imperative_grad(
     result.append(_aggregate_grads(g))
   return result
 
+_op_attr_type_cache = {}
+
 
 def op_attr_type(op_type, attr_name):
-  with errors.raise_exception_on_not_ok_status() as status:
-    h = context.context()._handle  # pylint: disable=protected-access
-    op = pywrap_tensorflow.TFE_NewOp(h, op_type, status)
-    attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status)
+  try:
+    return _op_attr_type_cache[(op_type, attr_name)]
+  except KeyError:
+    with errors.raise_exception_on_not_ok_status() as status:
+      h = context.context()._handle  # pylint: disable=protected-access
+      op = pywrap_tensorflow.TFE_NewOp(h, op_type, status)
+      attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status)
+    _op_attr_type_cache[(op_type, attr_name)] = attr_type
   return attr_type
 
 
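The rewritten op_attr_type is plain dict memoization in the EAFP style: try the cache first and fall back to the C-API round trip (TFE_NewOp plus TFE_OpGetAttrType) only on a miss, so each (op_type, attr_name) pair pays for the handle lookup once per process. A minimal standalone sketch of the same pattern; expensive_lookup is a hypothetical stand-in for the C-API call, not TensorFlow code.

_attr_type_cache = {}


def expensive_lookup(op_type, attr_name):
  # Hypothetical stand-in for the C-API round trip
  # (TFE_NewOp + TFE_OpGetAttrType); fabricates a value so the sketch runs.
  return "type_of_%s.%s" % (op_type, attr_name)


def op_attr_type_cached(op_type, attr_name):
  key = (op_type, attr_name)
  try:
    return _attr_type_cache[key]  # fast path: cache hit
  except KeyError:
    value = expensive_lookup(op_type, attr_name)
    _attr_type_cache[key] = value  # later calls skip the expensive lookup
    return value


print(op_attr_type_cached("MatMul", "transpose_a"))  # computed, then cached
print(op_attr_type_cached("MatMul", "transpose_a"))  # served from the cache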
@@ -277,7 +277,9 @@ class BatchNormalization(base.Layer):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, one_minus_decay]) as scope:
       with ops.colocate_with(variable):
-        update_delta = (variable.read_value() - value) * one_minus_decay
+        update_delta = math_ops.multiply(
+            math_ops.subtract(variable.read_value(), value),
+            one_minus_decay)
         if isinstance(variable, resource_variable_ops.ResourceVariable):
           # state_ops.assign_sub does an extra read_variable_op after the
           # assign. We avoid that here.
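The normalization change is behavior-preserving: it swaps the overloaded - and * operators for explicit math_ops.subtract and math_ops.multiply calls, while the quantity computed stays the usual moving-average step, variable -= (variable - value) * (1 - decay). A plain-Python sketch of that arithmetic on scalars (not the layer's actual code path):

def moving_average_step(variable, value, decay):
  # variable - (variable - value) * (1 - decay)
  # == decay * variable + (1 - decay) * value, the standard EMA update.
  one_minus_decay = 1.0 - decay
  update_delta = (variable - value) * one_minus_decay
  return variable - update_delta


avg = 0.0
for sample in (1.0, 1.0, 1.0):
  avg = moving_average_step(avg, sample, decay=0.9)
  print(avg)  # 0.1, 0.19, 0.271: closes 10% of the remaining gap each step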
@@ -79,11 +79,12 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
   def _ExtractInputShapes(inputs):
     """Extract the shapes of a set of input tensors."""
+    if not context.in_graph_mode():
+      return array_ops.shape_n(inputs)
     sizes = []
     fully_known = True
     for x in inputs:
       input_shape = array_ops.shape(x)
-      if context.in_graph_mode():
-        if not isinstance(input_shape,
-                          ops.Tensor) or input_shape.op.type != "Const":
-          fully_known = False
-          break
+      if not isinstance(input_shape,
+                        ops.Tensor) or input_shape.op.type != "Const":
+        fully_known = False
+        break
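_ExtractInputShapes now opens with a guard clause: under eager execution every shape is already concrete, so array_ops.shape_n fetches them all in one call and the graph-only constant-detection loop is skipped, which also lets the isinstance check lose a level of nesting. A generic sketch of the same guard-clause shape, using a hypothetical FakeTensor rather than real TensorFlow types:

class FakeTensor(object):
  """Hypothetical stand-in for ops.Tensor in this sketch."""

  def __init__(self, shape, is_constant=True):
    self.shape = shape
    self.is_constant = is_constant


def extract_input_shapes(inputs, in_graph_mode):
  if not in_graph_mode:
    # Eager path: shapes are concrete, so return them all at once
    # (the real code calls array_ops.shape_n(inputs) here).
    return [t.shape for t in inputs]
  sizes = []
  for t in inputs:
    if not t.is_constant:
      # One non-constant shape makes the cheap path useless; bail out early.
      return None
    sizes.append(t.shape)
  return sizes


print(extract_input_shapes([FakeTensor([2, 3])], in_graph_mode=False))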
@@ -460,16 +460,25 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
 
 @ops.RegisterGradient("Conv2D")
 def _Conv2DGrad(op, grad):
-  return [nn_ops.conv2d_backprop_input(
-      array_ops.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr("strides"),
-      op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
-      op.get_attr("data_format")),
+  strides = op.get_attr("strides")
+  padding = op.get_attr("padding")
+  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
+  data_format = op.get_attr("data_format")
+  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])
+  return [nn_ops.conv2d_backprop_input(shape_0,
+                                       op.inputs[1],
+                                       grad,
+                                       strides,
+                                       padding,
+                                       use_cudnn_on_gpu,
+                                       data_format),
           nn_ops.conv2d_backprop_filter(op.inputs[0],
-                                        array_ops.shape(op.inputs[1]), grad,
-                                        op.get_attr("strides"),
-                                        op.get_attr("padding"),
-                                        op.get_attr("use_cudnn_on_gpu"),
-                                        op.get_attr("data_format"))]
+                                        shape_1,
+                                        grad,
+                                        strides,
+                                        padding,
+                                        use_cudnn_on_gpu,
+                                        data_format)]
 
 
 @ops.RegisterGradient("DepthwiseConv2dNative")
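Two costs drop in the new _Conv2DGrad: each attribute is read once and bound to a local instead of being fetched separately for the input and filter gradients, and the two shape computations collapse into a single array_ops.shape_n op. A toy illustration of the hoisting half, using a hypothetical FakeOp that counts attribute fetches:

class FakeOp(object):
  """Hypothetical stand-in for a TF op that counts attribute fetches."""

  def __init__(self, attrs):
    self.attrs = attrs
    self.fetches = 0

  def get_attr(self, name):
    self.fetches += 1  # each call simulates an attribute-lookup round trip
    return self.attrs[name]


op = FakeOp({"strides": [1, 1, 1, 1], "padding": "SAME"})

# Un-hoisted, as in the old gradient: both backprop calls re-fetch.
input_args = (op.get_attr("strides"), op.get_attr("padding"))
filter_args = (op.get_attr("strides"), op.get_attr("padding"))
print(op.fetches)  # 4

# Hoisted, as in the rewrite: fetch once, bind locally, reuse.
op.fetches = 0
strides = op.get_attr("strides")
padding = op.get_attr("padding")
input_args = (strides, padding)
filter_args = (strides, padding)
print(op.fetches)  # 2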