mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[Java] Add base classes and utilities for operation wrappers. (#11188)
* Add base classes and utilities for operation wrappers. * Rename Input interface to Operand * Introduce changes after code review
This commit is contained in:
parent
a72fc31bca
commit
7c1fe9068b
|
|
@ -162,6 +162,32 @@ java_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
java_test(
|
||||||
|
name = "PrimitiveOpTest",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
test_class = "org.tensorflow.op.PrimitiveOpTest",
|
||||||
|
deps = [
|
||||||
|
":tensorflow",
|
||||||
|
":testutil",
|
||||||
|
"@junit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
java_test(
|
||||||
|
name = "OperandsTest",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
test_class = "org.tensorflow.op.OperandsTest",
|
||||||
|
deps = [
|
||||||
|
":tensorflow",
|
||||||
|
":testutil",
|
||||||
|
"@junit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "libtensorflow_jni",
|
name = "libtensorflow_jni",
|
||||||
srcs = select({
|
srcs = select({
|
||||||
|
|
|
||||||
|
|
@ -21,20 +21,20 @@ package org.tensorflow;
|
||||||
* <p>Example usage:
|
* <p>Example usage:
|
||||||
*
|
*
|
||||||
* <pre>{@code
|
* <pre>{@code
|
||||||
* // The "decodeJpeg" operation can be used as input to the "cast" operation
|
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
|
||||||
* Input decodeJpeg = ops.image().decodeJpeg(...);
|
* Operand decodeJpeg = ops.image().decodeJpeg(...);
|
||||||
* ops.math().cast(decodeJpeg, DataType.FLOAT);
|
* ops.math().cast(decodeJpeg, DataType.FLOAT);
|
||||||
*
|
*
|
||||||
* // The output "y" of the "unique" operation can be used as input to the "cast" operation
|
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
|
||||||
* Output y = ops.array().unique(...).y();
|
* Output y = ops.array().unique(...).y();
|
||||||
* ops.math().cast(y, DataType.FLOAT);
|
* ops.math().cast(y, DataType.FLOAT);
|
||||||
*
|
*
|
||||||
* // The "split" operation can be used as input list to the "concat" operation
|
* // The "split" operation can be used as operand list to the "concat" operation
|
||||||
* Iterable<? extends Input> split = ops.array().split(...);
|
* Iterable<? extends Operand> split = ops.array().split(...);
|
||||||
* ops.array().concat(0, split);
|
* ops.array().concat(0, split);
|
||||||
* }</pre>
|
* }</pre>
|
||||||
*/
|
*/
|
||||||
public interface Input {
|
public interface Operand {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the symbolic handle of a tensor.
|
* Returns the symbolic handle of a tensor.
|
||||||
|
|
@ -91,6 +91,21 @@ public final class Operation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns symbolic handles to a list of tensors produced by this operation.
|
||||||
|
*
|
||||||
|
* @param idx index of the first tensor of the list
|
||||||
|
* @param length number of tensors in the list
|
||||||
|
* @return array of {@code Output}
|
||||||
|
*/
|
||||||
|
public Output[] outputList(int idx, int length) {
|
||||||
|
Output[] outputs = new Output[length];
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
outputs[i] = output(idx + i);
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
/** Returns a symbolic handle to one of the tensors produced by this operation. */
|
/** Returns a symbolic handle to one of the tensors produced by this operation. */
|
||||||
public Output output(int idx) {
|
public Output output(int idx) {
|
||||||
return new Output(this, idx);
|
return new Output(this, idx);
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,10 @@ import java.util.Objects;
|
||||||
* <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
|
* <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
|
||||||
* the {@link Operation} in a {@link Session}.
|
* the {@link Operation} in a {@link Session}.
|
||||||
*
|
*
|
||||||
* <p>By implementing the {@link Input} interface, instances of this class could also be passed
|
* <p>By implementing the {@link Operand} interface, instances of this class also act as operands to
|
||||||
* directly in input to an operation.
|
* {@link org.tensorflow.op.Op Op} instances.
|
||||||
*/
|
*/
|
||||||
public final class Output implements Input {
|
public final class Output implements Operand {
|
||||||
|
|
||||||
/** Handle to the idx-th output of the Operation {@code op}. */
|
/** Handle to the idx-th output of the Operation {@code op}. */
|
||||||
public Output(Operation op, int idx) {
|
public Output(Operation op, int idx) {
|
||||||
|
|
|
||||||
35
tensorflow/java/src/main/java/org/tensorflow/op/Op.java
Normal file
35
tensorflow/java/src/main/java/org/tensorflow/op/Op.java
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A marker interface for all operation wrappers.
|
||||||
|
*
|
||||||
|
* <p>Operation wrappers provide strongly typed interfaces for building operations and linking them
|
||||||
|
* into a graph without the use of literals and indexes required by the core classes.
|
||||||
|
*
|
||||||
|
* <p>This interface allows keeping references to any operation wrapper using a common type.
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* // All values returned by an Ops call can be referred as a Op
|
||||||
|
* Op split = ops.array().split(...);
|
||||||
|
* Op shape = ops.array().shape(...);
|
||||||
|
*
|
||||||
|
* // All operations could be added to an Op collection
|
||||||
|
* Collection<Op> allOps = Arrays.asList(split, shape);
|
||||||
|
* }
|
||||||
|
*/
|
||||||
|
public interface Op {}
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.tensorflow.Operand;
|
||||||
|
import org.tensorflow.OperationBuilder;
|
||||||
|
import org.tensorflow.Output;
|
||||||
|
|
||||||
|
/** Utilities for manipulating operand related types and lists. */
|
||||||
|
public final class Operands {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a list of {@link Operand} into an array of {@link Output}.
|
||||||
|
*
|
||||||
|
* <p>Operation wrappers need to convert back a list of inputs into an array of outputs in order
|
||||||
|
* to build an operation, see {@link OperationBuilder#addInputList(Output[])}.
|
||||||
|
*
|
||||||
|
* @param inputs an iteration of input operands
|
||||||
|
* @return an array of outputs
|
||||||
|
*/
|
||||||
|
public static Output[] asOutputs(Iterable<? extends Operand> inputs) {
|
||||||
|
List<Output> outputList = new ArrayList<>();
|
||||||
|
for (Operand input : inputs) {
|
||||||
|
outputList.add(input.asOutput());
|
||||||
|
}
|
||||||
|
return outputList.toArray(new Output[outputList.size()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disabled constructor
|
||||||
|
private Operands() {}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op;
|
||||||
|
|
||||||
|
import org.tensorflow.Operation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A base class for {@link Op} implementations that are backed by a single {@link Operation}.
|
||||||
|
*
|
||||||
|
* <p>Each operation registered in the TensorFlow core is a primitive and is provided as a {@code
|
||||||
|
* PrimitiveOp}. Custom operations working with only one primitive may also derive from this class.
|
||||||
|
*/
|
||||||
|
public abstract class PrimitiveOp implements Op {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final int hashCode() {
|
||||||
|
return operation.hashCode();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final boolean equals(Object obj) {
|
||||||
|
if (this == obj) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Note: we consider that all objects wrapping the same operation are equal, no matter their
|
||||||
|
// implementation
|
||||||
|
if (!(obj instanceof PrimitiveOp)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return operation.equals(((PrimitiveOp) obj).operation);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return String.format("<%s '%s'>", operation.type(), operation.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Underlying operation. It is deliberately not exposed by a getter method to avoid any name
|
||||||
|
* conflict with generated methods of the subclasses.
|
||||||
|
*/
|
||||||
|
protected final Operation operation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor.
|
||||||
|
*
|
||||||
|
* @param operation the underlying operation
|
||||||
|
*/
|
||||||
|
protected PrimitiveOp(Operation operation) {
|
||||||
|
this.operation = operation;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -24,6 +24,7 @@ import static org.junit.Assert.fail;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.JUnit4;
|
import org.junit.runners.JUnit4;
|
||||||
|
|
@ -153,6 +154,19 @@ public class OperationTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void outputList() {
|
||||||
|
try (Graph g = new Graph()) {
|
||||||
|
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||||
|
Output[] outputs = split.outputList(1, 2);
|
||||||
|
assertNotNull(outputs);
|
||||||
|
assertEquals(2, outputs.length);
|
||||||
|
for (int i = 0; i < outputs.length; ++i) {
|
||||||
|
assertEquals(i + 1, outputs[i].index());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static int split(int[] values, int num_split) {
|
private static int split(int[] values, int num_split) {
|
||||||
try (Graph g = new Graph()) {
|
try (Graph g = new Graph()) {
|
||||||
return g.opBuilder("Split", "Split")
|
return g.opBuilder("Split", "Split")
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,14 @@ public class TestUtil {
|
||||||
.output(0);
|
.output(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static Operation split(Graph g, String name, int[] values, int num_split) {
|
||||||
|
return g.opBuilder("Split", name)
|
||||||
|
.addInput(constant(g, "split_dim", 0))
|
||||||
|
.addInput(constant(g, "values", values))
|
||||||
|
.setAttr("num_split", num_split)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
public static void transpose_A_times_X(Graph g, int[][] a) {
|
public static void transpose_A_times_X(Graph g, int[][] a) {
|
||||||
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
|
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertSame;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
import org.tensorflow.Operation;
|
||||||
|
import org.tensorflow.Output;
|
||||||
|
import org.tensorflow.TestUtil;
|
||||||
|
|
||||||
|
/** Unit tests for {@link org.tensorflow.op.Operands}. */
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public class OperandsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createOutputArrayFromOperandList() {
|
||||||
|
try (Graph g = new Graph()) {
|
||||||
|
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||||
|
List<Output> list = Arrays.asList(split.output(0), split.output(2));
|
||||||
|
Output[] array = Operands.asOutputs(list);
|
||||||
|
assertEquals(list.size(), array.length);
|
||||||
|
assertSame(array[0], list.get(0));
|
||||||
|
assertSame(array[1], list.get(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.op;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertNotEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
import org.tensorflow.Output;
|
||||||
|
import org.tensorflow.TestUtil;
|
||||||
|
|
||||||
|
public class PrimitiveOpTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void equalsHashcode() {
|
||||||
|
try (Graph g = new Graph()) {
|
||||||
|
Output array = TestUtil.constant(g, "array", new int[2]);
|
||||||
|
|
||||||
|
PrimitiveOp test1 =
|
||||||
|
new PrimitiveOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {};
|
||||||
|
PrimitiveOp test2 =
|
||||||
|
new PrimitiveOp(g.opBuilder("Shape", "shape2").addInput(array).build()) {};
|
||||||
|
PrimitiveOp test3 = new PrimitiveOp(test1.operation) {};
|
||||||
|
|
||||||
|
// equals() tests
|
||||||
|
assertNotEquals(test1, test2);
|
||||||
|
assertEquals(test1, test3);
|
||||||
|
assertEquals(test3, test1);
|
||||||
|
assertNotEquals(test2, test3);
|
||||||
|
|
||||||
|
// hashcode() tests
|
||||||
|
Set<PrimitiveOp> ops = new HashSet<>();
|
||||||
|
assertTrue(ops.add(test1));
|
||||||
|
assertTrue(ops.add(test2));
|
||||||
|
assertFalse(ops.add(test3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user