From 599165861e9f8656f80a0162b40575f55fae171a Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Thu, 27 Jul 2017 10:53:57 -0700 Subject: [PATCH] Add the Constant operator class (#11559) Create a custom operator class to create constants in the Graph, and introduce the Operator marker annotation to identify operator classes. Please see #7149 for the master tracking issue. --- tensorflow/java/BUILD | 15 +- .../tensorflow/op/annotation/Operator.java | 112 ++++++++++++ .../java/org/tensorflow/op/core/Constant.java | 173 ++++++++++++++++++ .../org/tensorflow/op/core/ConstantTest.java | 131 +++++++++++++ 4 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java create mode 100644 tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java create mode 100644 tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 9fb4821cb15..64b37677357 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -34,7 +34,7 @@ filegroup( filegroup( name = "java_op_sources", - srcs = glob(["src/main/java/org/tensorflow/op/*.java"]), + srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]), visibility = [ "//tensorflow/java:__pkg__", ], @@ -191,6 +191,19 @@ java_test( ], ) +java_test( + name = "ConstantTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.ConstantTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "libtensorflow_jni", srcs = select({ diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java new file mode 100644 index 00000000000..59476fb43d4 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation used by classes to make TensorFlow operations conveniently accessible via {@code + * org.tensorflow.op.Ops}. + * + *

An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by + * aggregating all classes annotated as {@code @Operator}s. Each annotated class must have at + * least one public static factory method named {@code create} that accepts a {@link + * org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in + * the {@code Ops} class. For example: + * + *

{@code
+ * @Operator
+ * public final class MyOp implements Op {
+ *   public static MyOp create(Scope scope, Operand operand) {
+ *     ...
+ *   }
+ * }
+ * }
+ * + *

results in a method in the {@code Ops} class + * + *

{@code
+ * import org.tensorflow.op.Ops;
+ * ...
+ * Ops ops = new Ops(graph);
+ * ...
+ * ops.myOp(operand);
+ * // and has exactly the same effect as calling
+ * // MyOp.create(ops.getScope(), operand);
+ * }
+ */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.CLASS) +public @interface Operator { + /** + * Specify an optional group within the {@code Ops} class. + * + *

By default, an annotation processor will create convenience methods directly in the {@code + * Ops} class. An annotated operator may optionally choose to place the method within a group. For + * example: + * + *

{@code
+   * @Operator(group="math")
+   * public final class Add extends PrimitiveOp implements Operand {
+   *   ...
+   * }
+   * }
+ * + *

results in the {@code add} method placed within a {@code math} group within the {@code Ops} + * class. + * + *

{@code
+   * ops.math().add(...);
+   * }
+ * + *

The group name must be a valid Java + * identifier. + */ + String group() default ""; + + /** + * Name for the wrapper method used in the {@code Ops} class. + * + *

By default, a processor derives the method name in the {@code Ops} class from the class name + * of the operator. This attribute allow you to provide a different name instead. For example: + * + *

{@code
+   * @Operator(name="myOperation")
+   * public final class MyRealOperation implements Operand {
+   *   public static MyRealOperation create(...)
+   * }
+   * }
+ * + *

results in this method added to the {@code Ops} class + * + *

{@code
+   * ops.myOperation(...);
+   * // and is the same as calling
+   * // MyRealOperation.create(...)
+   * }
+ * + *

The name must be a valid Java + * identifier. + */ + String name() default ""; +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java new file mode 100644 index 00000000000..cd7931d3bb7 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.op.PrimitiveOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** An operator producing a constant value. */ +@Operator +public final class Constant extends PrimitiveOp implements Operand { + /** + * Create a constant from a Java object. + * + *

The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + *

{@code
+   * Constant.create(scope, 7); // returns a constant scalar tensor 7
+   * }
+ * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object) { + try (Tensor value = Tensor.create(object)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT32} constant with data from the given buffer. + * + *

Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * + *

Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * + *

Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT64} constant with data from the given buffer. + * + *

Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a constant with data from the given buffer. + * + *

Creates a Constant with the provided shape of any type where the constant data has been + * encoded into {@code data} as per the specification of the TensorFlow C API. + * + * @param scope is a scope used to add the underlying operation. + * @param dataType the tensor datatype. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the + * buffer + */ + public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) { + try (Tensor value = Tensor.create(dataType, shape, data)) { + return createWithTensor(scope, value); + } + } + + private static Constant createWithTensor(Scope scope, Tensor value) { + return new Constant( + scope + .graph() + .opBuilder("Const", scope.makeOpName("Const")) + .setAttr("value", value) + .setAttr("dtype", value.dataType()) + .build()); + } + + @Override + public Output asOutput() { + return output; + } + + private Constant(Operation operation) { + super(operation); + output = operation.output(0); + } + + private final Output output; +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java new file mode 100644 index 00000000000..ec237924855 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -0,0 +1,131 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class ConstantTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntBuffer() { + int[] ints = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.copyTo(actual)); + } + } + + @Test + public void createFloatBuffer() { + float[] floats = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.copyTo(actual), EPSILON); + } + } + + @Test + public void createDoubleBuffer() { + double[] doubles = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + } + } + + @Test + public void createLongBuffer() { + long[] longs = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.copyTo(actual)); + } + } + + @Test + public void createStringBuffer() throws IOException { + + byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; + long[] shape = {}; + + // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, + // followed by a varint encoded size, followed by the data. + ByteArrayOutputStream baout = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(baout); + // Offset in array. + out.writeLong(0L); + // Varint encoded length of buffer. + // For any number < 0x80, the varint encoding is simply the number itself. + // https://developers.google.com/protocol-buffers/docs/encoding#varints + assertTrue(data.length < 0x80); + out.write(data.length); + out.write(data); + out.close(); + byte[] content = baout.toByteArray(); + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + assertArrayEquals(data, result.bytesValue()); + } + } +}