mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Support TensorSpec input for experimental_get_compiler_ir
PiperOrigin-RevId: 501907754
This commit is contained in:
parent
49c939590d
commit
ea3ddeefb0
|
|
@ -129,6 +129,9 @@
|
|||
* Added `tf.nn.experimental.general_dropout`, which is similar to
|
||||
`tf.random.experimental.stateless_dropout` but accepts a custom sampler
|
||||
function.
|
||||
* `tf.types.experimental.GenericFunction`
|
||||
* The `experimental_get_compiler_ir` method supports tf.TensorSpec
|
||||
compilation arguments.
|
||||
|
||||
|
||||
# Thanks to our Contributors
|
||||
|
|
|
|||
|
|
@ -545,21 +545,20 @@ cc_library(
|
|||
hdrs = ["get_compiler_ir.h"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":common",
|
||||
":compilability_check_util",
|
||||
":device_compiler",
|
||||
":flags",
|
||||
":xla_device_no_jit_rewrite_registration",
|
||||
":xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
|
|
|||
|
|
@ -15,31 +15,31 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/jit/get_compiler_ir.h"
|
||||
|
||||
#include <deque>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_compiler.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/resource_handle.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/statusor.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
#include "tensorflow/tsl/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
|
@ -77,108 +77,9 @@ static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
|
|||
return std::move(executables[0]);
|
||||
}
|
||||
|
||||
StatusOr<std::string> GetCompilerIr(
|
||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||
absl::Span<const TensorHandle* const> inputs_handles) {
|
||||
using XlaDeviceCompiler =
|
||||
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
|
||||
|
||||
auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) {
|
||||
return stage == IrExportStage::HLO ||
|
||||
stage == IrExportStage::HLO_NO_METADATA ||
|
||||
stage == IrExportStage::HLO_SERIALIZED;
|
||||
};
|
||||
// TODO(b/238830423): support GetCompilerIr on TFRT TPU device for stages
|
||||
// that requires compilation from HLO to executable.
|
||||
if (dev->device_type() != DEVICE_CPU &&
|
||||
dev->tensorflow_accelerator_device_info()->stream == nullptr &&
|
||||
!is_tfrt_tpu_supported_stage(stage)) {
|
||||
return errors::Internal(
|
||||
"GetCompilerIr with requested stage is not supported on this device.");
|
||||
}
|
||||
NameAttrList function;
|
||||
function.set_name(std::string{func_name});
|
||||
|
||||
FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
|
||||
ResourceMgr* rmgr = dev->resource_manager();
|
||||
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
MemoryTypeVector input_memory_types =
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
||||
std::deque<Tensor> inputs_storage;
|
||||
std::vector<const Tensor*> inputs;
|
||||
inputs.reserve(inputs_handles.size());
|
||||
for (int i = 0; i < inputs_handles.size(); i++) {
|
||||
const TensorHandle* th = inputs_handles[i];
|
||||
const Tensor* t;
|
||||
// Handle owns the tensor.
|
||||
TF_RETURN_IF_ERROR(th->Tensor(&t));
|
||||
if (absl::c_binary_search(constant_arg_indices, i)) {
|
||||
// Need to make sure it's on the host.
|
||||
inputs_storage.emplace_back(t->dtype(), t->shape());
|
||||
TF_RETURN_IF_ERROR(
|
||||
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
|
||||
inputs.push_back(&inputs_storage.back());
|
||||
} else {
|
||||
inputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
|
||||
|
||||
XlaDeviceCompiler* xla_device_compiler;
|
||||
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaDeviceCompiler>(
|
||||
rmgr->default_container(), "xla_device_compiler", &xla_device_compiler,
|
||||
[&](XlaDeviceCompiler** xla_device_compiler) {
|
||||
return BuildXlaDeviceCompiler(dev, flr, platform_info,
|
||||
xla_device_compiler);
|
||||
}));
|
||||
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
|
||||
|
||||
se::Stream* stream = nullptr;
|
||||
if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
|
||||
dev->tensorflow_accelerator_device_info()) {
|
||||
stream = accelerator_device_info->stream;
|
||||
}
|
||||
|
||||
XlaCompiler::Options options;
|
||||
if (platform_info.device_type() == DEVICE_TPU && stream == nullptr) {
|
||||
options = GenerateTfrtTpuCompilerOptions(*xla_device_compiler, *flr);
|
||||
} else {
|
||||
options = GenerateCompilerOptions(*xla_device_compiler, *flr, dev, stream,
|
||||
platform_info,
|
||||
/*has_ref_vars=*/false);
|
||||
}
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.always_return_tuple = false;
|
||||
compile_options.alias_resource_update = true;
|
||||
|
||||
XlaCompiler compiler(options);
|
||||
|
||||
StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arg_indices, inputs, variable_infos, dev);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
|
||||
xla::LocalClient* local_client = xla_device_compiler->client();
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
compiler.CompileFunction(compile_options, function, *args, &result));
|
||||
|
||||
static StatusOr<std::string> BuildHLOString(
|
||||
IrExportStage stage, const XlaCompiler::CompilationResult& result,
|
||||
xla::LocalClient* local_client, const XlaCompiler::Options& options) {
|
||||
switch (stage) {
|
||||
case IrExportStage::HLO:
|
||||
case IrExportStage::HLO_NO_METADATA:
|
||||
|
|
@ -232,4 +133,158 @@ StatusOr<std::string> GetCompilerIr(
|
|||
}
|
||||
}
|
||||
|
||||
// Builds the XlaCompiler arguments for `fbody` from the function definition
// alone, i.e. without any concrete input tensors (the "TensorSpec" path).
//
// Every argument is emitted as a kParameter. Its dtype and name are taken from
// the signature's input_arg list; its shape is parsed from the first entry of
// the "_output_shapes" attribute stored for that argument in fdef.arg_attr().
//
// Errors:
//   * InvalidArgument if the number of arg_attr entries differs from the
//     number of signature inputs (reported to the user as a tf.Variable
//     captured from outside the function body).
//   * A TF_RET_CHECK failure if `fbody` is null, if an arg_attr key does not
//     index a signature input, or if an argument carries no usable
//     "_output_shapes" attribute.
static StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArgumentFromFuncBody(const FunctionBody* fbody) {
  TF_RET_CHECK(fbody != nullptr);
  const auto& input_args = fbody->fdef.signature().input_arg();
  const int input_arg_size = input_args.size();

  // Shape info is not in input_arg. parse it from arg_attrs.
  const auto& arg_attrs = fbody->fdef.arg_attr();
  if (arg_attrs.size() != input_arg_size) {
    return errors::InvalidArgument(
        "The function to be lowered uses some tf.Variable defined outside_"
        "_the_function body. This is not supported with using_tensor_spec. "
        "Please modify the function with pure functional style.");
  }

  // The vector must be sized, not merely reserved: the loop below writes
  // through &shapes[idx] and the argument loop reads shapes[input_num], both
  // of which require live elements.
  std::vector<TensorShape> shapes(input_arg_size);
  for (const auto& attr : arg_attrs) {
    const unsigned int idx = attr.first;
    // arg_attr is keyed by argument index; a matching entry count alone does
    // not guarantee that every key is in range.
    TF_RET_CHECK(idx < shapes.size())
        << "arg_attr index " << idx << " is out of range for "
        << input_arg_size << " function inputs";
    bool has_function_input_shape = false;
    for (const auto& attr_value : attr.second.attr()) {
      if (attr_value.first == "_output_shapes") {
        const auto& shape_list = attr_value.second.list();
        TF_RET_CHECK(shape_list.shape_size() > 0);
        TF_RETURN_IF_ERROR(
            TensorShape::BuildTensorShape(shape_list.shape()[0], &shapes[idx]));
        has_function_input_shape = true;
      }
    }
    TF_RET_CHECK(has_function_input_shape);
  }

  // Build Xla Compiler Arguments
  std::vector<XlaCompiler::Argument> args(input_arg_size);
  for (int input_num = 0; input_num < input_arg_size; ++input_num) {
    XlaCompiler::Argument& arg = args[input_num];
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = input_args[input_num].type();
    arg.shape = shapes[input_num];
    arg.name = input_args[input_num].name();
  }
  return args;
}
|
||||
|
||||
// Compiles the function named `func_name` for device `dev` with XlaCompiler
// and returns the result rendered at the requested export `stage`.
//
// Arguments for the compilation are obtained in one of two ways:
//   * `inputs_handles` is non-empty: arguments are built from the concrete
//     input tensors. Inputs at constant argument indices are first copied via
//     TensorHandle::CopyToDevice with a null target device, and the variables
//     behind resource arguments are looked up and locked.
//   * `inputs_handles` is empty: arguments are derived from the function
//     definition only (see BuildXlaCompilerArgumentFromFuncBody). This is the
//     path used for tf.TensorSpec inputs; note that a call that simply has no
//     inputs takes this path as well.
//
// Returns Internal when `dev` is not a CPU device, has no stream, and `stage`
// is not one of HLO / HLO_NO_METADATA / HLO_SERIALIZED. Any error from the
// function lookup, input preparation, or compilation is propagated.
StatusOr<std::string> GetCompilerIr(
    IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
    absl::string_view func_name, Device* dev, EagerContext* context,
    absl::Span<const TensorHandle* const> inputs_handles) {
  using XlaDeviceCompiler =
      DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;

  // input_handles vector is empty for using_tensor_spec case
  const bool using_tensor_spec = inputs_handles.empty();

  // Resolve the stream once. The accelerator info pointer may be null, in
  // which case the device is treated as having no stream instead of being
  // dereferenced.
  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
      dev->tensorflow_accelerator_device_info();
  se::Stream* stream = accelerator_device_info != nullptr
                           ? accelerator_device_info->stream
                           : nullptr;

  auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) {
    return stage == IrExportStage::HLO ||
           stage == IrExportStage::HLO_NO_METADATA ||
           stage == IrExportStage::HLO_SERIALIZED;
  };
  // TODO(b/238830423): support GetCompilerIr on TFRT TPU device for stages
  // that requires compilation from HLO to executable.
  if (dev->device_type() != DEVICE_CPU && stream == nullptr &&
      !is_tfrt_tpu_supported_stage(stage)) {
    return errors::Internal(
        "GetCompilerIr with requested stage is not supported on this device.");
  }

  NameAttrList function;
  function.set_name(std::string{func_name});

  FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
  ResourceMgr* rmgr = dev->resource_manager();

  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));

  // `inputs` holds non-owning pointers. Tensors copied for constant arguments
  // live in `inputs_storage`; a deque is used so that pointers to earlier
  // elements stay valid while more are appended.
  std::vector<const Tensor*> inputs;
  std::deque<Tensor> inputs_storage;
  inputs.reserve(inputs_handles.size());
  std::vector<VariableInfo> variable_infos;
  if (!using_tensor_spec) {
    const int num_inputs = static_cast<int>(inputs_handles.size());
    for (int i = 0; i < num_inputs; ++i) {
      const TensorHandle* th = inputs_handles[i];
      const Tensor* t = nullptr;
      // Handle owns the tensor.
      TF_RETURN_IF_ERROR(th->Tensor(&t));
      if (absl::c_binary_search(constant_arg_indices, i)) {
        // Need to make sure it's on the host.
        inputs_storage.emplace_back(t->dtype(), t->shape());
        TF_RETURN_IF_ERROR(
            th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
        inputs.push_back(&inputs_storage.back());
      } else {
        inputs.push_back(t);
      }
    }

    TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
        rmgr, dev, inputs, resource_arg_indices, &variable_infos));
    TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
  }

  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);

  XlaDeviceCompiler* xla_device_compiler;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaDeviceCompiler>(
      rmgr->default_container(), "xla_device_compiler", &xla_device_compiler,
      [&](XlaDeviceCompiler** xla_device_compiler) {
        return BuildXlaDeviceCompiler(dev, flr, platform_info,
                                      xla_device_compiler);
      }));
  // Drops the reference obtained from LookupOrCreate on every return path.
  core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);

  XlaCompiler::Options options;
  if (platform_info.device_type() == DEVICE_TPU && stream == nullptr) {
    options = GenerateTfrtTpuCompilerOptions(*xla_device_compiler, *flr);
  } else {
    options = GenerateCompilerOptions(*xla_device_compiler, *flr, dev, stream,
                                      platform_info,
                                      /*has_ref_vars=*/false);
  }

  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = true;

  XlaCompiler compiler(options);

  StatusOr<std::vector<XlaCompiler::Argument>> args =
      using_tensor_spec
          ? BuildXlaCompilerArgumentFromFuncBody(fbody)
          : XlaComputationLaunchContext::BuildXlaCompilerArguments(
                constant_arg_indices, inputs, variable_infos, dev);
  TF_RETURN_IF_ERROR(args.status());

  xla::LocalClient* local_client = xla_device_compiler->client();
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler.CompileFunction(compile_options, function, *args, &result));

  return BuildHLOString(stage, result, local_client, options);
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ py_library(
|
|||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":attributes",
|
||||
":compiler_ir",
|
||||
":function_spec",
|
||||
"//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
|
|
@ -380,3 +381,48 @@ tf_py_test(
|
|||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "compiler_ir",
|
||||
srcs = ["compiler_ir.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "compiler_ir_test",
|
||||
srcs = ["compiler_ir_test.py"],
|
||||
disabled_backends = [
|
||||
"cpu_ondemand",
|
||||
],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_mac",
|
||||
"no_pip",
|
||||
"no_tfrt", # TODO(b/185944215)
|
||||
"no_windows",
|
||||
],
|
||||
use_xla_device = False,
|
||||
deps = [
|
||||
":compiler_ir",
|
||||
":polymorphic_function",
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:array_ops_gen",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/framework:test_lib",
|
||||
"//tensorflow/python/util",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
70
tensorflow/python/eager/polymorphic_function/compiler_ir.py
Normal file
70
tensorflow/python/eager/polymorphic_function/compiler_ir.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation for defining get_compiler_ir."""
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def maybe_get_device_name(device_name):
  """Returns `device_name`, falling back to the current device when it is None.

  Args:
    device_name: a device name, or None to have one looked up.

  Returns:
    `device_name` unchanged when it is not None; otherwise the `.device` of a
    freshly created scalar `random_normal` tensor.
  """
  if device_name is not None:
    return device_name
  # TODO(cheshire): This is a hack to get the current "preferred" device,
  # there is no current API to get it otherwise.
  return random_ops.random_normal([]).device
|
||||
|
||||
|
||||
def from_concrete_function(concrete_fn):
  """Creates a compiler IR generator for a concrete function.

  No input values are needed: the returned generator asks the eager context to
  compile `concrete_fn` by name with an empty argument list.

  Args:
    concrete_fn: a concrete function, as returned by `get_concrete_function`.

  Returns:
    A callable `(stage="hlo", device_name=None)` that returns the compiler IR.
    The result is `bytes` for the serialized stages ("hlo_serialized",
    "optimized_hlo_serialized", "optimized_hlo_proto_serialized") and a UTF-8
    decoded `str` for every other stage.

  Raises:
    ValueError: if any flattened input of `concrete_fn` has a shape that is not
      fully defined.
  """
  context.ensure_initialized()
  # TODO(b/265073174) support users input tf.TensorSpec list here.
  flat_inputs = nest.flatten(concrete_fn.inputs)
  if any(not spec.shape.is_fully_defined() for spec in flat_inputs):
    raise ValueError(
        f"Only support static input shape but got inputs = {concrete_fn.inputs}"
    )
  fn_name = concrete_fn.name

  def compiler_ir_generator(stage="hlo", device_name=None):
    # args list is empty for using_tensor_spec case
    res_bytes = context.context().get_compiler_ir(
        device_name=maybe_get_device_name(device_name),
        stage=stage,
        function_name=fn_name,
        args=[],
    )
    returns_raw_bytes = stage in (
        "hlo_serialized",
        "optimized_hlo_serialized",
        "optimized_hlo_proto_serialized",
    )
    return res_bytes if returns_raw_bytes else res_bytes.decode("utf-8")

  return compiler_ir_generator
|
||||
176
tensorflow/python/eager/polymorphic_function/compiler_ir_test.py
Normal file
176
tensorflow/python/eager/polymorphic_function/compiler_ir_test.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.eager.polymorphic_function import compiler_ir
|
||||
from tensorflow.python.eager.polymorphic_function import polymorphic_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class CompilerIrTest(xla_test.XLATestCase):
  """Tests compiler IR generation from tf.TensorSpec inputs."""

  def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs):
    """Fails unless tensor inputs and their TensorSpecs give the same IR.

    The test is skipped when any value in `args` or `kwargs` is not a
    tf.Tensor.

    Args:
      f: the function whose `experimental_get_compiler_ir` is exercised.
      args: positional tensor arguments for `f`.
      kwargs: keyword tensor arguments for `f`.
    """
    every_input = list(args) + list(kwargs.values())
    if any(not isinstance(value, ops.Tensor) for value in every_input):
      self.skipTest('It only support args and kwargs are all tf.Tensor types.')

    to_spec = tensor_spec.TensorSpec.from_tensor
    spec_args = nest.map_structure(to_spec, args)
    spec_kwargs = nest.map_structure(to_spec, kwargs)

    ir_from_tensors = f.experimental_get_compiler_ir(*args, **kwargs)()
    ir_from_specs = f.experimental_get_compiler_ir(*spec_args, **spec_kwargs)()

    if ir_from_tensors == ir_from_specs:
      return
    self.fail(
        'The tensor_spec way experimental_get_compiler_ir give diff result to'
        f' normal experimental_get_compiler_ir. \nhlo_1:\n{ir_from_tensors}'
        f'\nhlo_2:\n{ir_from_specs}\n'
    )

  def _lowerToHloForBridge(self, concrete_fn):
    """Lowers `concrete_fn` to HLO, expecting failure under the MLIR bridge."""
    if not test_util.is_mlir_bridge_enabled():
      _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
      return
    with self.assertRaisesRegex(ValueError, 'TF to XLA legalization failed'):
      _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')

  def test_zero_input(self):
    """A function without inputs."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):

      @polymorphic_function.function(jit_compile=True, autograph=False)
      def fun_tf():
        return array_ops.zeros((10), dtype=dtypes.int32)

      self._compareTwoMethodsCompilerIROutput(fun_tf, [], {})

  def test_constant_slice(self):
    """Constant slice. This is the common case."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):
      x = array_ops.zeros((10,), dtype=dtypes.int32)

      @polymorphic_function.function(jit_compile=True, autograph=False)
      def fun_tf(x):
        begin = 0
        return x[begin:5]

      self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})

  def test_compile_time_constant(self):
    """Non-constant slice whose bound depends only on the input shape."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):
      x = array_ops.zeros((10,), dtype=dtypes.int32)

      @polymorphic_function.function(jit_compile=True, autograph=False)
      def fun_tf(x):
        # begin is a compile-time constant, even if x is not
        begin = array_ops.shape_v2(x)[0] - 2
        return x[begin:]

      self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})

  def test_capture_constant(self):
    """The function body captures a Python constant."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):
      outer_ct = [3.0]
      x = ops.convert_to_tensor([2.0, 3.0, 4.0], dtype=dtypes.float32)

      @polymorphic_function.function(jit_compile=True, autograph=False)
      def fun_tf(x):
        return x * gen_array_ops.broadcast_to(outer_ct, x.shape) + 1.0

      self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})

  def test_unsupported_dynamic_input(self):
    """A TensorSpec without a fully defined shape is rejected."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):

      @polymorphic_function.function(jit_compile=True)
      def f(x):
        return x

      expected_error = 'Only support static input shape but got'
      with self.assertRaisesRegex(ValueError, expected_error):
        specs = [tensor_spec.TensorSpec((None), dtype=dtypes.float32)]
        concrete_fn = f.get_concrete_function(*specs)
        _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')

  def test_unsupported_shape_depend_input(self):
    """Functions whose lowering depends on input values, not only shapes."""
    to_spec = tensor_spec.TensorSpec.from_tensor
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):
      # Those cases output shapes are dynamic.
      @polymorphic_function.function(jit_compile=True)
      def f2(x):
        return x[x[0] : 0]

      tensors = [ops.convert_to_tensor([1, 2, 3, 4])]
      specs = nest.map_structure(to_spec, tensors)
      self._lowerToHloForBridge(f2.get_concrete_function(*specs))

      # Those cases both input and output shapes are static but tf graph
      # contains constant args inside.
      @polymorphic_function.function(jit_compile=True)
      def f3(a, b):
        c = array_ops.slice(a, [b], [-1])
        return math_ops.reduce_sum(c)

      tensors = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor(1)]
      specs = nest.map_structure(to_spec, tensors)
      self._lowerToHloForBridge(f3.get_concrete_function(*specs))

  def test_unsupported_capture_outside_variable(self):
    """Those cases define tf.Variable outside function body."""
    device_scope = 'device:{}:0'.format(self.device)
    with ops.device(device_scope):
      error_msg = (
          'The function to be lowered uses some tf.Variable defined outside_'
          '_the_function body.'
      )

      v = variables.Variable([0.1, 0.1])

      @polymorphic_function.function(jit_compile=True)
      def f4(a, b):
        return (a + b) * v

      a = constant_op.constant([1.1, 1.1])
      b = constant_op.constant([2.2, 2.2])

      kwargs = {'b': a, 'a': b}
      with self.assertRaisesRegex(ValueError, error_msg):
        spec_kwargs = nest.map_structure(
            tensor_spec.TensorSpec.from_tensor, kwargs
        )
        concrete_fn = f4.get_concrete_function(**spec_kwargs)
        _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
