diff --git a/AUTHORS b/AUTHORS index a46ae7e616a..aa4be5169dc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,4 +7,4 @@ # The email address is not required for organizations. Google Inc. -Yuan Tang terrytangyuan@gmail.com +Yuan Tang diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 43abdaafbf4..1b537ca73cc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -114,6 +114,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) +* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) #### Running sanity check diff --git a/configure.py b/configure.py index 99c0a8d3215..680448d7b6a 100644 --- a/configure.py +++ b/configure.py @@ -1088,6 +1088,28 @@ def set_computecpp_toolkit_path(environ_cp): computecpp_toolkit_path) +def set_trisycl_include_dir(environ_cp): + """Set TRISYCL_INCLUDE_DIR.""" + ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' + 'include directory. (Use --config=sycl_trisycl ' + 'when building with Bazel) ' + '[Default is %s]: ') % _DEFAULT_TRISYCL_INCLUDE_DIR + while True: + trisycl_include_dir = get_from_env_or_user_or_default( + environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, + _DEFAULT_TRISYCL_INCLUDE_DIR) + if os.path.exists(trisycl_include_dir): + break + + print('Invalid triSYCL include directory, %s cannot be found' + % (trisycl_include_dir)) + + # Set TRISYCL_INCLUDE_DIR + environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', + trisycl_include_dir) + + def set_trisycl_include_dir(environ_cp): """Set TRISYCL_INCLUDE_DIR.""" diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 09fadfcab51..13a3bba5e6d 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -196,6 +196,18 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status LRNGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs){ + internal::LRNGrad::Attrs grad_attrs; + + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), + grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LRN", LRNGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index ac66f51cf01..f9063e83650 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -191,5 +191,12 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, LRN){ + TensorShape x_shape({1, 1, 2, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = LRN(scope_, x); + RunTest(x, x_shape, y, x_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45c..5e302f88990 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5e798c2045d..03cf9aaf907 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -428,7 +428,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. const std::vector& control_predecessors() const { return control_predecessors_; diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index b7ade951150..61f7821519b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -64,6 +64,7 @@ py_library( "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 1eda1abfcf7..08247c6b38a 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -55,6 +55,7 @@ from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc..fe8bd072afd 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 00000000000..a0fca690ef6 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,449 @@ +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks +tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb +tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/platform/default +tensorflow/python/platform/summary +tensorflow/python/profiler/ +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util +tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android +tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/kernels +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/examples +tensorflow/contrib/bayesflow/examples/reinforce_simple +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python +tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions +tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python +tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops +tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/ios_examples +tensorflow/contrib/ios_examples/benchmark +tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj +tensorflow/contrib/ios_examples/benchmark/data +tensorflow/contrib/ios_examples/camera +tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj +tensorflow/contrib/ios_examples/camera/en.lproj +tensorflow/contrib/ios_examples/simple +tensorflow/contrib/ios_examples/simple/data +tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing +tensorflow/contrib/keras/api/keras/preprocessing/image +tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/dataframe +tensorflow/contrib/learn/python/learn/dataframe/queues +tensorflow/contrib/learn/python/learn/dataframe/transforms +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops +tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/ops +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/ndlstm +tensorflow/contrib/ndlstm/python +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/kernels +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python +tensorflow/contrib/receptive_field +tensorflow/contrib/receptive_field/python +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python +tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/core +tensorflow/contrib/tensor_forest/core/ops +tensorflow/contrib/tensor_forest/data +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 00000000000..8a9c406d8b1 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 00000000000..d4a257b25c8 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4a61ed7a354..e8c2cd34732 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -92,6 +92,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/con GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index d102b442e7a..8db6929e31a 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,33 +120,34 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. -file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -192,315 +193,15 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") -add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") -add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") -add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/fashion_mnist") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") -add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") -add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") -add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") -add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") -add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") -add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") -add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") -add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") -add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") -add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") -add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") -add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") -add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") -add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + add_python_module(${python_module}) +endforeach(python_module) + add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") @@ -514,157 +215,6 @@ add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) -add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") -add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/model_pruning") -add_python_module("tensorflow/contrib/model_pruning/examples") -add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") -add_python_module("tensorflow/contrib/model_pruning/python") -add_python_module("tensorflow/contrib/model_pruning/python/layers") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") -add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") -add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") -add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") -add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") -add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") -add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") -add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -817,6 +367,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) + GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -1019,6 +572,20 @@ target_link_libraries(pywrap_tensorflow_internal PRIVATE ) if(WIN32) + + # include contrib/periodic_resample as .so + # + set(tf_periodic_resample_srcs + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc" + ) + + AddUserOps(TARGET _periodic_resample_op + SOURCES "${tf_periodic_resample_srcs}" + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/) + # include contrib/nearest_neighbor as .so # set(tf_nearest_neighbor_srcs diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 09e22285e10..2d58a48a497 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -154,6 +154,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py" @@ -224,6 +225,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Numerical issues, calculations off. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1e8a6b26c9e..2cb6b7e76c6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -143,6 +143,7 @@ py_test( size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", @@ -315,6 +316,7 @@ py_test( size = "small", srcs = ["prefetch_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":dataset_serialization_test", "//tensorflow/python:platform", @@ -423,6 +425,7 @@ py_test( size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 145b9495ff4..b2c641f8ab3 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -204,6 +204,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "half_normal_test", + size = "medium", + srcs = ["python/kernel_tests/half_normal_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "inverse_gamma_test", srcs = ["python/kernel_tests/inverse_gamma_test.py"], diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 0d12d838932..66827179e9f 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -36,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * +from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.logistic import * @@ -107,6 +108,7 @@ _allowed_symbols = [ 'Gamma', 'GammaWithSoftplusConcentrationRate', 'Geometric', + 'HalfNormal', 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py new file mode 100644 index 00000000000..a4e75660083 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class HalfNormalTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertAllEqual(expected, scale_shape.eval()) + scale = array_ops.ones(scale_shape) + self.assertAllEqual( + expected, + array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertEqual(expected, scale_shape) + + def _testBatchShapes(self, dist, tensor): + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape) + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape) + self.assertAllEqual(dist.batch_shape, tensor.shape) + self.assertAllEqual(dist.batch_shape, tensor.eval().shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testHalfNormalLogPDF(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([3.0] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([[3.0, 1.0]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalCDF(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + cdf = halfnorm.cdf(x) + self._testBatchShapes(halfnorm, cdf) + + log_cdf = halfnorm.log_cdf(x) + self._testBatchShapes(halfnorm, log_cdf) + + if not stats: + return + expected_logcdf = stats.halfnorm(scale=scale).logcdf(x) + self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0) + + def testHalfNormalSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sf = halfnorm.survival_function(x) + self._testBatchShapes(halfnorm, sf) + + log_sf = halfnorm.log_survival_function(x) + self._testBatchShapes(halfnorm, log_sf) + + if not stats: + return + expected_logsf = stats.halfnorm(scale=scale).logsf(x) + self.assertAllClose(expected_logsf, log_sf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0) + + def testHalfNormalQuantile(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size).astype(np.float64) + + halfnorm = hn_lib.HalfNormal(scale=scale) + x = halfnorm.quantile(p) + self._testBatchShapes(halfnorm, x) + + if not stats: + return + expected_x = stats.halfnorm(scale=scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0) + + def testFiniteGradients(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + scale = variables.Variable(dtype(3.0)) + dist = hn_lib.HalfNormal(scale=scale) + x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_prob, dist.prob, dist.log_survival_function, + ]: + print(func.__name__) + value = func(x) + grads = gradients_impl.gradients(value, [scale]) + with self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + + def testHalfNormalEntropy(self): + with self.test_session(): + scale = np.array([[1.0, 2.0, 3.0]]) + halfnorm = hn_lib.HalfNormal(scale=scale) + + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the + # entropy formula used here. + expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 + + entropy = halfnorm.entropy() + self._testBatchShapes(halfnorm, entropy) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testHalfNormalMeanAndMode(self): + with self.test_session(): + scale = np.array([11., 12., 13.]) + + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi) + + self.assertAllEqual((3,), halfnorm.mean().eval().shape) + self.assertAllEqual(expected_mean, halfnorm.mean().eval()) + + self.assertAllEqual((3,), halfnorm.mode().eval().shape) + self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval()) + + def testHalfNormalVariance(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.variance().eval().shape) + self.assertAllEqual(expected_variance, halfnorm.variance().eval()) + + def testHalfNormalStandardDeviation(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.stddev().shape) + self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval()) + + def testHalfNormalSample(self): + with self.test_session(): + scale = constant_op.constant(3.0) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + + self.assertEqual(sample.eval().shape, (100000,)) + self.assertAllClose(sample.eval().mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testHalfNormalSampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + scale = constant_op.constant([[2.0, 3.0]] * batch_size) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + self.assertEqual(sample.shape, (100000, batch_size, 2)) + self.assertAllClose(sample.eval()[:, 0, 0].mean(), + 2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + self.assertAllClose(sample.eval()[:, 0, 1].mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testNegativeSigmaFails(self): + with self.test_session(): + halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") + with self.assertRaisesOpError("Condition x > 0 did not hold"): + halfnorm.mean().eval() + + def testHalfNormalShape(self): + with self.test_session(): + scale = constant_op.constant([6.0] * 5) + halfnorm = hn_lib.HalfNormal(scale=scale) + + self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5]) + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([])) + + def testHalfNormalShapeWithPlaceholders(self): + scale = array_ops.placeholder(dtype=dtypes.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(halfnorm.event_shape, ()) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run(halfnorm.batch_shape_tensor(), + feed_dict={scale: [1.0, 2.0]}), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py new file mode 100644 index 00000000000..fc0751a6e0b --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Half Normal distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import special_math + + +__all__ = [ + "HalfNormal", +] + + +class HalfNormal(distribution.Distribution): + """The Half Normal distribution with scale `scale`. + + #### Mathematical details + + The half normal is a transformation of a centered normal distribution. + If some random variable `X` has normal distribution, + ```none + X ~ Normal(0.0, scale) + Y = |X| + ``` + Then `Y` will have half normal distribution. The probability density + function (pdf) is: + + ```none + pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) * + exp(- 1/2 * (x / scale) ** 2) + ) + ``` + Where `scale = sigma` is the standard deviation of the underlying normal + distribution. + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar HalfNormal distribution. + dist = tf.contrib.distributions.HalfNormal(scale=3.0) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued HalfNormals. + # The first has scale 11.0, the second 22.0 + dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) + + # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, + # returning a length two tensor. + dist.prob([1.0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + """ + + def __init__(self, + scale, + validate_args=False, + allow_nan_stats=True, + name="HalfNormal"): + """Construct HalfNormals with scale `scale`. + + Args: + scale: Floating point tensor; the scales of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + with ops.name_scope(name, values=[scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._scale = array_ops.identity(scale, name="scale") + super(HalfNormal, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + + @property + def scale(self): + """Distribution parameter for the scale.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.shape(self.scale) + + def _batch_shape(self): + return self.scale.shape + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + sampled = random_ops.random_normal( + shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) + return math_ops.abs(sampled * self.scale) + + def _prob(self, x): + coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi) + pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2) + return pdf * math_ops.cast(x >= 0, self.dtype) + + def _cdf(self, x): + truncated_x = nn.relu(x) + return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0)) + + def _entropy(self): + return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5 + + def _mean(self): + return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) + + def _quantile(self, p): + return np.sqrt(2.0) * self.scale * special_math.erfinv(p) + + def _mode(self): + return array_ops.zeros(self.batch_shape_tensor()) + + def _variance(self): + return self.scale ** 2.0 * (1.0 - 2.0 / np.pi) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5448918a507..0623b2c7263 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -320,13 +320,14 @@ class MixtureSameFamily(distribution.Distribution): return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(self) cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( + pad_ndims = distribution_util.pick_vector( self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] + [dist_batch_ndims], + [dist_batch_ndims - cat_batch_ndims])[0] s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), + array_ops.ones([pad_ndims], dtype=dtypes.int32), s[-1:], array_ops.ones([self._event_ndims], dtype=dtypes.int32), ], axis=0)) diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 0263d213250..a1f8a759e2a 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,4 +38,5 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], + tags = ["no_pip"], # because spinn.py is under third_party/. ) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 07b7857e7b2..3f1ece45105 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -441,7 +441,7 @@ def get_unique_variable(var_op_name): """ candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 2d42875b468..0d25a098525 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2654,51 +2654,52 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. """ - shape = array_ops.shape(features) - static_shape = features.shape - if data_format == DATA_FORMAT_NHWC: - height, width, num_channels = shape[1], shape[2], static_shape[3] - elif data_format == DATA_FORMAT_NCHW: - num_channels, height, width = static_shape[1], shape[2], shape[3] - else: - raise ValueError('data_format has to be either NCHW or NHWC.') - if num_channels.value is None: - raise ValueError('The num_channels dimension of the inputs to ' - '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope(name, 'spatial_softmax', [features]) as name: - # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') - pos_x = array_ops.reshape(pos_x, [height * width]) - pos_y = array_ops.reshape(pos_y, [height * width]) - if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) - if data_format == 'NCHW': - features = array_ops.reshape(features, [-1, height * width]) + with variable_scope.variable_scope(name, 'spatial_softmax'): + shape = array_ops.shape(features) + static_shape = features.shape + if data_format == DATA_FORMAT_NHWC: + height, width, num_channels = shape[1], shape[2], static_shape[3] + elif data_format == DATA_FORMAT_NCHW: + num_channels, height, width = static_shape[1], shape[2], shape[3] else: - features = array_ops.reshape( - array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) + raise ValueError('data_format has to be either NCHW or NHWC.') + if num_channels.value is None: + raise ValueError('The num_channels dimension of the inputs to ' + '`spatial_softmax` should be defined. Found `None`.') - softmax_attention = nn.softmax(features/temperature) - expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) - expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) - expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) - feature_keypoints.set_shape([None, num_channels.value * 2]) - return feature_keypoints + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + # Create tensors for x and y coordinate values, scaled to range [-1, 1]. + pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') + pos_x = array_ops.reshape(pos_x, [height * width]) + pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: + temperature_collections = utils.get_variable_collections( + variables_collections, 'temperature') + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=init_ops.ones_initializer(), + collections=temperature_collections, + trainable=trainable) + if data_format == 'NCHW': + features = array_ops.reshape(features, [-1, height * width]) + else: + features = array_ops.reshape( + array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) + + softmax_attention = nn.softmax(features/temperature) + expected_x = math_ops.reduce_sum( + pos_x * softmax_attention, [1], keep_dims=True) + expected_y = math_ops.reduce_sum( + pos_y * softmax_attention, [1], keep_dims=True) + expected_xy = array_ops.concat([expected_x, expected_y], 1) + feature_keypoints = array_ops.reshape( + expected_xy, [-1, num_channels.value * 2]) + feature_keypoints.set_shape([None, num_channels.value * 2]) + return feature_keypoints def stack(inputs, layer, stack_args, **kwargs): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a5..05ed8b3409e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -30,7 +30,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +59,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1230,7 +1230,7 @@ class Estimator(BaseEstimator): if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index c7464bcc9d3..fc9144d5fc7 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -4,7 +4,7 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded dev TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. ![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with a Demo App +# Getting Started with an Android Demo App This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. @@ -17,7 +17,7 @@ There are 3 ways to get the demo app to your device In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. ## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary [TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. @@ -69,7 +69,7 @@ android_ndk_repository( Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). -### Build the source code +### Build the source code Run bazel with the following command to build the demo. Build the demo app: @@ -86,6 +86,17 @@ environment (due to a Bazel bug). ### More about the demo The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: + +1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. + # TensorFlow Lite Quick Start ## Step 1. Decide which GraphDef to use diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 778d618361e..7fce1ba3461 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -19,6 +19,13 @@ set -e DOWNLOADS_DIR=tensorflow/contrib/lite/downloads BZL_FILE_PATH=tensorflow/workspace.bzl +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index ea398ad14e8..10f31bb6f17 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -123,7 +123,11 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; AVCaptureDeviceInput* deviceInput = [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. Note: This app doesn't work with simulator"); + assert(NO); + } if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm index fe26ceec427..d1215fa0bff 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -20,6 +20,7 @@ - (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + UITabBarController *bar = [[UITabBarController alloc] init]; [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; bar.selectedIndex = 0; diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm index cb19377d7e3..cb0fe1a7650 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -31,6 +31,7 @@ std::vector LoadImageFromFile(const char* file_name, int* out_width, in std::vector file_data(bytes_in_file); fread(file_data.data(), 1, bytes_in_file, file_handle); fclose(file_handle); + CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); @@ -63,6 +64,7 @@ std::vector LoadImageFromFile(const char* file_name, int* out_width, in const int bytes_in_image = (bytes_per_row * height); std::vector result(bytes_in_image); const int bits_per_component = 8; + CGContextRef context = CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 89e8693490d..3d6a3ec0fd4 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -24,6 +24,7 @@ py_test( name = "lite_test", srcs = ["lite_test.py"], srcs_version = "PY2AND3", + tags = ["no_oss"], deps = [ ":lite", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h new file mode 100755 index 00000000000..cbf10275f31 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -0,0 +1,5417 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ +#define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tflite { + +struct QuantizationParameters; +struct QuantizationParametersT; + +struct Tensor; +struct TensorT; + +struct Conv2DOptions; +struct Conv2DOptionsT; + +struct Pool2DOptions; +struct Pool2DOptionsT; + +struct DepthwiseConv2DOptions; +struct DepthwiseConv2DOptionsT; + +struct ConcatEmbeddingsOptions; +struct ConcatEmbeddingsOptionsT; + +struct LSHProjectionOptions; +struct LSHProjectionOptionsT; + +struct SVDFOptions; +struct SVDFOptionsT; + +struct RNNOptions; +struct RNNOptionsT; + +struct FullyConnectedOptions; +struct FullyConnectedOptionsT; + +struct SoftmaxOptions; +struct SoftmaxOptionsT; + +struct ConcatenationOptions; +struct ConcatenationOptionsT; + +struct AddOptions; +struct AddOptionsT; + +struct MulOptions; +struct MulOptionsT; + +struct L2NormOptions; +struct L2NormOptionsT; + +struct LocalResponseNormalizationOptions; +struct LocalResponseNormalizationOptionsT; + +struct LSTMOptions; +struct LSTMOptionsT; + +struct ResizeBilinearOptions; +struct ResizeBilinearOptionsT; + +struct CallOptions; +struct CallOptionsT; + +struct ReshapeOptions; +struct ReshapeOptionsT; + +struct SkipGramOptions; +struct SkipGramOptionsT; + +struct SpaceToDepthOptions; +struct SpaceToDepthOptionsT; + +struct EmbeddingLookupSparseOptions; +struct EmbeddingLookupSparseOptionsT; + +struct OperatorCode; +struct OperatorCodeT; + +struct Operator; +struct OperatorT; + +struct SubGraph; +struct SubGraphT; + +struct Buffer; +struct BufferT; + +struct Model; +struct ModelT; + +enum TensorType { + TensorType_FLOAT32 = 0, + TensorType_FLOAT16 = 1, + TensorType_INT32 = 2, + TensorType_UINT8 = 3, + TensorType_INT64 = 4, + TensorType_STRING = 5, + TensorType_MIN = TensorType_FLOAT32, + TensorType_MAX = TensorType_STRING +}; + +inline TensorType (&EnumValuesTensorType())[6] { + static TensorType values[] = {TensorType_FLOAT32, TensorType_FLOAT16, + TensorType_INT32, TensorType_UINT8, + TensorType_INT64, TensorType_STRING}; + return values; +} + +inline const char **EnumNamesTensorType() { + static const char *names[] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", + "INT64", "STRING", nullptr}; + return names; +} + +inline const char *EnumNameTensorType(TensorType e) { + const size_t index = static_cast(e); + return EnumNamesTensorType()[index]; +} + +enum BuiltinOperator { + BuiltinOperator_ADD = 0, + BuiltinOperator_AVERAGE_POOL_2D = 1, + BuiltinOperator_CONCATENATION = 2, + BuiltinOperator_CONV_2D = 3, + BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_EMBEDDING_LOOKUP = 7, + BuiltinOperator_FULLY_CONNECTED = 9, + BuiltinOperator_HASHTABLE_LOOKUP = 10, + BuiltinOperator_L2_NORMALIZATION = 11, + BuiltinOperator_L2_POOL_2D = 12, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION = 13, + BuiltinOperator_LOGISTIC = 14, + BuiltinOperator_LSH_PROJECTION = 15, + BuiltinOperator_LSTM = 16, + BuiltinOperator_MAX_POOL_2D = 17, + BuiltinOperator_MUL = 18, + BuiltinOperator_RELU = 19, + BuiltinOperator_RELU1 = 20, + BuiltinOperator_RELU6 = 21, + BuiltinOperator_RESHAPE = 22, + BuiltinOperator_RESIZE_BILINEAR = 23, + BuiltinOperator_RNN = 24, + BuiltinOperator_SOFTMAX = 25, + BuiltinOperator_SPACE_TO_DEPTH = 26, + BuiltinOperator_SVDF = 27, + BuiltinOperator_TANH = 28, + BuiltinOperator_CONCAT_EMBEDDINGS = 29, + BuiltinOperator_SKIP_GRAM = 30, + BuiltinOperator_CALL = 31, + BuiltinOperator_CUSTOM = 32, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33, + BuiltinOperator_MIN = BuiltinOperator_ADD, + BuiltinOperator_MAX = BuiltinOperator_EMBEDDING_LOOKUP_SPARSE +}; + +inline BuiltinOperator (&EnumValuesBuiltinOperator())[31] { + static BuiltinOperator values[] = { + BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, + BuiltinOperator_RELU, + BuiltinOperator_RELU1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE}; + return values; +} + +inline const char **EnumNamesBuiltinOperator() { + static const char *names[] = {"ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "", + "", + "EMBEDDING_LOOKUP", + "", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + nullptr}; + return names; +} + +inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { + const size_t index = static_cast(e); + return EnumNamesBuiltinOperator()[index]; +} + +enum BuiltinOptions { + BuiltinOptions_NONE = 0, + BuiltinOptions_Conv2DOptions = 1, + BuiltinOptions_DepthwiseConv2DOptions = 2, + BuiltinOptions_ConcatEmbeddingsOptions = 3, + BuiltinOptions_LSHProjectionOptions = 4, + BuiltinOptions_Pool2DOptions = 5, + BuiltinOptions_SVDFOptions = 6, + BuiltinOptions_RNNOptions = 7, + BuiltinOptions_FullyConnectedOptions = 8, + BuiltinOptions_SoftmaxOptions = 9, + BuiltinOptions_ConcatenationOptions = 10, + BuiltinOptions_AddOptions = 11, + BuiltinOptions_L2NormOptions = 12, + BuiltinOptions_LocalResponseNormalizationOptions = 13, + BuiltinOptions_LSTMOptions = 14, + BuiltinOptions_ResizeBilinearOptions = 15, + BuiltinOptions_CallOptions = 16, + BuiltinOptions_ReshapeOptions = 17, + BuiltinOptions_SkipGramOptions = 18, + BuiltinOptions_SpaceToDepthOptions = 19, + BuiltinOptions_EmbeddingLookupSparseOptions = 20, + BuiltinOptions_MulOptions = 21, + BuiltinOptions_MIN = BuiltinOptions_NONE, + BuiltinOptions_MAX = BuiltinOptions_MulOptions +}; + +inline BuiltinOptions (&EnumValuesBuiltinOptions())[22] { + static BuiltinOptions values[] = { + BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions}; + return values; +} + +inline const char **EnumNamesBuiltinOptions() { + static const char *names[] = {"NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + nullptr}; + return names; +} + +inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { + const size_t index = static_cast(e); + return EnumNamesBuiltinOptions()[index]; +} + +template +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NONE; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = + BuiltinOptions_DepthwiseConv2DOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = + BuiltinOptions_ConcatEmbeddingsOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = + BuiltinOptions_LocalResponseNormalizationOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = + BuiltinOptions_EmbeddingLookupSparseOptions; +}; + +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; +}; + +struct BuiltinOptionsUnion { + BuiltinOptions type; + void *value; + + BuiltinOptionsUnion() : type(BuiltinOptions_NONE), value(nullptr) {} + BuiltinOptionsUnion(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT + : type(BuiltinOptions_NONE), + value(nullptr) { + std::swap(type, u.type); + std::swap(value, u.value); + } + BuiltinOptionsUnion(const BuiltinOptionsUnion &) FLATBUFFERS_NOEXCEPT; + BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) + FLATBUFFERS_NOEXCEPT { + BuiltinOptionsUnion t(u); + std::swap(type, t.type); + std::swap(value, t.value); + return *this; + } + BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT { + std::swap(type, u.type); + std::swap(value, u.value); + return *this; + } + ~BuiltinOptionsUnion() { Reset(); } + + void Reset(); + +#ifndef FLATBUFFERS_CPP98_STL + template + void Set(T &&val) { + Reset(); + type = BuiltinOptionsTraits::enum_value; + if (type != BuiltinOptions_NONE) { + value = new T(std::forward(val)); + } + } +#endif // FLATBUFFERS_CPP98_STL + + static void *UnPack(const void *obj, BuiltinOptions type, + const flatbuffers::resolver_function_t *resolver); + flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + + Conv2DOptionsT *AsConv2DOptions() { + return type == BuiltinOptions_Conv2DOptions + ? reinterpret_cast(value) + : nullptr; + } + const Conv2DOptionsT *AsConv2DOptions() const { + return type == BuiltinOptions_Conv2DOptions + ? reinterpret_cast(value) + : nullptr; + } + DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { + return type == BuiltinOptions_DepthwiseConv2DOptions + ? reinterpret_cast(value) + : nullptr; + } + const DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { + return type == BuiltinOptions_DepthwiseConv2DOptions + ? reinterpret_cast(value) + : nullptr; + } + ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { + return type == BuiltinOptions_ConcatEmbeddingsOptions + ? reinterpret_cast(value) + : nullptr; + } + const ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { + return type == BuiltinOptions_ConcatEmbeddingsOptions + ? reinterpret_cast(value) + : nullptr; + } + LSHProjectionOptionsT *AsLSHProjectionOptions() { + return type == BuiltinOptions_LSHProjectionOptions + ? reinterpret_cast(value) + : nullptr; + } + const LSHProjectionOptionsT *AsLSHProjectionOptions() const { + return type == BuiltinOptions_LSHProjectionOptions + ? reinterpret_cast(value) + : nullptr; + } + Pool2DOptionsT *AsPool2DOptions() { + return type == BuiltinOptions_Pool2DOptions + ? reinterpret_cast(value) + : nullptr; + } + const Pool2DOptionsT *AsPool2DOptions() const { + return type == BuiltinOptions_Pool2DOptions + ? reinterpret_cast(value) + : nullptr; + } + SVDFOptionsT *AsSVDFOptions() { + return type == BuiltinOptions_SVDFOptions + ? reinterpret_cast(value) + : nullptr; + } + const SVDFOptionsT *AsSVDFOptions() const { + return type == BuiltinOptions_SVDFOptions + ? reinterpret_cast(value) + : nullptr; + } + RNNOptionsT *AsRNNOptions() { + return type == BuiltinOptions_RNNOptions + ? reinterpret_cast(value) + : nullptr; + } + const RNNOptionsT *AsRNNOptions() const { + return type == BuiltinOptions_RNNOptions + ? reinterpret_cast(value) + : nullptr; + } + FullyConnectedOptionsT *AsFullyConnectedOptions() { + return type == BuiltinOptions_FullyConnectedOptions + ? reinterpret_cast(value) + : nullptr; + } + const FullyConnectedOptionsT *AsFullyConnectedOptions() const { + return type == BuiltinOptions_FullyConnectedOptions + ? reinterpret_cast(value) + : nullptr; + } + SoftmaxOptionsT *AsSoftmaxOptions() { + return type == BuiltinOptions_SoftmaxOptions + ? reinterpret_cast(value) + : nullptr; + } + const SoftmaxOptionsT *AsSoftmaxOptions() const { + return type == BuiltinOptions_SoftmaxOptions + ? reinterpret_cast(value) + : nullptr; + } + ConcatenationOptionsT *AsConcatenationOptions() { + return type == BuiltinOptions_ConcatenationOptions + ? reinterpret_cast(value) + : nullptr; + } + const ConcatenationOptionsT *AsConcatenationOptions() const { + return type == BuiltinOptions_ConcatenationOptions + ? reinterpret_cast(value) + : nullptr; + } + AddOptionsT *AsAddOptions() { + return type == BuiltinOptions_AddOptions + ? reinterpret_cast(value) + : nullptr; + } + const AddOptionsT *AsAddOptions() const { + return type == BuiltinOptions_AddOptions + ? reinterpret_cast(value) + : nullptr; + } + L2NormOptionsT *AsL2NormOptions() { + return type == BuiltinOptions_L2NormOptions + ? reinterpret_cast(value) + : nullptr; + } + const L2NormOptionsT *AsL2NormOptions() const { + return type == BuiltinOptions_L2NormOptions + ? reinterpret_cast(value) + : nullptr; + } + LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { + return type == BuiltinOptions_LocalResponseNormalizationOptions + ? reinterpret_cast(value) + : nullptr; + } + const LocalResponseNormalizationOptionsT * + AsLocalResponseNormalizationOptions() const { + return type == BuiltinOptions_LocalResponseNormalizationOptions + ? reinterpret_cast( + value) + : nullptr; + } + LSTMOptionsT *AsLSTMOptions() { + return type == BuiltinOptions_LSTMOptions + ? reinterpret_cast(value) + : nullptr; + } + const LSTMOptionsT *AsLSTMOptions() const { + return type == BuiltinOptions_LSTMOptions + ? reinterpret_cast(value) + : nullptr; + } + ResizeBilinearOptionsT *AsResizeBilinearOptions() { + return type == BuiltinOptions_ResizeBilinearOptions + ? reinterpret_cast(value) + : nullptr; + } + const ResizeBilinearOptionsT *AsResizeBilinearOptions() const { + return type == BuiltinOptions_ResizeBilinearOptions + ? reinterpret_cast(value) + : nullptr; + } + CallOptionsT *AsCallOptions() { + return type == BuiltinOptions_CallOptions + ? reinterpret_cast(value) + : nullptr; + } + const CallOptionsT *AsCallOptions() const { + return type == BuiltinOptions_CallOptions + ? reinterpret_cast(value) + : nullptr; + } + ReshapeOptionsT *AsReshapeOptions() { + return type == BuiltinOptions_ReshapeOptions + ? reinterpret_cast(value) + : nullptr; + } + const ReshapeOptionsT *AsReshapeOptions() const { + return type == BuiltinOptions_ReshapeOptions + ? reinterpret_cast(value) + : nullptr; + } + SkipGramOptionsT *AsSkipGramOptions() { + return type == BuiltinOptions_SkipGramOptions + ? reinterpret_cast(value) + : nullptr; + } + const SkipGramOptionsT *AsSkipGramOptions() const { + return type == BuiltinOptions_SkipGramOptions + ? reinterpret_cast(value) + : nullptr; + } + SpaceToDepthOptionsT *AsSpaceToDepthOptions() { + return type == BuiltinOptions_SpaceToDepthOptions + ? reinterpret_cast(value) + : nullptr; + } + const SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { + return type == BuiltinOptions_SpaceToDepthOptions + ? reinterpret_cast(value) + : nullptr; + } + EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { + return type == BuiltinOptions_EmbeddingLookupSparseOptions + ? reinterpret_cast(value) + : nullptr; + } + const EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { + return type == BuiltinOptions_EmbeddingLookupSparseOptions + ? reinterpret_cast(value) + : nullptr; + } + MulOptionsT *AsMulOptions() { + return type == BuiltinOptions_MulOptions + ? reinterpret_cast(value) + : nullptr; + } + const MulOptionsT *AsMulOptions() const { + return type == BuiltinOptions_MulOptions + ? reinterpret_cast(value) + : nullptr; + } +}; + +bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, + BuiltinOptions type); +bool VerifyBuiltinOptionsVector( + flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types); + +enum Padding { + Padding_SAME = 0, + Padding_VALID = 1, + Padding_MIN = Padding_SAME, + Padding_MAX = Padding_VALID +}; + +inline Padding (&EnumValuesPadding())[2] { + static Padding values[] = {Padding_SAME, Padding_VALID}; + return values; +} + +inline const char **EnumNamesPadding() { + static const char *names[] = {"SAME", "VALID", nullptr}; + return names; +} + +inline const char *EnumNamePadding(Padding e) { + const size_t index = static_cast(e); + return EnumNamesPadding()[index]; +} + +enum ActivationFunctionType { + ActivationFunctionType_NONE = 0, + ActivationFunctionType_RELU = 1, + ActivationFunctionType_RELU1 = 2, + ActivationFunctionType_RELU6 = 3, + ActivationFunctionType_TANH = 4, + ActivationFunctionType_SIGN_BIT = 5, + ActivationFunctionType_MIN = ActivationFunctionType_NONE, + ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT +}; + +inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { + static ActivationFunctionType values[] = { + ActivationFunctionType_NONE, ActivationFunctionType_RELU, + ActivationFunctionType_RELU1, ActivationFunctionType_RELU6, + ActivationFunctionType_TANH, ActivationFunctionType_SIGN_BIT}; + return values; +} + +inline const char **EnumNamesActivationFunctionType() { + static const char *names[] = {"NONE", "RELU", "RELU1", "RELU6", + "TANH", "SIGN_BIT", nullptr}; + return names; +} + +inline const char *EnumNameActivationFunctionType(ActivationFunctionType e) { + const size_t index = static_cast(e); + return EnumNamesActivationFunctionType()[index]; +} + +enum LSHProjectionType { + LSHProjectionType_UNKNOWN = 0, + LSHProjectionType_SPARSE = 1, + LSHProjectionType_DENSE = 2, + LSHProjectionType_MIN = LSHProjectionType_UNKNOWN, + LSHProjectionType_MAX = LSHProjectionType_DENSE +}; + +inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] { + static LSHProjectionType values[] = {LSHProjectionType_UNKNOWN, + LSHProjectionType_SPARSE, + LSHProjectionType_DENSE}; + return values; +} + +inline const char **EnumNamesLSHProjectionType() { + static const char *names[] = {"UNKNOWN", "SPARSE", "DENSE", nullptr}; + return names; +} + +inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { + const size_t index = static_cast(e); + return EnumNamesLSHProjectionType()[index]; +} + +enum CombinerType { + CombinerType_SUM = 0, + CombinerType_MEAN = 1, + CombinerType_SQRTN = 2, + CombinerType_MIN = CombinerType_SUM, + CombinerType_MAX = CombinerType_SQRTN +}; + +inline CombinerType (&EnumValuesCombinerType())[3] { + static CombinerType values[] = {CombinerType_SUM, CombinerType_MEAN, + CombinerType_SQRTN}; + return values; +} + +inline const char **EnumNamesCombinerType() { + static const char *names[] = {"SUM", "MEAN", "SQRTN", nullptr}; + return names; +} + +inline const char *EnumNameCombinerType(CombinerType e) { + const size_t index = static_cast(e); + return EnumNamesCombinerType()[index]; +} + +enum CustomOptionsFormat { + CustomOptionsFormat_FLEXBUFFERS = 0, + CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS, + CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS +}; + +inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] { + static CustomOptionsFormat values[] = {CustomOptionsFormat_FLEXBUFFERS}; + return values; +} + +inline const char **EnumNamesCustomOptionsFormat() { + static const char *names[] = {"FLEXBUFFERS", nullptr}; + return names; +} + +inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) { + const size_t index = static_cast(e); + return EnumNamesCustomOptionsFormat()[index]; +} + +struct QuantizationParametersT : public flatbuffers::NativeTable { + typedef QuantizationParameters TableType; + std::vector min; + std::vector max; + std::vector scale; + std::vector zero_point; + QuantizationParametersT() {} +}; + +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef QuantizationParametersT NativeTableType; + enum { VT_MIN = 4, VT_MAX = 6, VT_SCALE = 8, VT_ZERO_POINT = 10 }; + const flatbuffers::Vector *min() const { + return GetPointer *>(VT_MIN); + } + const flatbuffers::Vector *max() const { + return GetPointer *>(VT_MAX); + } + const flatbuffers::Vector *scale() const { + return GetPointer *>(VT_SCALE); + } + const flatbuffers::Vector *zero_point() const { + return GetPointer *>(VT_ZERO_POINT); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN) && + verifier.Verify(min()) && VerifyOffset(verifier, VT_MAX) && + verifier.Verify(max()) && VerifyOffset(verifier, VT_SCALE) && + verifier.Verify(scale()) && VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.Verify(zero_point()) && verifier.EndTable(); + } + QuantizationParametersT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + QuantizationParametersT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct QuantizationParametersBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min(flatbuffers::Offset> min) { + fbb_.AddOffset(QuantizationParameters::VT_MIN, min); + } + void add_max(flatbuffers::Offset> max) { + fbb_.AddOffset(QuantizationParameters::VT_MAX, max); + } + void add_scale(flatbuffers::Offset> scale) { + fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); + } + void add_zero_point( + flatbuffers::Offset> zero_point) { + fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); + } + explicit QuantizationParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + QuantizationParametersBuilder &operator=( + const QuantizationParametersBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateQuantizationParameters( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> min = 0, + flatbuffers::Offset> max = 0, + flatbuffers::Offset> scale = 0, + flatbuffers::Offset> zero_point = 0) { + QuantizationParametersBuilder builder_(_fbb); + builder_.add_zero_point(zero_point); + builder_.add_scale(scale); + builder_.add_max(max); + builder_.add_min(min); + return builder_.Finish(); +} + +inline flatbuffers::Offset +CreateQuantizationParametersDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *min = nullptr, + const std::vector *max = nullptr, + const std::vector *scale = nullptr, + const std::vector *zero_point = nullptr) { + return tflite::CreateQuantizationParameters( + _fbb, min ? _fbb.CreateVector(*min) : 0, + max ? _fbb.CreateVector(*max) : 0, + scale ? _fbb.CreateVector(*scale) : 0, + zero_point ? _fbb.CreateVector(*zero_point) : 0); +} + +flatbuffers::Offset CreateQuantizationParameters( + flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TensorT : public flatbuffers::NativeTable { + typedef Tensor TableType; + std::vector shape; + TensorType type; + uint32_t buffer; + std::string name; + std::unique_ptr quantization; + TensorT() : type(TensorType_FLOAT32), buffer(0) {} +}; + +struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TensorT NativeTableType; + enum { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_BUFFER = 8, + VT_NAME = 10, + VT_QUANTIZATION = 12 + }; + const flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + TensorType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + uint32_t buffer() const { return GetField(VT_BUFFER, 0); } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const QuantizationParameters *quantization() const { + return GetPointer(VT_QUANTIZATION); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.Verify(shape()) && VerifyField(verifier, VT_TYPE) && + VerifyField(verifier, VT_BUFFER) && + VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + VerifyOffset(verifier, VT_QUANTIZATION) && + verifier.VerifyTable(quantization()) && verifier.EndTable(); + } + TensorT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = + nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TensorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shape(flatbuffers::Offset> shape) { + fbb_.AddOffset(Tensor::VT_SHAPE, shape); + } + void add_type(TensorType type) { + fbb_.AddElement(Tensor::VT_TYPE, static_cast(type), 0); + } + void add_buffer(uint32_t buffer) { + fbb_.AddElement(Tensor::VT_BUFFER, buffer, 0); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(Tensor::VT_NAME, name); + } + void add_quantization( + flatbuffers::Offset quantization) { + fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); + } + explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TensorBuilder &operator=(const TensorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTensor( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> shape = 0, + TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + flatbuffers::Offset name = 0, + flatbuffers::Offset quantization = 0) { + TensorBuilder builder_(_fbb); + builder_.add_quantization(quantization); + builder_.add_name(name); + builder_.add_buffer(buffer); + builder_.add_shape(shape); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTensorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *shape = nullptr, + TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + const char *name = nullptr, + flatbuffers::Offset quantization = 0) { + return tflite::CreateTensor( + _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, + name ? _fbb.CreateString(name) : 0, quantization); +} + +flatbuffers::Offset CreateTensor( + flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Conv2DOptionsT : public flatbuffers::NativeTable { + typedef Conv2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + ActivationFunctionType fused_activation_function; + Conv2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Conv2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + Conv2DOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + Conv2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Conv2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(Conv2DOptions::VT_PADDING, + static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit Conv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Conv2DOptionsBuilder &operator=(const Conv2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + Conv2DOptionsBuilder builder_(_fbb); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreateConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Pool2DOptionsT : public flatbuffers::NativeTable { + typedef Pool2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + int32_t filter_width; + int32_t filter_height; + ActivationFunctionType fused_activation_function; + Pool2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + filter_width(0), + filter_height(0), + fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Pool2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FILTER_WIDTH = 10, + VT_FILTER_HEIGHT = 12, + VT_FUSED_ACTIVATION_FUNCTION = 14 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t filter_width() const { return GetField(VT_FILTER_WIDTH, 0); } + int32_t filter_height() const { + return GetField(VT_FILTER_HEIGHT, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_FILTER_WIDTH) && + VerifyField(verifier, VT_FILTER_HEIGHT) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + Pool2DOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + Pool2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Pool2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(Pool2DOptions::VT_PADDING, + static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_filter_width(int32_t filter_width) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_WIDTH, filter_width, 0); + } + void add_filter_height(int32_t filter_height) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit Pool2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Pool2DOptionsBuilder &operator=(const Pool2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePool2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0, + int32_t filter_height = 0, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + Pool2DOptionsBuilder builder_(_fbb); + builder_.add_filter_height(filter_height); + builder_.add_filter_width(filter_width); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePool2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { + typedef DepthwiseConv2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + int32_t depth_multiplier; + ActivationFunctionType fused_activation_function; + DepthwiseConv2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + depth_multiplier(0), + fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef DepthwiseConv2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_DEPTH_MULTIPLIER = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } + int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t depth_multiplier() const { + return GetField(VT_DEPTH_MULTIPLIER, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_DEPTH_MULTIPLIER) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + DepthwiseConv2DOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + DepthwiseConv2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DepthwiseConv2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, + static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_depth_multiplier(int32_t depth_multiplier) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, + depth_multiplier, 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement( + DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DepthwiseConv2DOptionsBuilder &operator=( + const DepthwiseConv2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDepthwiseConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, int32_t depth_multiplier = 0, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + DepthwiseConv2DOptionsBuilder builder_(_fbb); + builder_.add_depth_multiplier(depth_multiplier); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreateDepthwiseConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatEmbeddingsOptionsT : public flatbuffers::NativeTable { + typedef ConcatEmbeddingsOptions TableType; + int32_t num_channels; + std::vector num_columns_per_channel; + std::vector embedding_dim_per_channel; + ConcatEmbeddingsOptionsT() : num_channels(0) {} +}; + +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef ConcatEmbeddingsOptionsT NativeTableType; + enum { + VT_NUM_CHANNELS = 4, + VT_NUM_COLUMNS_PER_CHANNEL = 6, + VT_EMBEDDING_DIM_PER_CHANNEL = 8 + }; + int32_t num_channels() const { return GetField(VT_NUM_CHANNELS, 0); } + const flatbuffers::Vector *num_columns_per_channel() const { + return GetPointer *>( + VT_NUM_COLUMNS_PER_CHANNEL); + } + const flatbuffers::Vector *embedding_dim_per_channel() const { + return GetPointer *>( + VT_EMBEDDING_DIM_PER_CHANNEL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_CHANNELS) && + VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && + verifier.Verify(num_columns_per_channel()) && + VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && + verifier.Verify(embedding_dim_per_channel()) && verifier.EndTable(); + } + ConcatEmbeddingsOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + ConcatEmbeddingsOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatEmbeddingsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_num_channels(int32_t num_channels) { + fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, + num_channels, 0); + } + void add_num_columns_per_channel( + flatbuffers::Offset> + num_columns_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, + num_columns_per_channel); + } + void add_embedding_dim_per_channel( + flatbuffers::Offset> + embedding_dim_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, + embedding_dim_per_channel); + } + explicit ConcatEmbeddingsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConcatEmbeddingsOptionsBuilder &operator=( + const ConcatEmbeddingsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + flatbuffers::Offset> + num_columns_per_channel = 0, + flatbuffers::Offset> + embedding_dim_per_channel = 0) { + ConcatEmbeddingsOptionsBuilder builder_(_fbb); + builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); + builder_.add_num_columns_per_channel(num_columns_per_channel); + builder_.add_num_channels(num_channels); + return builder_.Finish(); +} + +inline flatbuffers::Offset +CreateConcatEmbeddingsOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, int32_t num_channels = 0, + const std::vector *num_columns_per_channel = nullptr, + const std::vector *embedding_dim_per_channel = nullptr) { + return tflite::CreateConcatEmbeddingsOptions( + _fbb, num_channels, + num_columns_per_channel + ? _fbb.CreateVector(*num_columns_per_channel) + : 0, + embedding_dim_per_channel + ? _fbb.CreateVector(*embedding_dim_per_channel) + : 0); +} + +flatbuffers::Offset CreateConcatEmbeddingsOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSHProjectionOptionsT : public flatbuffers::NativeTable { + typedef LSHProjectionOptions TableType; + LSHProjectionType type; + LSHProjectionOptionsT() : type(LSHProjectionType_UNKNOWN) {} +}; + +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef LSHProjectionOptionsT NativeTableType; + enum { VT_TYPE = 4 }; + LSHProjectionType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TYPE) && verifier.EndTable(); + } + LSHProjectionOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + LSHProjectionOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSHProjectionOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(LSHProjectionType type) { + fbb_.AddElement(LSHProjectionOptions::VT_TYPE, + static_cast(type), 0); + } + explicit LSHProjectionOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LSHProjectionOptionsBuilder &operator=(const LSHProjectionOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLSHProjectionOptions( + flatbuffers::FlatBufferBuilder &_fbb, + LSHProjectionType type = LSHProjectionType_UNKNOWN) { + LSHProjectionOptionsBuilder builder_(_fbb); + builder_.add_type(type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLSHProjectionOptions( + flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SVDFOptionsT : public flatbuffers::NativeTable { + typedef SVDFOptions TableType; + int32_t rank; + ActivationFunctionType fused_activation_function; + SVDFOptionsT() + : rank(0), fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SVDFOptionsT NativeTableType; + enum { VT_RANK = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; + int32_t rank() const { return GetField(VT_RANK, 0); } + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RANK) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + SVDFOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SVDFOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SVDFOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_rank(int32_t rank) { + fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SVDFOptionsBuilder &operator=(const SVDFOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSVDFOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + SVDFOptionsBuilder builder_(_fbb); + builder_.add_rank(rank); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSVDFOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RNNOptionsT : public flatbuffers::NativeTable { + typedef RNNOptions TableType; + ActivationFunctionType fused_activation_function; + RNNOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RNNOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + RNNOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + RNNOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RNNOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RNNOptionsBuilder &operator=(const RNNOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + RNNOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FullyConnectedOptionsT : public flatbuffers::NativeTable { + typedef FullyConnectedOptions TableType; + ActivationFunctionType fused_activation_function; + FullyConnectedOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef FullyConnectedOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + FullyConnectedOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + FullyConnectedOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FullyConnectedOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FullyConnectedOptionsBuilder &operator=(const FullyConnectedOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFullyConnectedOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateFullyConnectedOptions( + flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SoftmaxOptionsT : public flatbuffers::NativeTable { + typedef SoftmaxOptions TableType; + float beta; + SoftmaxOptionsT() : beta(0.0f) {} +}; + +struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SoftmaxOptionsT NativeTableType; + enum { VT_BETA = 4 }; + float beta() const { return GetField(VT_BETA, 0.0f); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BETA) && verifier.EndTable(); + } + SoftmaxOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SoftmaxOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SoftmaxOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_beta(float beta) { + fbb_.AddElement(SoftmaxOptions::VT_BETA, beta, 0.0f); + } + explicit SoftmaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SoftmaxOptionsBuilder &operator=(const SoftmaxOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSoftmaxOptions( + flatbuffers::FlatBufferBuilder &_fbb, float beta = 0.0f) { + SoftmaxOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSoftmaxOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatenationOptionsT : public flatbuffers::NativeTable { + typedef ConcatenationOptions TableType; + int32_t axis; + ActivationFunctionType fused_activation_function; + ConcatenationOptionsT() + : axis(0), fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef ConcatenationOptionsT NativeTableType; + enum { VT_AXIS = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; + int32_t axis() const { return GetField(VT_AXIS, 0); } + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + ConcatenationOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + ConcatenationOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatenationOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit ConcatenationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConcatenationOptionsBuilder &operator=(const ConcatenationOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConcatenationOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + ConcatenationOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateConcatenationOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AddOptionsT : public flatbuffers::NativeTable { + typedef AddOptions TableType; + ActivationFunctionType fused_activation_function; + AddOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AddOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + AddOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + AddOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AddOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit AddOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AddOptionsBuilder &operator=(const AddOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAddOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + AddOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateAddOptions( + flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MulOptionsT : public flatbuffers::NativeTable { + typedef MulOptions TableType; + ActivationFunctionType fused_activation_function; + MulOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MulOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + MulOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + MulOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MulOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit MulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MulOptionsBuilder &operator=(const MulOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMulOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + MulOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateMulOptions( + flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct L2NormOptionsT : public flatbuffers::NativeTable { + typedef L2NormOptions TableType; + ActivationFunctionType fused_activation_function; + L2NormOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef L2NormOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + L2NormOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + L2NormOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct L2NormOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + explicit L2NormOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + L2NormOptionsBuilder &operator=(const L2NormOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateL2NormOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + L2NormOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateL2NormOptions( + flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { + typedef LocalResponseNormalizationOptions TableType; + int32_t radius; + float bias; + float alpha; + float beta; + LocalResponseNormalizationOptionsT() + : radius(0), bias(0.0f), alpha(0.0f), beta(0.0f) {} +}; + +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef LocalResponseNormalizationOptionsT NativeTableType; + enum { VT_RADIUS = 4, VT_BIAS = 6, VT_ALPHA = 8, VT_BETA = 10 }; + int32_t radius() const { return GetField(VT_RADIUS, 0); } + float bias() const { return GetField(VT_BIAS, 0.0f); } + float alpha() const { return GetField(VT_ALPHA, 0.0f); } + float beta() const { return GetField(VT_BETA, 0.0f); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RADIUS) && + VerifyField(verifier, VT_BIAS) && + VerifyField(verifier, VT_ALPHA) && + VerifyField(verifier, VT_BETA) && verifier.EndTable(); + } + LocalResponseNormalizationOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + LocalResponseNormalizationOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const LocalResponseNormalizationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LocalResponseNormalizationOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_radius(int32_t radius) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, + radius, 0); + } + void add_bias(float bias) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, + 0.0f); + } + void add_alpha(float alpha) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, + 0.0f); + } + void add_beta(float beta) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, + 0.0f); + } + explicit LocalResponseNormalizationOptionsBuilder( + flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LocalResponseNormalizationOptionsBuilder &operator=( + const LocalResponseNormalizationOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, + int32_t radius = 0, float bias = 0.0f, + float alpha = 0.0f, float beta = 0.0f) { + LocalResponseNormalizationOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + builder_.add_alpha(alpha); + builder_.add_bias(bias); + builder_.add_radius(radius); + return builder_.Finish(); +} + +flatbuffers::Offset +CreateLocalResponseNormalizationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + const LocalResponseNormalizationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSTMOptionsT : public flatbuffers::NativeTable { + typedef LSTMOptions TableType; + ActivationFunctionType fused_activation_function; + float cell_clip; + float proj_clip; + LSTMOptionsT() + : fused_activation_function(ActivationFunctionType_NONE), + cell_clip(0.0f), + proj_clip(0.0f) {} +}; + +struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LSTMOptionsT NativeTableType; + enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8 }; + ActivationFunctionType fused_activation_function() const { + return static_cast( + GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { return GetField(VT_CELL_CLIP, 0.0f); } + float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_CELL_CLIP) && + VerifyField(verifier, VT_PROJ_CLIP) && verifier.EndTable(); + } + LSTMOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + LSTMOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSTMOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LSTMOptionsBuilder &operator=(const LSTMOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE, + float cell_clip = 0.0f, float proj_clip = 0.0f) { + LSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { + typedef ResizeBilinearOptions TableType; + int32_t new_height; + int32_t new_width; + ResizeBilinearOptionsT() : new_height(0), new_width(0) {} +}; + +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef ResizeBilinearOptionsT NativeTableType; + enum { VT_NEW_HEIGHT = 4, VT_NEW_WIDTH = 6 }; + int32_t new_height() const { return GetField(VT_NEW_HEIGHT, 0); } + int32_t new_width() const { return GetField(VT_NEW_WIDTH, 0); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NEW_HEIGHT) && + VerifyField(verifier, VT_NEW_WIDTH) && verifier.EndTable(); + } + ResizeBilinearOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + ResizeBilinearOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ResizeBilinearOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_new_height(int32_t new_height) { + fbb_.AddElement(ResizeBilinearOptions::VT_NEW_HEIGHT, new_height, + 0); + } + void add_new_width(int32_t new_width) { + fbb_.AddElement(ResizeBilinearOptions::VT_NEW_WIDTH, new_width, 0); + } + explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ResizeBilinearOptionsBuilder &operator=(const ResizeBilinearOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateResizeBilinearOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t new_height = 0, + int32_t new_width = 0) { + ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_new_width(new_width); + builder_.add_new_height(new_height); + return builder_.Finish(); +} + +flatbuffers::Offset CreateResizeBilinearOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CallOptionsT : public flatbuffers::NativeTable { + typedef CallOptions TableType; + uint32_t subgraph; + CallOptionsT() : subgraph(0) {} +}; + +struct CallOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CallOptionsT NativeTableType; + enum { VT_SUBGRAPH = 4 }; + uint32_t subgraph() const { return GetField(VT_SUBGRAPH, 0); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SUBGRAPH) && verifier.EndTable(); + } + CallOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + CallOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CallOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_subgraph(uint32_t subgraph) { + fbb_.AddElement(CallOptions::VT_SUBGRAPH, subgraph, 0); + } + explicit CallOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CallOptionsBuilder &operator=(const CallOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCallOptions( + flatbuffers::FlatBufferBuilder &_fbb, uint32_t subgraph = 0) { + CallOptionsBuilder builder_(_fbb); + builder_.add_subgraph(subgraph); + return builder_.Finish(); +} + +flatbuffers::Offset CreateCallOptions( + flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReshapeOptionsT : public flatbuffers::NativeTable { + typedef ReshapeOptions TableType; + std::vector new_shape; + ReshapeOptionsT() {} +}; + +struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReshapeOptionsT NativeTableType; + enum { VT_NEW_SHAPE = 4 }; + const flatbuffers::Vector *new_shape() const { + return GetPointer *>(VT_NEW_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.Verify(new_shape()) && verifier.EndTable(); + } + ReshapeOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + ReshapeOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReshapeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_new_shape( + flatbuffers::Offset> new_shape) { + fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); + } + explicit ReshapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReshapeOptionsBuilder &operator=(const ReshapeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateReshapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> new_shape = 0) { + ReshapeOptionsBuilder builder_(_fbb); + builder_.add_new_shape(new_shape); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateReshapeOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *new_shape = nullptr) { + return tflite::CreateReshapeOptions( + _fbb, new_shape ? _fbb.CreateVector(*new_shape) : 0); +} + +flatbuffers::Offset CreateReshapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SkipGramOptionsT : public flatbuffers::NativeTable { + typedef SkipGramOptions TableType; + int32_t ngram_size; + int32_t max_skip_size; + bool include_all_ngrams; + SkipGramOptionsT() + : ngram_size(0), max_skip_size(0), include_all_ngrams(false) {} +}; + +struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SkipGramOptionsT NativeTableType; + enum { VT_NGRAM_SIZE = 4, VT_MAX_SKIP_SIZE = 6, VT_INCLUDE_ALL_NGRAMS = 8 }; + int32_t ngram_size() const { return GetField(VT_NGRAM_SIZE, 0); } + int32_t max_skip_size() const { + return GetField(VT_MAX_SKIP_SIZE, 0); + } + bool include_all_ngrams() const { + return GetField(VT_INCLUDE_ALL_NGRAMS, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NGRAM_SIZE) && + VerifyField(verifier, VT_MAX_SKIP_SIZE) && + VerifyField(verifier, VT_INCLUDE_ALL_NGRAMS) && + verifier.EndTable(); + } + SkipGramOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SkipGramOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SkipGramOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_ngram_size(int32_t ngram_size) { + fbb_.AddElement(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); + } + void add_max_skip_size(int32_t max_skip_size) { + fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, + 0); + } + void add_include_all_ngrams(bool include_all_ngrams) { + fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, + static_cast(include_all_ngrams), 0); + } + explicit SkipGramOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SkipGramOptionsBuilder &operator=(const SkipGramOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSkipGramOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t ngram_size = 0, + int32_t max_skip_size = 0, bool include_all_ngrams = false) { + SkipGramOptionsBuilder builder_(_fbb); + builder_.add_max_skip_size(max_skip_size); + builder_.add_ngram_size(ngram_size); + builder_.add_include_all_ngrams(include_all_ngrams); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSkipGramOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SpaceToDepthOptionsT : public flatbuffers::NativeTable { + typedef SpaceToDepthOptions TableType; + int32_t block_size; + SpaceToDepthOptionsT() : block_size(0) {} +}; + +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef SpaceToDepthOptionsT NativeTableType; + enum { VT_BLOCK_SIZE = 4 }; + int32_t block_size() const { return GetField(VT_BLOCK_SIZE, 0); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BLOCK_SIZE) && verifier.EndTable(); + } + SpaceToDepthOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SpaceToDepthOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SpaceToDepthOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_block_size(int32_t block_size) { + fbb_.AddElement(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0); + } + explicit SpaceToDepthOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SpaceToDepthOptionsBuilder &operator=(const SpaceToDepthOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSpaceToDepthOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0) { + SpaceToDepthOptionsBuilder builder_(_fbb); + builder_.add_block_size(block_size); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSpaceToDepthOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EmbeddingLookupSparseOptionsT : public flatbuffers::NativeTable { + typedef EmbeddingLookupSparseOptions TableType; + CombinerType combiner; + EmbeddingLookupSparseOptionsT() : combiner(CombinerType_SUM) {} +}; + +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef EmbeddingLookupSparseOptionsT NativeTableType; + enum { VT_COMBINER = 4 }; + CombinerType combiner() const { + return static_cast(GetField(VT_COMBINER, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COMBINER) && verifier.EndTable(); + } + EmbeddingLookupSparseOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EmbeddingLookupSparseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_combiner(CombinerType combiner) { + fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, + static_cast(combiner), 0); + } + explicit EmbeddingLookupSparseOptionsBuilder( + flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EmbeddingLookupSparseOptionsBuilder &operator=( + const EmbeddingLookupSparseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset +CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, + CombinerType combiner = CombinerType_SUM) { + EmbeddingLookupSparseOptionsBuilder builder_(_fbb); + builder_.add_combiner(combiner); + return builder_.Finish(); +} + +flatbuffers::Offset +CreateEmbeddingLookupSparseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + const EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorCodeT : public flatbuffers::NativeTable { + typedef OperatorCode TableType; + BuiltinOperator builtin_code; + std::string custom_code; + OperatorCodeT() : builtin_code(BuiltinOperator_ADD) {} +}; + +struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorCodeT NativeTableType; + enum { VT_BUILTIN_CODE = 4, VT_CUSTOM_CODE = 6 }; + BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); + } + const flatbuffers::String *custom_code() const { + return GetPointer(VT_CUSTOM_CODE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BUILTIN_CODE) && + VerifyOffset(verifier, VT_CUSTOM_CODE) && + verifier.Verify(custom_code()) && verifier.EndTable(); + } + OperatorCodeT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + OperatorCodeT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OperatorCodeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_builtin_code(BuiltinOperator builtin_code) { + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, + static_cast(builtin_code), 0); + } + void add_custom_code(flatbuffers::Offset custom_code) { + fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); + } + explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + flatbuffers::Offset custom_code = 0) { + OperatorCodeBuilder builder_(_fbb); + builder_.add_custom_code(custom_code); + builder_.add_builtin_code(builtin_code); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + const char *custom_code = nullptr) { + return tflite::CreateOperatorCode( + _fbb, builtin_code, custom_code ? _fbb.CreateString(custom_code) : 0); +} + +flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorT : public flatbuffers::NativeTable { + typedef Operator TableType; + uint32_t opcode_index; + std::vector inputs; + std::vector outputs; + BuiltinOptionsUnion builtin_options; + std::vector custom_options; + CustomOptionsFormat custom_options_format; + OperatorT() + : opcode_index(0), + custom_options_format(CustomOptionsFormat_FLEXBUFFERS) {} +}; + +struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorT NativeTableType; + enum { + VT_OPCODE_INDEX = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_BUILTIN_OPTIONS_TYPE = 10, + VT_BUILTIN_OPTIONS = 12, + VT_CUSTOM_OPTIONS = 14, + VT_CUSTOM_OPTIONS_FORMAT = 16 + }; + uint32_t opcode_index() const { + return GetField(VT_OPCODE_INDEX, 0); + } + const flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + BuiltinOptions builtin_options_type() const { + return static_cast( + GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + } + const void *builtin_options() const { + return GetPointer(VT_BUILTIN_OPTIONS); + } + template + const T *builtin_options_as() const; + const Conv2DOptions *builtin_options_as_Conv2DOptions() const { + return builtin_options_type() == BuiltinOptions_Conv2DOptions + ? static_cast(builtin_options()) + : nullptr; + } + const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() + const { + return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions + ? static_cast(builtin_options()) + : nullptr; + } + const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() + const { + return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions + ? static_cast(builtin_options()) + : nullptr; + } + const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { + return builtin_options_type() == BuiltinOptions_LSHProjectionOptions + ? static_cast(builtin_options()) + : nullptr; + } + const Pool2DOptions *builtin_options_as_Pool2DOptions() const { + return builtin_options_type() == BuiltinOptions_Pool2DOptions + ? static_cast(builtin_options()) + : nullptr; + } + const SVDFOptions *builtin_options_as_SVDFOptions() const { + return builtin_options_type() == BuiltinOptions_SVDFOptions + ? static_cast(builtin_options()) + : nullptr; + } + const RNNOptions *builtin_options_as_RNNOptions() const { + return builtin_options_type() == BuiltinOptions_RNNOptions + ? static_cast(builtin_options()) + : nullptr; + } + const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() + const { + return builtin_options_type() == BuiltinOptions_FullyConnectedOptions + ? static_cast(builtin_options()) + : nullptr; + } + const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { + return builtin_options_type() == BuiltinOptions_SoftmaxOptions + ? static_cast(builtin_options()) + : nullptr; + } + const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatenationOptions + ? static_cast(builtin_options()) + : nullptr; + } + const AddOptions *builtin_options_as_AddOptions() const { + return builtin_options_type() == BuiltinOptions_AddOptions + ? static_cast(builtin_options()) + : nullptr; + } + const L2NormOptions *builtin_options_as_L2NormOptions() const { + return builtin_options_type() == BuiltinOptions_L2NormOptions + ? static_cast(builtin_options()) + : nullptr; + } + const LocalResponseNormalizationOptions * + builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == + BuiltinOptions_LocalResponseNormalizationOptions + ? static_cast( + builtin_options()) + : nullptr; + } + const LSTMOptions *builtin_options_as_LSTMOptions() const { + return builtin_options_type() == BuiltinOptions_LSTMOptions + ? static_cast(builtin_options()) + : nullptr; + } + const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() + const { + return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions + ? static_cast(builtin_options()) + : nullptr; + } + const CallOptions *builtin_options_as_CallOptions() const { + return builtin_options_type() == BuiltinOptions_CallOptions + ? static_cast(builtin_options()) + : nullptr; + } + const ReshapeOptions *builtin_options_as_ReshapeOptions() const { + return builtin_options_type() == BuiltinOptions_ReshapeOptions + ? static_cast(builtin_options()) + : nullptr; + } + const SkipGramOptions *builtin_options_as_SkipGramOptions() const { + return builtin_options_type() == BuiltinOptions_SkipGramOptions + ? static_cast(builtin_options()) + : nullptr; + } + const SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { + return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions + ? static_cast(builtin_options()) + : nullptr; + } + const EmbeddingLookupSparseOptions * + builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions + ? static_cast( + builtin_options()) + : nullptr; + } + const MulOptions *builtin_options_as_MulOptions() const { + return builtin_options_type() == BuiltinOptions_MulOptions + ? static_cast(builtin_options()) + : nullptr; + } + const flatbuffers::Vector *custom_options() const { + return GetPointer *>(VT_CUSTOM_OPTIONS); + } + CustomOptionsFormat custom_options_format() const { + return static_cast( + GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OPCODE_INDEX) && + VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyField(verifier, VT_BUILTIN_OPTIONS_TYPE) && + VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && + VerifyBuiltinOptions(verifier, builtin_options(), + builtin_options_type()) && + VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && + verifier.Verify(custom_options()) && + VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + verifier.EndTable(); + } + OperatorT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + OperatorT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +template <> +inline const Conv2DOptions *Operator::builtin_options_as() + const { + return builtin_options_as_Conv2DOptions(); +} + +template <> +inline const DepthwiseConv2DOptions * +Operator::builtin_options_as() const { + return builtin_options_as_DepthwiseConv2DOptions(); +} + +template <> +inline const ConcatEmbeddingsOptions * +Operator::builtin_options_as() const { + return builtin_options_as_ConcatEmbeddingsOptions(); +} + +template <> +inline const LSHProjectionOptions * +Operator::builtin_options_as() const { + return builtin_options_as_LSHProjectionOptions(); +} + +template <> +inline const Pool2DOptions *Operator::builtin_options_as() + const { + return builtin_options_as_Pool2DOptions(); +} + +template <> +inline const SVDFOptions *Operator::builtin_options_as() const { + return builtin_options_as_SVDFOptions(); +} + +template <> +inline const RNNOptions *Operator::builtin_options_as() const { + return builtin_options_as_RNNOptions(); +} + +template <> +inline const FullyConnectedOptions * +Operator::builtin_options_as() const { + return builtin_options_as_FullyConnectedOptions(); +} + +template <> +inline const SoftmaxOptions *Operator::builtin_options_as() + const { + return builtin_options_as_SoftmaxOptions(); +} + +template <> +inline const ConcatenationOptions * +Operator::builtin_options_as() const { + return builtin_options_as_ConcatenationOptions(); +} + +template <> +inline const AddOptions *Operator::builtin_options_as() const { + return builtin_options_as_AddOptions(); +} + +template <> +inline const L2NormOptions *Operator::builtin_options_as() + const { + return builtin_options_as_L2NormOptions(); +} + +template <> +inline const LocalResponseNormalizationOptions * +Operator::builtin_options_as() const { + return builtin_options_as_LocalResponseNormalizationOptions(); +} + +template <> +inline const LSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_LSTMOptions(); +} + +template <> +inline const ResizeBilinearOptions * +Operator::builtin_options_as() const { + return builtin_options_as_ResizeBilinearOptions(); +} + +template <> +inline const CallOptions *Operator::builtin_options_as() const { + return builtin_options_as_CallOptions(); +} + +template <> +inline const ReshapeOptions *Operator::builtin_options_as() + const { + return builtin_options_as_ReshapeOptions(); +} + +template <> +inline const SkipGramOptions *Operator::builtin_options_as() + const { + return builtin_options_as_SkipGramOptions(); +} + +template <> +inline const SpaceToDepthOptions * +Operator::builtin_options_as() const { + return builtin_options_as_SpaceToDepthOptions(); +} + +template <> +inline const EmbeddingLookupSparseOptions * +Operator::builtin_options_as() const { + return builtin_options_as_EmbeddingLookupSparseOptions(); +} + +template <> +inline const MulOptions *Operator::builtin_options_as() const { + return builtin_options_as_MulOptions(); +} + +struct OperatorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_opcode_index(uint32_t opcode_index) { + fbb_.AddElement(Operator::VT_OPCODE_INDEX, opcode_index, 0); + } + void add_inputs(flatbuffers::Offset> inputs) { + fbb_.AddOffset(Operator::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset> outputs) { + fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); + } + void add_builtin_options_type(BuiltinOptions builtin_options_type) { + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, + static_cast(builtin_options_type), 0); + } + void add_builtin_options(flatbuffers::Offset builtin_options) { + fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); + } + void add_custom_options( + flatbuffers::Offset> custom_options) { + fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); + } + void add_custom_options_format(CustomOptionsFormat custom_options_format) { + fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, + static_cast(custom_options_format), 0); + } + explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorBuilder &operator=(const OperatorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperator( + flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + flatbuffers::Offset> inputs = 0, + flatbuffers::Offset> outputs = 0, + BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + flatbuffers::Offset builtin_options = 0, + flatbuffers::Offset> custom_options = 0, + CustomOptionsFormat custom_options_format = + CustomOptionsFormat_FLEXBUFFERS) { + OperatorBuilder builder_(_fbb); + builder_.add_custom_options(custom_options); + builder_.add_builtin_options(builtin_options); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_opcode_index(opcode_index); + builder_.add_custom_options_format(custom_options_format); + builder_.add_builtin_options_type(builtin_options_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorDirect( + flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + flatbuffers::Offset builtin_options = 0, + const std::vector *custom_options = nullptr, + CustomOptionsFormat custom_options_format = + CustomOptionsFormat_FLEXBUFFERS) { + return tflite::CreateOperator( + _fbb, opcode_index, inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, builtin_options_type, + builtin_options, + custom_options ? _fbb.CreateVector(*custom_options) : 0, + custom_options_format); +} + +flatbuffers::Offset CreateOperator( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SubGraphT : public flatbuffers::NativeTable { + typedef SubGraph TableType; + std::vector> tensors; + std::vector inputs; + std::vector outputs; + std::vector> operators; + std::string name; + SubGraphT() {} +}; + +struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SubGraphT NativeTableType; + enum { + VT_TENSORS = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_OPERATORS = 10, + VT_NAME = 12 + }; + const flatbuffers::Vector> *tensors() const { + return GetPointer> *>( + VT_TENSORS); + } + const flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + const flatbuffers::Vector> *operators() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OPERATORS); + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) && + verifier.Verify(tensors()) && + verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.Verify(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + verifier.EndTable(); + } + SubGraphT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SubGraphT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SubGraphBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_tensors( + flatbuffers::Offset>> + tensors) { + fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); + } + void add_inputs(flatbuffers::Offset> inputs) { + fbb_.AddOffset(SubGraph::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset> outputs) { + fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); + } + void add_operators( + flatbuffers::Offset>> + operators) { + fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(SubGraph::VT_NAME, name); + } + explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SubGraphBuilder &operator=(const SubGraphBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSubGraph( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> + tensors = 0, + flatbuffers::Offset> inputs = 0, + flatbuffers::Offset> outputs = 0, + flatbuffers::Offset>> + operators = 0, + flatbuffers::Offset name = 0) { + SubGraphBuilder builder_(_fbb); + builder_.add_name(name); + builder_.add_operators(operators); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSubGraphDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *tensors = nullptr, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + const std::vector> *operators = nullptr, + const char *name = nullptr) { + return tflite::CreateSubGraph( + _fbb, + tensors ? _fbb.CreateVector>(*tensors) : 0, + inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, + operators ? _fbb.CreateVector>(*operators) + : 0, + name ? _fbb.CreateString(name) : 0); +} + +flatbuffers::Offset CreateSubGraph( + flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BufferT : public flatbuffers::NativeTable { + typedef Buffer TableType; + std::vector data; + BufferT() {} +}; + +struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BufferT NativeTableType; + enum { VT_DATA = 4 }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DATA) && + verifier.Verify(data()) && verifier.EndTable(); + } + BufferT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = + nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BufferBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(Buffer::VT_DATA, data); + } + explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BufferBuilder &operator=(const BufferBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBuffer( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + BufferBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateBufferDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + return tflite::CreateBuffer(_fbb, + data ? _fbb.CreateVector(*data) : 0); +} + +flatbuffers::Offset CreateBuffer( + flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ModelT : public flatbuffers::NativeTable { + typedef Model TableType; + uint32_t version; + std::vector> operator_codes; + std::vector> subgraphs; + std::string description; + std::vector> buffers; + ModelT() : version(0) {} +}; + +struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModelT NativeTableType; + enum { + VT_VERSION = 4, + VT_OPERATOR_CODES = 6, + VT_SUBGRAPHS = 8, + VT_DESCRIPTION = 10, + VT_BUFFERS = 12 + }; + uint32_t version() const { return GetField(VT_VERSION, 0); } + const flatbuffers::Vector> *operator_codes() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OPERATOR_CODES); + } + const flatbuffers::Vector> *subgraphs() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_SUBGRAPHS); + } + const flatbuffers::String *description() const { + return GetPointer(VT_DESCRIPTION); + } + const flatbuffers::Vector> *buffers() const { + return GetPointer> *>( + VT_BUFFERS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VERSION) && + VerifyOffset(verifier, VT_OPERATOR_CODES) && + verifier.Verify(operator_codes()) && + verifier.VerifyVectorOfTables(operator_codes()) && + VerifyOffset(verifier, VT_SUBGRAPHS) && + verifier.Verify(subgraphs()) && + verifier.VerifyVectorOfTables(subgraphs()) && + VerifyOffset(verifier, VT_DESCRIPTION) && + verifier.Verify(description()) && + VerifyOffset(verifier, VT_BUFFERS) && verifier.Verify(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && verifier.EndTable(); + } + ModelT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = + nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ModelBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(uint32_t version) { + fbb_.AddElement(Model::VT_VERSION, version, 0); + } + void add_operator_codes( + flatbuffers::Offset< + flatbuffers::Vector>> + operator_codes) { + fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); + } + void add_subgraphs( + flatbuffers::Offset>> + subgraphs) { + fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); + } + void add_description(flatbuffers::Offset description) { + fbb_.AddOffset(Model::VT_DESCRIPTION, description); + } + void add_buffers( + flatbuffers::Offset>> + buffers) { + fbb_.AddOffset(Model::VT_BUFFERS, buffers); + } + explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ModelBuilder &operator=(const ModelBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateModel( + flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, + flatbuffers::Offset>> + operator_codes = 0, + flatbuffers::Offset>> + subgraphs = 0, + flatbuffers::Offset description = 0, + flatbuffers::Offset>> + buffers = 0) { + ModelBuilder builder_(_fbb); + builder_.add_buffers(buffers); + builder_.add_description(description); + builder_.add_subgraphs(subgraphs); + builder_.add_operator_codes(operator_codes); + builder_.add_version(version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateModelDirect( + flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, + const std::vector> *operator_codes = + nullptr, + const std::vector> *subgraphs = nullptr, + const char *description = nullptr, + const std::vector> *buffers = nullptr) { + return tflite::CreateModel( + _fbb, version, + operator_codes ? _fbb.CreateVector>( + *operator_codes) + : 0, + subgraphs ? _fbb.CreateVector>(*subgraphs) + : 0, + description ? _fbb.CreateString(description) : 0, + buffers ? _fbb.CreateVector>(*buffers) : 0); +} + +flatbuffers::Offset CreateModel( + flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +inline QuantizationParametersT *QuantizationParameters::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new QuantizationParametersT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void QuantizationParameters::UnPackTo( + QuantizationParametersT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = min(); + if (_e) { + _o->min.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->min[_i] = _e->Get(_i); + } + } + }; + { + auto _e = max(); + if (_e) { + _o->max.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->max[_i] = _e->Get(_i); + } + } + }; + { + auto _e = scale(); + if (_e) { + _o->scale.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->scale[_i] = _e->Get(_i); + } + } + }; + { + auto _e = zero_point(); + if (_e) { + _o->zero_point.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->zero_point[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset QuantizationParameters::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateQuantizationParameters(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateQuantizationParameters( + flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const QuantizationParametersT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _min = _o->min.size() ? _fbb.CreateVector(_o->min) : 0; + auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0; + auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0; + auto _zero_point = + _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; + return tflite::CreateQuantizationParameters(_fbb, _min, _max, _scale, + _zero_point); +} + +inline TensorT *Tensor::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TensorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Tensor::UnPackTo( + TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = shape(); + if (_e) { + _o->shape.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->shape[_i] = _e->Get(_i); + } + } + }; + { + auto _e = type(); + _o->type = _e; + }; + { + auto _e = buffer(); + _o->buffer = _e; + }; + { + auto _e = name(); + if (_e) _o->name = _e->str(); + }; + { + auto _e = quantization(); + if (_e) + _o->quantization = + std::unique_ptr(_e->UnPack(_resolver)); + }; +} + +inline flatbuffers::Offset Tensor::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTensor(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTensor( + flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const TensorT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; + auto _type = _o->type; + auto _buffer = _o->buffer; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _quantization = _o->quantization + ? CreateQuantizationParameters( + _fbb, _o->quantization.get(), _rehasher) + : 0; + return tflite::CreateTensor(_fbb, _shape, _type, _buffer, _name, + _quantization); +} + +inline Conv2DOptionsT *Conv2DOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Conv2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Conv2DOptions::UnPackTo( + Conv2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = padding(); + _o->padding = _e; + }; + { + auto _e = stride_w(); + _o->stride_w = _e; + }; + { + auto _e = stride_h(); + _o->stride_h = _e; + }; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset Conv2DOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConv2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const Conv2DOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateConv2DOptions(_fbb, _padding, _stride_w, _stride_h, + _fused_activation_function); +} + +inline Pool2DOptionsT *Pool2DOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Pool2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Pool2DOptions::UnPackTo( + Pool2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = padding(); + _o->padding = _e; + }; + { + auto _e = stride_w(); + _o->stride_w = _e; + }; + { + auto _e = stride_h(); + _o->stride_h = _e; + }; + { + auto _e = filter_width(); + _o->filter_width = _e; + }; + { + auto _e = filter_height(); + _o->filter_height = _e; + }; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset Pool2DOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePool2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePool2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const Pool2DOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _filter_width = _o->filter_width; + auto _filter_height = _o->filter_height; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreatePool2DOptions(_fbb, _padding, _stride_w, _stride_h, + _filter_width, _filter_height, + _fused_activation_function); +} + +inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new DepthwiseConv2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void DepthwiseConv2DOptions::UnPackTo( + DepthwiseConv2DOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = padding(); + _o->padding = _e; + }; + { + auto _e = stride_w(); + _o->stride_w = _e; + }; + { + auto _e = stride_h(); + _o->stride_h = _e; + }; + { + auto _e = depth_multiplier(); + _o->depth_multiplier = _e; + }; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset DepthwiseConv2DOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateDepthwiseConv2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateDepthwiseConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const DepthwiseConv2DOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _depth_multiplier = _o->depth_multiplier; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateDepthwiseConv2DOptions(_fbb, _padding, _stride_w, + _stride_h, _depth_multiplier, + _fused_activation_function); +} + +inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ConcatEmbeddingsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ConcatEmbeddingsOptions::UnPackTo( + ConcatEmbeddingsOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = num_channels(); + _o->num_channels = _e; + }; + { + auto _e = num_columns_per_channel(); + if (_e) { + _o->num_columns_per_channel.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->num_columns_per_channel[_i] = _e->Get(_i); + } + } + }; + { + auto _e = embedding_dim_per_channel(); + if (_e) { + _o->embedding_dim_per_channel.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->embedding_dim_per_channel[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset +ConcatEmbeddingsOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatEmbeddingsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset +CreateConcatEmbeddingsOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const ConcatEmbeddingsOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _num_channels = _o->num_channels; + auto _num_columns_per_channel = + _o->num_columns_per_channel.size() + ? _fbb.CreateVector(_o->num_columns_per_channel) + : 0; + auto _embedding_dim_per_channel = + _o->embedding_dim_per_channel.size() + ? _fbb.CreateVector(_o->embedding_dim_per_channel) + : 0; + return tflite::CreateConcatEmbeddingsOptions(_fbb, _num_channels, + _num_columns_per_channel, + _embedding_dim_per_channel); +} + +inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LSHProjectionOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LSHProjectionOptions::UnPackTo( + LSHProjectionOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = type(); + _o->type = _e; + }; +} + +inline flatbuffers::Offset LSHProjectionOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSHProjectionOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLSHProjectionOptions( + flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const LSHProjectionOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _type = _o->type; + return tflite::CreateLSHProjectionOptions(_fbb, _type); +} + +inline SVDFOptionsT *SVDFOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SVDFOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SVDFOptions::UnPackTo( + SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = rank(); + _o->rank = _e; + }; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset SVDFOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSVDFOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSVDFOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SVDFOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _rank = _o->rank; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateSVDFOptions(_fbb, _rank, _fused_activation_function); +} + +inline RNNOptionsT *RNNOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new RNNOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void RNNOptions::UnPackTo( + RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset RNNOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateRNNOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const RNNOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateRNNOptions(_fbb, _fused_activation_function); +} + +inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FullyConnectedOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FullyConnectedOptions::UnPackTo( + FullyConnectedOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset FullyConnectedOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFullyConnectedOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateFullyConnectedOptions( + flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const FullyConnectedOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateFullyConnectedOptions(_fbb, _fused_activation_function); +} + +inline SoftmaxOptionsT *SoftmaxOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SoftmaxOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SoftmaxOptions::UnPackTo( + SoftmaxOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = beta(); + _o->beta = _e; + }; +} + +inline flatbuffers::Offset SoftmaxOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSoftmaxOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSoftmaxOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SoftmaxOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _beta = _o->beta; + return tflite::CreateSoftmaxOptions(_fbb, _beta); +} + +inline ConcatenationOptionsT *ConcatenationOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ConcatenationOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ConcatenationOptions::UnPackTo( + ConcatenationOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = axis(); + _o->axis = _e; + }; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset ConcatenationOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatenationOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateConcatenationOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const ConcatenationOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _axis = _o->axis; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateConcatenationOptions(_fbb, _axis, + _fused_activation_function); +} + +inline AddOptionsT *AddOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new AddOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void AddOptions::UnPackTo( + AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset AddOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateAddOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateAddOptions( + flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const AddOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateAddOptions(_fbb, _fused_activation_function); +} + +inline MulOptionsT *MulOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new MulOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void MulOptions::UnPackTo( + MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset MulOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateMulOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateMulOptions( + flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const MulOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateMulOptions(_fbb, _fused_activation_function); +} + +inline L2NormOptionsT *L2NormOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new L2NormOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void L2NormOptions::UnPackTo( + L2NormOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; +} + +inline flatbuffers::Offset L2NormOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateL2NormOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateL2NormOptions( + flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const L2NormOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateL2NormOptions(_fbb, _fused_activation_function); +} + +inline LocalResponseNormalizationOptionsT * +LocalResponseNormalizationOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LocalResponseNormalizationOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LocalResponseNormalizationOptions::UnPackTo( + LocalResponseNormalizationOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = radius(); + _o->radius = _e; + }; + { + auto _e = bias(); + _o->bias = _e; + }; + { + auto _e = alpha(); + _o->alpha = _e; + }; + { + auto _e = beta(); + _o->beta = _e; + }; +} + +inline flatbuffers::Offset +LocalResponseNormalizationOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const LocalResponseNormalizationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLocalResponseNormalizationOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset +CreateLocalResponseNormalizationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + const LocalResponseNormalizationOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const LocalResponseNormalizationOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _radius = _o->radius; + auto _bias = _o->bias; + auto _alpha = _o->alpha; + auto _beta = _o->beta; + return tflite::CreateLocalResponseNormalizationOptions(_fbb, _radius, _bias, + _alpha, _beta); +} + +inline LSTMOptionsT *LSTMOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LSTMOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LSTMOptions::UnPackTo( + LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + }; + { + auto _e = cell_clip(); + _o->cell_clip = _e; + }; + { + auto _e = proj_clip(); + _o->proj_clip = _e; + }; +} + +inline flatbuffers::Offset LSTMOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSTMOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const LSTMOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + return tflite::CreateLSTMOptions(_fbb, _fused_activation_function, _cell_clip, + _proj_clip); +} + +inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ResizeBilinearOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ResizeBilinearOptions::UnPackTo( + ResizeBilinearOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = new_height(); + _o->new_height = _e; + }; + { + auto _e = new_width(); + _o->new_width = _e; + }; +} + +inline flatbuffers::Offset ResizeBilinearOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateResizeBilinearOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateResizeBilinearOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const ResizeBilinearOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _new_height = _o->new_height; + auto _new_width = _o->new_width; + return tflite::CreateResizeBilinearOptions(_fbb, _new_height, _new_width); +} + +inline CallOptionsT *CallOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CallOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CallOptions::UnPackTo( + CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = subgraph(); + _o->subgraph = _e; + }; +} + +inline flatbuffers::Offset CallOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCallOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateCallOptions( + flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const CallOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _subgraph = _o->subgraph; + return tflite::CreateCallOptions(_fbb, _subgraph); +} + +inline ReshapeOptionsT *ReshapeOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReshapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ReshapeOptions::UnPackTo( + ReshapeOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = new_shape(); + if (_e) { + _o->new_shape.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->new_shape[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset ReshapeOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReshapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateReshapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const ReshapeOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _new_shape = _o->new_shape.size() ? _fbb.CreateVector(_o->new_shape) : 0; + return tflite::CreateReshapeOptions(_fbb, _new_shape); +} + +inline SkipGramOptionsT *SkipGramOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SkipGramOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SkipGramOptions::UnPackTo( + SkipGramOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = ngram_size(); + _o->ngram_size = _e; + }; + { + auto _e = max_skip_size(); + _o->max_skip_size = _e; + }; + { + auto _e = include_all_ngrams(); + _o->include_all_ngrams = _e; + }; +} + +inline flatbuffers::Offset SkipGramOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSkipGramOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSkipGramOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SkipGramOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _ngram_size = _o->ngram_size; + auto _max_skip_size = _o->max_skip_size; + auto _include_all_ngrams = _o->include_all_ngrams; + return tflite::CreateSkipGramOptions(_fbb, _ngram_size, _max_skip_size, + _include_all_ngrams); +} + +inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SpaceToDepthOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SpaceToDepthOptions::UnPackTo( + SpaceToDepthOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = block_size(); + _o->block_size = _e; + }; +} + +inline flatbuffers::Offset SpaceToDepthOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSpaceToDepthOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSpaceToDepthOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SpaceToDepthOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _block_size = _o->block_size; + return tflite::CreateSpaceToDepthOptions(_fbb, _block_size); +} + +inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EmbeddingLookupSparseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EmbeddingLookupSparseOptions::UnPackTo( + EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = combiner(); + _o->combiner = _e; + }; +} + +inline flatbuffers::Offset +EmbeddingLookupSparseOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEmbeddingLookupSparseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset +CreateEmbeddingLookupSparseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + const EmbeddingLookupSparseOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const EmbeddingLookupSparseOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _combiner = _o->combiner; + return tflite::CreateEmbeddingLookupSparseOptions(_fbb, _combiner); +} + +inline OperatorCodeT *OperatorCode::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OperatorCodeT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OperatorCode::UnPackTo( + OperatorCodeT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = builtin_code(); + _o->builtin_code = _e; + }; + { + auto _e = custom_code(); + if (_e) _o->custom_code = _e->str(); + }; +} + +inline flatbuffers::Offset OperatorCode::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOperatorCode(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const OperatorCodeT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _builtin_code = _o->builtin_code; + auto _custom_code = + _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); + return tflite::CreateOperatorCode(_fbb, _builtin_code, _custom_code); +} + +inline OperatorT *Operator::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OperatorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Operator::UnPackTo( + OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = opcode_index(); + _o->opcode_index = _e; + }; + { + auto _e = inputs(); + if (_e) { + _o->inputs.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->inputs[_i] = _e->Get(_i); + } + } + }; + { + auto _e = outputs(); + if (_e) { + _o->outputs.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->outputs[_i] = _e->Get(_i); + } + } + }; + { + auto _e = builtin_options_type(); + _o->builtin_options.type = _e; + }; + { + auto _e = builtin_options(); + if (_e) + _o->builtin_options.value = + BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); + }; + { + auto _e = custom_options(); + if (_e) { + _o->custom_options.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->custom_options[_i] = _e->Get(_i); + } + } + }; + { + auto _e = custom_options_format(); + _o->custom_options_format = _e; + }; +} + +inline flatbuffers::Offset Operator::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOperator(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOperator( + flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const OperatorT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _opcode_index = _o->opcode_index; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _builtin_options_type = _o->builtin_options.type; + auto _builtin_options = _o->builtin_options.Pack(_fbb); + auto _custom_options = + _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; + auto _custom_options_format = _o->custom_options_format; + return tflite::CreateOperator(_fbb, _opcode_index, _inputs, _outputs, + _builtin_options_type, _builtin_options, + _custom_options, _custom_options_format); +} + +inline SubGraphT *SubGraph::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SubGraphT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SubGraph::UnPackTo( + SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = tensors(); + if (_e) { + _o->tensors.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->tensors[_i] = + std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); + } + } + }; + { + auto _e = inputs(); + if (_e) { + _o->inputs.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->inputs[_i] = _e->Get(_i); + } + } + }; + { + auto _e = outputs(); + if (_e) { + _o->outputs.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->outputs[_i] = _e->Get(_i); + } + } + }; + { + auto _e = operators(); + if (_e) { + _o->operators.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->operators[_i] = + std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); + } + } + }; + { + auto _e = name(); + if (_e) _o->name = _e->str(); + }; +} + +inline flatbuffers::Offset SubGraph::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSubGraph(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSubGraph( + flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SubGraphT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _tensors = + _o->tensors.size() + ? _fbb.CreateVector>( + _o->tensors.size(), + [](size_t i, _VectorArgs *__va) { + return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), + __va->__rehasher); + }, + &_va) + : 0; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _operators = _o->operators.size() + ? _fbb.CreateVector>( + _o->operators.size(), + [](size_t i, _VectorArgs *__va) { + return CreateOperator( + *__va->__fbb, __va->__o->operators[i].get(), + __va->__rehasher); + }, + &_va) + : 0; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + return tflite::CreateSubGraph(_fbb, _tensors, _inputs, _outputs, _operators, + _name); +} + +inline BufferT *Buffer::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BufferT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Buffer::UnPackTo( + BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = data(); + if (_e) { + _o->data.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->data[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset Buffer::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBuffer(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateBuffer( + flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const BufferT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; + return tflite::CreateBuffer(_fbb, _data); +} + +inline ModelT *Model::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ModelT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Model::UnPackTo( + ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = version(); + _o->version = _e; + }; + { + auto _e = operator_codes(); + if (_e) { + _o->operator_codes.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->operator_codes[_i] = + std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); + } + } + }; + { + auto _e = subgraphs(); + if (_e) { + _o->subgraphs.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->subgraphs[_i] = + std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); + } + } + }; + { + auto _e = description(); + if (_e) _o->description = _e->str(); + }; + { + auto _e = buffers(); + if (_e) { + _o->buffers.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->buffers[_i] = + std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); + } + } + }; +} + +inline flatbuffers::Offset Model::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateModel(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateModel( + flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const ModelT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _version = _o->version; + auto _operator_codes = + _o->operator_codes.size() + ? _fbb.CreateVector>( + _o->operator_codes.size(), + [](size_t i, _VectorArgs *__va) { + return CreateOperatorCode(*__va->__fbb, + __va->__o->operator_codes[i].get(), + __va->__rehasher); + }, + &_va) + : 0; + auto _subgraphs = _o->subgraphs.size() + ? _fbb.CreateVector>( + _o->subgraphs.size(), + [](size_t i, _VectorArgs *__va) { + return CreateSubGraph( + *__va->__fbb, __va->__o->subgraphs[i].get(), + __va->__rehasher); + }, + &_va) + : 0; + auto _description = + _o->description.empty() ? 0 : _fbb.CreateString(_o->description); + auto _buffers = + _o->buffers.size() + ? _fbb.CreateVector>( + _o->buffers.size(), + [](size_t i, _VectorArgs *__va) { + return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), + __va->__rehasher); + }, + &_va) + : 0; + return tflite::CreateModel(_fbb, _version, _operator_codes, _subgraphs, + _description, _buffers); +} + +inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, + const void *obj, BuiltinOptions type) { + switch (type) { + case BuiltinOptions_NONE: { + return true; + } + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = + reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: + return false; + } +} + +inline bool VerifyBuiltinOptionsVector( + flatbuffers::Verifier &verifier, + const flatbuffers::Vector> *values, + const flatbuffers::Vector *types) { + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyBuiltinOptions(verifier, values->Get(i), + types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *BuiltinOptionsUnion::UnPack( + const void *obj, BuiltinOptions type, + const flatbuffers::resolver_function_t *resolver) { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = + reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: + return nullptr; + } +} + +inline flatbuffers::Offset BuiltinOptionsUnion::Pack( + flatbuffers::FlatBufferBuilder &_fbb, + const flatbuffers::rehasher_function_t *_rehasher) const { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateDepthwiseConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatEmbeddingsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSHProjectionOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + return CreatePool2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + return CreateSVDFOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + return CreateRNNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(value); + return CreateFullyConnectedOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateSoftmaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatenationOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + return CreateAddOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = + reinterpret_cast(value); + return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher) + .Union(); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + return CreateResizeBilinearOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + return CreateCallOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateReshapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + return CreateSkipGramOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + return CreateSpaceToDepthOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + return CreateEmbeddingLookupSparseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + return CreateMulOptions(_fbb, ptr, _rehasher).Union(); + } + default: + return 0; + } +} + +inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) + FLATBUFFERS_NOEXCEPT : type(u.type), + value(nullptr) { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + value = new Conv2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + value = new DepthwiseConv2DOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + value = new ConcatEmbeddingsOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSHProjectionOptions: { + value = new LSHProjectionOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_Pool2DOptions: { + value = new Pool2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SVDFOptions: { + value = new SVDFOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RNNOptions: { + value = new RNNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FullyConnectedOptions: { + value = new FullyConnectedOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SoftmaxOptions: { + value = + new SoftmaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatenationOptions: { + value = new ConcatenationOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AddOptions: { + value = new AddOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_L2NormOptions: { + value = new L2NormOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + value = new LocalResponseNormalizationOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSTMOptions: { + value = new LSTMOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + value = new ResizeBilinearOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CallOptions: { + value = new CallOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReshapeOptions: { + value = + new ReshapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SkipGramOptions: { + value = + new SkipGramOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SpaceToDepthOptions: { + value = new SpaceToDepthOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + value = new EmbeddingLookupSparseOptionsT( + *reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MulOptions: { + value = new MulOptionsT(*reinterpret_cast(u.value)); + break; + } + default: + break; + } +} + +inline void BuiltinOptionsUnion::Reset() { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: + break; + } + value = nullptr; + type = BuiltinOptions_NONE; +} + +inline const tflite::Model *GetModel(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const char *ModelIdentifier() { return "TFL3"; } + +inline bool ModelBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier(buf, ModelIdentifier()); +} + +inline bool VerifyModelBuffer(flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(ModelIdentifier()); +} + +inline const char *ModelExtension() { return "tflite"; } + +inline void FinishModelBuffer(flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, ModelIdentifier()); +} + +inline std::unique_ptr UnPackModel( + const void *buf, const flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetModel(buf)->UnPack(res)); +} + +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index 793eb366a43..332253a092a 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -1,3 +1,8 @@ +package( + # To suppress build cleaner error about inclusion of schema_generate.h. + features = ["-layering_check"], +) + licenses(["notice"]) # Apache 2.0 load( diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc index ef43f64131c..6ae3ab57294 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark_model.cc @@ -31,6 +31,7 @@ void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); #endif #define LOG(x) std::cerr + #define CHECK(x) \ if (!(x)) { \ LOG(ERROR) << #x << "failed"; \ diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index a2b444d53ae..b6104413089 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -19,13 +19,20 @@ set -e DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads BZL_FILE_PATH=tensorflow/workspace.bzl -EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py index ae60d8b1e18..95dfd8f4213 100644 --- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -72,8 +72,8 @@ class _MaskedConv(base.Layer): linear activation. use_bias: Boolean, whether the layer uses a bias. kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, no bias will - be applied. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. kernel_regularizer: Optional regularizer for the convolution kernel. bias_regularizer: Optional regularizer for the bias vector. activity_regularizer: Regularizer function for the output. @@ -279,8 +279,8 @@ class MaskedConv2D(_MaskedConv): linear activation. use_bias: Boolean, whether the layer uses a bias. kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, no bias will - be applied. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. kernel_regularizer: Optional regularizer for the convolution kernel. bias_regularizer: Optional regularizer for the bias vector. activity_regularizer: Regularizer function for the output. diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD new file mode 100644 index 00000000000..71582f9c9a0 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -0,0 +1,113 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_libs", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_wrapper_py", +) + +cc_library( + name = "all_ops", + srcs = [":custom_op_sources"], + hdrs = [":custom_op_headers"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "python/ops/_periodic_resample_op.so", + srcs = [ + ":custom_op_headers", + ":custom_op_sources", + ], +) + +tf_gen_op_libs( + op_lib_names = ["array_ops"], +) + +tf_gen_op_wrapper_py( + name = "gen_periodic_resample_op_py", + out = "python/ops/gen_periodic_resample_op.py", + deps = [":array_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "periodic_resample_op_py", + srcs = ["python/ops/periodic_resample_op.py"], + dso = ["python/ops/_periodic_resample_op.so"], + kernels = [ + ":array_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_periodic_resample_op_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + "python/__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":periodic_resample_op_py", + ], +) + +# py_library( +# name = "periodic_resample_op_py", +# srcs = ["python/ops/periodic_resample_op.py"], +# data = ["python/ops/_periodic_resample_op.so"], +# srcs_version = "PY2AND3", +# ) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "custom_op_sources", + srcs = glob( + [ + "ops/*.cc", + "kernels/*.cc", + ], + exclude = [ + "ops/*_test.cc", + "kernels/*_test.cc", + ], + ), +) + +filegroup( + name = "custom_op_headers", + srcs = glob( + [ + "kernels/*.h", + "ops/*.h", + ], + ), +) diff --git a/tensorflow/contrib/periodic_resample/__init__.py b/tensorflow/contrib/periodic_resample/__init__.py new file mode 100644 index 00000000000..fde9091b88f --- /dev/null +++ b/tensorflow/contrib/periodic_resample/__init__.py @@ -0,0 +1,27 @@ +# ============================================================================= +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Custom op used by periodic_resample.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.periodic_resample.python.ops.periodic_resample_op import periodic_resample +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["periodic_resample"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc new file mode 100644 index 00000000000..9cee405cef2 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -0,0 +1,26 @@ +// ============================================================================= +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("PeriodicResample") + .Device(DEVICE_CPU), + PeriodicResampleOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h new file mode 100644 index 00000000000..bef21f7a5c8 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -0,0 +1,230 @@ +// ============================================================================= +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ +#define TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ + +#include +#include +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +namespace { + +template +IndexT compute_input_index( + IndexVecT* target_dimensions, const IndexT& output_index, + const IndexVecT& original_dimensions, const int& adjustable_dimension, + const std::vector& dimension_ceiling, + const std::vector& cumulative_dimensions, IndexT* result, + std::vector* output_indices, const int& rank) { + *result = 0; + output_indices->clear(); + + // un-rasterize the output index + auto last_reduced_i = output_index; + for (auto r = rank - 1; r >= 0; --r) { + (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; + last_reduced_i = + (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; + } + + // rasterize the input index + IndexT last_index_factor = 1; + for (auto r = rank - 1; r >= 0; --r) { + IndexT index = 0; + if (r != adjustable_dimension) + index = (*output_indices)[r] / dimension_ceiling[r]; + else { + for (int qi = 0; qi < rank; ++qi) { + if (qi == adjustable_dimension) continue; + index += cumulative_dimensions[qi] * + ((*output_indices)[qi] % dimension_ceiling[qi]); + } + index *= (*target_dimensions)[adjustable_dimension]; + index += (*output_indices)[r]; + } + *result += last_index_factor * index; + last_index_factor *= original_dimensions[r]; + } + + return *result; +} + +template // both types are needed here b/c IndexVecT and + // InputDataT are not related + void + fill_periodic_tensor( + tensorflow::OpKernelContext* context, + const IndexVecT& desired_shape, + const tensorflow::Tensor& input_tensor) { + // input is a strided array (last index is fastest, C-ordered) + auto input = input_tensor.flat(); + const int rank = input_tensor.dims(); + // original and target dimensions + std::vector original_dimensions(rank), + target_dimensions(rank); + tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); + // factors by which original_dimensions increases/decreases w.r.t. + // target_dimensions + std::vector dimension_ceiling(rank), + cumulative_dimensions(rank); + // index of adjustable dimension + int adjustable_dimension; + tensorflow::TensorShape output_shape; + + // requires that the rank of the input tensor and length of the desired shape + // are equal + OP_REQUIRES(context, rank == desired_shape.size(), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the rank of the input tensor, ", + rank, ", to be the same as the length of the desired shape, ", + desired_shape.size(), ".")); + + bool found = false; + for (int i = 0; i < rank; ++i) { + // if (desired_shape(i) < 1) { + if (desired_shape[i] < 1) { + // only one index can be adjustable + OP_REQUIRES(context, !found, + tensorflow::errors::InvalidArgument( + "periodic_resample expects only " + "one index to be marked as adjustable.")); + adjustable_dimension = i; + found = true; + } else { + // target_dimensions[i] = desired_shape(i); + target_dimensions[i] = desired_shape[i]; + new_sliced_size *= target_dimensions[i]; + } + } + // at least one index needs to be adjustable + OP_REQUIRES(context, found, + tensorflow::errors::InvalidArgument( + "periodic_resample expects at least " + "one index to be marked as adjustable.")); + + int count = 0; + for (const auto dim_info : input_tensor.shape()) { + original_dimensions[count] = dim_info.size; + ++count; + } + + target_dimensions[adjustable_dimension] = total_size / new_sliced_size; + + count = 0; + for (int i = 0; i < input_tensor.shape().dims(); ++i) { + dimension_ceiling[count] = tensorflow::int64(std::ceil( + float(target_dimensions[count]) / float(original_dimensions[count]))); + if (count == 0) + cumulative_dimensions[count] = 1; + else + cumulative_dimensions[count] = + cumulative_dimensions[count - 1] * dimension_ceiling[count - 1]; + ++count; + } + + // ensure that the new dimension is greater than zero + OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, + tensorflow::errors::InvalidArgument( + "periodic_resample found that the " + "adjustable dimension, ", + adjustable_dimension, ", isn't greater than zero, ", + target_dimensions[adjustable_dimension], ".")); + for (int i = 0; i < rank; ++i) { + output_shape.AddDim(target_dimensions[i]); + } + const auto new_size = + new_sliced_size * target_dimensions[adjustable_dimension]; + + // Create an output tensor and attach it to the current context + tensorflow::Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_tensor)); + auto output = output_tensor->flat(); + + // memory is allocated for these variables outside the inner loop for + // efficiency (although, I could create a separate class scope for + // this purpose instead) + tensorflow::int64 result = 0; + std::vector output_indices(target_dimensions.size()); + + // Fill output tensor with periodically resampled input tensor values + for (tensorflow::int64 output_index = 0; output_index < new_size; + ++output_index) { + output(output_index) = input(compute_input_index( + &target_dimensions, output_index, original_dimensions, + adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, + &output_indices, rank)); + } +} + +void create_output_tensor( + tensorflow::OpKernelContext* context, + const tensorflow::Tensor& input_tensor, + const tensorflow::DataType& input_tensor_type, + const tensorflow::PartialTensorShape& desired_shape_tensor) { + auto desired_shape = desired_shape_tensor.dim_sizes(); + + // obligatory type switch + switch (input_tensor_type) { + case tensorflow::DataTypeToEnum::value: + fill_periodic_tensor(context, desired_shape, input_tensor); + break; + case tensorflow::DataTypeToEnum::value: + fill_periodic_tensor(context, desired_shape, input_tensor); + break; + case tensorflow::DataTypeToEnum::value: + fill_periodic_tensor(context, desired_shape, + input_tensor); + break; + case tensorflow::DataTypeToEnum::value: + fill_periodic_tensor(context, desired_shape, + input_tensor); + break; + default:; + } +} + +} // namespace + +class PeriodicResampleOp : public tensorflow::OpKernel { + public: + explicit PeriodicResampleOp(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + // Get the desired shape + OP_REQUIRES_OK(context, context->GetAttr("shape", &desired_shape)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + // Grab the input tensor + const tensorflow::Tensor& input_tensor = context->input(0); + const tensorflow::DataType input_tensor_type = context->input_dtype(0); + + create_output_tensor(context, input_tensor, input_tensor_type, + desired_shape); + } + + private: + tensorflow::PartialTensorShape desired_shape; +}; + +#endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc new file mode 100644 index 00000000000..6029ad6a0d1 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -0,0 +1,88 @@ +// ============================================================================= +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; + +REGISTER_OP("PeriodicResample") + .Attr("T: numbertype") + .Input("values: T") + .Attr("shape: shape") + .Output("output: T") + .SetShapeFn(shape_inference::ExplicitShape) + .Doc(R"doc( +Periodically resample elements of a tensor to conform to `shape`. + +This function implements a slightly more generic version of the subpixel +convolutions found in this [paper](https://arxiv.org/abs/1609.05158). + +The formula for computing the elements in the `output` tensor is as follows: + `T` = `values` tensor of rank `R` + `S` = desired `shape` of output tensor (vector of length `R`) + `P` = `output` tensor of rank `R` + \((T_1,\ldots,T_R)\) = shape(`T`) + \([S_1,\ldots,S_q,\ldots,S_R]\) = elements of vector `S` + + A single element in `S` is left unspecified (denoted \(S_q=-1\)). + Let \(f_i\) denote the (possibly non-integer) factor that relates the original + dimension to the desired dimensions, \(S_i=f_i T_i\), for \(i\neq q\) where + \(f_i>0\). + Define the following: + \(g_i=\lceil f_i\rceil\) + \(t=\prod_i T_i\) + \(s=\prod_{i\neq q} S_i\) + \(S_q\) can then be defined as by \(S_q=\lfloor t/s\rfloor\). + The elements of the resulting tensor are defined as + \(P_{s_1,\ldots,s_R}=T_{h_1,\ldots,h_q,\ldots,h_R}\). + The \(h_i\) (\(i\neq q\)) are defined by \(h_i=\lfloor s_i/g_i\rfloor\). + \(h_q=S_q\sum_{j\neq q}^{q-1}G_j \mathrm{mod}(s_j,g_j) + s_q\), where + \(G_j=\prod_{i}^{j-1}g_i\) (\(G_0=1\)). + +One drawback of this method is that whenever the output dimensions are slightly +less than integer multiples of the input dimensions, many of the tensor elements +are repeated in an inefficient way. This is resolved by specifying that all +desired dimensions are integer multiples of the input tensor. + +For example: + +```prettyprint +`input` is [[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] + +tf.periodic_resample(input, [6, None]) ==> [[ 0 1] + [ 2 3] + [ 4 5] + [ 6 7] + [ 8 9] + [10 11]] +``` + +values: The tensor of rank `R` to periodic_resample +shape: A 1-D tensor representing the desired shape of the output tensor. + Exactly one element of this tensor must have the value `None` which represents + that this dimension of `values` can be adjusted downward in order to + accommodate increases in other dimensions. The specified sizes of the + non-adjustable dimensions must by at least as large as in the `values` tensor. +output: Periodically resampled tensor that has dimensions specified as in + `shape` except that the dimension specified as `None` will be minimally + decreased as necessary. + +)doc"); diff --git a/tensorflow/contrib/periodic_resample/python/__init__.py b/tensorflow/contrib/periodic_resample/python/__init__.py new file mode 100644 index 00000000000..a8b6ead0f59 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/python/__init__.py @@ -0,0 +1,20 @@ +# ============================================================================= +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Public API of periodic_resample.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py new file mode 100644 index 00000000000..1d727870f65 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -0,0 +1,101 @@ +# ============================================================================= +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy +import tensorflow +from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class PeriodicResampleTest(test_util.TensorFlowTestCase): + + def testPeriodicResampleBasic2D(self): + + input_tensor = numpy.arange(12).reshape((3, 4)) + desired_shape = numpy.array([6, None]) + output_tensor = input_tensor.reshape((6, 2)) + + with self.test_session(): + variables.global_variables_initializer().run() + result = periodic_resample(input_tensor, desired_shape).eval() + self.assertAllEqual(result, output_tensor) + + def testPeriodicResampleTruncatedBasic2D(self): + + input_tensor = numpy.arange(12).reshape((3, 4)) + desired_shape = numpy.array([5, None]) + output_tensor = input_tensor.reshape((6, 2))[:-1] + + with self.test_session(): + variables.global_variables_initializer().run() + result = periodic_resample(input_tensor, desired_shape).eval() + self.assertAllEqual(result, output_tensor) + + def testPeriodicResampleBasic3D(self): + + input_tensor = numpy.arange(2*2*4).reshape((2, 2, 4)) + desired_shape = numpy.array([4, 4, None]) + output_tensor = numpy.array([[[0], [2], [4], [6]], + [[1], [3], [5], [7]], + [[8], [10], [12], [14]], + [[9], [11], [13], [15]]]) + + # NOTE: output_tensor != input_tensor.reshape((4, 4, -1)) + with self.test_session(): + variables.global_variables_initializer().run() + result = periodic_resample(input_tensor, desired_shape).eval() + # input_tensor[0, 0, 0] == result[0, 0, 0] + # input_tensor[0, 0, 1] == result[1, 0, 0] + # input_tensor[0, 0, 2] == result[0, 1, 0] + # input_tensor[0, 0, 3] == result[1, 1, 0] + self.assertAllEqual(result, output_tensor) + + def testPeriodicResampleBasic4D(self): + + input_tensor = numpy.arange(2*2*2*8).reshape((2, 2, 2, 8)) + desired_shape = numpy.array([4, 4, 4, None]) + output_tensor = numpy.array([[[[0], [4], [8], [12]], + [[2], [6], [10], [14]], + [[16], [20], [24], [28]], + [[18], [22], [26], [30]]], + [[[1], [5], [9], [13]], + [[3], [7], [11], [15]], + [[17], [21], [25], [29]], + [[19], [23], [27], [31]]], + [[[32], [36], [40], [44]], + [[34], [38], [42], [46]], + [[48], [52], [56], [60]], + [[50], [54], [58], [62]]], + [[[33], [37], [41], [45]], + [[35], [39], [43], [47]], + [[49], [53], [57], [61]], + [[51], [55], [59], [63]]]]) + + # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1)) + with self.test_session(): + variables.global_variables_initializer().run() + result = periodic_resample(input_tensor, desired_shape).eval() + self.assertAllEqual(result, output_tensor) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py new file mode 100644 index 00000000000..6a09f70f442 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -0,0 +1,30 @@ +# ============================================================================= +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op + +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_periodic_resample_op = loader.load_op_library( + resource_loader.get_path_to_datafile('_periodic_resample_op.so')) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index f130a2187c2..84fcf733c14 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -40,7 +40,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test - # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 3bdd475fade..7970c20a269 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -24,24 +24,20 @@ import sys import tensorflow as tf - tf.flags.DEFINE_string('service_addr', '', 'Address of TPU profiler service e.g. localhost:8466') - - tf.flags.DEFINE_string('logdir', '', 'Path of TensorBoard log directory e.g. /tmp/tb_log') - - tf.flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') - FLAGS = tf.flags.FLAGS - - EXECUTABLE = 'data/capture_tpu_profile' +def run_main(): + tf.app.run(main) + + def main(unused_argv=None): if not FLAGS.service_addr or not FLAGS.logdir: sys.exit('service_addr and logdir must be provided.') @@ -54,4 +50,4 @@ def main(unused_argv=None): if __name__ == '__main__': - tf.app.run(main) + run_main() diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index e77cae4695d..ee6950699e7 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -23,7 +23,7 @@ from setuptools import setup _VERSION = '1.3.0-a1' CONSOLE_SCRIPTS = [ - 'capture_tpu_profile=cloud_tpu_profiler.main:main', + 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', ] REQUIRED_PACKAGES = [ diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 746ff38b37f..38a84ffb10e 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -7,6 +7,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + exports_files(["LICENSE"]) filegroup( @@ -97,7 +99,7 @@ cc_library( alwayslink = 1, ) -cc_library( +tf_cuda_library( name = "rdma_rendezvous_mgr", srcs = ["rdma_rendezvous_mgr.cc"], hdrs = ["rdma_rendezvous_mgr.h"], @@ -130,7 +132,7 @@ cc_library( ], ) -cc_library( +tf_cuda_library( name = "rdma", srcs = ["rdma.cc"], hdrs = ["rdma.h"], diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index ac8d994502f..ae9a384565a 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -18,11 +18,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/rdma.h" #include #include +#include #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/rendezvous.h" @@ -31,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { @@ -418,9 +422,6 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) 0); CHECK(cq_) << "Failed to create completion queue"; CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification"; - polling_thread_.reset(Env::Default()->StartThread( - ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); })); - VLOG(2) << "Start RdmaAdapter: " << name(); } RdmaAdapter::~RdmaAdapter() { @@ -432,6 +433,12 @@ RdmaAdapter::~RdmaAdapter() { CHECK(!ibv_close_device(context_)) << "Failed to release context"; } +void RdmaAdapter::StartPolling() { + polling_thread_.reset(Env::Default()->StartThread( + ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); })); + VLOG(2) << "Start RdmaAdapter: " << name(); +} + string RdmaAdapter::name() const { return string(context_->device->name); } // Function to process incoming messages @@ -452,9 +459,9 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" - << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " - << static_cast(wc_[i].wr_id) << " " << wc_[i].vendor_err; + << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " + << wc_[i].status << " " << static_cast(wc_[i].wr_id) << " " + << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast(wc_[i].wr_id); // put back a recv wr. @@ -557,9 +564,44 @@ void RdmaAdapter::Process_CQ() { } } +int RdmaChannel::PingPostRecv() { + struct ibv_recv_wr wr, *bad_wr; + memset(&wr, 0, sizeof(wr)); + wr.sg_list = &ping_sge_list_; + wr.num_sge = 1; + wr.wr_id = kPingRecvWrid; + + return ibv_post_recv(qp_, &wr, &bad_wr); +} + +int RdmaChannel::PingPostSend() { + struct ibv_send_wr wr, *bad_wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = (uint64_t) this; + wr.sg_list = &ping_sge_list_; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = IBV_SEND_SIGNALED; + + return ibv_post_send(qp_, &wr, &bad_wr); +} + RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, const string remote_name) : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) { + + struct ibv_sge list; + + mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize, + IBV_ACCESS_LOCAL_WRITE); + CHECK(mr_) << "Failed to register memory region"; + + memset(&list, 0, sizeof(list)); + list.addr = (uintptr_t)ping_buff_; + list.length = kPingBuffSize; + list.lkey = mr_->lkey; + + ping_sge_list_ = list; // Create queue pair { struct ibv_qp_init_attr attr; @@ -610,7 +652,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. { const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + "tx_ack_buffer", "rx_ack_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); @@ -632,15 +674,13 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, buffer_index_name_table_.insert({index, buffer_names[i]}); buffer_name_index_table_.insert({buffer_names[i], index}); } - - // Initiate recv - for (int i = 0; i < 100; i++) { - Recv(); - } } + CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_ + << " with error " << std::strerror(errno); } RdmaChannel::~RdmaChannel() { + ibv_dereg_mr(mr_); CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP"; delete tx_message_buffer_; delete rx_message_buffer_; @@ -671,7 +711,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -825,11 +865,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | + IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); @@ -840,10 +880,10 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; @@ -930,7 +970,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -1025,9 +1065,10 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { - CHECK(send_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); +#if GOOGLE_CUDA + CHECK(send_args.device_context) << "send dev name: " << src_dev->name() + << " gpu_info: " + << src_dev->tensorflow_gpu_device_info(); if (can_memcpy) { AllocatorAttributes host_alloc_attrs; @@ -1053,8 +1094,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( // aync instead GPUUtil::SetProtoFromGPU( in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead, send_args, recv_args](const Status& s) mutable { CHECK(s.ok()) << "copy proto from gpu sync"; auto tensor_bytes = proto.ByteSize(); buffer_size += tensor_bytes; @@ -1063,6 +1104,7 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( &proto, NULL, send_args, recv_args); }); } +#endif // GOOGLE_CUDA } else { // tensor is in CPU memory. StringPiece copy_buf; diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index 00217c81d4c..fea2327d77f 100644 --- a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -67,9 +67,20 @@ struct RemoteMR { uint64_t remote_addr; uint32_t rkey; }; -enum BufferStatus { none, idle, busy }; -enum Location { local, remote }; -enum BufferType { ACK, MESSAGE, TENSOR }; +enum BufferStatus { + none, + idle, + busy +}; +enum Location { + local, + remote +}; +enum BufferType { + ACK, + MESSAGE, + TENSOR +}; enum RdmaMessageType { RDMA_MESSAGE_ACK, RDMA_MESSAGE_BUFFER_IDLE, @@ -96,6 +107,7 @@ class RdmaAdapter { ~RdmaAdapter(); // Adapter name, e.g. mlx5_0. string name() const; + void StartPolling(); void Process_CQ(); protected: @@ -150,6 +162,15 @@ class RdmaChannel { void RemoveRecvCallback(const string& key); void RunRecvCallback(const string& key); static const int kNumMessageBuffers = 4; + static const int kPingRecvWrid = 0; + + private: + static const int kPingBuffSize = 1024; + char ping_buff_[kPingBuffSize]; + struct ibv_mr* mr_; + struct ibv_sge ping_sge_list_; + int PingPostRecv(); + int PingPostSend(); protected: const RdmaAdapter* adapter_; @@ -202,7 +223,7 @@ class RdmaBuffer { } void FreeBuffer(); void EnqueueItem(string Item); - virtual void SendNextItem(){}; + virtual void SendNextItem() {}; void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); uint32_t LookupBufferIndex(const string& buffer_name) { diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 09b878843f5..9cb307bcfa0 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -115,6 +115,57 @@ void RdmaMgr::SetupChannels() { } } +// Check connectivity by pinging every channel +bool RdmaMgr::ConnectivityCheck() { + int i, rcnt = 0, scnt = 0; + + for (const auto& p : channel_table_) { + string worker_name = p.first; + RdmaChannel* rc = p.second; + + VLOG(2) << "Ping to " << worker_name; + CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name + << " with error: " << std::strerror(errno); + for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) { + rc->Recv(); + } + } + + while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) { + int ne; + do { + ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_, + rdma_adapter_->wc_); + CHECK(ne >= 0) << "poll CQ failed " << ne << "with error" + << std::strerror(errno); + } while (ne < 1); + + for (i = 0; i < ne; ++i) { + ibv_wc_status s = rdma_adapter_->wc_[i].status; + // recv complete + if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) { + CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( + rdma_adapter_->wc_[i].status) + << "(" << rdma_adapter_->wc_[i].status + << ") for PING_RECV_WRID"; + ++rcnt; + // send complete + } else { + RdmaChannel* rc = + reinterpret_cast(rdma_adapter_->wc_[i].wr_id); + CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str( + rdma_adapter_->wc_[i].status) + << "(" << rdma_adapter_->wc_[i].status + << ") to " << rc->remote_name_; + ++scnt; + } + } // for + } // while + CHECK(rcnt == scnt) << "Connectivity check failed!"; + rdma_adapter_->StartPolling(); + return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt); +} + RdmaMgr::~RdmaMgr() { for (const auto& p : channel_table_) delete p.second; channel_table_.clear(); diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h index b156f64096c..e711e604788 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.h +++ b/tensorflow/contrib/verbs/rdma_mgr.h @@ -28,12 +28,16 @@ limitations under the License. namespace tensorflow { class RdmaMgr { + friend class RdmaChannel; + friend class RdmaAdapter; + public: explicit RdmaMgr(const WorkerEnv* const worker_env, GrpcChannelCache* const channel_cache); ~RdmaMgr(); RdmaChannel* FindChannel(const string& key); void SetupChannels(); + bool ConnectivityCheck(); const string& local_worker() { return local_worker_; } private: @@ -44,7 +48,6 @@ class RdmaMgr { RdmaAdapter* rdma_adapter_; typedef std::unordered_map ChannelTable; ChannelTable channel_table_; - TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr); }; diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index ce82ca28830..74f6681af3c 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -21,8 +21,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#endif // GOOGLE_CUDA #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -58,20 +60,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( // parse src_name and dst_name string src_name, dst_name, unused; if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name, + &unused) || + !DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, &unused)) { - s = errors::Internal("Could not parse src name."); + s = errors::Internal("Could not parse src or dst name."); } - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor{}, false); - return; - } - if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, - &unused)) { - s = errors::Internal("Could not parse dst name."); - } - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); if (!s.ok()) { + LOG(ERROR) << "s is not ok, error code " << s.error_message(); done(s, Args(), recv_args, Tensor{}, false); return; } @@ -82,18 +77,13 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( // insert callback rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc, recv_args, parsed, done]() { - Status s; - Device* src_dev; - s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor(), true); - return; - } - Device* dst_dev; - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); - CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); - if (!s.ok()) { + Status src_s, dst_s, s; + Device* src_dev, *dst_dev; + src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev); + dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); + if (!src_s.ok() || !dst_s.ok()) { + s = src_s.ok() ? dst_s : src_s; + LOG(ERROR) << "s is not ok, error code " << s.error_message(); done(s, Args(), recv_args, Tensor(), true); return; } @@ -110,9 +100,10 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( if (can_memcpy) { if (dst_dev->tensorflow_gpu_device_info() && (!recv_args.alloc_attrs.on_host())) { +#if GOOGLE_CUDA CHECK(recv_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); Tensor copy(alloc, rm.data_type_, rm.tensor_shape_); memcpy(DMAHelper::base(©), input, rm.tensor_bytes_); @@ -122,14 +113,15 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( GPUUtil::CopyCPUTensorToGPU( ©, recv_args.device_context, dst_dev, &gpu_copy, - [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, - rc](const Status& s) { + [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc]( + const Status& s) { CHECK(s.ok()) << "copy tensor to gpu sync"; Tensor val; val = std::move(gpu_copy); RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s); }); +#endif // GOOGLE_CUDA return; } else { AllocatorAttributes host_alloc_attrs; diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 6d1c79c0fb2..a606ef75a42 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() { Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, GrpcChannelCache** channel_cache) { string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", - "/task:", server_def.task_index()); + strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:", + server_def.task_index()); GrpcChannelSpec channel_spec; TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); @@ -103,6 +103,7 @@ Status VerbsServer::Start() { ThreadOptions(), "TF_verbs_service", [this] { verbs_service_->HandleRPCsLoop(); })); rdma_mgr_->SetupChannels(); + CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!"; verbs_state_ = CONNECTED; } } diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h index 9e39b6b7b93..5707f525922 100644 --- a/tensorflow/core/common_runtime/pending_counts.h +++ b/tensorflow/core/common_runtime/pending_counts.h @@ -44,7 +44,7 @@ namespace tensorflow { // PendingCounts counts(layout); // ... -// counts.decrement_panding(h[id], 1); +// counts.decrement_pending(h[id], 1); class PendingCounts { public: // The state machine for a node's execution. diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index c82d57694a7..3ae52f414fa 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -127,7 +127,7 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, // // NOTE: Recursive user-defined functions are not supported. // Maybe we won't support recursive functions at all in TF, because of -// other maintanabilty issues. +// other maintainability issues. Status ShapeRefiner::InferShapesForFunction( const tensorflow::FunctionDef* function_def, bool keep_nested_shapes, ExtendedInferenceContext* outer_context) { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 1924c05d3dd..add80eda23d 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1152,7 +1152,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, // Add control edges from 'ref_control_inputs' to 'ref_recvs'. // NOTE(yuanbyu): Adding these control edges should not introduce // deadlocks. 'dst' has implicit "read" nodes that, when we split - // across devices, are made explicit; Retargettig the dependencies + // across devices, are made explicit; Retargeting the dependencies // to 'dst' to those nodes would not introduce cycles if there isn't // one before the transformation. // NOTE(yuanbyu): This may impact performance because it defers the diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index cb32d643347..880e4e712ef 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -21,107 +21,108 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { -// Since our ops are going to produce and also consume N addition tensors -// (Mkl) for N Tensorflow tensors, we can have following different -// orderings among these 2N tensors. -// -// E.g., for Tensorflow tensors A, B, and C, our ops will produce and -// consume A_m, B_m, and C_m additionally. -// -// INTERLEAVED: in this case 2N tensors are interleaved. So for above -// example, the ordering looks like: A, A_m, B, B_m, C, C_m. -// -// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed -// by N Mkl tensors. So for above example, the ordering looks -// like: A, B, C, A_m, B_m, C_m -// -// Following APIs map index of original Tensorflow tensors to their -// appropriate position based on selected ordering. For contiguous ordering, -// we need to know the total number of tensors (parameter total). -// -typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; -// NOTE: Currently, we use contiguous ordering. If you change this, then you -// would need to change Mkl op definitions in nn_ops.cc. -static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; + // Since our ops are going to produce and also consume N addition tensors + // (Mkl) for N Tensorflow tensors, we can have following different + // orderings among these 2N tensors. + // + // E.g., for Tensorflow tensors A, B, and C, our ops will produce and + // consume A_m, B_m, and C_m additionally. + // + // INTERLEAVED: in this case 2N tensors are interleaved. So for above + // example, the ordering looks like: A, A_m, B, B_m, C, C_m. + // + // CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed + // by N Mkl tensors. So for above example, the ordering looks + // like: A, B, C, A_m, B_m, C_m + // + // Following APIs map index of original Tensorflow tensors to their + // appropriate position based on selected ordering. For contiguous ordering, + // we need to know the total number of tensors (parameter total). + // + typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; + // NOTE: Currently, we use contiguous ordering. If you change this, then you + // would need to change Mkl op definitions in nn_ops.cc. + static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; -// Get index of MetaData tensor from index 'n' of Data tensor. -inline int DataIndexToMetaDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - // For interleaved ordering, Mkl tensor follows immediately after - // Tensorflow tensor. - return n + 1; - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. - return n + total_tensors / 2; + // Get index of MetaData tensor from index 'n' of Data tensor. + inline int DataIndexToMetaDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + // For interleaved ordering, Mkl tensor follows immediately after + // Tensorflow tensor. + return n + 1; + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. + return n + total_tensors / 2; + } } -} -int inline GetTensorDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - return 2 * n; // index corresponding to nth input/output tensor - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - return n; - } -} + int inline GetTensorDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + return 2 * n; // index corresponding to nth input/output tensor + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + return n; + } + } -int inline GetTensorMetaDataIndex(int n, int total_tensors) { - // Get index for TensorData first and then use mapping function - // to get TensorMetaData index from TensorData index. - int tidx = GetTensorDataIndex(n, total_tensors); - return DataIndexToMetaDataIndex(tidx, total_tensors); -} + int inline GetTensorMetaDataIndex(int n, int total_tensors) { + // Get index for TensorData first and then use mapping function + // to get TensorMetaData index from TensorData index. + int tidx = GetTensorDataIndex(n, total_tensors); + return DataIndexToMetaDataIndex(tidx, total_tensors); + } namespace mkl_op_registry { -static const char* kMklOpLabel = "MklOp"; -static const char* kMklOpLabelPattern = "label='MklOp'"; + static const char* kMklOpLabel = "MklOp"; + static const char* kMklOpLabelPattern = "label='MklOp'"; -// Get the name of Mkl op from original TensorFlow op -// We prefix 'Mkl' to the original op to get Mkl op. -inline string GetMklOpName(const string& name) { - // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "_Mkl"; - return string(kMklOpPrefix) + name; -} - -// Check whether opname with type T is registered as MKL-compliant. -// -// @input: name of the op -// @input: T datatype to be used for checking op -// @return: true if opname is registered as Mkl op; false otherwise -static inline bool IsMklOp(const std::string& op_name, DataType T) { - string kernel = KernelsRegisteredForOp(op_name); - bool result = - kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); - if (result) { - VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel; - } - return result; -} - -// Check whether opname with type T is registered as MKL-compliant and -// is element-wise. -// -// @input: name of the op -// @input: T datatype to be used for checking op -// @return: true if opname is registered as element-wise Mkl op; -// false otherwise -static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { - if (!IsMklOp(op_name, T)) { - return false; + // Get the name of Mkl op from original TensorFlow op + // We prefix 'Mkl' to the original op to get Mkl op. + inline string GetMklOpName(const string& name) { + // Prefix that we add to Tensorflow op name to construct Mkl op name. + const char* const kMklOpPrefix = "_Mkl"; + return string(kMklOpPrefix) + name; } - bool result = (0 == op_name.compare(GetMklOpName("Add")) || - 0 == op_name.compare(GetMklOpName("Sub")) || - 0 == op_name.compare(GetMklOpName("Mul")) || - 0 == op_name.compare(GetMklOpName("Maximum")) || - 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + // Check whether opname with type T is registered as MKL-compliant. + // + // @input: name of the op + // @input: T datatype to be used for checking op + // @return: true if opname is registered as Mkl op; false otherwise + static inline bool IsMklOp(const std::string& op_name, DataType T) { + string kernel = KernelsRegisteredForOp(op_name); + bool result = + kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); + if (result) { + VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel; + } + return result; + } - VLOG(1) << "mkl_op_registry::" << op_name - << " is elementwise MKL op: " << result; - return result; -} + // Check whether opname with type T is registered as MKL-compliant and + // is element-wise. + // + // @input: name of the op + // @input: T datatype to be used for checking op + // @return: true if opname is registered as element-wise Mkl op; + // false otherwise + static inline bool IsMklElementWiseOp(const std::string& op_name, + DataType T) { + if (!IsMklOp(op_name, T)) { + return false; + } + + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + VLOG(1) << "mkl_op_registry::" << op_name + << " is elementwise MKL op: " << result; + return result; + } } // namespace mkl_op_registry } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index f4c9073deee..912075aa286 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -37,8 +37,8 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/mkl_layout_pass.h" +#include "tensorflow/core/graph/mkl_graph_util.h" namespace tensorflow { diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 3fd89e2b666..599bb88f015 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -33,8 +33,8 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/mkl_tfconversion_pass.h" +#include "tensorflow/core/graph/mkl_graph_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc index 822d72e0685..0aec6aac344 100644 --- a/tensorflow/core/kernels/cwise_op_asinh.cc +++ b/tensorflow/core/kernels/cwise_op_asinh.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/core/kernels/cwise_op_bitwise_and.cc b/tensorflow/core/kernels/cwise_op_bitwise_and.cc index 017a2182dcf..5a6cf4bad16 100644 --- a/tensorflow/core/kernels/cwise_op_bitwise_and.cc +++ b/tensorflow/core/kernels/cwise_op_bitwise_and.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(BinaryOp, CPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, CPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(TYPE) \ @@ -30,13 +30,15 @@ REGISTER_SYCL_KERNEL(int32); REGISTER_SYCL_KERNEL(int64); REGISTER_SYCL_KERNEL(uint8); REGISTER_SYCL_KERNEL(uint16); +REGISTER_SYCL_KERNEL(uint32); +REGISTER_SYCL_KERNEL(uint64); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA -REGISTER6(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_bitwise_or.cc b/tensorflow/core/kernels/cwise_op_bitwise_or.cc index 36f45fe92df..201a10198a6 100644 --- a/tensorflow/core/kernels/cwise_op_bitwise_or.cc +++ b/tensorflow/core/kernels/cwise_op_bitwise_or.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(TYPE) \ @@ -30,13 +30,15 @@ REGISTER_SYCL_KERNEL(int32); REGISTER_SYCL_KERNEL(int64); REGISTER_SYCL_KERNEL(uint8); REGISTER_SYCL_KERNEL(uint16); +REGISTER_SYCL_KERNEL(uint32); +REGISTER_SYCL_KERNEL(uint64); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA -REGISTER6(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_bitwise_xor.cc b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc index 36432d851d9..2a7cd269959 100644 --- a/tensorflow/core/kernels/cwise_op_bitwise_xor.cc +++ b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(BinaryOp, CPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, CPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(TYPE) \ @@ -30,13 +30,15 @@ REGISTER_SYCL_KERNEL(int32); REGISTER_SYCL_KERNEL(int64); REGISTER_SYCL_KERNEL(uint8); REGISTER_SYCL_KERNEL(uint16); +REGISTER_SYCL_KERNEL(uint32); +REGISTER_SYCL_KERNEL(uint64); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA -REGISTER6(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, - int64, uint8, uint16); +REGISTER8(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, + int64, uint8, uint16, uint32, uint64); #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc index 27f973c90d7..3fbf69c114d 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc @@ -19,7 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY6(bitwise_and, int8, int16, int32, int64, uint8, uint16); +DEFINE_BINARY8(bitwise_and, int8, int16, int32, int64, uint8, uint16, uint32, + uint64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc index a34c3a52cd6..8bcb82266a2 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc @@ -19,7 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY6(bitwise_or, int8, int16, int32, int64, uint8, uint16); +DEFINE_BINARY8(bitwise_or, int8, int16, int32, int64, uint8, uint16, uint32, + uint64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc index a4531ab7c6f..e62a87aba44 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc @@ -19,7 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY6(bitwise_xor, int8, int16, int32, int64, uint8, uint16); +DEFINE_BINARY8(bitwise_xor, int8, int16, int32, int64, uint8, uint16, uint32, + uint64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index 6d9fdfcf33b..c778278e8fb 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -49,6 +49,12 @@ class DecodeBmpOp : public OpKernel { // Start decoding image to get shape details const StringPiece input = contents.scalar()(); + OP_REQUIRES(context, (32 <= input.size()), + errors::InvalidArgument("Incomplete bmp content, requires at " + "least 32 bytes to find the header " + "size, width, height, and bpp, got ", + input.size(), " bytes")); + const uint8* img_bytes = reinterpret_cast(input.data()); const int32 header_size = internal::SubtleMustCopy( *(reinterpret_cast(img_bytes + 10))); @@ -74,6 +80,22 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); + // there may be padding bytes when the width is not a multiple of 4 bytes + // 8 * channels == bits per pixel + const int row_size = (8 * channels_ * width + 31) / 32 * 4; + + const int last_pixel_offset = + header_size + (abs(height) - 1) * row_size + (width - 1) * channels_; + + // [expected file size] = [last pixel offset] + [last pixel size=channels] + const int expected_file_size = last_pixel_offset + channels_; + + OP_REQUIRES( + context, (expected_file_size <= input.size()), + errors::InvalidArgument("Incomplete bmp content, requires at least ", + expected_file_size, " bytes, got ", + input.size(), " bytes")); + // if height is negative, data layout is top down // otherwise, it's bottom up bool top_down = (height < 0); @@ -86,25 +108,23 @@ class DecodeBmpOp : public OpKernel { const uint8* bmp_pixels = &img_bytes[header_size]; - Decode(bmp_pixels, output->flat().data(), width, abs(height), - channels_, top_down); + Decode(bmp_pixels, row_size, output->flat().data(), width, + abs(height), channels_, top_down); } - uint8* Decode(const uint8* input, uint8* const output, const int width, - const int height, const int channles, bool top_down); + uint8* Decode(const uint8* input, const int row_size, uint8* const output, + const int width, const int height, const int channles, + bool top_down); private: int channels_; }; REGISTER_KERNEL_BUILDER(Name("DecodeBmp").Device(DEVICE_CPU), DecodeBmpOp); -uint8* DecodeBmpOp::Decode(const uint8* input, uint8* const output, - const int width, const int height, - const int channels, bool top_down) { - // there may be padding bytes when the width is not a multiple of 4 bytes - // 8 * channels == bits per pixel - int row_size = (8 * channels * width + 31) / 32 * 4; - +uint8* DecodeBmpOp::Decode(const uint8* input, const int row_size, + uint8* const output, const int width, + const int height, const int channels, + bool top_down) { for (int i = 0; i < height; i++) { int src_pos; int dst_pos; diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index 02da64ce98d..a5fd07fbe17 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -430,10 +430,9 @@ TF_CALL_double(REGISTER_CPU_KERNEL); #endif #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") - .Device(DEVICE_GPU) - .TypeConstraint("T"), - DepthwiseConv2dNativeOp); +REGISTER_KERNEL_BUILDER( + Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint("T"), + DepthwiseConv2dNativeOp); REGISTER_KERNEL_BUILDER( Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint("T"), diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc new file mode 100644 index 00000000000..9bb58b13f38 --- /dev/null +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -0,0 +1,465 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The algorithm for dynamic partition has the following steps: +// 1. Let N be the size of partitions. We initialize a new vector indices_in +// with the values 0, 1, 2, ..., N-1. +// 2. We apply cub::DeviceRadixSort::SortPairs to the key - value pairs given +// by partitions and indices_in. This will result in two new vectors +// partitions_out and indices_out, with partitions_out sorted. +// 3. The first dimension of outputs[i] is equal to the number of i-values in +// partitions_out. We determine it in two steps: +// - apply cub::DeviceReduce::ReduceByKey to count how many times each value +// appears in partitions_out, +// - move the results to partition_count. This handles missing values +// (corresponding to empty parts). +// 4. Because partition_count is on the GPU, we bring it asynchronously to +// the CPU. Then we can allocate the output tensors. +// 5. Finally, we use indices_out and the gather functor to collect the output. +// This works, because for each interval of i-values, indices_out points +// to the slices which should form output[i]. + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "external/cub_archive/cub/device/device_radix_sort.cuh" +#include "external/cub_archive/cub/device/device_reduce.cuh" +#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh" +#include "external/cub_archive/cub/thread/thread_operators.cuh" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/gather_functor_gpu.cu.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/transform_output_iterator.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +template +__global__ void RangeInitKernel(const T start, const T delta, const int32 size, + T* out) { + CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; } +} + +__global__ void MoveValuesKernel(const int32* keys, const int32* values, + const int32* size, int32 out_size, + int32* out) { + int32 N = min(ldg(size), out_size); + CUDA_1D_KERNEL_LOOP(i, N) { + int32 key = ldg(keys + i); + int32 value = ldg(values + i); + if (FastBoundsCheck(key, out_size)) out[key] = value; + } +} + +// Initialize out with range start, start + delta, start + 2 * delta, ... +// This is needed because tf.range has no GPU implementation. +template +void RangeInit(const GPUDevice& d, const T start, const T delta, + const int32 size, typename TTypes::Flat out) { + CudaLaunchConfig config = GetCudaLaunchConfig(size, d); + RangeInitKernel< + T><<>>( + start, delta, size, out.data()); +} + +// Given *num_runs pairs (key, value), this function moves the value +// corresponding to key i at position i in the array out. +void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs, + int32 out_size, int32* out) { + // Because num_runs is located on the GPU, we can not access it directly. + // So we launch the kernel with size = out_size. + // This is valid for correct inputs, because then out_size >= *num_runs. + // For wrong inputs, we may have out_size < *num_runs. In this case we will + // only handle the first out_size values. + CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + MoveValuesKernel<<>>(keys, values, num_runs, out_size, out); +} + +template +void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices, + T* out, int64 gather_dim_size, int64 indices_size, + int64 slice_size, int64 out_size) { + CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + GatherOpKernel< + T, int32, + true><<>>( + params, indices, out, gather_dim_size, indices_size, slice_size, + out_size); +} + +struct IdentityOp { + __device__ int32 __forceinline__ operator()(const int32& a) const { + return a; + } +}; + +// Define an output iterator that only allows assignment to +// positions between [base, base + limit). +class BoundedOutputIterator + : public TransformOutputIterator { + private: + int32 limit; + int32* base; + + struct BoundedReference : Reference { + int32 limit; + int32* base; + // Constructor + __host__ __device__ __forceinline__ + BoundedReference(int32* ptr, int32* base, IdentityOp op, int32 limit) + : Reference(ptr, op), limit(limit), base(base) {} + + // Assignment + __host__ __device__ __forceinline__ int32 operator=(int32 val) { + if (ptr - base < limit && ptr - base >= 0) *ptr = val; + return val; + } + }; + + public: + typedef BoundedOutputIterator self_type; + typedef BoundedReference reference; + + __host__ __device__ __forceinline__ BoundedOutputIterator(int32* ptr, + IdentityOp op, + int32 size) + : TransformOutputIterator(ptr, op), limit(size), base(ptr) {} + + __host__ __device__ __forceinline__ + BoundedOutputIterator(int32* ptr, int32* base, IdentityOp op, int32 size) + : TransformOutputIterator(ptr, op), limit(size), base(base) {} + + // Indirection + __host__ __device__ __forceinline__ reference operator*() const { + return BoundedReference(ptr, base, conversion_op, limit); + } + + // Array subscript + __host__ __device__ __forceinline__ reference operator[](int32 n) const { + return BoundedReference(ptr + n, base, conversion_op, limit); + } + + // Addition + __host__ __device__ __forceinline__ self_type operator+(int32 n) const { + self_type retval(ptr + n, base, conversion_op, limit); + return retval; + } + + // Subtraction + __host__ __device__ __forceinline__ self_type operator-(int32 n) const { + self_type retval(ptr - n, base, conversion_op, limit); + return retval; + } +}; + +} // namespace + +// The current implementation has memory cost on GPU +// I + P + max(3N + R + P, O + N), where: +// I - the size of the input +// N - the size of the partitions tensor +// R - the temporary storage used by cub::RadixSort, about 2N +// P - the number of partitions +// O - the size of the output +// So roughly the cost is I + P + max(5N, O + N). +template +class DynamicPartitionOpGPU : public AsyncOpKernel { + public: + explicit DynamicPartitionOpGPU(OpKernelConstruction* c) : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_)); + OP_REQUIRES(c, num_partitions_ >= 1, + errors::InvalidArgument("num_partitions must be at least 1")); + } + + void AllocateTempSpace(OpKernelContext* c, int32 N, Tensor* indices_in, + Tensor* partitions_out, Tensor* indices_out, + DoneCallback done) { + int32 M = std::max(N, num_partitions_); + // indices_in will be made slightly larger to accommodate + // later computations. + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done); + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({N}), partitions_out), done); + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({N}), indices_out), done); + } + + void AllocateOutputs(OpKernelContext* c, const Tensor* data, + const Tensor* partitions, const Tensor* partition_count, + OpOutputList* Tout, DoneCallback done) { + auto e_part_count = partition_count->flat(); + // Allocate output tensors of the right size + OP_REQUIRES_OK_ASYNC(c, c->output_list("outputs", Tout), done); + for (int p = 0; p < num_partitions_; p++) { + TensorShape shape; + shape.AddDim(e_part_count(p)); + for (int i = partitions->dims(); i < data->dims(); i++) { + shape.AddDim(data->dim_size(i)); + } + Tensor* out; + OP_REQUIRES_OK_ASYNC(c, Tout->allocate(p, shape, &out), done); + } + } + + void ComputeAsync(OpKernelContext* c, DoneCallback done) { + const Tensor& data = c->input(0); + const Tensor& partitions = c->input(1); + + OP_REQUIRES_ASYNC( + c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()), + errors::InvalidArgument("data.shape must start with partitions.shape, ", + "got data.shape = ", data.shape().DebugString(), + ", partitions.shape = ", + partitions.shape().DebugString()), + done); + + Tensor partition_count; + + // We must handle the case of empty partitions separately, + // because kernels don't work with 0-sized tensors. + if (partitions.NumElements() == 0) { + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &partition_count, alloc_attr), + done); + auto e_part_count = partition_count.flat(); + for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0; + OpOutputList outputs; + this->AllocateOutputs(c, &data, &partitions, &partition_count, &outputs, + done); + if (c->status().ok()) done(); + return; + } + + // Prepare for counting. + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &partition_count), + done); + Tensor indices_out; + // Count how many times each partition index occurs. + // Also sort the info in partitions and output it in indices_out, + // in preparation for the next step. + this->CountAndSortParts(c, &partitions, &partition_count, &indices_out, + done); + if (!c->status().ok()) return; + + // In order to allocate the output tensor we have to move partition_count + // to CPU. + auto* stream = c->op_device_context()->stream(); + OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."), + done); + Tensor cpu_tensor; + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + alloc_attr.set_gpu_compatible(true); + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(partition_count.dtype(), partition_count.shape(), + &cpu_tensor, alloc_attr), + done); + perftools::gputools::DeviceMemoryBase wrapped( + partition_count.flat().data(), num_partitions_ * sizeof(int32)); + const bool status = + stream + ->ThenMemcpy(cpu_tensor.flat().data(), wrapped, + num_partitions_ * sizeof(int32)) + .ok(); + OP_REQUIRES_ASYNC( + c, status, + errors::Internal("Failed to launch copy from device to host."), done); + + // Keep a reference to partition_count so that the buffer + // is not deallocated at the end of the function, before + // memcpy is completed. + TensorReference partition_ref(partition_count); + auto wrapped_callback = [this, c, &data, &partitions, indices_out, + partition_ref, cpu_tensor, done]() { + OpOutputList outputs; + this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done); + if (!c->status().ok()) { + partition_ref.Unref(); + return; + } + int32 N = partitions.NumElements(); + int64 slice_size = data.NumElements() / N; + this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs); + partition_ref.Unref(); + done(); + }; + + c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, wrapped_callback); + } + + protected: + void RadixSort(OpKernelContext* c, const Tensor* partitions, + Tensor* indices_in, Tensor* partitions_out, + Tensor* indices_out, DoneCallback done) { + int32 N = partitions->NumElements(); + const GPUDevice& device = c->eigen_device(); + const cudaStream_t& cu_stream = GetCudaStream(c); + + // Initialize the indices_in tensor using the Range GPU kernel. + RangeInit(device, 0, 1, N, indices_in->flat()); + // Obtain the pointers to inner buffers. + const int32* partitions_ptr = partitions->flat().data(); + int32* partitions_out_ptr = partitions_out->flat().data(); + int32* indices_in_ptr = indices_in->flat().data(); + int32* indices_out_ptr = indices_out->flat().data(); + // Determine temporary device storage requirements. + Tensor cub_temp_storage; + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr, + indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream); + // Allocate temporary storage. + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), + done); + // Radix-sort the partition information. + cub::DeviceRadixSort::SortPairs( + cub_temp_storage.flat().data(), temp_storage_bytes, + partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N, + 0, sizeof(int32) * 8, cu_stream); + } // At this point cub_temp_storage will be marked for deallocation. + + void CountAndSortParts(OpKernelContext* c, const Tensor* partitions, + Tensor* partition_count, Tensor* indices_out, + DoneCallback done) { + const GPUDevice& device = c->eigen_device(); + const cudaStream_t& cu_stream = GetCudaStream(c); + int32 N = partitions->NumElements(); + Tensor indices_in; + Tensor partitions_out; + Tensor aggregates_out; + + // Allocate memory for Radix-Sort. + this->AllocateTempSpace(c, N, &indices_in, &partitions_out, indices_out, + done); + if (!c->status().ok()) return; + this->RadixSort(c, partitions, &indices_in, &partitions_out, indices_out, + done); + if (!c->status().ok()) return; + // We will now apply a reduce operation to count how many times + // each index appears in partitions. + + // Zero-out the partition_count tensor. + functor::SetZeroFunctor zero_functor; + zero_functor(device, partition_count->flat()); + // Allocate memory for aggregates_out. + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &aggregates_out), + done); + // Obtain the pointers to inner buffers. + int32* keys_in_ptr = partitions_out.flat().data(); + // Here we reuse the indices_in tensor for the unique keys output. + int32* unique_out_ptr = indices_in.flat().data(); + int32* aggregates_out_ptr = aggregates_out.flat().data(); + // We wrap the pointers in bounded output iterators to guard against + // wrong inputs (more than num_partitions distinct indices). + IdentityOp id_op; + BoundedOutputIterator unique_out_it(unique_out_ptr, id_op, num_partitions_); + BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op, + num_partitions_); + + cub::ConstantInputIterator values_in(1); + cub::Sum reduction_op; + + // Allocate space on GPU for the number of runs. This is required by CUB. + Tensor num_runs; + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp(DT_INT32, TensorShape({1}), &num_runs), done); + int32* num_runs_ptr = num_runs.flat().data(); + + // Determine temporary device storage requirements + Tensor cub_temp_storage; + size_t temp_storage_bytes = 0; + cub::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr, + unique_out_it, values_in, aggregates_out_it, + num_runs_ptr, reduction_op, N, cu_stream); + // Allocate temporary storage. + OP_REQUIRES_OK_ASYNC( + c, c->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), + done); + // Run reduce-by-key. The effect is that we count how many times + // each index appears in partitions. The distinct indices are stored + // in unique_out, while the count is stored in aggregates_out. + // The total number of distinct indices is stored in num_runs. + cub::DeviceReduce::ReduceByKey(cub_temp_storage.flat().data(), + temp_storage_bytes, keys_in_ptr, + unique_out_it, values_in, aggregates_out_it, + num_runs_ptr, reduction_op, N, cu_stream); + // We are not done yet. unique_out only contains the indices that appeared + // at least once in partitions. We move each value from aggregates_out + // to the corresponding position in partition_count. This will handle + // possibly empty parts. + MoveValues(device, unique_out_ptr, aggregates_out_ptr, num_runs_ptr, + num_partitions_, partition_count->flat().data()); + } // At this point indices_in, partitions_out, aggregates_out + // and cub_temp_storage will be marked for deallocation. + + void GatherSlices(OpKernelContext* c, const Tensor* data, + const Tensor* indices, int32 N, int64 slice_size, + OpOutputList& outs) { + const GPUDevice& device = c->eigen_device(); + const int32* ind_base = indices->flat().data(); + const T* data_base = data->flat().data(); + + for (int p = 0; p < num_partitions_; p++) { + int32 indices_size = outs[p]->dim_size(0); + int64 out_size = outs[p]->NumElements(); + T* out_base = outs[p]->flat().data(); + if (out_size > 0) + CallGatherKernel(device, data_base, ind_base, out_base, N, + indices_size, slice_size, out_size); + ind_base += indices_size; + } + } + + int32 num_partitions_; +}; + +#define REGISTER_DYNAMIC_PARTITION_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DynamicPartition").Device(DEVICE_GPU).TypeConstraint("T"), \ + DynamicPartitionOpGPU) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_PARTITION_GPU); +TF_CALL_complex64(REGISTER_DYNAMIC_PARTITION_GPU); +TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU); +#undef REGISTER_DYNAMIC_PARTITION_GPU + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index d8bdb700e66..2eefadad494 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/kernels/maxpooling_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -38,6 +37,7 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/kernels/maxpooling_op_gpu.h" diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index d9713075be6..9fee94f9465 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #define MKL_Complex8 tensorflow::complex64 #define MKL_Complex16 tensorflow::complex128 diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 9080bf7be89..f291281108d 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -45,12 +45,12 @@ limitations under the License. #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; +using mkldnn::convolution_forward; using mkldnn::convolution_backward_weights; using mkldnn::convolution_direct; -using mkldnn::convolution_forward; #endif @@ -463,13 +463,12 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Generate input shapes. TensorShape filter_shape; - OP_REQUIRES( - context, TensorShapeUtils::IsVector(filter_tensor.shape()), - errors::InvalidArgument( + OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()), + errors::InvalidArgument( "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", filter_tensor.dims())); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - filter_tensor.vec(), &filter_shape)); + filter_tensor.vec(), &filter_shape)); TensorShape input_shape = input_tensor.shape(); TensorShape obp_shape = obp_tensor.shape(); @@ -481,26 +480,27 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims, - &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward // API needs it. For that, we first need to create input, filter // and output memory descriptors. auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); - auto fwd_src_md = - memory::desc(fwd_input_dims, MklDnnType(), mkl_data_format); - auto fwd_filter_md = - memory::desc(fwd_filter_dims, MklDnnType(), memory::format::hwio); - auto fwd_out_md = - memory::desc(fwd_output_dims, MklDnnType(), mkl_data_format); - auto fwd_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md, - fwd_out_md, strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType(), + mkl_data_format); + auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType(), + memory::format::hwio); + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType(), + mkl_data_format); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Allocate output tensor and shape @@ -537,22 +537,23 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Create convolution backward weights primitive. - auto bwd_desc = convolution_backward_weights::desc( - convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(), - outbackprop.GetOpMemDesc(), strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto bwd_desc = convolution_backward_weights::desc(convolution_direct, + input.GetOpMemDesc(), output.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); - auto bwd_pd = convolution_backward_weights::primitive_desc( - bwd_desc, cpu_engine, fwd_pd); + auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc, + cpu_engine, + fwd_pd); PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output); - } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); } } @@ -563,8 +564,9 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_weights::primitive_desc& conv_pd, - MklDnnData* input, MklDnnData* obp, MklDnnData* output) { + const convolution_backward_weights::primitive_desc& conv_pd, + MklDnnData* input, MklDnnData* obp, + MklDnnData* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector net; @@ -575,10 +577,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { // output side, we will prepare reorder primitive in case output // reorder to user memory is required. bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_pd.diff_weights_primitive_desc()); + conv_pd.diff_weights_primitive_desc()); - net.push_back(convolution_backward_weights( - conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(), + obp->GetOpMem(), output->GetOpMem())); // Insert reorder primitive in the net for output reorder if reorder is // required. diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 4b6bf92e426..4a47d0463ef 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -23,8 +23,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include #include -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -43,16 +41,18 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" +#include "mkl_dnn.h" +#include "mkl_dnn_types.h" #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; -using mkldnn::convolution_backward_data; -using mkldnn::convolution_direct; using mkldnn::convolution_forward; +using mkldnn::convolution_direct; +using mkldnn::convolution_backward_data; #endif namespace tensorflow { @@ -397,13 +397,12 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Generate input shape. TensorShape input_shape; - OP_REQUIRES( - context, TensorShapeUtils::IsVector(input_tensor.shape()), - errors::InvalidArgument( + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), + errors::InvalidArgument( "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", input_tensor.dims())); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( - input_tensor.vec(), &input_shape)); + input_tensor.vec(), &input_shape)); TensorShape filter_shape = filter_tensor.shape(); TensorShape obp_shape = obp_tensor.shape(); @@ -415,26 +414,27 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims, - &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward // API needs it. For that, we first need to create input, filter // and output memory descriptors. auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); - auto fwd_src_md = - memory::desc(fwd_input_dims, MklDnnType(), mkl_data_format); - auto fwd_filter_md = - memory::desc(fwd_filter_dims, MklDnnType(), memory::format::hwio); - auto fwd_out_md = - memory::desc(fwd_output_dims, MklDnnType(), mkl_data_format); - auto fwd_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md, - fwd_out_md, strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType(), + mkl_data_format); + auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType(), + memory::format::hwio); + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType(), + mkl_data_format); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Allocate output tensor and shape @@ -475,22 +475,23 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { output.SetOpMemDesc(bwd_output_dims, memory::format::any); // Create convolution backward data primitive. - auto bwd_desc = convolution_backward_data::desc( - convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(), - outbackprop.GetOpMemDesc(), strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto bwd_desc = convolution_backward_data::desc(convolution_direct, + output.GetOpMemDesc(), filter.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); - auto bwd_pd = convolution_backward_data::primitive_desc( - bwd_desc, cpu_engine, fwd_pd); + auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, + cpu_engine, + fwd_pd); PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output); - } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); } } @@ -501,8 +502,9 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_data::primitive_desc& conv_pd, - MklDnnData* filter, MklDnnData* obp, MklDnnData* output) { + const convolution_backward_data::primitive_desc& conv_pd, + MklDnnData* filter, MklDnnData* obp, + MklDnnData* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector net; @@ -512,11 +514,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { // Memory for output of convolution. Since we may need reorder on the // output side, we will prepare reorder primitive in case output // reorder to user memory is required. - bool output_reorder_required = - output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc()); + bool output_reorder_required = output->PrepareReorderToUserMemIfReq( + conv_pd.diff_src_primitive_desc()); - net.push_back(convolution_backward_data( - conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(), + filter->GetOpMem(), output->GetOpMem())); // Insert reorder primitive in the net for output reorder if reorder is // required. diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 369f632fb46..a9872b8d6d3 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include #include +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -46,11 +46,11 @@ limitations under the License. #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::prop_kind; using mkldnn::stream; +using mkldnn::prop_kind; -using mkldnn::convolution_direct; using mkldnn::convolution_forward; +using mkldnn::convolution_direct; #endif namespace tensorflow { @@ -523,16 +523,19 @@ class MklConv2DOp : public OpKernel { // Get shapes of input tensors in MKL-DNN order MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder( - src_tensor.shape(), filter_tensor.shape(), &src_dims, &filter_dims, - &strides, &output_dims_tf_order, &output_dims_mkl_order, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder(src_tensor.shape(), + filter_tensor.shape(), + &src_dims, &filter_dims, &strides, + &output_dims_tf_order, + &output_dims_mkl_order, &padding_l, + &padding_r); if (!context->status().ok()) return; // Check for corner case - if there is nothing to compute, return. - TensorShape tf_output_shape( - {output_dims_tf_order[0], output_dims_tf_order[1], - output_dims_tf_order[2], output_dims_tf_order[3]}); + TensorShape tf_output_shape({output_dims_tf_order[0], + output_dims_tf_order[1], + output_dims_tf_order[2], + output_dims_tf_order[3]}); Tensor* output_tensor = nullptr; MklShape mkl_output_mkl_shape; mkl_output_mkl_shape.SetMklTensor(false); @@ -569,13 +572,13 @@ class MklConv2DOp : public OpKernel { // the layout is Tensorflow's layout (NHWC or NCHW depending on data // format). src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_), - const_cast( - static_cast(src_tensor.flat().data()))); + const_cast(static_cast( + src_tensor.flat().data()))); // Although filter shape (filter_dims) required is in MKL-DNN order, // the layout is Tensorflow's layout (HWIO). filter.SetUsrMem(filter_dims, memory::format::hwio, const_cast(static_cast( - filter_tensor.flat().data()))); + filter_tensor.flat().data()))); // Although output shape (output_dims) required is in MKL-DNN order, // layout is Tensorflow's layout (NHWC or NCHW depending on data format). output.SetUsrMem(output_dims_mkl_order, @@ -595,36 +598,36 @@ class MklConv2DOp : public OpKernel { const Tensor& bias_tensor = MklGetInput(context, 2); bias.SetUsrMem(bias_size, memory::format::x, const_cast(static_cast( - bias_tensor.flat().data()))); + bias_tensor.flat().data()))); bias.SetOpMemDesc(bias_size, memory::format::any); // Create convolution primitive with Bias. - auto conv_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, src.GetOpMemDesc(), - filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(), - strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto conv_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), + bias.GetOpMemDesc(), output.GetOpMemDesc(), strides, + padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); - auto conv_prim_desc = - convolution_forward::primitive_desc(conv_desc, cpu_engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, + cpu_engine); PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output); } else { // Create convolution primitive without Bias. - auto conv_desc = convolution_forward::desc( - prop_kind::forward, convolution_direct, src.GetOpMemDesc(), - filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l, - padding_r, TFPaddingToMklDnnPadding(padding_)); + auto conv_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), + output.GetOpMemDesc(), strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_)); - auto conv_prim_desc = - convolution_forward::primitive_desc(conv_desc, cpu_engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, + cpu_engine); PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output); } - } catch (mkldnn::error& e) { + } catch (mkldnn::error &e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + std::string(e.message) + ", in file " + - std::string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + ", message: " + std::string(e.message) + + ", in file " + std::string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -635,9 +638,9 @@ class MklConv2DOp : public OpKernel { // Prepare and execute net - checks for input and output reorders. void PrepareAndExecuteNet( - const convolution_forward::primitive_desc& conv_prim_desc, - MklDnnData* src, MklDnnData* filter, MklDnnData* bias, - MklDnnData* output) { + const convolution_forward::primitive_desc& conv_prim_desc, + MklDnnData* src, MklDnnData* filter, + MklDnnData* bias, MklDnnData* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector net; @@ -648,19 +651,18 @@ class MklConv2DOp : public OpKernel { // output side, we will prepare reorder primitive in case output // reorder to user memory is required. bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_prim_desc.dst_primitive_desc()); + conv_prim_desc.dst_primitive_desc()); // Create convolution primitive and add it to net. if (bias) { CHECK_EQ(biasEnabled, true); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), bias->GetOpMem(), - output->GetOpMem())); + filter->GetOpMem(), bias->GetOpMem(), + output->GetOpMem())); } else { CHECK_EQ(biasEnabled, false); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), - output->GetOpMem())); + filter->GetOpMem(), output->GetOpMem())); } // Insert reorder primitive in the net for output reorder if reorder is diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index e29af19ca9b..f0cb37f8a42 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ -#include #include +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -26,8 +26,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -49,15 +49,15 @@ namespace tensorflow { class MklDnnConvUtil { protected: - OpKernelContext *context_; // We don't own this. + OpKernelContext* context_; // We don't own this. std::vector strides_; Padding padding_; TensorFormat data_format_; public: - MklDnnConvUtil(OpKernelContext *context, const std::vector &strides, - Padding pad, TensorFormat fm) - : context_(context), strides_(strides), padding_(pad), data_format_(fm) {} + MklDnnConvUtil(OpKernelContext* context, const std::vector& strides, + Padding pad, TensorFormat fm) : context_(context), + strides_(strides), padding_(pad), data_format_(fm) {} virtual ~MklDnnConvUtil() { context_ = nullptr; } @@ -75,14 +75,14 @@ class MklDnnConvUtil { // requires input in NCHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void GetInputSizeInMklOrder(const TensorShape &input_shape, - memory::dims *input_dims) { -#define CHECK_BOUNDS(val, err_msg) \ - do { \ - OP_REQUIRES(context_, \ - FastBoundsCheck(val, std::numeric_limits::max()), \ - errors::InvalidArgument(err_msg)); \ - } while (0) + virtual inline void + GetInputSizeInMklOrder(const TensorShape& input_shape, + memory::dims *input_dims) { + #define CHECK_BOUNDS(val, err_msg) do { \ + OP_REQUIRES(context_, FastBoundsCheck(val, \ + std::numeric_limits::max()), \ + errors::InvalidArgument(err_msg)); \ + }while(0) CHECK_NOTNULL(input_dims); @@ -105,7 +105,7 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_batch_raw, "Input batch too large"); int input_batch = static_cast(input_batch_raw); -#undef CHECK_BOUNDS + #undef CHECK_BOUNDS // MKL-DNN always requires input in NCHW format. *input_dims = {input_batch, input_depth, input_rows, input_cols}; @@ -125,9 +125,10 @@ class MklDnnConvUtil { // forward gets actual tensor as input). // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void GetFilterSizeInMklOrder(const TensorShape &input_shape, - const TensorShape &filter_shape, - memory::dims *filter_dims) { + virtual inline void + GetFilterSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims *filter_dims) { CHECK_NOTNULL(filter_dims); OP_REQUIRES(context_, filter_shape.dims() == 4, @@ -135,18 +136,17 @@ class MklDnnConvUtil { filter_shape.DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES(context_, - FastBoundsCheck(filter_shape.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } int input_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter_shape.dim_size(2))); + OP_REQUIRES( + context_, input_depth == filter_shape.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(2))); // TF filter is always in (rows, cols, in_depth, out_depth) order. int filter_rows = static_cast(filter_shape.dim_size(0)); @@ -163,25 +163,25 @@ class MklDnnConvUtil { // requires filter in OIHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void GetFilterSizeInMklOrder(size_t src_index, - size_t filter_index, - memory::dims *filter_dims) { + virtual inline void + GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, + memory::dims *filter_dims) { CHECK_NOTNULL(filter_dims); - const Tensor &input = MklGetInput(context_, src_index); - const Tensor &filter = MklGetInput(context_, filter_index); + const Tensor& input = MklGetInput(context_, src_index); + const Tensor& filter = MklGetInput(context_, filter_index); GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims); } // Calculate Bias size for 2D Convolution. Function does not return // anything, but sets error in context status. - virtual inline void GetBiasSizeInMklOrder(size_t bias_index, - memory::dims *bias_dims) { - const Tensor &bias = MklGetInput(context_, bias_index); + virtual inline void + GetBiasSizeInMklOrder(size_t bias_index, memory::dims *bias_dims) { + const Tensor& bias = MklGetInput(context_, bias_index); OP_REQUIRES(context_, bias.dims() == 1, errors::InvalidArgument("bias must be 1-dimensional: ", bias.shape().DebugString())); - *bias_dims = {static_cast(bias.dim_size(0))}; + *bias_dims = { static_cast(bias.dim_size(0)) }; } // Function to calculate output and padding size for 2D convolution. @@ -193,11 +193,13 @@ class MklDnnConvUtil { // status is returned via context status. // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void GetOutputAndPadSizeInMklOrder( - const TensorShape &input_shape, const TensorShape &filter_shape, - const memory::dims &strides, memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, memory::dims *pad_l, - memory::dims *pad_r) { + virtual inline void + GetOutputAndPadSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + const memory::dims& strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, memory::dims *pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -223,21 +225,21 @@ class MklDnnConvUtil { int64 out_rows = 0, out_cols = 0; int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose(input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose(input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); // Tensorflow output is in data_format order. (NHWC or NCHW) - TensorShape out_shape = - ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth); + TensorShape out_shape = ShapeFromFormat(data_format_, out_batch, + out_rows, out_cols, out_depth); *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); // MKL-DNN always needs output in NCHW format. *output_dims_mkl_order = {out_batch, out_depth, static_cast(out_rows), - static_cast(out_cols)}; + static_cast(out_cols)}; // Now handle padding. MKL-DNN uses asymetric padding. *pad_l = {static_cast(pad_top), static_cast(pad_left)}; @@ -248,25 +250,27 @@ class MklDnnConvUtil { // See comment on GetConvOutputAndPadSizeInMklOrder for parameters. // // Function does not return anything, but sets error in context status. - inline void GetOutputAndPadSizeInMklOrder( - size_t src_index, size_t filter_index, const memory::dims &strides, - memory::dims *output_dims_tf_order, memory::dims *output_dims_mkl_order, - memory::dims *pad_l, memory::dims *pad_r) { + inline void + GetOutputAndPadSizeInMklOrder(size_t src_index, size_t filter_index, + const memory::dims& strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, memory::dims *pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); CHECK_NOTNULL(pad_r); - const Tensor &input = MklGetInput(context_, src_index); - const Tensor &filter = MklGetInput(context_, filter_index); + const Tensor& input = MklGetInput(context_, src_index); + const Tensor& filter = MklGetInput(context_, filter_index); OP_REQUIRES(context_, input.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); + input.shape().DebugString())); - GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides, - output_dims_tf_order, output_dims_mkl_order, - pad_l, pad_r); + GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), + strides, output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r); } // Wrapper function to calculate input, filter, and output sizes of @@ -275,12 +279,15 @@ class MklDnnConvUtil { // also calculates strides and paddings for 2D Convolution. // // Function does not return anything, but sets error in context status. - inline void GetConvFwdSizesInMklOrder( - const TensorShape &input_shape, const TensorShape &filter_shape, - memory::dims *input_dims, memory::dims *filter_dims, - memory::dims *strides, memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, memory::dims *pad_l, - memory::dims *pad_r) { + inline void GetConvFwdSizesInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims *input_dims, + memory::dims *filter_dims, + memory::dims *strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, + memory::dims *pad_r) { CHECK_NOTNULL(input_dims); CHECK_NOTNULL(filter_dims); CHECK_NOTNULL(strides); @@ -295,7 +302,8 @@ class MklDnnConvUtil { if (!context_->status().ok()) return; GetStridesInMklOrder(strides); GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides, - output_dims_tf_order, output_dims_mkl_order, + output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r); if (!context_->status().ok()) return; } diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 55be308901b..8d9d0ea8461 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -235,10 +235,10 @@ class SqueezeOp : public OpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument( - "Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", existing_dim)); + errors::InvalidArgument("Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", + existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h index db7eded745e..0362a021336 100644 --- a/tensorflow/core/kernels/slice_op.h +++ b/tensorflow/core/kernels/slice_op.h @@ -24,6 +24,7 @@ limitations under the License. namespace tensorflow { namespace functor { + template struct Slice { void operator()(const Device& d, typename TTypes::Tensor output, diff --git a/tensorflow/core/util/transform_output_iterator.h b/tensorflow/core/util/transform_output_iterator.h index 1640791ad17..059206c75b9 100644 --- a/tensorflow/core/util/transform_output_iterator.h +++ b/tensorflow/core/util/transform_output_iterator.h @@ -24,7 +24,7 @@ namespace tensorflow { template class TransformOutputIterator { - private: + protected: // Proxy object struct Reference { StoreType* ptr; diff --git a/tensorflow/docs_src/extend/add_filesys.md b/tensorflow/docs_src/extend/add_filesys.md index 44ba198998c..f0591b7b7d8 100644 --- a/tensorflow/docs_src/extend/add_filesys.md +++ b/tensorflow/docs_src/extend/add_filesys.md @@ -35,6 +35,7 @@ Note that TensorFlow already includes many filesystem implementations, such as: * HDFS - the Hadoop File System * GCS - Google Cloud Storage filesystem +* S3 - Amazon Simple Storage Service filesystem * A "memory-mapped-file" filesystem The rest of this guide describes how to implement a custom filesystem. diff --git a/tensorflow/docs_src/extend/estimators.md b/tensorflow/docs_src/extend/estimators.md index 7e6507c5840..96fc9fae472 100644 --- a/tensorflow/docs_src/extend/estimators.md +++ b/tensorflow/docs_src/extend/estimators.md @@ -515,7 +515,7 @@ using `mean_squared_error()` (in bold): loss = tf.losses.mean_squared_error(labels, predictions) ... -See the @{$python/contrib.losses$API guide} for a +See the @{tf.losses$API guide} for a full list of loss functions and more details on supported arguments and usage. Supplementary metrics for evaluation can be added to an `eval_metric_ops` dict. @@ -694,5 +694,5 @@ For additional reference materials on building `Estimator`s, see the following sections of the API guides: * @{$python/contrib.layers$Layers} -* @{$python/contrib.losses$Losses} +* @{tf.losses$Losses} * @{$python/contrib.layers#optimization$Optimization} diff --git a/tensorflow/docs_src/get_started/input_fn.md b/tensorflow/docs_src/get_started/input_fn.md index f0dcdc47ff1..24bfdbdd2e9 100644 --- a/tensorflow/docs_src/get_started/input_fn.md +++ b/tensorflow/docs_src/get_started/input_fn.md @@ -292,7 +292,7 @@ prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True, Next, create a list of `FeatureColumn`s for the input data, which formally specify the set of features to use for training. Because all features in the housing data set contain continuous values, you can create their -`FeatureColumn`s using the `tf.contrib.layers.real_valued_column()` function: +`FeatureColumn`s using the `tf.feature_column.numeric_column()` function: ```python feature_cols = [tf.feature_column.numeric_column(k) for k in FEATURES] diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 79b383817b4..3afd0aec0f3 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -79,22 +79,23 @@ Take the following steps to install TensorFlow with Virtualenv: 4. Activate the Virtualenv environment by issuing one of the following commands: -
$ source ~/tensorflow/bin/activate      # If using bash, sh, ksh, or zsh
-    $ source ~/tensorflow/bin/activate.csh  # If using csh or tcsh 
+
$ cd targetDirectory
+    $ source ./bin/activate      # If using bash, sh, ksh, or zsh
+    $ source ./bin/activate.csh  # If using csh or tcsh 
The preceding `source` command should change your prompt to the following: -
 (tensorflow)$ 
