mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add the Constant operator class (#11559)
Create a custom operator class to create constants in the Graph, and introduce the Operator marker annotation to identify operator classes. Please see #7149 for the master tracking issue.
This commit is contained in:
parent
d09304fca4
commit
599165861e
|
|
@ -34,7 +34,7 @@ filegroup(
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "java_op_sources",
|
name = "java_op_sources",
|
||||||
srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
|
srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]),
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/java:__pkg__",
|
"//tensorflow/java:__pkg__",
|
||||||
],
|
],
|
||||||
|
|
@ -191,6 +191,19 @@ java_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
java_test(
|
||||||
|
name = "ConstantTest",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
test_class = "org.tensorflow.op.core.ConstantTest",
|
||||||
|
deps = [
|
||||||
|
":tensorflow",
|
||||||
|
":testutil",
|
||||||
|
"@junit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "libtensorflow_jni",
|
name = "libtensorflow_jni",
|
||||||
srcs = select({
|
srcs = select({
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op.annotation;
|
||||||
|
|
||||||
|
import java.lang.annotation.Documented;
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotation used by classes to make TensorFlow operations conveniently accessible via {@code
|
||||||
|
* org.tensorflow.op.Ops}.
|
||||||
|
*
|
||||||
|
* <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by
|
||||||
|
* aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at
|
||||||
|
* least one public static factory method named {@code create} that accepts a {@link
|
||||||
|
* org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in
|
||||||
|
* the {@code Ops} class. For example:
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* @Operator
|
||||||
|
* public final class MyOp implements Op {
|
||||||
|
* public static MyOp create(Scope scope, Operand operand) {
|
||||||
|
* ...
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* <p>results in a method in the {@code Ops} class
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* import org.tensorflow.op.Ops;
|
||||||
|
* ...
|
||||||
|
* Ops ops = new Ops(graph);
|
||||||
|
* ...
|
||||||
|
* ops.myOp(operand);
|
||||||
|
* // and has exactly the same effect as calling
|
||||||
|
* // MyOp.create(ops.getScope(), operand);
|
||||||
|
* }</pre>
|
||||||
|
*/
|
||||||
|
@Documented
|
||||||
|
@Target(ElementType.TYPE)
|
||||||
|
@Retention(RetentionPolicy.CLASS)
|
||||||
|
public @interface Operator {
|
||||||
|
/**
|
||||||
|
* Specify an optional group within the {@code Ops} class.
|
||||||
|
*
|
||||||
|
* <p>By default, an annotation processor will create convenience methods directly in the {@code
|
||||||
|
* Ops} class. An annotated operator may optionally choose to place the method within a group. For
|
||||||
|
* example:
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* @Operator(group="math")
|
||||||
|
* public final class Add extends PrimitiveOp implements Operand {
|
||||||
|
* ...
|
||||||
|
* }
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops}
|
||||||
|
* class.
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* ops.math().add(...);
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* <p>The group name must be a <a
|
||||||
|
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
|
||||||
|
* identifier</a>.
|
||||||
|
*/
|
||||||
|
String group() default "";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Name for the wrapper method used in the {@code Ops} class.
|
||||||
|
*
|
||||||
|
* <p>By default, a processor derives the method name in the {@code Ops} class from the class name
|
||||||
|
* of the operator. This attribute allow you to provide a different name instead. For example:
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* @Operator(name="myOperation")
|
||||||
|
* public final class MyRealOperation implements Operand {
|
||||||
|
* public static MyRealOperation create(...)
|
||||||
|
* }
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* <p>results in this method added to the {@code Ops} class
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* ops.myOperation(...);
|
||||||
|
* // and is the same as calling
|
||||||
|
* // MyRealOperation.create(...)
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* <p>The name must be a <a
|
||||||
|
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
|
||||||
|
* identifier</a>.
|
||||||
|
*/
|
||||||
|
String name() default "";
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op.core;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.DoubleBuffer;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
import java.nio.IntBuffer;
|
||||||
|
import java.nio.LongBuffer;
|
||||||
|
import org.tensorflow.DataType;
|
||||||
|
import org.tensorflow.Operand;
|
||||||
|
import org.tensorflow.Operation;
|
||||||
|
import org.tensorflow.Output;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
import org.tensorflow.op.PrimitiveOp;
|
||||||
|
import org.tensorflow.op.Scope;
|
||||||
|
import org.tensorflow.op.annotation.Operator;
|
||||||
|
|
||||||
|
/** An operator producing a constant value. */
|
||||||
|
@Operator
|
||||||
|
public final class Constant extends PrimitiveOp implements Operand {
|
||||||
|
/**
|
||||||
|
* Create a constant from a Java object.
|
||||||
|
*
|
||||||
|
* <p>The argument {@code object} is first converted into a Tensor using {@link
|
||||||
|
* org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
|
||||||
|
* provided. For example:
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* Constant.create(scope, 7); // returns a constant scalar tensor 7
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param object a Java object representing the constant.
|
||||||
|
* @see org.tensorflow.Tensor#create(Object) Tensor.create
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, Object object) {
|
||||||
|
try (Tensor value = Tensor.create(object)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link DataType#INT32} constant with data from the given buffer.
|
||||||
|
*
|
||||||
|
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||||
|
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||||
|
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||||
|
* method.
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param shape the tensor shape.
|
||||||
|
* @param data a buffer containing the tensor data.
|
||||||
|
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, long[] shape, IntBuffer data) {
|
||||||
|
try (Tensor value = Tensor.create(shape, data)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
|
||||||
|
*
|
||||||
|
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||||
|
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||||
|
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||||
|
* method.
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param shape the tensor shape.
|
||||||
|
* @param data a buffer containing the tensor data.
|
||||||
|
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
|
||||||
|
try (Tensor value = Tensor.create(shape, data)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
|
||||||
|
*
|
||||||
|
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||||
|
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||||
|
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||||
|
* method.
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param shape the tensor shape.
|
||||||
|
* @param data a buffer containing the tensor data.
|
||||||
|
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
|
||||||
|
try (Tensor value = Tensor.create(shape, data)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a {@link DataType#INT64} constant with data from the given buffer.
|
||||||
|
*
|
||||||
|
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||||
|
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||||
|
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||||
|
* method.
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param shape the tensor shape.
|
||||||
|
* @param data a buffer containing the tensor data.
|
||||||
|
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, long[] shape, LongBuffer data) {
|
||||||
|
try (Tensor value = Tensor.create(shape, data)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a constant with data from the given buffer.
|
||||||
|
*
|
||||||
|
* <p>Creates a Constant with the provided shape of any type where the constant data has been
|
||||||
|
* encoded into {@code data} as per the specification of the TensorFlow <a
|
||||||
|
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
|
||||||
|
*
|
||||||
|
* @param scope is a scope used to add the underlying operation.
|
||||||
|
* @param dataType the tensor datatype.
|
||||||
|
* @param shape the tensor shape.
|
||||||
|
* @param data a buffer containing the tensor data.
|
||||||
|
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
|
||||||
|
* buffer
|
||||||
|
*/
|
||||||
|
public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
|
||||||
|
try (Tensor value = Tensor.create(dataType, shape, data)) {
|
||||||
|
return createWithTensor(scope, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Constant createWithTensor(Scope scope, Tensor value) {
|
||||||
|
return new Constant(
|
||||||
|
scope
|
||||||
|
.graph()
|
||||||
|
.opBuilder("Const", scope.makeOpName("Const"))
|
||||||
|
.setAttr("value", value)
|
||||||
|
.setAttr("dtype", value.dataType())
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Output asOutput() {
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Constant(Operation operation) {
|
||||||
|
super(operation);
|
||||||
|
output = operation.output(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Output output;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op.core;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.DoubleBuffer;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
import java.nio.IntBuffer;
|
||||||
|
import java.nio.LongBuffer;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
import org.tensorflow.DataType;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
import org.tensorflow.Session;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
import org.tensorflow.op.Scope;
|
||||||
|
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public class ConstantTest {
|
||||||
|
private static final float EPSILON = 1e-7f;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createIntBuffer() {
|
||||||
|
int[] ints = {1, 2, 3, 4};
|
||||||
|
long[] shape = {4};
|
||||||
|
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Scope scope = new Scope(g);
|
||||||
|
Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
|
||||||
|
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||||
|
int[] actual = new int[ints.length];
|
||||||
|
assertArrayEquals(ints, result.copyTo(actual));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createFloatBuffer() {
|
||||||
|
float[] floats = {1, 2, 3, 4};
|
||||||
|
long[] shape = {4};
|
||||||
|
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Scope scope = new Scope(g);
|
||||||
|
Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
|
||||||
|
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||||
|
float[] actual = new float[floats.length];
|
||||||
|
assertArrayEquals(floats, result.copyTo(actual), EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createDoubleBuffer() {
|
||||||
|
double[] doubles = {1, 2, 3, 4};
|
||||||
|
long[] shape = {4};
|
||||||
|
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Scope scope = new Scope(g);
|
||||||
|
Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
|
||||||
|
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||||
|
double[] actual = new double[doubles.length];
|
||||||
|
assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createLongBuffer() {
|
||||||
|
long[] longs = {1, 2, 3, 4};
|
||||||
|
long[] shape = {4};
|
||||||
|
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Scope scope = new Scope(g);
|
||||||
|
Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
|
||||||
|
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||||
|
long[] actual = new long[longs.length];
|
||||||
|
assertArrayEquals(longs, result.copyTo(actual));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createStringBuffer() throws IOException {
|
||||||
|
|
||||||
|
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
|
||||||
|
long[] shape = {};
|
||||||
|
|
||||||
|
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
|
||||||
|
// followed by a varint encoded size, followed by the data.
|
||||||
|
ByteArrayOutputStream baout = new ByteArrayOutputStream();
|
||||||
|
DataOutputStream out = new DataOutputStream(baout);
|
||||||
|
// Offset in array.
|
||||||
|
out.writeLong(0L);
|
||||||
|
// Varint encoded length of buffer.
|
||||||
|
// For any number < 0x80, the varint encoding is simply the number itself.
|
||||||
|
// https://developers.google.com/protocol-buffers/docs/encoding#varints
|
||||||
|
assertTrue(data.length < 0x80);
|
||||||
|
out.write(data.length);
|
||||||
|
out.write(data);
|
||||||
|
out.close();
|
||||||
|
byte[] content = baout.toByteArray();
|
||||||
|
|
||||||
|
try (Graph g = new Graph();
|
||||||
|
Session sess = new Session(g)) {
|
||||||
|
Scope scope = new Scope(g);
|
||||||
|
Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
|
||||||
|
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||||
|
assertArrayEquals(data, result.bytesValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user