|
@ -75,17 +75,18 @@ from tensorflow.python.eager import context
|
|||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
|
||||
from tensorflow.python.eager.polymorphic_function import compiler_ir
|
||||
from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib
|
||||
from tensorflow.python.eager.polymorphic_function import tracing_compiler
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import trace
|
||||
|
|
@ -998,6 +999,32 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
raise ValueError("Compiler IR can only be returned for functions marked "
|
||||
"with 'jit_compile=True'")
|
||||
|
||||
is_tensor_spec = lambda x: isinstance(x, tensor_spec.TensorSpec)
|
||||
|
||||
def _check_inputs(args, kwargs):
  """Rejects inputs that mix tf.TensorSpec values with other values.

  Args:
    args: positional inputs.
    kwargs: keyword inputs; only the values are inspected.

  Raises:
    ValueError: if some, but not all, inputs are tf.TensorSpec.
  """
  all_inputs = list(args) + list(kwargs.values())
  # Empty input is okay.
  if not all_inputs:
    return
  spec_flags = [is_tensor_spec(value) for value in all_inputs]
  if any(spec_flags) and not all(spec_flags):
    raise ValueError(
        "experimental_get_compiler_ir supports either "
        "(1) all inputs are TensorSpec or "
        "(2) all inputs are tf.Tensor/python variables"
    )
