mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Added get variable utils to tf.estimator.Estimator.
PiperOrigin-RevId: 171052121
This commit is contained in:
parent
083bd5dde5
commit
d66e77f7c3
|
|
@ -1,155 +0,0 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Demonstrates a regression on Boston housing data.
|
||||
|
||||
This example demonstrates how to run experiments with TF Boosted Trees on
|
||||
a regression dataset. We split all the data into 20% test and 80% train,
|
||||
and are using l2 loss and l2 regularization.
|
||||
|
||||
Example Usage:
|
||||
|
||||
python tensorflow/contrib/boosted_trees/examples/boston.py \
|
||||
--batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \
|
||||
--num_eval_steps=1 --num_trees=500 --l2=4 \
|
||||
--vmodule=training_ops=1
|
||||
|
||||
When training is done, mean squared error on eval data is reported.
|
||||
Point tensorboard to the directory for the run to see how the training
|
||||
progresses:
|
||||
|
||||
tensorboard --logdir=/tmp/boston
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
from tensorflow.contrib.learn import learn_runner
|
||||
|
||||
_TEST_SPLIT_RATIO = 0.2
|
||||
_TEST_SPLIT_SEED = 42
|
||||
_BOSTON_NUM_FEATURES = 13
|
||||
|
||||
|
||||
# Main config - creates a TF Boosted Trees Estimator based on flags.
|
||||
def _get_tfbt(output_dir, feature_cols):
  """Builds the boosted-trees regressor configured from the parsed flags.

  Args:
    output_dir: Forwarded to the estimator as its `model_dir`.
    feature_cols: Forwarded to the estimator as its `feature_columns`.

  Returns:
    A `GradientBoostedDecisionTreeRegressor` instance.
  """
  tree_config = learner_pb2.LearnerConfig()
  tree_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
  tree_config.constraints.max_tree_depth = FLAGS.depth
  tree_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate

  # No l1 penalty. The l2 flag is expressed for the full training data, so it
  # is divided by the batch size to obtain the per-instance regularization.
  tree_config.regularization.l1 = 0.0
  tree_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size

  return GradientBoostedDecisionTreeRegressor(
      learner_config=tree_config,
      # With the WHOLE_TREE strategy, examples_per_layer is set equal to the
      # batch size.
      examples_per_layer=FLAGS.batch_size,
      feature_columns=feature_cols,
      label_dimension=1,
      model_dir=output_dir,
      num_trees=FLAGS.num_trees,
      center_bias=False,
      config=tf.contrib.learn.RunConfig(save_checkpoints_secs=300))
|
||||
|
||||
|
||||
def _make_experiment_fn(output_dir):
  """Assembles the train/eval experiment for the Boston housing regressor.

  Args:
    output_dir: Directory handed to `_get_tfbt` for the estimator.

  Returns:
    A `tf.contrib.learn.Experiment` wiring the estimator to numpy-backed
    training and evaluation input functions.
  """
  train_data, test_data = tf.keras.datasets.boston_housing.load_data()
  x_train, y_train = train_data
  x_test, y_test = test_data

  make_input_fn = tf.estimator.inputs.numpy_input_fn

  # Training repeats indefinitely (num_epochs=None) over shuffled batches.
  train_input_fn = make_input_fn(
      x={"x": x_train},
      y=y_train,
      batch_size=FLAGS.batch_size,
      num_epochs=None,
      shuffle=True)

  # Evaluation makes a single, unshuffled pass over the test split.
  eval_input_fn = make_input_fn(
      x={"x": x_test},
      y=y_test,
      num_epochs=1,
      shuffle=False)

  # All features are fed as one real-valued column named "x".
  columns = [
      feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES),
  ]
  estimator = _get_tfbt(output_dir, columns)

  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      train_steps=None,
      eval_steps=FLAGS.num_eval_steps,
      eval_metrics=None)
