Support TensorSpec input for experimental_get_compiler_ir

PiperOrigin-RevId: 501907754
John QiangZhang 2023-01-13 12:08:14 -08:00 committed by TensorFlower Gardener
parent 49c939590d
commit ea3ddeefb0
9 changed files with 554 additions and 126 deletions

View File

@ -129,6 +129,9 @@
* Added `tf.nn.experimental.general_dropout`, which is similar to
  `tf.random.experimental.stateless_dropout` but accepts a custom sampler
  function.
* `tf.types.experimental.GenericFunction`
    * The `experimental_get_compiler_ir` method supports `tf.TensorSpec`
      compilation arguments.

# Thanks to our Contributors
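For illustration, a minimal sketch of the new call pattern (the function `f` below is hypothetical; as the docstring changes in this commit note, `tf.TensorSpec` inputs must have fully defined shapes):

```
import tensorflow as tf

@tf.function(jit_compile=True)
def f(x):
  return tf.sin(x)

# Existing behavior: compile against concrete eager tensors.
hlo_from_tensor = f.experimental_get_compiler_ir(tf.constant([1.0, 2.0]))()

# New behavior: compile against tf.TensorSpec alone; no tensor values needed.
spec = tf.TensorSpec(shape=(2,), dtype=tf.float32)
hlo_from_spec = f.experimental_get_compiler_ir(spec)()
```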

View File

@ -545,21 +545,20 @@ cc_library(
hdrs = ["get_compiler_ir.h"],
visibility = [":internal"],
deps = [
":common",
":compilability_check_util",
":device_compiler",
":flags",
":xla_device_no_jit_rewrite_registration",
":xla_launch_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:core_cpu_internal",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",

View File

@ -15,31 +15,31 @@ limitations under the License.
#include "tensorflow/compiler/jit/get_compiler_ir.h"

#include <deque>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_compiler.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/tsl/platform/errors.h"

namespace tensorflow {
@ -77,108 +77,9 @@ static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
return std::move(executables[0]);
}
StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const TensorHandle* const> inputs_handles) {
using XlaDeviceCompiler =
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) {
return stage == IrExportStage::HLO ||
stage == IrExportStage::HLO_NO_METADATA ||
stage == IrExportStage::HLO_SERIALIZED;
};
// TODO(b/238830423): support GetCompilerIr on TFRT TPU device for stages
// that require compilation from HLO to executable.
if (dev->device_type() != DEVICE_CPU &&
dev->tensorflow_accelerator_device_info()->stream == nullptr &&
!is_tfrt_tpu_supported_stage(stage)) {
return errors::Internal(
"GetCompilerIr with requested stage is not supported on this device.");
}
NameAttrList function;
function.set_name(std::string{func_name});
FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
ResourceMgr* rmgr = dev->resource_manager();
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
std::deque<Tensor> inputs_storage;
std::vector<const Tensor*> inputs;
inputs.reserve(inputs_handles.size());
for (int i = 0; i < inputs_handles.size(); i++) {
const TensorHandle* th = inputs_handles[i];
const Tensor* t;
// Handle owns the tensor.
TF_RETURN_IF_ERROR(th->Tensor(&t));
if (absl::c_binary_search(constant_arg_indices, i)) {
// Need to make sure it's on the host.
inputs_storage.emplace_back(t->dtype(), t->shape());
TF_RETURN_IF_ERROR(
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
inputs.push_back(&inputs_storage.back());
} else {
inputs.push_back(t);
}
}
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
XlaDeviceCompiler* xla_device_compiler;
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaDeviceCompiler>(
rmgr->default_container(), "xla_device_compiler", &xla_device_compiler,
[&](XlaDeviceCompiler** xla_device_compiler) {
return BuildXlaDeviceCompiler(dev, flr, platform_info,
xla_device_compiler);
}));
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
se::Stream* stream = nullptr;
if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
dev->tensorflow_accelerator_device_info()) {
stream = accelerator_device_info->stream;
}
XlaCompiler::Options options;
if (platform_info.device_type() == DEVICE_TPU && stream == nullptr) {
options = GenerateTfrtTpuCompilerOptions(*xla_device_compiler, *flr);
} else {
options = GenerateCompilerOptions(*xla_device_compiler, *flr, dev, stream,
platform_info,
/*has_ref_vars=*/false);
}
XlaCompiler::CompileOptions compile_options;
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = true;
XlaCompiler compiler(options);
StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arg_indices, inputs, variable_infos, dev);
TF_RETURN_IF_ERROR(args.status());
xla::LocalClient* local_client = xla_device_compiler->client();
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(
compiler.CompileFunction(compile_options, function, *args, &result));
static StatusOr<std::string> BuildHLOString(
IrExportStage stage, const XlaCompiler::CompilationResult& result,
xla::LocalClient* local_client, const XlaCompiler::Options& options) {
switch (stage) {
case IrExportStage::HLO:
case IrExportStage::HLO_NO_METADATA:
@ -232,4 +133,158 @@ StatusOr<std::string> GetCompilerIr(
}
}
static StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArgumentFromFuncBody(const FunctionBody* fbody) {
TF_RET_CHECK(fbody != nullptr);
auto& input_args = fbody->fdef.signature().input_arg();
int input_arg_size = input_args.size();
// Shape info is not in input_arg; parse it from arg_attr.
auto& arg_attrs = fbody->fdef.arg_attr();
if (arg_attrs.size() != input_arg_size) {
return errors::InvalidArgument(
"The function to be lowered uses some tf.Variable defined outside "
"the function body. This is not supported when compiling from "
"tf.TensorSpec inputs. Please rewrite the function in a purely "
"functional style.");
}
std::vector<TensorShape> shapes;
// Use resize (not reserve): shapes[idx] is assigned into below.
shapes.resize(input_arg_size);
for (const auto& attr : arg_attrs) {
const unsigned int& idx = attr.first;
bool has_function_input_shape = false;
for (const auto& attr_value : attr.second.attr()) {
if (attr_value.first == "_output_shapes") {
TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape(
attr_value.second.list().shape()[0], &shapes[idx]));
has_function_input_shape = true;
}
}
TF_RET_CHECK(has_function_input_shape);
}
// Build Xla Compiler Arguments
std::vector<XlaCompiler::Argument> args;
args.resize(input_arg_size);
for (int64_t input_num = 0; input_num < input_arg_size; ++input_num) {
XlaCompiler::Argument& arg = args[input_num];
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_args[input_num].type();
arg.shape = shapes[input_num];
arg.name = input_args[input_num].name();
}
return args;
}
StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const TensorHandle* const> inputs_handles) {
// The inputs_handles span is empty when compiling from tf.TensorSpec
// arguments.
bool using_tensor_spec = inputs_handles.empty();
using XlaDeviceCompiler =
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) {
return stage == IrExportStage::HLO ||
stage == IrExportStage::HLO_NO_METADATA ||
stage == IrExportStage::HLO_SERIALIZED;
};
// TODO(b/238830423): support GetCompilerIr on TFRT TPU device for stages
// that require compilation from HLO to executable.
if (dev->device_type() != DEVICE_CPU &&
dev->tensorflow_accelerator_device_info()->stream == nullptr &&
!is_tfrt_tpu_supported_stage(stage)) {
return errors::Internal(
"GetCompilerIr with requested stage is not supported on this device.");
}
NameAttrList function;
function.set_name(std::string{func_name});
FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
ResourceMgr* rmgr = dev->resource_manager();
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
std::vector<const Tensor*> inputs;
std::deque<Tensor> inputs_storage;
inputs.reserve(inputs_handles.size());
std::vector<VariableInfo> variable_infos;
if (!using_tensor_spec) {
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
for (int i = 0; i < inputs_handles.size(); i++) {
const TensorHandle* th = inputs_handles[i];
const Tensor* t;
// Handle owns the tensor.
TF_RETURN_IF_ERROR(th->Tensor(&t));
if (absl::c_binary_search(constant_arg_indices, i)) {
// Need to make sure it's on the host.
inputs_storage.emplace_back(t->dtype(), t->shape());
TF_RETURN_IF_ERROR(
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
inputs.push_back(&inputs_storage.back());
} else {
inputs.push_back(t);
}
}
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
}
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
XlaDeviceCompiler* xla_device_compiler;
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaDeviceCompiler>(
rmgr->default_container(), "xla_device_compiler", &xla_device_compiler,
[&](XlaDeviceCompiler** xla_device_compiler) {
return BuildXlaDeviceCompiler(dev, flr, platform_info,
xla_device_compiler);
}));
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
se::Stream* stream = nullptr;
if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
dev->tensorflow_accelerator_device_info()) {
stream = accelerator_device_info->stream;
}
XlaCompiler::Options options;
if (platform_info.device_type() == DEVICE_TPU && stream == nullptr) {
options = GenerateTfrtTpuCompilerOptions(*xla_device_compiler, *flr);
} else {
options = GenerateCompilerOptions(*xla_device_compiler, *flr, dev, stream,
platform_info,
/*has_ref_vars=*/false);
}
XlaCompiler::CompileOptions compile_options;
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = true;
XlaCompiler compiler(options);
StatusOr<std::vector<XlaCompiler::Argument>> args;
if (using_tensor_spec) {
args = BuildXlaCompilerArgumentFromFuncBody(fbody);
} else {
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arg_indices, inputs, variable_infos, dev);
}
TF_RETURN_IF_ERROR(args.status());
xla::LocalClient* local_client = xla_device_compiler->client();
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(
compiler.CompileFunction(compile_options, function, *args, &result));
return BuildHLOString(stage, result, local_client, options);
}
}  // namespace tensorflow

View File

@ -63,6 +63,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":attributes",
":compiler_ir",
":function_spec",
"//tensorflow/python:cond_v2",  # TODO(b/118513001): Imported via control_flow_ops; remove.
"//tensorflow/python:control_flow_ops",
@ -380,3 +381,48 @@ tf_py_test(
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "compiler_ir",
srcs = ["compiler_ir.py"],
srcs_version = "PY3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:random_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/util",
],
)
tf_xla_py_test(
name = "compiler_ir_test",
srcs = ["compiler_ir_test.py"],
disabled_backends = [
"cpu_ondemand",
],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_mac",
"no_pip",
"no_tfrt", # TODO(b/185944215)
"no_windows",
],
use_xla_device = False,
deps = [
":compiler_ir",
":polymorphic_function",
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/util",
],
)

View File

@ -0,0 +1,70 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of get_compiler_ir."""
from tensorflow.python.eager import context
from tensorflow.python.ops import random_ops
from tensorflow.python.util import nest
def maybe_get_device_name(device_name):
# TODO(cheshire): This is a hack to get the current "preferred" device,
# there is no current API to get it otherwise.
if device_name is None:
device_name = random_ops.random_normal([]).device
return device_name
def from_concrete_function(concrete_fn):
"""Generates the compiler IR from a tf concrete function.
Args:
concrete_fn: A ConcreteFunction, as returned by get_concrete_function.
Returns:
A callable that generates the HLO text for the requested stage.
Raises:
ValueError: If concrete_fn is not "compilable" without concrete inputs,
i.e. if its input shapes are not fully defined.
"""
context.ensure_initialized()
# TODO(b/265073174): Support a user-provided tf.TensorSpec list here.
if not all(
s.shape.is_fully_defined() for s in nest.flatten(concrete_fn.inputs)
):
raise ValueError(
"Only static input shapes are supported, but got inputs ="
f" {concrete_fn.inputs}"
)
fn_name = concrete_fn.name
def compiler_ir_generator(stage="hlo", device_name=None):
device_name = maybe_get_device_name(device_name)
res_bytes = context.context().get_compiler_ir(
device_name=device_name,
stage=stage,
function_name=fn_name,
# The args list is empty when compiling from tf.TensorSpec inputs.
args=[],
)
if stage in (
"hlo_serialized",
"optimized_hlo_serialized",
"optimized_hlo_proto_serialized",
):
return res_bytes
else:
return res_bytes.decode("utf-8")
return compiler_ir_generator
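As a usage sketch (modeled on the tests added in this commit; the function `f` is hypothetical):

```
import tensorflow as tf
from tensorflow.python.eager.polymorphic_function import compiler_ir

@tf.function(jit_compile=True)
def f(x):
  return x * 2.0

# Trace with a fully defined TensorSpec, then lower the concrete function.
concrete_fn = f.get_concrete_function(tf.TensorSpec((4,), tf.float32))
hlo_text = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
```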

View File

@ -0,0 +1,176 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager.polymorphic_function import compiler_ir
from tensorflow.python.eager.polymorphic_function import polymorphic_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
class CompilerIrTest(xla_test.XLATestCase):
def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs):
flat_args = list(args) + list(kwargs.values())
if not all([isinstance(x, ops.Tensor) for x in flat_args]):
self.skipTest('Only supported when all args and kwargs are tf.Tensor.')
args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args)
kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs)
hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)()
hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)()
if hlo_1 != hlo_2:
self.fail(
'experimental_get_compiler_ir with TensorSpec inputs gives a different'
f' result than with Tensor inputs.\nhlo_1:\n{hlo_1}'
f'\nhlo_2:\n{hlo_2}\n'
)
def test_zero_input(self):
with ops.device('device:{}:0'.format(self.device)):
@polymorphic_function.function(jit_compile=True, autograph=False)
def fun_tf():
return array_ops.zeros((10), dtype=dtypes.int32)
self._compareTwoMethodsCompilerIROutput(fun_tf, [], {})
def test_constant_slice(self):
with ops.device('device:{}:0'.format(self.device)):
# Constant slice. This is the common case.
x = array_ops.zeros((10,), dtype=dtypes.int32)
@polymorphic_function.function(jit_compile=True, autograph=False)
def fun_tf(x):
begin = 0
return x[begin:5]
self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})
def test_compile_time_constant(self):
with ops.device('device:{}:0'.format(self.device)):
# Non-constant slice, but compile-time constant depending only on shapes.
x = array_ops.zeros((10,), dtype=dtypes.int32)
@polymorphic_function.function(jit_compile=True, autograph=False)
def fun_tf(x):
# begin is a compile-time constant, even if x is not
begin = array_ops.shape_v2(x)[0] - 2
return x[begin:]
self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})
def test_capture_constant(self):
with ops.device('device:{}:0'.format(self.device)):
# Capture a constant
outer_ct = [3.0]
x = ops.convert_to_tensor([2.0, 3.0, 4.0], dtype=dtypes.float32)
@polymorphic_function.function(jit_compile=True, autograph=False)
def fun_tf(x):
return x * gen_array_ops.broadcast_to(outer_ct, x.shape) + 1.0
self._compareTwoMethodsCompilerIROutput(fun_tf, [x], {})
def test_unsupported_dynamic_input(self):
with ops.device('device:{}:0'.format(self.device)):
@polymorphic_function.function(jit_compile=True)
def f(x):
return x
with self.assertRaisesRegex(
ValueError, 'Only static input shapes are supported'
):
args_spec = [tensor_spec.TensorSpec((None,), dtype=dtypes.float32)]
concrete_fn = f.get_concrete_function(*args_spec)
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
def test_unsupported_shape_depend_input(self):
with ops.device('device:{}:0'.format(self.device)):
# In these cases the output shapes are dynamic.
@polymorphic_function.function(jit_compile=True)
def f2(x):
return x[x[0] : 0]
args = [ops.convert_to_tensor([1, 2, 3, 4])]
args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args)
concrete_fn = f2.get_concrete_function(*args_spec)
if test_util.is_mlir_bridge_enabled():
with self.assertRaisesRegex(
ValueError, 'TF to XLA legalization failed'
):
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
else:
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
# In these cases both input and output shapes are static, but the tf
# graph contains constant args.
@polymorphic_function.function(jit_compile=True)
def f3(a, b):
c = array_ops.slice(a, [b], [-1])
return math_ops.reduce_sum(c)
args = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor(1)]
args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args)
concrete_fn = f3.get_concrete_function(*args_spec)
if test_util.is_mlir_bridge_enabled():
with self.assertRaisesRegex(
ValueError, 'TF to XLA legalization failed'
):
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
else:
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
def test_unsupported_capture_outside_variable(self):
"""These cases define a tf.Variable outside the function body."""
with ops.device('device:{}:0'.format(self.device)):
error_msg = (
'The function to be lowered uses some tf.Variable defined outside '
'the function body.'
)
v = variables.Variable([0.1, 0.1])
@polymorphic_function.function(jit_compile=True)
def f4(a, b):
return (a + b) * v
a = constant_op.constant([1.1, 1.1])
b = constant_op.constant([2.2, 2.2])
kwargs = {'b': a, 'a': b}
with self.assertRaisesRegex(ValueError, error_msg):
kwargs_spec = nest.map_structure(
tensor_spec.TensorSpec.from_tensor, kwargs
)
concrete_fn = f4.get_concrete_function(**kwargs_spec)
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -75,17 +75,18 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import monitoring
from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
from tensorflow.python.eager.polymorphic_function import compiler_ir
from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib
from tensorflow.python.eager.polymorphic_function import tracing_compiler
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
@ -998,6 +999,32 @@ class Function(core.GenericFunction, trackable.Trackable):
raise ValueError("Compiler IR can only be returned for functions marked "
"with 'jit_compile=True'")
is_tensor_spec = lambda x: isinstance(x, tensor_spec.TensorSpec)
def _check_inputs(args, kwargs):
all_inputs = list(args) + list(kwargs.values())
# Empty input is okay.
if not all_inputs:
return
if any(map(is_tensor_spec, all_inputs)) and any(
map(lambda x: not is_tensor_spec(x), all_inputs)
):
raise ValueError(
"experimental_get_compiler_ir supports either "
"(1) all inputs are TensorSpec or "
"(2) all inputs are tf.Tensor/Python variables."
)
_check_inputs(args, kwargs)
if (
len(args) + len(kwargs.values()) > 0
and all(map(is_tensor_spec, args))
and all(map(is_tensor_spec, kwargs.values()))
):
# Inputs are non-empty and all of them are tf.TensorSpec.
concrete_fn = self.get_concrete_function(*args, **kwargs)
return compiler_ir.from_concrete_function(concrete_fn)
concrete_fn = self.get_concrete_function(*args, **kwargs)
fn_name = concrete_fn.name
@ -1006,10 +1033,7 @@ class Function(core.GenericFunction, trackable.Trackable):
concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs))
def compiler_ir_generator(stage="hlo", device_name=None):
# TODO(cheshire): This is a hack to get the current "preferred" device,
# there is no current API to get it otherwise.
if device_name is None:
device_name = random_ops.random_normal([]).device
device_name = compiler_ir.maybe_get_device_name(device_name)
res_bytes = context.context().get_compiler_ir(
device_name=device_name,
stage=stage,

View File

@ -38,11 +38,31 @@ from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@test_util.with_eager_op_as_function
class FunctionTest(xla_test.XLATestCase):
def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs):
"""Asserts TensorSpec inputs and Tensor inputs give the same HLO text."""
flat_args = list(args) + list(kwargs.values())
if not all([isinstance(x, ops.Tensor) for x in flat_args]):
self.skipTest('Only supported when all args and kwargs are tf.Tensor.')
args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args)
kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs)
hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)()
hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)()
if hlo_1 != hlo_2:
self.fail(
'experimental_get_compiler_ir with TensorSpec inputs gives a different'
f' result than with Tensor inputs.\nhlo_1:\n{hlo_1}'
f'\nhlo_2:\n{hlo_2}\n'
)
def testAutoclusteringWithTfFunction(self):
if 'tpu' in self.device.lower():
self.skipTest('Autoclustering does not run on TPU')
@ -189,6 +209,7 @@ class FunctionTest(xla_test.XLATestCase):
matches = re.findall('channel_id=([0-9]*),', hlo_str)
self.assertLen(matches, 2)
self.assertNotEqual(matches[0], matches[1])
self._compareTwoMethodsCompilerIROutput(fn, [inputs, inputs], {})
def testCollectiveReduceGroupAssignment(self):
if not test_util.is_mlir_bridge_enabled():
@ -209,6 +230,7 @@ class FunctionTest(xla_test.XLATestCase):
# instructions generated by XLA.
hlo_str = fn.experimental_get_compiler_ir(inputs)()
self.assertIn('replica_groups={{0}}', hlo_str)
self._compareTwoMethodsCompilerIROutput(fn, [inputs], {})
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
' support stack traces')
@ -222,6 +244,7 @@ class FunctionTest(xla_test.XLATestCase):
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertIn('polymorphic_function_xla_jit_test',
fn.experimental_get_compiler_ir(inputs, inputs)())
self._compareTwoMethodsCompilerIROutput(fn, [inputs, inputs], {})
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
' support stack traces')
@ -239,6 +262,7 @@ class FunctionTest(xla_test.XLATestCase):
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertIn('polymorphic_function_xla_jit_test',
g.experimental_get_compiler_ir(inputs, inputs)())
self._compareTwoMethodsCompilerIROutput(g, [inputs, inputs], {})
def testPythonStackTrace(self):
with ops.device('device:{}:0'.format(self.device)):
@ -864,6 +888,7 @@ class FunctionTest(xla_test.XLATestCase):
self.assertIn(
'label',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
self._compareTwoMethodsCompilerIROutput(f, [a, b], {})
def testGetCompilerIrNoDevicePlacement(self):
if 'gpu' not in self.device.lower():
@ -879,6 +904,7 @@ class FunctionTest(xla_test.XLATestCase):
self.assertIn(
'label',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
self._compareTwoMethodsCompilerIROutput(f, [a, b], {})
def testGetCompilerIrNonTensors(self):
with ops.device('device:{}:0'.format(self.device)):
@ -891,6 +917,7 @@ class FunctionTest(xla_test.XLATestCase):
self.assertIn('tuple',
f.experimental_get_compiler_ir(l)())
self._compareTwoMethodsCompilerIROutput(f, [l], {})
def testGetCompilerIrSerialized(self):
with ops.device('device:{}:0'.format(self.device)):
@ -904,6 +931,7 @@ class FunctionTest(xla_test.XLATestCase):
hlo = fn.experimental_get_compiler_ir(inputs)(
stage=stage, device_name=f'/device:{self.device}:0')
self.assertIsInstance(hlo, bytes)
self._compareTwoMethodsCompilerIROutput(fn, [inputs], {})
def testDotOptimizedHlo(self):
with ops.device('device:{}:0'.format(self.device)):
@ -1171,12 +1199,35 @@ class FunctionTest(xla_test.XLATestCase):
def f(x):
return math_ops.argmax(x)
hlo = f.experimental_get_compiler_ir(
array_ops.ones([10], dtype=dtypes.float32))(
stage='hlo')
inputs = array_ops.ones([10], dtype=dtypes.float32)
hlo = f.experimental_get_compiler_ir(inputs)(stage='hlo')
# Test that reduction occurs only once.
self.assertGreater(hlo.count('reduce'), 1)
self._compareTwoMethodsCompilerIROutput(f, [inputs], {})
def testExperimentalGetCompilerIRBasic(self):
with ops.device('device:{}:0'.format(self.device)):
@polymorphic_function.function(jit_compile=True)
def inner_tf_func(x):
return math_ops.sin(x)
x = constant_op.constant([2.0, 3.0])
self._compareTwoMethodsCompilerIROutput(inner_tf_func, [x], {})
def testExperimentalGetCompilerIRAutograph(self):
with ops.device('device:{}:0'.format(self.device)):
@polymorphic_function.function(jit_compile=True, autograph=True)
def f(x, y):
if x[0] > 1:
return y[0]
else:
return y[1]
x, y = constant_op.constant([2, 3]), constant_op.constant([2, 3])
self._compareTwoMethodsCompilerIROutput(f, [x, y], {})
if __name__ == '__main__':

View File

@ -179,9 +179,10 @@ class GenericFunction(Callable):
backwards compatibility of returned IR or the allowed values of `stage`.

Args:
*args: Arguments used for compilation; same arguments as used for calling
the function. Need to be eager tensors.
*args: Compilation arguments; either (1) all inputs are TensorSpec, or
(2) all inputs are tf.Tensor/Python variables.
**kwargs: Keyword arguments used for compilation. Same requirement as
compilation args.
Returns:
Function callable with the following kwargs:
@ -230,8 +231,11 @@ class GenericFunction(Callable):
```
Raises:
ValueError: If an invalid `stage` is selected or if applied to a function
which is not compiled (`jit_compile=True` is not set).
ValueError:
(1) If an invalid `stage` is selected,
(2) or if applied to a function which is not compiled
(`jit_compile=True` is not set),
(3) or if input shapes are not fully defined for tf.TensorSpec inputs.
TypeError: When called with input in graph mode.
"""
pass
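A sketch of the two accepted input conventions (the function here is hypothetical; mixing the conventions raises `ValueError`, per the input check added in this commit):

```
import tensorflow as tf

@tf.function(jit_compile=True)
def f(x, y):
  return x + y

a = tf.constant([1.0])
b = tf.constant([2.0])
spec = tf.TensorSpec((1,), tf.float32)

f.experimental_get_compiler_ir(a, b)()        # all tf.Tensor inputs: OK
f.experimental_get_compiler_ir(spec, spec)()  # all TensorSpec inputs: OK
# f.experimental_get_compiler_ir(a, spec)     # mixed inputs: ValueError
```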