|
||||
|
||||
_check_inputs(args, kwargs)
|
||||
if (
|
||||
len(args) + len(kwargs.values()) > 0
|
||||
and all(map(is_tensor_spec, args))
|
||||
and all(map(is_tensor_spec, kwargs.values()))
|
||||
):
|
||||
# For the case inputs are not empty and input types are all tf.TensorSpec
|
||||
concrete_fn = self.get_concrete_function(*args, **kwargs)
|
||||
return compiler_ir.from_concrete_function(concrete_fn)
|
||||
|
||||
concrete_fn = self.get_concrete_function(*args, **kwargs)
|
||||
fn_name = concrete_fn.name
|
||||
|
||||
|
|
@ -1006,10 +1033,7 @@ class Function(core.GenericFunction, trackable.Trackable):
|
|||
concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs))
|
||||
|
||||
def compiler_ir_generator(stage="hlo", device_name=None):
|
||||
# TODO(cheshire): This is a hack to get the current "preferred" device,
|
||||
# there is no current API to get it otherwise.
|
||||
if device_name is None:
|
||||
device_name = random_ops.random_normal([]).device
|
||||
device_name = compiler_ir.maybe_get_device_name(device_name)
|
||||
res_bytes = context.context().get_compiler_ir(
|
||||
device_name=device_name,
|
||||
stage=stage,
|
||||
|
|
|
|||
|
|
@ -38,11 +38,31 @@ from tensorflow.python.ops import summary_ops_v2
|
|||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@test_util.with_eager_op_as_function
|
||||
class FunctionTest(xla_test.XLATestCase):
|
||||
|
||||
def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs):
  """Fails unless tensor inputs and their TensorSpecs give the same HLO text.

  `experimental_get_compiler_ir` is called once with the given tensors and
  once with the TensorSpecs derived from them; the two results must be equal.
  The test is skipped when any value in `args` or `kwargs` is not a tf.Tensor.

  Args:
    f: the function whose `experimental_get_compiler_ir` is exercised.
    args: positional tensor arguments for `f`.
    kwargs: keyword tensor arguments for `f`.
  """
  every_input = list(args) + list(kwargs.values())
  if any(not isinstance(value, ops.Tensor) for value in every_input):
    self.skipTest('It only support args and kwargs are all tf.Tensor types.')

  to_spec = tensor_spec.TensorSpec.from_tensor
  spec_args = nest.map_structure(to_spec, args)
  spec_kwargs = nest.map_structure(to_spec, kwargs)

  ir_from_tensors = f.experimental_get_compiler_ir(*args, **kwargs)()
  ir_from_specs = f.experimental_get_compiler_ir(*spec_args, **spec_kwargs)()

  if ir_from_tensors == ir_from_specs:
    return
  self.fail(
      'The tensor_spec way experimental_get_compiler_ir give diff result to'
      f' normal experimental_get_compiler_ir. \nhlo_1:\n{ir_from_tensors}'
      f'\nhlo_2:\n{ir_from_specs}\n'
  )