+
 (targetDirectory)$ 
5. Ensure pip ≥8.1 is installed: -
 (tensorflow)$ easy_install -U pip
+
 (targetDirectory)$ easy_install -U pip
6. Issue one of the following commands to install TensorFlow and all the packages that TensorFlow requires into the active Virtualenv environment: -
 (tensorflow)$ pip install --upgrade tensorflow      # for Python 2.7
-     (tensorflow)$ pip3 install --upgrade tensorflow     # for Python 3.n
+     
 (targetDirectory)$ pip install --upgrade tensorflow      # for Python 2.7
+     (targetDirectory)$ pip3 install --upgrade tensorflow     # for Python 3.n
 
   7. Optional. If Step 6 failed (typically because you invoked a pip version
      lower than 8.1), install TensorFlow in the active
@@ -128,16 +129,18 @@ to confirm that the installation worked properly.
 
 Note that you must activate the Virtualenv environment each time you
 use TensorFlow in a new shell.  If the Virtualenv environment is not
-currently active (that is, the prompt is not `(tensorflow)`, invoke
+currently active (that is, the prompt is not `(targetDirectory)`, invoke
 one of the following commands:
 
-
$ source ~/tensorflow/bin/activate      # bash, sh, ksh, or zsh
-$ source ~/tensorflow/bin/activate.csh  # csh or tcsh 
+
$ cd targetDirectory
+$ source ./bin/activate      # If using bash, sh, ksh, or zsh
+$ source ./bin/activate.csh  # If using csh or tcsh 
+ Your prompt will transform to the following to indicate that your tensorflow environment is active: -
 (tensorflow)$ 
+
 (targetDirectory)$ 
When the Virtualenv environment is active, you may run TensorFlow programs from this shell. @@ -145,7 +148,7 @@ TensorFlow programs from this shell. When you are done using TensorFlow, you may deactivate the environment by issuing the following command: -
 (tensorflow)$ deactivate 
+
 (targetDirectory)$ deactivate 
The prompt will revert back to your default prompt (as defined by `PS1`). @@ -331,19 +334,19 @@ Take the following steps to install TensorFlow in an Anaconda environment: 3. Activate the conda environment by issuing the following command:
$ source activate tensorflow
-     (tensorflow)$  # Your prompt should change
+ (targetDirectory)$ # Your prompt should change
4. Issue a command of the following format to install TensorFlow inside your conda environment: -
(tensorflow)$ pip install --ignore-installed --upgrade TF_PYTHON_URL
+
(targetDirectory)$ pip install --ignore-installed --upgrade TF_PYTHON_URL
where TF_PYTHON_URL is the [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package). For example, the following command installs the CPU-only version of TensorFlow for Python 2.7: -
 (tensorflow)$ pip install --ignore-installed --upgrade \
+     
 (targetDirectory)$ pip install --ignore-installed --upgrade \
      https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0-py2-none-any.whl
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index 073bdb7baa9..308cbad3764 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -190,8 +190,8 @@ validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the # `output_types` and `output_shapes` properties of either `training_dataset` # or `validation_dataset` here, because they are compatible. -iterator = Iterator.from_structure(training_dataset.output_types, - training_dataset.output_shapes) +iterator = tf.data.Iterator.from_structure(training_dataset.output_types, + training_dataset.output_shapes) next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset) @@ -735,7 +735,7 @@ def dataset_input_fn(): parsed = tf.parse_single_example(record, keys_to_features) # Perform additional preprocessing on the parsed data. - image = tf.decode_jpeg(parsed["image_data"]) + image = tf.image.decode_jpeg(parsed["image_data"]) image = tf.reshape(image, [299, 299, 1]) label = tf.cast(parsed["label"], tf.int32) diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index 34e8e5faf56..54693f3d4d3 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -33,7 +33,7 @@ roughly speaking, map variable names to tensor values. Create a `Saver` with `tf.train.Saver()` to manage all variables in the model. For example, the following snippet demonstrates how to call the -`tf.train.Saver.save` method to save variables to a checkpoint file: +`tf.train.Saver.save` method to save variables to checkpoint files: ```python # Create some variables. @@ -58,7 +58,7 @@ with tf.Session() as sess: dec_v2.op.run() # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") - print("Model saved in file: %s" % save_path) + print("Model saved in path: %s" % save_path) ``` @@ -66,10 +66,10 @@ with tf.Session() as sess: ### Restoring variables The `tf.train.Saver` object not only saves variables to checkpoint files, it -also restores variables. Note that when you restore variables from a file you -do not have to initialize them beforehand. For example, the following snippet -demonstrates how to call the `tf.train.Saver.restore` method to restore -variables from a checkpoint file: +also restores variables. Note that when you restore variables you do not have +to initialize them beforehand. For example, the following snippet demonstrates +how to call the `tf.train.Saver.restore` method to restore variables from the +checkpoint files: ```python tf.reset_default_graph() @@ -92,6 +92,12 @@ with tf.Session() as sess: print("v2 : %s" % v2.eval()) ``` +Notes: + +* There is not a physical file called "/tmp/model.ckpt". It is the **prefix** + of filenames created for the checkpoint. Users only interact with the + prefix instead of physical checkpoint files. + ### Choosing which variables to save and restore diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 51621d51ef1..30a26d13c57 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -168,7 +168,7 @@ download-models.gradle. **Optional**: If you wish to place the models in your assets manually, remove all of the `model_files` entries from the `assets` list in `tensorflow_demo` -found in the `[BUILD](BUILD)` file. Then download and extract the archives +found in the [`BUILD`](BUILD#L92) file. Then download and extract the archives yourself to the `assets` directory in the source tree: ```bash diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py index f5bf04305a7..bec7dacd211 100644 --- a/tensorflow/examples/speech_commands/train.py +++ b/tensorflow/examples/speech_commands/train.py @@ -161,7 +161,7 @@ def main(_): evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.summary.scalar('accuracy', evaluation_step) - global_step = tf.contrib.framework.get_or_create_global_step() + global_step = tf.train.get_or_create_global_step() increment_global_step = tf.assign(global_step, global_step + 1) saver = tf.train.Saver(tf.global_variables()) diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index cd05e2aa0af..2d25c04dc9b 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -328,6 +328,14 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { } } + // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice + if len(shape) == 1 && v.Len() > 0 { + switch v.Index(0).Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return binary.Write(w, nativeEndian, v.Interface()) + } + } + subShape := shape[1:] for i := 0; i < v.Len(); i++ { err := encodeTensor(w, v.Index(i), subShape) @@ -360,6 +368,15 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. case reflect.Slice: val := reflect.Indirect(ptr) val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0]))) + + // Optimization: if only one dimension is left we can use binary.Read() directly for this slice + if len(shape) == 1 && val.Len() > 0 { + switch val.Index(0).Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return binary.Read(r, nativeEndian, val.Interface()) + } + } + for i := 0; i < val.Len(); i++ { if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil { return err diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 674a8ce86f8..793c36dd4db 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -243,3 +243,23 @@ func BenchmarkNewTensor(b *testing.B) { ) b.Run("[150528]", func(b *testing.B) { benchmarkNewTensor(b, vector) }) } + +func benchmarkDecodeTensor(b *testing.B, t *Tensor) { + for i := 0; i < b.N; i++ { + _ = t.Value() + } +} + +func BenchmarkDecodeTensor(b *testing.B) { + var ( + // Some sample sizes from the Inception image labeling model. + // Where input tensors correspond to a 224x224 RGB image + // flattened into a vector. + vector [224 * 224 * 3]int32 + ) + t, err := NewTensor(vector) + if err != nil { + b.Fatalf("(%v, %v)", t, err) + } + b.Run("[150528]", func(b *testing.B) { benchmarkDecodeTensor(b, t) }) +} diff --git a/tensorflow/python/debug/lib/stepper.py b/tensorflow/python/debug/lib/stepper.py index 1fa0b3dba2b..c27b3f51cdd 100644 --- a/tensorflow/python/debug/lib/stepper.py +++ b/tensorflow/python/debug/lib/stepper.py @@ -80,7 +80,7 @@ class NodeStepper(object): when they are required as data dependencies. The temporary directories are automatically clean when the NodeStepper - instance exits as a context mananger. + instance exits as a context manager. Once the tracing is complete, it will issue a run() call on the underlying session, using the aforementioned feed_dict prepared by the input diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 3b295a7e35c..51075731ddc 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -191,7 +191,8 @@ def build_all_signature_defs(receiver_tensors, if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} if export_outputs is None or not isinstance(export_outputs, dict): - raise ValueError('export_outputs must be a dict.') + raise ValueError('export_outputs must be a dict and not' + '{}'.format(type(export_outputs))) signature_def_map = {} excluded_signatures = {} diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index 3cbef4707a5..8442bf04acc 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -358,7 +358,8 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as e: export.build_all_signature_defs(receiver_tensor, None) - self.assertEqual("export_outputs must be a dict.", str(e.exception)) + self.assertTrue(str(e.exception).startswith( + "export_outputs must be a dict")) def test_get_timestamped_export_dir(self): export_dir_base = tempfile.mkdtemp() + "export/" diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD old mode 100644 new mode 100755 index d9391dd6c58..4a60b7835ec --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -150,6 +150,7 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model", "@six_archive//:six", ], ) @@ -552,7 +553,7 @@ py_test( py_test( name = "data_utils_test", - size = "small", + size = "medium", srcs = ["_impl/keras/utils/data_utils_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 2e931769c73..4370341ad1b 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -19,9 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import export as export_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -31,9 +33,12 @@ from tensorflow.python.keras._impl.keras import models from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import metrics as metrics_module from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def _create_ordered_io(keras_model, estimator_io_dict, is_input=True): """Create a list of tensors from IO dictionary based on Keras IO order. @@ -184,7 +189,11 @@ def _create_keras_model_fn(keras_model, custom_objects=None): predictions=predictions, loss=loss, train_op=train_op, - eval_metric_ops=eval_metric_ops) + eval_metric_ops=eval_metric_ops, + export_outputs={ + _DEFAULT_SERVING_KEY: + export_lib.export_output.PredictOutput(predictions) + }) return model_fn @@ -222,7 +231,7 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, K._initialize_variables(sess) # pylint: enable=protected-access saver = saver_lib.Saver() - saver.save(sess, estimator.model_dir + '/') + saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) def model_to_estimator(keras_model=None, diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py index 35f8f76991a..c67c26b7be0 100644 --- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index 9441cdbe39e..2d434a39c29 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -332,6 +332,32 @@ class LogNdtrGradientTest(NdtrGradientTest): _use_log = True +class ErfInvTest(test.TestCase): + + def testErfInvValues(self): + with self.test_session(): + if not special: + return + + x = np.linspace(0., 1.0, 50).astype(np.float64) + + expected_x = special.erfinv(x) + x = special_math.erfinv(x) + self.assertAllClose(expected_x, x.eval(), atol=0.) + + def testErfInvIntegerInput(self): + with self.test_session(): + + with self.assertRaises(TypeError): + x = np.array([1, 2, 3]).astype(np.int32) + special_math.erfinv(x) + + with self.assertRaises(TypeError): + x = np.array([1, 2, 3]).astype(np.int64) + special_math.erfinv(x) + + + class LogCDFLaplaceTest(test.TestCase): # Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot # rely on scipy to cross check the extreme values. diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index 48830957075..b4fb5aa4117 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -33,13 +33,14 @@ from tensorflow.python.platform import test class DynamicPartitionTest(test.TestCase): def testSimpleOneDimensional(self): - with self.test_session() as sess: - data = constant_op.constant([0, 13, 2, 39, 4, 17]) + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32) indices = constant_op.constant([0, 0, 2, 3, 2, 1]) partitions = data_flow_ops.dynamic_partition( data, indices, num_partitions=4) partition_vals = sess.run(partitions) + self.assertEqual(4, len(partition_vals)) self.assertAllEqual([0, 13], partition_vals[0]) self.assertAllEqual([17], partition_vals[1]) self.assertAllEqual([2, 4], partition_vals[2]) @@ -52,14 +53,16 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual([None], partitions[3].get_shape().as_list()) def testSimpleTwoDimensional(self): - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], - [12, 13, 14], [15, 16, 17]]) + [12, 13, 14], [15, 16, 17]], + dtype=dtypes.float32) indices = constant_op.constant([0, 0, 2, 3, 2, 1]) partitions = data_flow_ops.dynamic_partition( data, indices, num_partitions=4) partition_vals = sess.run(partitions) + self.assertEqual(4, len(partition_vals)) self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0]) self.assertAllEqual([[15, 16, 17]], partition_vals[1]) self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2]) @@ -71,9 +74,84 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual([None, 3], partitions[2].get_shape().as_list()) self.assertEqual([None, 3], partitions[3].get_shape().as_list()) + def testLargeOneDimensional(self): + num = 100000 + data_list = [x for x in range(num)] + indices_list = [x % 2 for x in range(num)] + part1 = [x for x in range(num) if x % 2 == 0] + part2 = [x for x in range(num) if x % 2 == 1] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertEqual(2, len(partition_vals)) + self.assertAllEqual(part1, partition_vals[0]) + self.assertAllEqual(part2, partition_vals[1]) + + def testLargeTwoDimensional(self): + rows = 100000 + cols = 100 + data_list = [None] * rows + for i in range(rows): + data_list[i] = [i for _ in range(cols)] + num_partitions = 97 + indices_list = [(i ** 2) % num_partitions for i in range(rows)] + parts = [[] for _ in range(num_partitions)] + for i in range(rows): + parts[(i ** 2) % num_partitions].append(data_list[i]) + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=num_partitions) + partition_vals = sess.run(partitions) + + self.assertEqual(num_partitions, len(partition_vals)) + for i in range(num_partitions): + # reshape because of empty parts + parts_np = np.array(parts[i], dtype=np.float).reshape(-1, cols) + self.assertAllEqual(parts_np, partition_vals[i]) + + def testSimpleComplex(self): + data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j] + indices_list = [1, 0, 1, 0] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.complex64) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertEqual(2, len(partition_vals)) + self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0]) + self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1]) + + def testScalarPartitions(self): + data_list = [10, 13, 12, 11] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float64) + indices = 3 + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=4) + partition_vals = sess.run(partitions) + + self.assertEqual(4, len(partition_vals)) + self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), + partition_vals[0]) + self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), + partition_vals[1]) + self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), + partition_vals[2]) + self.assertAllEqual(np.array([10, 13, 12, 11], + dtype=np.float64).reshape(-1, 4), + partition_vals[3]) + def testHigherRank(self): np.random.seed(7) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: for n in 2, 3: for shape in (4,), (4, 5), (4, 5, 2): partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape) @@ -95,6 +173,115 @@ class DynamicPartitionTest(test.TestCase): self.assertEqual(grads[1], None) # Partitions has no gradients self.assertAllEqual(7 * data, sess.run(grads[0])) + def testEmptyParts(self): + data_list = [1, 2, 3, 4] + indices_list = [1, 3, 1, 3] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=4) + partition_vals = sess.run(partitions) + + self.assertEqual(4, len(partition_vals)) + self.assertAllEqual([], partition_vals[0]) + self.assertAllEqual([1, 3], partition_vals[1]) + self.assertAllEqual([], partition_vals[2]) + self.assertAllEqual([2, 4], partition_vals[3]) + + def testEmptyDataTwoDimensional(self): + data_list = [[], []] + indices_list = [0, 1] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=3) + partition_vals = sess.run(partitions) + + self.assertEqual(3, len(partition_vals)) + self.assertAllEqual([[]], partition_vals[0]) + self.assertAllEqual([[]], partition_vals[1]) + self.assertAllEqual(np.array([], dtype=np.float).reshape(0, 0), + partition_vals[2]) + + def testEmptyPartitions(self): + data_list = [] + indices_list = [] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertEqual(2, len(partition_vals)) + self.assertAllEqual([], partition_vals[0]) + self.assertAllEqual([], partition_vals[1]) + + def testGPUTooManyParts(self): + # This test only makes sense on the GPU. There we do not check + # for errors. In this case, we should discard all but the first + # num_partitions indices. + if not test.is_gpu_available(): + return + + data_list = [1, 2, 3, 4, 5, 6] + indices_list = [6, 5, 4, 3, 1, 0] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=2) + partition_vals = sess.run(partitions) + + self.assertEqual(2, len(partition_vals)) + self.assertAllEqual([6], partition_vals[0]) + self.assertAllEqual([5], partition_vals[1]) + + def testGPUPartsTooLarge(self): + # This test only makes sense on the GPU. There we do not check + # for errors. In this case, we should discard all the values + # larger than num_partitions. + if not test.is_gpu_available(): + return + + data_list = [1, 2, 3, 4, 5, 6] + indices_list = [10, 11, 2, 12, 0, 1000] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=5) + partition_vals = sess.run(partitions) + + self.assertEqual(5, len(partition_vals)) + self.assertAllEqual([5], partition_vals[0]) + self.assertAllEqual([], partition_vals[1]) + self.assertAllEqual([3], partition_vals[2]) + self.assertAllEqual([], partition_vals[3]) + self.assertAllEqual([], partition_vals[4]) + + def testGPUAllIndicesBig(self): + # This test only makes sense on the GPU. There we do not check + # for errors. In this case, we should discard all the values + # and have an empty output. + if not test.is_gpu_available(): + return + + data_list = [1.1, 2.1, 3.1, 4.1, 5.1, 6.1] + indices_list = [90, 70, 60, 100, 110, 40] + with self.test_session(use_gpu=True) as sess: + data = constant_op.constant(data_list, dtype=dtypes.float32) + indices = constant_op.constant(indices_list, dtype=dtypes.int32) + partitions = data_flow_ops.dynamic_partition( + data, indices, num_partitions=40) + partition_vals = sess.run(partitions) + + self.assertEqual(40, len(partition_vals)) + for i in range(40): + self.assertAllEqual([], partition_vals[i]) + def testErrorIndexOutOfRange(self): with self.test_session() as sess: data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py index fa1b219b177..75eb100a90f 100644 --- a/tensorflow/python/ops/bitwise_ops_test.py +++ b/tensorflow/python/ops/bitwise_ops_test.py @@ -36,7 +36,7 @@ class BitwiseOpTest(test_util.TensorFlowTestCase): def testBinaryOps(self): dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16] + dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64] with self.test_session(use_gpu=True) as sess: for dtype in dtype_list: diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py index 222a39ad828..bed4cbb2c1a 100644 --- a/tensorflow/python/ops/distributions/special_math.py +++ b/tensorflow/python/ops/distributions/special_math.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops __all__ = [ + "erfinv", "ndtr", "ndtri", "log_ndtr", @@ -350,6 +351,29 @@ def _log_ndtr_asymptotic_series(x, series_order): return 1. + even_sum - odd_sum +def erfinv(x, name="erfinv"): + """The inverse function for erf, the error function. + + Args: + x: `Tensor` of type `float32`, `float64`. + name: Python string. A name for the operation (default="erfinv"). + + Returns: + x: `Tensor` with `dtype=x.dtype`. + + Raises: + TypeError: if `x` is not floating-type. + """ + + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + if x.dtype.as_numpy_dtype not in [np.float32, np.float64]: + raise TypeError( + "x.dtype=%s is not handled, see docstring for supported types." + % x.dtype) + return ndtri((x + 1.0) / 2.0) / np.sqrt(2) + + def _double_factorial(n): """The double factorial function for small Python integer `n`.""" return np.prod(np.arange(n, 1, -2)) diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 55a18d28cae..b74971f6542 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -652,7 +652,7 @@ def softmax_cross_entropy( Args: onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels. - logits: [batch_size, num_classes] logits outputs of the network . + logits: `[batch_size, num_classes]` logits outputs of the network . weights: Optional `Tensor` whose rank is either 0, or rank 1 and is broadcastable to the loss which is a `Tensor` of shape `[batch_size]`. label_smoothing: If greater than 0 then smooth the labels. diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py index 71ee5e365f7..85ed4f071c7 100644 --- a/tensorflow/python/platform/tf_logging.py +++ b/tensorflow/python/platform/tf_logging.py @@ -30,64 +30,92 @@ from logging import ERROR from logging import FATAL from logging import INFO from logging import WARN +import threading import six from tensorflow.python.util.all_util import remove_undocumented -# Determine whether we are in an interactive environment -_interactive = False -try: - # This is only defined in interactive shells - if _sys.ps1: _interactive = True -except AttributeError: - # Even now, we may be in an interactive shell with `python -i`. - _interactive = _sys.flags.interactive +# Don't use this directly. Use _get_logger() instead. +_logger = None +_logger_lock = threading.Lock() -# Scope the tensorflow logger to not conflict with users' loggers -_logger = _logging.getLogger('tensorflow') -# If we are in an interactive environment (like jupyter), set loglevel to info -# and pipe the output to stdout -if _interactive: - _logger.setLevel(INFO) - _logging_target = _sys.stdout -else: - _logging_target = _sys.stderr +def _get_logger(): + global _logger -# Add the output handler -_handler = _logging.StreamHandler(_logging_target) -_handler.setFormatter(_logging.Formatter(_logging.BASIC_FORMAT, None)) -_logger.addHandler(_handler) + # Use double-checked locking to avoid taking lock unnecessarily. + if _logger: + return _logger + + _logger_lock.acquire() + + try: + if _logger: + return _logger + + # Scope the TensorFlow logger to not conflict with users' loggers. + logger = _logging.getLogger('tensorflow') + + # Don't further configure the TensorFlow logger if the root logger is + # already configured. This prevents double logging in those cases. + if not _logging.getLogger().handlers: + # Determine whether we are in an interactive environment + _interactive = False + try: + # This is only defined in interactive shells. + if _sys.ps1: _interactive = True + except AttributeError: + # Even now, we may be in an interactive shell with `python -i`. + _interactive = _sys.flags.interactive + + # If we are in an interactive environment (like Jupyter), set loglevel + # to INFO and pipe the output to stdout. + if _interactive: + logger.setLevel(INFO) + _logging_target = _sys.stdout + else: + _logging_target = _sys.stderr + + # Add the output handler. + _handler = _logging.StreamHandler(_logging_target) + _handler.setFormatter(_logging.Formatter(_logging.BASIC_FORMAT, None)) + logger.addHandler(_handler) + + _logger = logger + return _logger + + finally: + _logger_lock.release() def log(level, msg, *args, **kwargs): - _logger.log(level, msg, *args, **kwargs) + _get_logger().log(level, msg, *args, **kwargs) def debug(msg, *args, **kwargs): - _logger.debug(msg, *args, **kwargs) + _get_logger().debug(msg, *args, **kwargs) def error(msg, *args, **kwargs): - _logger.error(msg, *args, **kwargs) + _get_logger().error(msg, *args, **kwargs) def fatal(msg, *args, **kwargs): - _logger.fatal(msg, *args, **kwargs) + _get_logger().fatal(msg, *args, **kwargs) def info(msg, *args, **kwargs): - _logger.info(msg, *args, **kwargs) + _get_logger().info(msg, *args, **kwargs) def warn(msg, *args, **kwargs): - _logger.warn(msg, *args, **kwargs) + _get_logger().warn(msg, *args, **kwargs) def warning(msg, *args, **kwargs): - _logger.warning(msg, *args, **kwargs) + _get_logger().warning(msg, *args, **kwargs) _level_names = { @@ -118,7 +146,7 @@ def flush(): # Code below is taken from pyglib/logging def vlog(level, msg, *args, **kwargs): - _logger.log(level, msg, *args, **kwargs) + _get_logger().log(level, msg, *args, **kwargs) def _GetNextLogCountPerToken(token): @@ -225,12 +253,12 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None): def get_verbosity(): """Return how much logging output will be produced.""" - return _logger.getEffectiveLevel() + return _get_logger().getEffectiveLevel() def set_verbosity(v): """Sets the threshold for what messages will be logged.""" - _logger.setLevel(v) + _get_logger().setLevel(v) def _get_thread_id(): diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index dc19e1bc94e..5ddc688a4cf 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -57,6 +57,8 @@ if sys.version_info.major == 3: REQUIRED_PACKAGES.append('wheel >= 0.26') else: REQUIRED_PACKAGES.append('wheel') + # mock comes with unittest.mock for python3, need to install for python2 + REQUIRED_PACKAGES.append('mock >= 2.0.0') # tf-nightly should depend on tb-nightly if 'tf_nightly' in project_name: @@ -65,6 +67,11 @@ if 'tf_nightly' in project_name: REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.5.0a0, < 1.6.0a0' break +# weakref.finalize and enum were introduced in Python 3.4 +if sys.version_info < (3, 4): + REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1') + REQUIRED_PACKAGES.append('enum34 >= 1.1.6') + # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:main', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 5753b0c897c..20e1aaaf6ec 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,40 +1,21 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") - -load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") -load( - "@io_bazel_rules_closure//closure/private:java_import_external.bzl", - "java_import_external", -) -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") load("//third_party/py:python_configure.bzl", "python_configure") -load( - "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", - "arm_compiler_configure", -) - -def _is_windows(repository_ctx): - """Returns true if the host operating system is windows.""" - return repository_ctx.os.name.lower().find("windows") != -1 - -def _get_env_var(repository_ctx, name): - """Find an environment variable.""" - if name in repository_ctx.os.environ: - return repository_ctx.os.environ[name] - else: - return None +load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") +load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") +load("//third_party:repo.bzl", "tf_http_archive") +load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") +load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") # Parse the bazel version string from `native.bazel_version`. def _parse_bazel_version(bazel_version): # Remove commit from version. version = bazel_version.split(" ", 1)[0] - # Split into (release, date) parts and only return the release # as a tuple of integers. parts = version.split("-", 1) - # Turn "release" into a tuple of strings version_tuple = () for number in parts[0].split("."): @@ -57,50 +38,6 @@ def check_version(bazel_version): fail("\nCurrent Bazel version is {}, expected at least {}\n".format( native.bazel_version, bazel_version)) -# Executes specified command with arguments and calls 'fail' if it exited with -# non-zero code -def _execute_and_check_ret_code(repo_ctx, cmd_and_args): - result = repo_ctx.execute(cmd_and_args, timeout=10) - if result.return_code != 0: - fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" - + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code, - result.stdout, result.stderr)) - -# Apply a patch_file to the repository root directory -# Runs 'patch -p1' -def _apply_patch(repo_ctx, patch_file): - # Don't check patch on Windows, because patch is only available under bash. - if not _is_windows(repo_ctx) and not repo_ctx.which("patch"): - fail("patch command is not found, please install it") - - cmd = [ - "patch", "-p1", "-d", repo_ctx.path("."), "-i", repo_ctx.path(patch_file) - ] - if _is_windows(repo_ctx): - bazel_sh = _get_env_var(repo_ctx, "BAZEL_SH") - if not bazel_sh: - fail("BAZEL_SH environment variable is not set") - cmd = [bazel_sh, "-l", "-c", " ".join(cmd)] - _execute_and_check_ret_code(repo_ctx, cmd) - -# Download the repository and apply a patch to its root -def _patched_http_archive_impl(repo_ctx): - repo_ctx.download_and_extract( - repo_ctx.attr.urls, - sha256=repo_ctx.attr.sha256, - stripPrefix=repo_ctx.attr.strip_prefix) - _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - -patched_http_archive = repository_rule( - attrs = { - "patch_file": attr.label(), - "urls": attr.string_list(default = []), - "sha256": attr.string(default = ""), - "strip_prefix": attr.string(default = ""), - }, - implementation = _patched_http_archive_impl, -) - # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. @@ -134,7 +71,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): print("path_prefix was specified to tf_workspace but is no longer used " + "and will be removed in the future.") - native.new_http_archive( + tf_http_archive( name = "mkl_dnn", urls = [ "https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", @@ -145,7 +82,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")), ) - native.http_archive( + tf_http_archive( name = "com_google_absl", urls = [ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/cc4bed2d74f7c8717e31f9579214ab52a9c9c610.tar.gz", @@ -155,7 +92,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "abseil-cpp-cc4bed2d74f7c8717e31f9579214ab52a9c9c610", ) - native.new_http_archive( + tf_http_archive( name = "eigen_archive", urls = [ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", @@ -166,18 +103,20 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:eigen.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "arm_compiler", - build_file = str(Label("//:arm_compiler.BUILD")), sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", urls = [ "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + # Please uncomment me, when the next upgrade happens. Then + # remove the whitelist entry in third_party/repo.bzl. # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", ], + build_file = str(Label("//:arm_compiler.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "libxsmm_archive", urls = [ "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", @@ -188,15 +127,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:libxsmm.BUILD")), ) - native.bind( - name = "xsmm_avx", - actual = "@libxsmm_archive//third_party:xsmm_avx", - ) - - native.new_http_archive( + tf_http_archive( name = "ortools_archive", urls = [ "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", + # Please uncomment me, when the next upgrade happens. Then + # remove the whitelist entry in third_party/repo.bzl. # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", ], sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", @@ -204,7 +140,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:ortools.BUILD")), ) - native.http_archive( + tf_http_archive( name = "com_googlesource_code_re2", urls = [ "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", @@ -215,7 +151,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857", ) - native.http_archive( + tf_http_archive( name = "gemmlowp", urls = [ "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", @@ -225,7 +161,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "gemmlowp-010bb3e71a26ca1d0884a167081d092b43563996", ) - native.new_http_archive( + tf_http_archive( name = "farmhash_archive", urls = [ "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", @@ -236,12 +172,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:farmhash.BUILD")), ) - native.bind( - name = "farmhash", - actual = "@farmhash//:farmhash", - ) - - native.new_http_archive( + tf_http_archive( name = "highwayhash", urls = [ "https://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", @@ -252,7 +183,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:highwayhash.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "nasm", urls = [ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", @@ -263,7 +194,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:nasm.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "jpeg", urls = [ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", @@ -274,7 +205,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party/jpeg:jpeg.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "png_archive", urls = [ "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", @@ -285,7 +216,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:png.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "sqlite_archive", urls = [ "https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", @@ -293,10 +224,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4", strip_prefix = "sqlite-amalgamation-3200000", - build_file = str(Label("//third_party:sqlite.BUILD")) + build_file = str(Label("//third_party:sqlite.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "gif_archive", urls = [ "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", @@ -307,7 +238,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:gif.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "six_archive", urls = [ "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", @@ -318,7 +249,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:six.BUILD")), ) - native.http_archive( + tf_http_archive( name = "absl_py", urls = [ "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/acec853355ef987eae48a8d87a79351c15dff593.tar.gz", @@ -328,7 +259,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "abseil-py-acec853355ef987eae48a8d87a79351c15dff593", ) - native.new_http_archive( + tf_http_archive( name = "org_python_pypi_backports_weakref", urls = [ "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", @@ -339,7 +270,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:backports_weakref.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "com_github_andreif_codegen", urls = [ "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", @@ -361,12 +292,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): }, ) - native.bind( - name = "six", - actual = "@six_archive//:six", - ) - - patched_http_archive( + tf_http_archive( name = "protobuf_archive", urls = [ "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", @@ -381,20 +307,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patch_file = str(Label("//third_party/protobuf:add_noinlines.patch")), ) - native.bind( - name = "protobuf", - actual = "@protobuf_archive//:protobuf", - ) - - native.bind( - name = "protobuf_headers", - actual = "@protobuf_archive//:protobuf_headers", - ) - # We need to import the protobuf library under the names com_google_protobuf # and com_google_protobuf_cc to enable proto_library support in bazel. # Unfortunately there is no way to alias http_archives at the moment. - native.http_archive( + tf_http_archive( name = "com_google_protobuf", urls = [ "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", @@ -404,7 +320,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) - native.http_archive( + tf_http_archive( name = "com_google_protobuf_cc", urls = [ "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", @@ -414,7 +330,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) - native.http_archive( + tf_http_archive( name = "nsync", urls = [ "https://mirror.bazel.build/github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz", @@ -424,7 +340,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "nsync-8502189abfa44c249c01c2cad64e6ed660a9a668", ) - native.http_archive( + tf_http_archive( name = "com_google_googletest", urls = [ "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", @@ -434,7 +350,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6", ) - native.http_archive( + tf_http_archive( name = "com_github_gflags_gflags", urls = [ "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", @@ -444,12 +360,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88", ) - native.bind( - name = "python_headers", - actual = str(Label("//util/python:python_headers")), - ) - - native.new_http_archive( + tf_http_archive( name = "pcre", sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", urls = [ @@ -460,7 +371,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:pcre.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "swig", sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453", urls = [ @@ -472,7 +383,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:swig.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "curl", sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", urls = [ @@ -483,26 +394,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:curl.BUILD")), ) - # grpc expects //external:protobuf_clib and //external:protobuf_compiler - # to point to the protobuf's compiler library. - native.bind( - name = "protobuf_clib", - actual = "@protobuf_archive//:protoc_lib", - ) - - native.bind( - name = "libssl", - actual = "@boringssl//:ssl", - ) - - # gRPC has includes directly from their third_party path for nanopb, so we - # must depend on their version of it. - native.bind( - name = "nanopb", - actual = "@grpc//third_party/nanopb:nanopb", - ) - - native.http_archive( + tf_http_archive( name = "grpc", urls = [ "https://mirror.bazel.build/github.com/grpc/grpc/archive/f836c7e941beb003289dc6e9a58a6e47f5caa5f0.tar.gz", @@ -512,26 +404,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "grpc-f836c7e941beb003289dc6e9a58a6e47f5caa5f0", ) - # gRPC wants the existence of a cares dependence but its contents are not - # actually important since we have set GRPC_ARES=0 in tools/bazel.rc - native.bind( - name = "cares", - actual = "@grpc//third_party/nanopb:nanopb", - ) - - # protobuf expects //external:grpc_cpp_plugin to point to grpc's - # C++ plugin code generator. - native.bind( - name = "grpc_cpp_plugin", - actual = "@grpc//:grpc_cpp_plugin", - ) - - native.bind( - name = "grpc_lib", - actual = "@grpc//:grpc++_unsecure", - ) - - native.new_http_archive( + tf_http_archive( name = "linenoise", sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", urls = [ @@ -544,7 +417,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror. # Switch to an official source of snapshots if/when possible. - native.new_http_archive( + tf_http_archive( name = "llvm", urls = [ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/9ab4c272cb604a7f947865428c4ef2169fee2100.tar.gz", @@ -555,7 +428,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "lmdb", urls = [ "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", @@ -566,7 +439,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:lmdb.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "jsoncpp_git", urls = [ "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", @@ -577,12 +450,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:jsoncpp.BUILD")), ) - native.bind( - name = "jsoncpp", - actual = "@jsoncpp_git//:jsoncpp", - ) - - native.http_archive( + tf_http_archive( name = "boringssl", urls = [ "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", @@ -592,7 +460,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778", ) - native.new_http_archive( + tf_http_archive( name = "zlib_archive", urls = [ "https://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", @@ -603,12 +471,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:zlib.BUILD")), ) - native.bind( - name = "zlib", - actual = "@zlib_archive//:zlib", - ) - - native.new_http_archive( + tf_http_archive( name = "fft2d", urls = [ "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", @@ -618,7 +481,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party/fft2d:fft2d.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "snappy", urls = [ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", @@ -629,7 +492,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:snappy.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "nccl_archive", urls = [ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", @@ -640,14 +503,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:nccl.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "aws", urls = [ - "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", - "https://github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", + "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", + "https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", ], - sha256 = "f599b57aec4f03ad696044dd430b2d201864113937353adc346f53ad47991319", - strip_prefix = "aws-sdk-cpp-1.0.90", + sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c", + strip_prefix = "aws-sdk-cpp-1.3.15", build_file = str(Label("//third_party:aws.BUILD")), ) @@ -676,7 +539,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): testonly_ = True, ) - native.new_http_archive( + tf_http_archive( name = "jemalloc", urls = [ "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", @@ -722,7 +585,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): licenses = ["notice"], # Apache 2.0 ) - native.new_http_archive( + tf_http_archive( name = "com_google_pprof", urls = [ "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", @@ -733,7 +596,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:pprof.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "cub_archive", urls = [ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", @@ -744,12 +607,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:cub.BUILD")), ) - native.bind( - name = "cub", - actual = "@cub_archive//:cub", - ) - - native.new_http_archive( + tf_http_archive( name = "cython", sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5", urls = [ @@ -758,9 +616,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", build_file = str(Label("//third_party:cython.BUILD")), + delete = ["BUILD.bazel"], ) - native.http_archive( + tf_http_archive( name = "bazel_toolchains", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b49ba3689f46ac50e9277dafd8ff32b26951f82e.tar.gz", @@ -770,7 +629,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "bazel-toolchains-b49ba3689f46ac50e9277dafd8ff32b26951f82e", ) - native.new_http_archive( + tf_http_archive( name = "arm_neon_2_x86_sse", sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5", strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d", @@ -781,25 +640,102 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:arm_neon_2_x86_sse.BUILD")), ) - native.new_http_archive( + tf_http_archive( name = "flatbuffers", - build_file = str(Label("//third_party/flatbuffers:flatbuffers.BUILD")), strip_prefix = "flatbuffers-971a68110e4fc1bace10fcb6deeb189e7e1a34ce", sha256 = "874088d2ee0d9f8524191f77209556415f03dd44e156276edf19e5b90ceb5f55", urls = [ "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", ], + build_file = str(Label("//third_party/flatbuffers:flatbuffers.BUILD")), ) - native.new_http_archive( + + tf_http_archive( name = "tflite_mobilenet", - build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", urls = [ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", ], + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + ) + + ############################################################################## + # BIND DEFINITIONS + # + # Please do not add bind() definitions unless we have no other choice. + # If that ends up being the case, please leave a comment explaining + # why we can't depend on the canonical build target. + + # gRPC wants a cares dependency but its contents is not actually + # important since we have set GRPC_ARES=0 in tools/bazel.rc + native.bind( + name = "cares", + actual = "@grpc//third_party/nanopb:nanopb", + ) + + # Needed by Protobuf + native.bind( + name = "grpc_cpp_plugin", + actual = "@grpc//:grpc_cpp_plugin", + ) + + # gRPC has three empty C++ functions which it wants the user to define + # at build time. https://github.com/grpc/grpc/issues/13590 + native.bind( + name = "grpc_lib", + actual = "@grpc//:grpc++_unsecure", + ) + + # Needed by gRPC + native.bind( + name = "libssl", + actual = "@boringssl//:ssl", + ) + + # Needed by gRPC + native.bind( + name = "nanopb", + actual = "@grpc//third_party/nanopb:nanopb", + ) + + # Needed by gRPC + native.bind( + name = "protobuf", + actual = "@protobuf_archive//:protobuf", + ) + + # gRPC expects //external:protobuf_clib and //external:protobuf_compiler + # to point to Protobuf's compiler library. + native.bind( + name = "protobuf_clib", + actual = "@protobuf_archive//:protoc_lib", + ) + + # Needed by gRPC + native.bind( + name = "protobuf_headers", + actual = "@protobuf_archive//:protobuf_headers", + ) + + # Needed by Protobuf + native.bind( + name = "python_headers", + actual = str(Label("//util/python:python_headers")), + ) + + # Needed by Protobuf + native.bind( + name = "six", + actual = "@six_archive//:six", + ) + + # Needed by gRPC + native.bind( + name = "zlib", + actual = "@zlib_archive//:zlib", ) native.new_http_archive( diff --git a/third_party/repo.bzl b/third_party/repo.bzl index eb91316f67d..d6e5dfced0f 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -96,6 +96,7 @@ tf_http_archive = repository_rule( "build_file": attr.label(), }) """Downloads and creates Bazel repos for dependencies. + This is a swappable replacement for both http_archive() and new_http_archive() that offers some additional features. It also helps ensure best practices are followed.