|
||||
|
||||
|
||||
def main(unused_argv):
  """Runs the "train_and_evaluate" schedule through `learn_runner`.

  Args:
    unused_argv: Ignored.
  """
  run_options = {
      "experiment_fn": _make_experiment_fn,
      "output_dir": FLAGS.output_dir,
      "schedule": "train_and_evaluate",
  }
  learn_runner.run(**run_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = argparse.ArgumentParser()
  # Define the list of flags that users can change.
  parser.add_argument(
      "--batch_size",
      type=int,
      default=1000,
      help="The batch size for reading data.")
  parser.add_argument(
      "--output_dir",
      type=str,
      required=True,
      help="Choose the dir for the output.")
  parser.add_argument(
      "--num_eval_steps",
      type=int,
      default=1,
      help="The number of steps to run evaluation for.")
  # Flags for gradient boosted trees config.
  parser.add_argument(
      "--depth", type=int, default=4, help="Maximum depth of weak learners.")
  parser.add_argument(
      "--l2", type=float, default=1.0, help="l2 regularization per batch.")
  parser.add_argument(
      "--learning_rate",
      type=float,
      default=0.1,
      help="Learning rate (shrinkage weight) with which each new tree is added."
  )
  # A required flag never falls back to a default, so none is declared.
  parser.add_argument(
      "--num_trees",
      type=int,
      required=True,
      help="Number of trees to grow before stopping.")

  # Flags this parser does not recognise are handed through to tf.app.run.
  FLAGS, unparsed = parser.parse_known_args()

  # The l2 flag is divided by batch_size when the learner is configured, so
  # reject non-positive values up front with a usage error instead of failing
  # later with ZeroDivisionError (or silently using negative regularization).
  if FLAGS.batch_size <= 0:
    parser.error("--batch_size must be a positive integer.")

  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
@ -129,8 +129,8 @@ def _get_tfbt(output_dir):
|
|||
def _make_experiment_fn(output_dir):
|
||||
"""Creates experiment for gradient boosted decision trees."""
|
||||
data = tf.contrib.learn.datasets.mnist.load_mnist()
|
||||
train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
|
||||
eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
|
||||
train_input_fn = get_input_fn(data.train, batch_size=256)
|
||||
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
|
||||
|
||||
return tf.contrib.learn.Experiment(
|
||||
estimator=_get_tfbt(output_dir),
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
|
|||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_binary_additional_srcs",
|
||||
"tf_cc_binary",
|
||||
"tf_copts",
|
||||
"tf_custom_op_library",
|
||||
"tf_cc_binary",
|
||||
"tf_java_test",
|
||||
)
|
||||
|
||||
|
|
@ -181,16 +180,10 @@ tf_java_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "my_test_op.so",
|
||||
srcs = ["src/test/native/my_test_op.cc"],
|
||||
)
|
||||
|
||||
tf_java_test(
|
||||
name = "TensorFlowTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
|
||||
data = [":my_test_op.so"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.TensorFlowTest",
|
||||
deps = [
|
||||
|
|
|
|||
|
|
@ -29,36 +29,6 @@ public final class TensorFlow {
|
|||
*/
|
||||
public static native byte[] registeredOpList();
|
||||
|
||||
/**
|
||||
* Load the dynamic library in filename and register the operations and kernels present in that
|
||||
* library.
|
||||
*
|
||||
* @param filename Path of the dynamic library containing operations and kernels to load.
|
||||
* @return Serialized bytes of the <a
|
||||
* href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList</a>
|
||||
* protocol buffer message defining the operations defined in the library.
|
||||
* @throws UnsatisfiedLinkError if filename cannot be loaded.
|
||||
*/
|
||||
public static byte[] loadLibrary(String filename) {
|
||||
long h = 0;
|
||||
try {
|
||||
h = libraryLoad(filename);
|
||||
} catch (RuntimeException e) {
|
||||
throw new UnsatisfiedLinkError(e.getMessage());
|
||||
}
|
||||
try {
|
||||
return libraryOpList(h);
|
||||
} finally {
|
||||
libraryDelete(h);
|
||||
}
|
||||
}
|
||||
|
||||
private static native long libraryLoad(String filename);
|
||||
|
||||
private static native void libraryDelete(long handle);
|
||||
|
||||
private static native byte[] libraryOpList(long handle);
|
||||
|
||||
private TensorFlow() {}
|
||||
|
||||
/** Load the TensorFlow runtime C library. */
|
||||
|
|
|
|||
|
|
@ -14,10 +14,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/java/src/main/native/tensorflow_jni.h"
|
||||
|
||||
#include <limits>
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/java/src/main/native/exception_jni.h"
|
||||
|
||||
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv* env,
|
||||
jclass clazz) {
|
||||
|
|
@ -33,35 +30,3 @@ Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv* env, jclass clazz) {
|
|||
TF_DeleteBuffer(buf);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Loads the dynamic library named by `filename` through TF_LoadLibrary and
// returns the resulting TF_Library* encoded as a jlong.
//
// If the load status is not OK, throwExceptionIfNotOK is given the chance to
// raise a Java exception; the (possibly null) handle is still returned, as
// before. If the file name cannot be converted to a C string, 0 is returned
// and the exception raised by the JVM is left pending.
JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(
    JNIEnv* env, jclass clazz, jstring filename) {
  const char* cname = env->GetStringUTFChars(filename, nullptr);
  if (cname == nullptr) {
    // GetStringUTFChars reports failure by returning NULL; do not hand a null
    // path to TF_LoadLibrary.
    return 0;
  }
  TF_Status* status = TF_NewStatus();
  TF_Library* h = TF_LoadLibrary(cname, status);
  throwExceptionIfNotOK(env, status);
  env->ReleaseStringUTFChars(filename, cname);
  TF_DeleteStatus(status);
  return reinterpret_cast<jlong>(h);
}
|
||||
|
||||
// Releases the library handle previously returned by libraryLoad.
// A zero handle is treated as "nothing to free".
JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(
    JNIEnv* env, jclass clazz, jlong handle) {
  if (handle == 0) return;
  TF_Library* library = reinterpret_cast<TF_Library*>(handle);
  TF_DeleteLibraryHandle(library);
}
|
||||
|
||||
// Copies the serialized OpList of the library identified by `handle` into a
// new Java byte[].
//
// Returns nullptr with a Java exception pending when the list does not fit in
// a byte[] (IndexOutOfBoundsException) or when the array cannot be allocated.
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_TensorFlow_libraryOpList(
    JNIEnv* env, jclass clazz, jlong handle) {
  // NOTE(review): `buf` is not freed here; presumably its memory is owned by
  // the library handle -- confirm against the TF_GetOpList documentation.
  TF_Buffer buf = TF_GetOpList(reinterpret_cast<TF_Library*>(handle));
  // Java arrays are indexed by jint, so larger payloads cannot be returned.
  if (buf.length > std::numeric_limits<jint>::max()) {
    throwException(env, kIndexOutOfBoundsException,
                   "Serialized OpList is too large for a byte[] array");
    return nullptr;
  }
  auto ret_len = static_cast<jint>(buf.length);
  jbyteArray ret = env->NewByteArray(ret_len);
  if (ret == nullptr) {
    // NewByteArray returns NULL when the array cannot be constructed; writing
    // into it would be invalid, so bail out and let the pending exception
    // propagate to the Java caller.
    return nullptr;
  }
  env->SetByteArrayRegion(ret, 0, ret_len, static_cast<const jbyte*>(buf.data));
  return ret;
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ extern "C" {
|
|||
* Method: version
|
||||
* Signature: ()Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv *,
|
||||
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv*,
|
||||
jclass);
|
||||
|
||||
/*
|
||||
|
|
@ -36,33 +36,7 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv *,
|
|||
* Signature: ()[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_TensorFlow
|
||||
* Method: libraryLoad
|
||||
* Signature: (Ljava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(JNIEnv *,
|
||||
jclass,
|
||||
jstring);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_TensorFlow
|
||||
* Method: libraryDelete
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(JNIEnv *,
|
||||
jclass,
|
||||
jlong);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_TensorFlow
|
||||
* Method: libraryOpList
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_tensorflow_TensorFlow_libraryOpList(JNIEnv *, jclass, jlong);
|
||||
Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv*, jclass);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -37,26 +36,4 @@ public class TensorFlowTest {
|
|||
// was not sorted out. Revisit? Till then, at least exercise the code.
|
||||
assertTrue(TensorFlow.registeredOpList().length > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadLibrary() {
|
||||
// TODO(ashankar): This tell will fail when built with --config=monolithic.
|
||||
// Figure out how we can ignore the test in that case.
|
||||
try (Graph g = new Graph()) {
|
||||
// Build a graph with an unrecognized operation.
|
||||
try {
|
||||
g.opBuilder("MyTest", "MyTest").build();
|
||||
fail("should not be able to construct graphs with unregistered ops");
|
||||
} catch (IllegalArgumentException e) {
|
||||
// expected exception
|
||||
}
|
||||
|
||||
// Load the library containing the operation.
|
||||
byte[] opList = TensorFlow.loadLibrary("tensorflow/java/my_test_op.so");
|
||||
assertTrue(opList.length > 0);
|
||||
|
||||
// Now graph building should succeed.
|
||||
g.opBuilder("MyTest", "MyTest").build();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// Test-only op registration: "MyTest" declares no inputs, outputs or attrs
// here and uses the generic UnknownShape shape function.
REGISTER_OP("MyTest")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .Doc("Custom operation for testing.");
|
||||
|
|
@ -168,31 +168,27 @@ def make_tensor(v, arg_name):
|
|||
|
||||
def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
"""Convert sequence `l` to eager same-type Tensors."""
|
||||
EagerTensor = ops.EagerTensor # pylint: disable=invalid-name
|
||||
if all(isinstance(x, EagerTensor) for x in l):
|
||||
return l[0].dtype, l
|
||||
# TODO(josh11b): Could we do a better job if we also passed in the
|
||||
# allowed dtypes when that was known?
|
||||
|
||||
# Is some input already a Tensor with a dtype?
|
||||
dtype = None
|
||||
for t in l:
|
||||
if isinstance(t, EagerTensor):
|
||||
if isinstance(t, ops.EagerTensor):
|
||||
dtype = t.dtype
|
||||
break
|
||||
|
||||
internal_convert_to_tensor = ops.internal_convert_to_tensor
|
||||
if dtype is None:
|
||||
# Infer a dtype based on the first value, and use that dtype for the
|
||||
# remaining values.
|
||||
ret = []
|
||||
for t in l:
|
||||
ret.append(internal_convert_to_tensor(
|
||||
ret.append(ops.internal_convert_to_tensor(
|
||||
t, dtype, preferred_dtype=default_dtype, ctx=ctx))
|
||||
if dtype is None:
|
||||
dtype = ret[-1].dtype
|
||||
else:
|
||||
ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
|
||||
ret = [ops.internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
|
||||
|
||||
return dtype, ret
|
||||
|
||||
|
|
|
|||
|
|
@ -112,10 +112,8 @@ class Layer(object):
|
|||
self._per_input_losses = {}
|
||||
self._per_input_updates = {}
|
||||
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
|
||||
call_fn_args = estimator_util.fn_args(self.call)
|
||||
self._compute_previous_mask = ('mask' in call_fn_args or
|
||||
hasattr(self, 'compute_mask'))
|
||||
self._call_has_scope_arg = 'scope' in call_fn_args
|
||||
self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
|
||||
or hasattr(self, 'compute_mask'))
|
||||
|
||||
# These lists will be filled via successive calls
|
||||
# to self._add_inbound_node().
|
||||
|
|
@ -557,15 +555,7 @@ class Layer(object):
|
|||
self.build(input_shapes[0])
|
||||
else:
|
||||
self.build(input_shapes)
|
||||
try:
|
||||
# Note: not all sub-classes of Layer call Layer.__init__ (especially
|
||||
# the ones under tensorflow/python/keras). Hence we recompute this
|
||||
# attribute here if it is not set.
|
||||
# TODO(agarwal): Fix the sub-classes and avoid this complexity.
|
||||
call_has_scope_arg = self._call_has_scope_arg
|
||||
except AttributeError:
|
||||
call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
|
||||
if call_has_scope_arg:
|
||||
if 'scope' in estimator_util.fn_args(self.call):
|
||||
kwargs['scope'] = scope
|
||||
# Check input assumptions set after layer building, e.g. input shape.
|
||||
if in_graph_mode:
|
||||
|
|
@ -1443,10 +1433,8 @@ class Network(Layer):
|
|||
self._activity_regularizer = None
|
||||
self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
|
||||
self._base_name = base_name
|
||||
call_fn_args = estimator_util.fn_args(self.call)
|
||||
self._compute_previous_mask = ('mask' in call_fn_args or
|
||||
hasattr(self, 'compute_mask'))
|
||||
self._call_has_scope_arg = 'scope' in call_fn_args
|
||||
self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
|
||||
or hasattr(self, 'compute_mask'))
|
||||
|
||||
# This acts just like the `trainable` attribute of any layer instance.
|
||||
# It does not affect users of the underlying layers, only users of the
|
||||
|
|
|
|||
|
|
@ -330,7 +330,7 @@ class BatchNormalization(base.Layer):
|
|||
lambda: self._one_minus_decay,
|
||||
lambda: 0.)
|
||||
else:
|
||||
one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
|
||||
one_minus_decay = self._one_minus_decay
|
||||
if training_value or training_value is None:
|
||||
mean_update = self._assign_moving_average(self.moving_mean, mean,
|
||||
one_minus_decay)
|
||||
|
|
|
|||
|
|
@ -2317,10 +2317,6 @@ def conj(x, name=None):
|
|||
Raises:
|
||||
TypeError: If `x` is not a numeric tensor.
|
||||
"""
|
||||
if isinstance(x, ops.Tensor):
|
||||
dt = x.dtype
|
||||
if dt.is_floating or dt.is_integer:
|
||||
return x
|
||||
with ops.name_scope(name, "Conj", [x]) as name:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
if x.dtype.is_complex or x.dtype == dtypes.variant:
|
||||
|
|
|
|||
|
|
@ -540,8 +540,16 @@ class ResourceVariable(variables.Variable):
|
|||
the read operation.
|
||||
"""
|
||||
with ops.name_scope("Read"):
|
||||
# Ensure we read the variable in the same device as the handle.
|
||||
with ops.device(self._handle_device):
|
||||
# In graph mode, ensure we read the variable in the same device as the
|
||||
# handle. In eager mode, however, this sometimes tries to read a GPU
|
||||
# variable in the CPU because the handle is host memory. For now, then, we
|
||||
# need to skip the device block in eager. TODO(apassos): eager should have
|
||||
# separate notions of device and memory, so handle.device can be GPU while
|
||||
# handle.memory_space is always CPU.
|
||||
if context.in_graph_mode():
|
||||
with ops.device(self._handle_device):
|
||||
value = self._read_variable_op()
|
||||
else:
|
||||
value = self._read_variable_op()
|
||||
# Return an identity so it can get placed on whatever device the context
|
||||
# specifies instead of the device where the variable is.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user