|
||||
|
||||
def testAutoclusteringWithTfFunction(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('Autoclustering does not run on TPU')
|
||||
|
|
@ -189,6 +209,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
matches = re.findall('channel_id=([0-9]*),', hlo_str)
|
||||
self.assertLen(matches, 2)
|
||||
self.assertNotEqual(matches[0], matches[1])
|
||||
self._compareTwoMethodsCompilerIROutput(fn, [inputs, inputs], {})
|
||||
|
||||
def testCollectiveReduceGroupAssignment(self):
|
||||
if not test_util.is_mlir_bridge_enabled():
|
||||
|
|
@ -209,6 +230,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
# instructions generated by XLA.
|
||||
hlo_str = fn.experimental_get_compiler_ir(inputs)()
|
||||
self.assertIn('replica_groups={{0}}', hlo_str)
|
||||
self._compareTwoMethodsCompilerIROutput(fn, [inputs], {})
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
|
|
@ -222,6 +244,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
self.assertIn('polymorphic_function_xla_jit_test',
|
||||
fn.experimental_get_compiler_ir(inputs, inputs)())
|
||||
self._compareTwoMethodsCompilerIROutput(fn, [inputs, inputs], {})
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
|
|
@ -239,6 +262,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
self.assertIn('polymorphic_function_xla_jit_test',
|
||||
g.experimental_get_compiler_ir(inputs, inputs)())
|
||||
self._compareTwoMethodsCompilerIROutput(g, [inputs, inputs], {})
|
||||
|
||||
def testPythonStackTrace(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
|
@ -864,6 +888,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
self.assertIn(
|
||||
'label',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
|
||||
self._compareTwoMethodsCompilerIROutput(f, [a, b], {})
|
||||
|
||||
def testGetCompilerIrNoDevicePlacement(self):
|
||||
if 'gpu' not in self.device.lower():
|
||||
|
|
@ -879,6 +904,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
self.assertIn(
|
||||
'label',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
|
||||
self._compareTwoMethodsCompilerIROutput(f, [a, b], {})
|
||||
|
||||
def testGetCompilerIrNonTensors(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
|
@ -891,6 +917,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
|
||||
self.assertIn('tuple',
|
||||
f.experimental_get_compiler_ir(l)())
|
||||
self._compareTwoMethodsCompilerIROutput(f, [l], {})
|
||||
|
||||
def testGetCompilerIrSerialized(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
|
@ -904,6 +931,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
hlo = fn.experimental_get_compiler_ir(inputs)(
|
||||
stage=stage, device_name=f'/device:{self.device}:0')
|
||||
self.assertIsInstance(hlo, bytes)
|
||||
self._compareTwoMethodsCompilerIROutput(fn, [inputs], {})
|
||||
|
||||
def testDotOptimizedHlo(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
|
@ -1171,12 +1199,35 @@ class FunctionTest(xla_test.XLATestCase):
|
|||
def f(x):
|
||||
return math_ops.argmax(x)
|
||||
|
||||
hlo = f.experimental_get_compiler_ir(
|
||||
array_ops.ones([10], dtype=dtypes.float32))(
|
||||
stage='hlo')
|
||||
inputs = array_ops.ones([10], dtype=dtypes.float32)
|
||||
hlo = f.experimental_get_compiler_ir(inputs)(stage='hlo')
|
||||
|
||||
# Test that reduction occurs only once.
|
||||
self.assertGreater(hlo.count('reduce'), 1)
|
||||
self._compareTwoMethodsCompilerIROutput(f, [inputs], {})
|
||||
|
||||
def testExperimentalGetCompilerIRBasic(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@polymorphic_function.function(jit_compile=True)
|
||||
def inner_tf_func(x):
|
||||
return math_ops.sin(x)
|
||||
|
||||
x = constant_op.constant([2.0, 3.0])
|
||||
self._compareTwoMethodsCompilerIROutput(inner_tf_func, [x], {})
|
||||
|
||||
def testExperimentalGetCompilerIRAutograph(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@polymorphic_function.function(jit_compile=True, autograph=True)
|
||||
def f(x, y):
|
||||
if x[0] > 1:
|
||||
return y[0]
|
||||
else:
|
||||
return y[1]
|
||||
|
||||
x, y = constant_op.constant([2, 3]), constant_op.constant([2, 3])
|
||||
self._compareTwoMethodsCompilerIROutput(f, [x, y], {})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -179,9 +179,10 @@ class GenericFunction(Callable):
|
|||
backwards compatibility of returned IR or the allowed values of `stage`.
|
||||
|
||||
Args:
|
||||
*args: Arguments used for compilation; same arguments as used for calling
|
||||
the function. Need to be eager tensors.
|
||||
**kwargs: Keyword arguments used for compilation.
|
||||
*args: Arguments used for compilation. Either: (1) all inputs are
|
||||
TensorSpec or (2) all inputs are tf.Tensor/Python variables.
|
||||
**kwargs: Keyword arguments used for compilation. Same requirement as
|
||||
compilation args.
|
||||
|
||||
Returns:
|
||||
Function callable with the following kwargs:
|
||||
|
|
@ -230,8 +231,11 @@ class GenericFunction(Callable):
|
|||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid `stage` is selected or if applied to a function
|
||||
which is not compiled (`jit_compile=True` is not set).
|
||||
ValueError:
|
||||
(1) If an invalid `stage` is selected
|
||||
(2) or if applied to a function which is not compiled
|
||||
(`jit_compile=True` is not set).
|
||||
(3) or if input shapes are not fully defined for tf.TensorSpec inputs
|
||||
TypeError: When called with input in graph mode.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user