Merge branch 'tensorflow:master' into to_ordinal

Awsaf 2023-01-26 01:27:12 +06:00 committed by GitHub
commit a5303845f1
212 changed files with 11951 additions and 1485 deletions


@ -39,9 +39,9 @@ const get_email_domain = async ({github, username}) => {
return domain;
};
/** For trusted partners like Intel, we want to auto-run tests and mark the PR as ready to pull
/** For trusted partners like Intel, we want to auto-run tests
This allows us to reduce the delay to external partners
Add Labels - kokoro:force-run, ready to pull
Add Labels - kokoro:force-run
The PR is also assigned to specific teams to fast track review
Additional reviewers can be added manually based on PR contents
@param {!object}
@ -50,7 +50,7 @@ const get_email_domain = async ({github, username}) => {
@return {string} Returns the message with labels attached and assignees added
*/
const filter_action = async ({github, context, domain}) => {
const labels = ['kokoro:force-run', 'ready to pull'];
const labels = ['kokoro:force-run'];
let assignees = [];
const title = context.payload.pull_request && context.payload.pull_request.title;


@ -5,6 +5,16 @@
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* Build, Compilation and Packaging
* Removal of redundant packages: the `tensorflow-gpu` and `tf-nightly-gpu`
packages have been effectively removed and replaced with packages that
direct users to switch to `tensorflow` or `tf-nightly` respectively.
The naming difference was the only difference between the two sets of
packages ever since TensorFlow 2.1, so there is no loss of functionality
or GPU support. See
https://pypi.org/project/tensorflow-gpu for more details.
* `tf.function`:
* tf.function now uses the Python inspect library directly for parsing
@ -126,6 +136,8 @@
* `tf.test`:
* Added `tf.test.experimental.sync_devices`, which is useful for
accurately measuring performance in benchmarks (see the sketch after this list).
* `tf.experimental.dtensor`:
* Added experimental support to ReduceScatter fuse on GPU (NCCL).
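The snippet below is a minimal sketch (not part of this change) of how `tf.test.experimental.sync_devices` might bracket a timed block so that pending asynchronous or GPU work does not distort the measurement; the workload and shapes are illustrative and assume a build where this experimental API is present.

# Sketch: drain pending device work before and after the timed op.
import time

import tensorflow as tf

x = tf.random.uniform((2048, 2048))

tf.test.experimental.sync_devices()  # ensure prior async work has finished
start = time.time()
y = tf.matmul(x, x)
tf.test.experimental.sync_devices()  # wait until the matmul itself has finished
print("matmul time (s):", time.time() - start)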
# Bug Fixes and Other Changes


@ -970,6 +970,7 @@ package_group(
"//learning/brain/tfrt/...",
"//learning/lib/ami/simple_ml/...",
"//learning/pathways/...",
"//learning/serving/contrib/tfrt/mlir/canonical_ops/...",
"//perftools/accelerators/xprof/integration_tests/...",
"//smartass/brain/configure/...",
"//tensorflow/...",


@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/c/kernels_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
@ -129,6 +130,17 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx,
return new TF_VariableInfo(index, handle.name(), variable);
}
void TF_LockVariableInfos(TF_VariableInfo** vars, int num_vars,
TF_Status* status) {
std::vector<tensorflow::VariableInfo*> variable_ptrs;
variable_ptrs.reserve(num_vars);
for (int i = 0; i < num_vars; ++i) {
variable_ptrs.push_back(&(vars[i]->var_info));
}
tsl::Status cc_status = LockVariables(absl::MakeSpan(variable_ptrs));
tsl::Set_TF_Status_from_Status(status, cc_status);
}
void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx,
TF_VariableInfo* var_info,
TF_Status* status) {


@ -89,6 +89,10 @@ TF_CAPI_EXPORT extern void TF_LookupOrCreatePluginResource(
TF_CAPI_EXPORT extern TF_VariableInfo* TF_CreateVariableInfoFromContext(
TF_OpKernelContext* ctx, int index, TF_Status* status);
TF_CAPI_EXPORT extern void TF_LockVariableInfos(TF_VariableInfo** vars,
int num_vars,
TF_Status* status);
TF_CAPI_EXPORT extern void TF_AllocateTempForVariableInfo(
TF_OpKernelContext* ctx, TF_VariableInfo* var_info, TF_Status* status);


@ -0,0 +1,28 @@
load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist")
package_group(
name = "internal_visibility_allowlist_package",
packages = [
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/compiler/mlir/quantization/...",
"//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1
] + internal_visibility_allowlist(),
)
tf_proto_library(
name = "quantization_options_proto",
srcs = ["quantization_options.proto"],
cc_api_version = 2,
make_default_target_header_only = True,
visibility = [":internal_visibility_allowlist_package"],
)
# copybara:uncomment_begin(google-only)
# py_proto_library(
# name = "quantization_options_py_pb2",
# api_version = 2,
# visibility = [":internal_visibility_allowlist_package"],
# deps = [":quantization_options_proto"],
# )
# copybara:uncomment_end


@ -0,0 +1,10 @@
"""Internal visibility rules."""
def internal_visibility_allowlist():
"""Returns a list of g3 packages that can depend on internal targets."""
return [
"//learning/brain/experimental/mlir/quantization/...",
"//learning/brain/mlir/quantization/tensorflow/...",
"//learning/brain/mobile/programmability/...",
"//learning/brain/experimental/tfq/...",
]


@ -0,0 +1,107 @@
syntax = "proto3";
package stablehlo.quantization;
option cc_enable_arenas = true;
// Defines various options to specify and control the behavior of the
// StableHLO quantizer.
// NEXT ID: 2
message QuantizationOptions {
QuantizationMethod quantization_method = 1;
}
// NEXT ID: 3
message QuantizationMethod {
// Quantization Method can be either preset or custom.
oneof quantization_method {
PresetQuantizationMethod preset_quantization_method = 1;
CustomQuantizationMethod custom_quantization_method = 2;
}
}
// Preset model quantization method for optimization.
//
// Common quantization methods are defined as preset methods in this message.
// Note that the quantization specs (e.g., bit width) for preset quantization
// methods are fixed. To use a different quantization spec for a given method,
// use CustomQuantizationMethod.
// NEXT ID: 2
message PresetQuantizationMethod {
// Preset quantization methods that are supported as a stable API.
// NEXT ID: 3
enum PresetMethod {
// TODO(b/266173150): Update preset methods after redefining quantization
// pattern matching in DarwiNN.
// This should never be used. Using this will generally result in an error.
METHOD_UNSPECIFIED = 0; // go/do-include-enum-unspecified
// Apply default weight-only quantization. Weights are quantized during
// conversion, then dequantized during inference. Data type is as follows:
// Weight: i8, Bias: f32, Activation: f32, Input/output: f32
WEIGHT_ONLY = 1;
// Apply default dynamic range quantization. Quantized tensor values'
// ranges are determined at graph runtime. Data type is as follows:
// Weight: i8, Bias: f32, Activation: f32, Input/output: f32
DYNAMIC_RANGE = 2;
}
PresetMethod preset_method = 1;
}
// Custom option for specifying quantization spec details.
// If the selected quantization option is not available, StableHLO quantizer
// will raise an error.
// NEXT ID: 2
message CustomQuantizationMethod {
// Specify component name, bit width, and other specs for all components
// intended to be quantized.
repeated QuantizationComponentSpec quantization_component_spec = 1;
}
// Quantization spec for each component designated to be quantized.
// Components whose QuantizationComponentSpec is not specified will not be
// quantized, and remain f32.
// NEXT ID: 7
message QuantizationComponentSpec {
// NEXT ID: 4
enum QuantizationComponent {
COMPONENT_UNSPECIFIED = 0;
COMPONENT_ACTIVATION = 1;
COMPONENT_WEIGHT = 2;
COMPONENT_BIAS = 3;
}
// NEXT ID: 4
enum BitWidth {
BIT_WIDTH_UNSPECIFIED = 0;
BIT_WIDTH_4 = 1;
BIT_WIDTH_8 = 2;
BIT_WIDTH_16 = 3;
}
// NEXT ID: 4
enum BitType {
BIT_TYPE_UNSPECIFIED = 0;
BIT_TYPE_INT = 1;
BIT_TYPE_FLOAT = 2;
BIT_TYPE_BFLOAT = 3;
}
QuantizationComponent quantization_component = 1;
// Defines the target bit width of the data.
BitWidth bit_width = 2;
// Defines the data type of the quantized component.
BitType bit_type = 3;
// Defines whether quantization is done in narrow range.
bool enable_narrow_range = 4;
// Defines whether quantization is done per-channel.
bool enable_per_channel_quantization = 5;
// Defines whether quantization is done symmetrically.
bool enable_symmetric = 6;
}
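As an illustration of how these messages compose, the following is a minimal Python sketch, not part of this change: it builds one QuantizationOptions using the preset weight-only method and another using a custom spec that quantizes only weights to 8-bit integers. The module name quantization_options_pb2 is an assumption about how Python bindings generated from this proto would be named (the py_proto_library target above is commented out here).

# Hedged sketch: composing the StableHLO quantizer options defined above.
# The import path/module name is assumed, not taken from this change.
from quantization_options_pb2 import (
    CustomQuantizationMethod,
    PresetQuantizationMethod,
    QuantizationComponentSpec,
    QuantizationMethod,
    QuantizationOptions,
)

# Preset: default weight-only quantization (Weight: i8, everything else: f32).
preset_opts = QuantizationOptions(
    quantization_method=QuantizationMethod(
        preset_quantization_method=PresetQuantizationMethod(
            preset_method=PresetQuantizationMethod.WEIGHT_ONLY
        )
    )
)

# Custom: quantize only the weight component to 8-bit signed integers.
custom_opts = QuantizationOptions(
    quantization_method=QuantizationMethod(
        custom_quantization_method=CustomQuantizationMethod(
            quantization_component_spec=[
                QuantizationComponentSpec(
                    quantization_component=QuantizationComponentSpec.COMPONENT_WEIGHT,
                    bit_width=QuantizationComponentSpec.BIT_WIDTH_8,
                    bit_type=QuantizationComponentSpec.BIT_TYPE_INT,
                )
            ]
        )
    )
)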


@ -289,9 +289,9 @@ cc_library(
hdrs = ["ops/uniform_op_quant_spec.h"],
compatible_with = get_compatible_with_cloud(),
deps = [
":tf_quant_ops",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//mlir:IR",
@ -423,6 +423,7 @@ cc_library(
],
compatible_with = get_compatible_with_cloud(),
deps = [
":passes",
"//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
@ -434,6 +435,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:path",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",


@ -10,6 +10,7 @@ option cc_enable_arenas = true;
// metadata required for building a SavedModel. This message is primarily used
// to "export" the model produced from various quantization passes in C++ to
// the Python layer.
// Next ID: 7
message ExportedModel {
GraphDef graph_def = 1;
@ -29,4 +30,11 @@ message ExportedModel {
// fetching the restore op (see `restore_node_name`), this value is provided
// to the "file_prefix" tensor to identify the checkpoint directory.
string checkpoint_dir = 5;
// Function name -> function alias mapping. This associates the quantized
// functions with the original functions' aliases. This information is used
// to populate `MetaInfoDef`'s `function_aliases` when the quantized model is
// exported to the SavedModel. This field is usually only populated for
// TF2 models.
map<string, string> function_aliases = 6;
}
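For context, the sketch below shows how such a function name -> alias map would typically be written back into a SavedModel's MetaInfoDef; it mirrors the _save_function_alias helper added later in this change. The helper name copy_function_aliases and the persistence note are illustrative assumptions, not part of this change.

# Sketch: copying ExportedModel.function_aliases into a loaded MetaGraphDef.
from typing import Mapping, Set

from tensorflow.python.saved_model import loader_impl as saved_model_loader


def copy_function_aliases(saved_model_dir: str,
                          tags: Set[str],
                          function_aliases: Mapping[str, str]) -> None:
    """Copies a function name -> alias map into the SavedModel's MetaInfoDef."""
    loader = saved_model_loader.SavedModelLoader(saved_model_dir)
    meta_graph_def = loader.get_meta_graph_def_from_tags(tags)
    for func_name, alias in function_aliases.items():
        meta_graph_def.meta_info_def.function_aliases[func_name] = alias
    # This only mutates the in-memory proto; persisting it requires serializing
    # loader.saved_model back to the saved_model.pb file, as the helper in this
    # change does.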


@ -59,6 +59,8 @@ cc_library(
"//tensorflow/tsl/platform:env",
"//tensorflow/tsl/platform:path",
"//tensorflow/tsl/platform:status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",


@ -53,6 +53,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader_impl as saved_model_loader
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import tag_constants
@ -1852,6 +1853,75 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
self.assertAllClose(new_outputs, got_outputs, atol=0.0666)
self.assertAllClose(new_outputs, expected_outputs, atol=0.057)
@test_util.run_in_graph_and_eager_modes
def test_function_alias_preserved(self):
model = self._create_conv2d_model(
input_shape=(1, 3, 4, 3), filter_shape=(2, 3, 3, 2)
)
signatures = {
'serving_default': model.conv.get_concrete_function(),
}
save_opts = save_options.SaveOptions(
function_aliases={'conv_func': model.conv}
)
saved_model_save.save(
model, self._input_saved_model_path, signatures, save_opts
)
def data_gen() -> repr_dataset.RepresentativeDataset:
rng = np.random.default_rng(seed=123)
for _ in range(2):
yield {
'input_tensor': rng.uniform(
low=0, high=150, size=(1, 3, 4, 3)
).astype(np.float32),
}
tags = {tag_constants.SERVING}
quantization_options = quant_opts_pb2.QuantizationOptions(
quantization_method=quant_opts_pb2.QuantizationMethod(
experimental_method=_ExperimentalMethod.STATIC_RANGE
),
op_set=quant_opts_pb2.OpSet.XLA,
)
converted_model = quantize_model.quantize(
self._input_saved_model_path,
['serving_default'],
tags,
self._output_saved_model_path,
quantization_options,
representative_dataset=data_gen(),
)
self.assertIsNotNone(converted_model)
self.assertCountEqual(
converted_model.signatures._signatures.keys(), {'serving_default'}
)
# Test whether the aliased function exists.
output_loader = saved_model_loader.SavedModelLoader(
self._output_saved_model_path
)
# Confirm that the function alias is preserved.
meta_graph_def = output_loader.get_meta_graph_def_from_tags(tags)
function_aliases = meta_graph_def.meta_info_def.function_aliases
self.assertNotEmpty(function_aliases)
self.assertCountEqual(function_aliases.values(), {'conv_func'})
# Test that the aliased function contains a quantized op.
for func_name, alias in function_aliases.items():
if alias == 'conv_func':
for func in meta_graph_def.graph_def.library.function:
if func.signature.name == func_name:
self._contains_op_with_name_and_attribute(
func.node_def, op_name='XlaConvV2', attr_name='', attr_val=None
)
@test_util.deprecated_graph_mode_only
def test_matmul_ptq_model_with_unfreeze_constants(self):
# Uses large weight to exceed the constant size threshold of 64KiB


@ -176,11 +176,12 @@ PYBIND11_MODULE(pywrap_quantize_model, m) {
[](const absl::string_view saved_model_path,
const std::vector<std::string>& signature_keys,
const std::unordered_set<std::string>& tags,
const QuantizationOptions& quant_opts)
const QuantizationOptions& quant_opts,
const absl::flat_hash_map<std::string, std::string>& function_aliases)
-> absl::StatusOr<ExportedModel> {
return tensorflow::quantization::internal::
QuantizePtqModelPreCalibration(saved_model_path, signature_keys,
tags, quant_opts);
tags, quant_opts, function_aliases);
},
R"pbdoc(
Returns serialized ExportedModel that contains the model's GraphDef and
@ -196,11 +197,12 @@ PYBIND11_MODULE(pywrap_quantize_model, m) {
[](const absl::string_view saved_model_path,
const std::vector<std::string>& signature_keys,
const std::unordered_set<std::string>& tags,
const QuantizationOptions& quant_opts)
const QuantizationOptions& quant_opts,
const absl::flat_hash_map<std::string, std::string>& function_aliases)
-> absl::StatusOr<ExportedModel> {
return tensorflow::quantization::internal::
QuantizePtqModelPostCalibration(saved_model_path, signature_keys,
tags, quant_opts);
tags, quant_opts, function_aliases);
},
R"pbdoc(
Returns serialized ExportedModel that contains the quantized model's


@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
@ -116,6 +117,11 @@ void AddExportPasses(const bool duplicate_shape_determining_constants,
mlir::CreateFunctionalToExecutorDialectConversionPass());
pm.addPass(mlir::CreateBreakUpIslandsPass());
pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass());
// Used to clean up the "tf._noinliner" attribute that was previously used to
// prevent certain functions from being inlined (see
// `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass.
pm.addPass(mlir::TF::CreateStripNoinlineAttributePass());
}
// Finds and returns the name of the node from a set of control output nodes.
@ -139,7 +145,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
GraphDef &&graph_def, const absl::string_view init_node_name,
const absl::string_view restore_node_name,
const absl::string_view checkpoint_dir,
const std::vector<std::string> &variable_shared_names) {
const std::vector<std::string> &variable_shared_names,
const absl::flat_hash_map<std::string, std::string> &function_aliases) {
ExportedModel exported_model{};
*exported_model.mutable_graph_def() = graph_def;
exported_model.set_init_node_name(std::string(init_node_name));
@ -148,6 +155,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
for (auto &shared_name : variable_shared_names) {
*exported_model.mutable_variable_shared_names()->Add() = shared_name;
}
exported_model.mutable_function_aliases()->insert(function_aliases.begin(),
function_aliases.end());
return exported_model;
}
@ -156,7 +165,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
// when the conversion fails.
absl::StatusOr<ExportedModel> ConvertMlirModuleToExportedModel(
const mlir::ModuleOp module_op, const absl::string_view checkpoint_dir,
const std::vector<std::string> &variable_shared_names) {
const std::vector<std::string> &variable_shared_names,
const absl::flat_hash_map<std::string, std::string> &function_aliases) {
const GraphExportConfig config{};
FunctionLibraryDefinition flib_def{OpRegistry::Global(),
FunctionDefLibrary()};
@ -179,7 +189,7 @@ absl::StatusOr<ExportedModel> ConvertMlirModuleToExportedModel(
return CreateExportedModel(std::move(graph_def), init_node_name,
restore_node_name, checkpoint_dir,
variable_shared_names);
variable_shared_names, function_aliases);
}
// Runs MLIR passes with `module_op`. The passes are added by calling
@ -387,14 +397,49 @@ absl::StatusOr<ExportedModel> QuantizeQatModel(
}
return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
*variable_shared_names);
*variable_shared_names,
/*function_aliases=*/{});
}
// Returns the updated function aliases. `module_op` may have different function
// names from the original model, so it re-associates the aliases with the new
// function names. Both the input `function_aliases` and the returned value
// are function name -> alias mappings. `function_aliases` is the function alias
// mapping of the original model.
absl::flat_hash_map<std::string, std::string> UpdateFunctionAliases(
const absl::flat_hash_map<std::string, std::string> function_aliases,
mlir::ModuleOp module_op) {
absl::flat_hash_map<std::string, std::string> updated_function_aliases;
module_op->walk([&](mlir::func::FuncOp func_op) {
// We may retrieve the original function's name from the attribute.
// Functions without this attribute are ignored.
auto original_func_name =
func_op->getAttrOfType<mlir::StringAttr>("tf._original_func_name");
if (original_func_name) {
if (auto alias_itr = function_aliases.find(original_func_name.str());
alias_itr != function_aliases.end()) {
const std::string alias = alias_itr->second;
const std::string new_func_name = func_op.getSymName().str();
updated_function_aliases[new_func_name] = alias;
VLOG(1) << "Updated function alias. Alias: " << alias
<< ", New function name: " << new_func_name
<< ", Old function name: " << original_func_name.str();
}
}
});
return updated_function_aliases;
}
absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
const absl::string_view saved_model_path,
const std::vector<std::string> &signature_keys,
const std::unordered_set<std::string> &tags,
const QuantizationOptions &quantization_options) {
const QuantizationOptions &quantization_options,
const absl::flat_hash_map<std::string, std::string> &function_aliases) {
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context = CreateMlirContextForTfQuantization();
@ -416,10 +461,21 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
}
mlir::OwningOpRef<mlir::ModuleOp> module_ref = std::move(module).value();
const absl::flat_hash_map<std::string, std::string> updated_function_aliases =
UpdateFunctionAliases(function_aliases, *module_ref);
// Collect the names of the functions that have aliases so that they may not
// be inlined.
absl::flat_hash_set<std::string> aliased_function_names;
absl::c_for_each(updated_function_aliases, [&](const auto &aliases) {
return aliased_function_names.insert(aliases.first);
});
if (const absl::Status preprocess_status = PreprocessAndFreezeGraph(
/*mlir_dump_file_prefix=*/kTfQuantPtqPreCalibrationStepName,
/*is_inliner_run=*/true, module_ref.get(), &context,
bundle ? bundle->GetSession() : nullptr);
/*is_inliner_run=*/true,
/*noinline_functions=*/aliased_function_names, module_ref.get(),
&context, bundle ? bundle->GetSession() : nullptr);
!preprocess_status.ok()) {
return preprocess_status;
}
@ -456,14 +512,16 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
}
return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
*variable_shared_names);
*variable_shared_names,
updated_function_aliases);
}
absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
const absl::string_view saved_model_path,
const std::vector<std::string> &signature_keys,
const std::unordered_set<std::string> &tags,
const QuantizationOptions &quantization_options) {
const QuantizationOptions &quantization_options,
const absl::flat_hash_map<std::string, std::string> &function_aliases) {
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context = CreateMlirContextForTfQuantization();
@ -486,13 +544,24 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
mlir::OwningOpRef<mlir::ModuleOp> module_ref = std::move(module).value();
const absl::flat_hash_map<std::string, std::string> updated_function_aliases =
UpdateFunctionAliases(function_aliases, *module_ref);
// Collect the names of the functions that have aliases so that they may not
// be inlined.
absl::flat_hash_set<std::string> aliased_function_names;
absl::c_for_each(updated_function_aliases, [&](const auto &aliases) {
return aliased_function_names.insert(aliases.first);
});
// Freezing is required again since variables might have been produced during
// the pre-calibration step. `is_inliner_run = false` to prevent the functions
// lifted for quantization from being inlined.
if (const absl::Status preprocess_status = PreprocessAndFreezeGraph(
/*mlir_dump_file_prefix=*/kTfQuantPtqPostCalibrationStepName,
/*is_inliner_run=*/false, module_ref.get(), &context,
bundle ? bundle->GetSession() : nullptr);
/*is_inliner_run=*/false,
/*noinline_functions=*/aliased_function_names, module_ref.get(),
&context, bundle ? bundle->GetSession() : nullptr);
!preprocess_status.ok()) {
return preprocess_status;
}
@ -527,7 +596,8 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
}
return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
*variable_shared_names);
*variable_shared_names,
updated_function_aliases);
}
absl::StatusOr<ExportedModel> QuantizePtqDynamicRange(
@ -592,7 +662,8 @@ absl::StatusOr<ExportedModel> QuantizePtqDynamicRange(
}
return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
*variable_shared_names);
*variable_shared_names,
/*function_aliases=*/{});
}
} // namespace internal


@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
@ -60,13 +61,15 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
absl::string_view saved_model_path,
const std::vector<std::string>& exported_names,
const std::unordered_set<std::string>& tags,
const QuantizationOptions& quant_opts);
const QuantizationOptions& quant_opts,
const absl::flat_hash_map<std::string, std::string>& function_aliases);
absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
absl::string_view saved_model_path,
const std::vector<std::string>& signature_keys,
const std::unordered_set<std::string>& tags,
const QuantizationOptions& quant_opts);
const QuantizationOptions& quant_opts,
const absl::flat_hash_map<std::string, std::string>& function_aliases);
} // namespace internal
} // namespace quantization


@ -619,12 +619,19 @@ def _run_static_range_ptq(
according to the quantized graph to match the original signature defs.
"""
logging.info('Running post-training quantization pre-calibration step.')
loader = saved_model_loader.SavedModelLoader(saved_model_path)
function_aliases = loader.get_meta_graph_def_from_tags(
tags
).meta_info_def.function_aliases
exported_model_serialized = (
quantize_model_wrapper.quantize_ptq_model_pre_calibration(
saved_model_path,
list(signature_def_keys),
set(tags),
quant_opts.SerializeToString(),
dict(function_aliases),
)
)
@ -648,6 +655,7 @@ def _run_static_range_ptq(
exported_model.restore_node_name,
exported_model.checkpoint_dir,
exported_model.variable_shared_names,
exported_model.function_aliases,
)
# Uses the representative dataset to collect statistics for calibration.
@ -678,6 +686,7 @@ def _run_static_range_ptq(
list(signature_def_keys),
set(tags),
quant_opts.SerializeToString(),
dict(exported_model.function_aliases),
)
)
@ -685,10 +694,7 @@ def _run_static_range_ptq(
exported_model_serialized
)
return (
exported_model,
signature_def_map,
)
return exported_model, signature_def_map
def _static_range_quantize(
@ -780,6 +786,7 @@ def _static_range_quantize(
restore_op_name=exported_model.restore_node_name,
checkpoint_dir=exported_model.checkpoint_dir,
variable_shared_names=exported_model.variable_shared_names,
function_aliases=exported_model.function_aliases,
)
return saved_model_load(output_directory)


@ -298,6 +298,39 @@ def _find_variables(
return var_mapping
def _save_function_alias(
saved_model_dir: str,
tags: Collection[str],
function_aliases: Mapping[str, str],
) -> None:
"""Saves the function alias to the SavedModel.
SavedModelBuilder (TF1 saved model saver) does not support saving function
aliases, so this function loads the SavedModel proto and adds the
`function_aliases` field.
Args:
saved_model_dir: Path to the saved model directory.
tags: A collection of tags to specify the meta graph.
function_aliases: Function name -> function alias mapping.
"""
loader = saved_model_loader.SavedModelLoader(saved_model_dir)
meta_graph_def = loader.get_meta_graph_def_from_tags(tags)
for function_name, function_alias in function_aliases.items():
meta_graph_def.meta_info_def.function_aliases[function_name] = (
function_alias
)
saved_model_proto_serialized = loader.saved_model.SerializeToString()
# TODO(b/266015731): Also update and set the SavedModel fingerprint.
path = file_io.join(
saved_model_dir, saved_model_constants.SAVED_MODEL_FILENAME_PB
)
file_io.atomic_write_string_to_file(path, saved_model_proto_serialized)
def save_model_v1(
graph_def: graph_pb2.GraphDef,
output_dir: str,
@ -307,6 +340,7 @@ def save_model_v1(
restore_op_name: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
variable_shared_names: Optional[Sequence[str]] = None,
function_aliases: Optional[Mapping[str, str]] = None,
) -> None:
"""Saves the model.
@ -322,6 +356,7 @@ def save_model_v1(
restore_op_name: Name of the node for restoration.
checkpoint_dir: Path to checkpoint file where variable values are saved.
variable_shared_names: Shared name of the variables in the model.
function_aliases: Function name -> function alias mapping.
Raises:
ValueError iff the graph does not contain a valid signature.
@ -372,3 +407,6 @@ def save_model_v1(
)
v1_builder.save()
if function_aliases:
_save_function_alias(output_dir, tags, function_aliases)


@ -23,6 +23,7 @@ option cc_enable_arenas = true;
// Various techniques for model quantization are defined within this message
// along with a field that specifies a method to be used for a particular
// quantization request.
// NEXT ID: 3
message QuantizationMethod {
// Quantization methods that are supported as a stable API.
enum Method {
@ -74,8 +75,10 @@ enum QuantizationPrecision {
// Unit (either nodes or ops at this moment) wise quantization method for
// mixed bit precision quantization. It contains the name of the unit,
// the granularity of the unit, and the quantization method for each unit.
// NEXT ID: 6
message UnitWiseQuantizationPrecision {
// Quantization unit granularity.
// NEXT ID: 3
enum UnitType {
// This should never be used. Using this will generally result in an error.
UNIT_UNSPECIFIED = 0;
@ -101,6 +104,7 @@ message UnitWiseQuantizationPrecision {
// List of supported opsets to deploy the quantized model.
// The quantized model contains different set of ops depending on the opset.
// NEXT ID: 4
enum OpSet {
OP_SET_UNSPECIFIED = 0; // go/do-include-enum-unspecified
// Uses TF ops that mimic quantization behavior. Used when the corresponding
@ -113,6 +117,7 @@ enum OpSet {
}
// Configurations for variable freezing during quantization passes.
// NEXT ID: 2
message FreezeAllVariables {
// Setting this to true freezes all variables to constants during
// quantization. Setting this to `false` is an experimental feature and does
@ -127,6 +132,7 @@ message FreezeAllVariables {
// 2) A set of supported operations.
// 3) Unit wise quantization precision.
// 4) Target hardware name.
// NEXT ID: 9
message QuantizationOptions {
// The default quantization configuration for the model. If the below
// unit-wise configuration does not exist, we use this default quantization


@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"
#include <memory>
#include <string>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@ -28,6 +30,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
@ -64,6 +67,7 @@ absl::Status RunPassesOnModuleOp(const absl::string_view mlir_dump_file_name,
absl::Status PreprocessAndFreezeGraph(
const absl::string_view mlir_dump_file_prefix, const bool is_inliner_run,
const absl::flat_hash_set<std::string>& noinline_functions,
mlir::ModuleOp module_op, mlir::MLIRContext* context,
llvm::Optional<Session*> session) {
mlir::PassManager pm_before_freezing_variables(context);
@ -82,6 +86,12 @@ absl::Status PreprocessAndFreezeGraph(
mlir::PassManager pm_after_freezing_variables(context);
pm_after_freezing_variables.addPass(mlir::TF::CreateTFShapeInferencePass());
pm_after_freezing_variables.addPass(mlir::createCanonicalizerPass());
// Makes certain functions immune to the `InlinerPass`. Used to preserve
// aliased functions.
pm_after_freezing_variables.addNestedPass<mlir::func::FuncOp>(
mlir::quant::CreateMarkFunctionsNoinlinePass(std::vector<std::string>(
noinline_functions.begin(), noinline_functions.end())));
if (is_inliner_run) {
pm_after_freezing_variables.addPass(mlir::createInlinerPass());
}


@ -15,6 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_
#include <string>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
@ -36,11 +39,11 @@ inline constexpr absl::string_view kDefaultTfQuantMlirDumpFilePrefix =
// `mlir_dump_file_prefix` is primarily used for debugging and does not affect
// the preprocessing behavior. Instructions for producing MLIR dump files are in
// the comments of `tensorflow::quantization::MaybeEnableIrPrinting` function.
absl::Status PreprocessAndFreezeGraph(absl::string_view mlir_dump_file_prefix,
bool is_inliner_run,
mlir::ModuleOp module_op,
mlir::MLIRContext* context,
llvm::Optional<Session*> session);
absl::Status PreprocessAndFreezeGraph(
absl::string_view mlir_dump_file_prefix, bool is_inliner_run,
const absl::flat_hash_set<std::string>& noinline_functions,
mlir::ModuleOp module_op, mlir::MLIRContext* context,
llvm::Optional<Session*> session);
// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file
// prefix.
@ -49,7 +52,8 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op,
llvm::Optional<Session*> session) {
return PreprocessAndFreezeGraph(
/*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix,
/*is_inliner_run=*/true, module_op, context, session);
/*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context,
session);
}
} // namespace quantization


@ -2307,6 +2307,33 @@ Mutually reduces multiple tensors of identical type and shape.
}];
}
def TF_CollectiveReduceScatterV2Op : TF_Op<"CollectiveReduceScatterV2", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_CollectiveReduceOrderingEffect]> {
let summary = [{
Mutually reduces multiple tensors of identical type and shape and scatters the result.
}];
let arguments = (ins
TF_FpOrI32OrI64Tensor:$input,
TF_Int32Tensor:$group_size,
TF_Int32Tensor:$group_key,
TF_Int32Tensor:$instance_key,
Variadic<TF_ResourceTensor>:$ordering_token,
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
DefaultValuedOptionalAttr<StrAttr, "\"auto\"">:$communication_hint,
DefaultValuedOptionalAttr<F32Attr, "0.0f">:$timeout_seconds,
DefaultValuedOptionalAttr<I64Attr, "-1">:$max_subdivs_per_device
);
let results = (outs
TF_FpOrI32OrI64Tensor:$data
);
TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_CollectiveReduceOrderingEffect]> {
let summary = [{
Mutually reduces multiple tensors of identical type and shape.


@ -1476,6 +1476,8 @@ An op that groups a list of partitioned inputs together. Supports ND sharding.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
let hasVerifier = 1;
}
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [Pure]> {


@ -1074,6 +1074,12 @@ std::optional<std::string> CollectiveReduceV2Op::GetResourceInstanceStr() {
: std::nullopt;
}
std::optional<std::string>
CollectiveReduceScatterV2Op::GetResourceInstanceStr() {
return getNorderingToken() == 0 ? std::optional<std::string>("")
: std::nullopt;
}
//===----------------------------------------------------------------------===//
// ConcatOp and ConcatV2Op
//===----------------------------------------------------------------------===//


@ -2666,6 +2666,32 @@ LogicalResult ToBoolOp::inferReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// TPUPartitionedInputV2
//===----------------------------------------------------------------------===//
// This method mimics this op's core/TF-level shape inference logic
LogicalResult TPUPartitionedInputV2Op::verify() {
TPUPartitionedInputV2Op op = *this;
int num_partitions = 1;
const mlir::ArrayAttr partition_dims = op.getPartitionDims();
for (const mlir::Attribute &dim : partition_dims) {
num_partitions *= dim.cast<IntegerAttr>().getInt();
}
const bool is_packed = op.getIsPacked();
const bool replicated = partition_dims.empty();
const int num_inputs_expected = is_packed ? 1 : num_partitions;
if (!((replicated && !is_packed) || (op.getN() == num_inputs_expected))) {
return op.emitOpError() << "expected " << num_inputs_expected
<< " inputs, got " << op.getN();
}
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//


@ -2561,6 +2561,26 @@ func.func @testInvalidToBool(%arg0: tensor<i32>) -> tensor<1xi1> {
// -----
// Test invalid tf.TPUPartitionedInputV2 with packing
func.func @testPackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> {
^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>):
// expected-error @+1 {{expected 1 inputs, got 2}}
%0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {partition_dims = [2, 1], is_packed = true} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32>
func.return %0 : tensor<4x4xf32>
}
// -----
// Test invalid tf.TPUPartitionedInputV2 without packing
func.func @testUnpackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> {
^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>):
// expected-error @+1 {{expected 2 inputs, got 1}}
%0 = "tf.TPUPartitionedInputV2"(%arg0) {partition_dims = [2, 1], is_packed = false} : (tensor<2x4xf32>) -> tensor<4x4xf32>
func.return %0 : tensor<4x4xf32>
}
// -----
// Test valid tf.Transpose
// CHECK-LABEL: testTranspose
func.func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> {


@ -13,6 +13,42 @@ func.func @simple(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: ten
func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// CHECK-LABEL:func @simple_packed
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @simple_packed(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
// CHECK: "tf.TPUReplicateMetadata"()
// CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]])
// CHECK-SAME: is_packed = true
// CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]])
// CHECK-SAME: is_packed = true
// CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]])
// CHECK-SAME: is_packed = false
"tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
// CHECK: return [[PI]]
func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// CHECK-LABEL:func @multi_arg_packed
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG1:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @multi_arg_packed(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
// CHECK: "tf.TPUReplicateMetadata"()
// CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]])
// CHECK-SAME: is_packed = false
// CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]])
// CHECK-SAME: is_packed = false
// CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]])
// CHECK-SAME: is_packed = false
"tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%1 = "tf.TPUPartitionedInputV2"(%arg1) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
// CHECK: return [[PI]]
func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// CHECK-LABEL:func @missing_xla_sharding
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG1:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG2:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG3:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @missing_xla_sharding(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg3: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
@ -44,6 +80,27 @@ func.func @no_change_to_dag(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>,
// -----
func.func @missing_metadata(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
// expected-error@+1 {{num cores per replica unavailable, metadata missing?}}
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// -----
func.func @inconsistent_packing(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
"tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
// expected-error@+1 {{packing should match across ops}}
%1 = "tf.TPUPartitionedInputV2"(%arg0, %arg0) {_XlaSharding = "", partition_dims = [], is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// -----
func.func @xla_sharding_mismatch(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg3: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
%pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {_XlaSharding = "123", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
@ -80,3 +137,13 @@ func.func @mixed_inputs_to_replicated_op(%arg0: tensor<!tf_type.resource<tensor<
%ri = "tf.TPUReplicatedInput"(%pi_0, %arg2) {index = 1} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
// -----
func.func @num_partitioned_inputs_mismatch_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
"tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 1 : i64} : () -> ()
// expected-error@+1 {{expects 2 operands but found 3}}
%pi = "tf.TPUPartitionedInputV2"(%arg0, %arg1, %arg2) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
%ri = "tf.TPUReplicatedInput"(%pi) : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}


@ -23,6 +23,27 @@ func.func @read_write_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %a
func.return
}
// CHECK-LABEL: func @read_write_packed_resource
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf_type.resource<tensor<i32>>>)
func.func @read_write_packed_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
// CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]])
// CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]])
// CHECK-SAME: _XlaSharding = ""
// CHECK-SAME: is_packed = true
// CHECK-SAME: partition_dims = []
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
%1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
// CHECK: [[COMPUTATION:%.+]] = "tf_device.cluster_func"([[INPUT]])
%2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor<i32>) -> tensor<i32>
// CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[COMPUTATION]])
// CHECK-SAME: _XlaSharding = ""
// CHECK-SAME: partition_dims = []
// CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#0)
// CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#1)
"tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
// CHECK-LABEL: func @read_only_resource
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf_type.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf_type.resource<tensor<i32>>>)
func.func @read_only_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32> {
@ -127,6 +148,28 @@ func.func @resource_missing_subtype(%arg0: tensor<!tf_type.resource>, %arg1: ten
// -----
func.func @missing_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
// expected-error@+1 {{op num cores per replica unavailable}}
%0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
%1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
// -----
func.func @mismatch_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
// expected-error@+1 {{expects 2 operands but found 3}}
%0 = "tf.TPUPartitionedInputV2"(%arg0, %arg0, %arg0) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
%1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
// -----
// Check outside compiled that uses a TPUPartitionedInputV2.
func.func private @computation(%arg0: tensor<i32>) -> tensor<i32>


@ -144,6 +144,19 @@ bool AddAccessedResourceIds(
return false;
}
/* Resources may be merged with an execute op when they are on its device or a
 * `COMPOSITE`. Note that a `COMPOSITE` represents a set of devices and is
 * typically associated with packed variables. Presently, we assume this set
 * spans all the devices, so a variable on a `COMPOSITE` will have a local
 * instance on the execute op's device.
 */
bool IsResourceMergeable(Attribute& resource_attr, Attribute& device_attr) {
return resource_attr &&
((resource_attr == device_attr) ||
(resource_attr.cast<mlir::StringAttr>().getValue().find(
"COMPOSITE") != llvm::StringRef::npos));
}
// Finds the variable access info for a TPUExecute op.
// - `check_device` specifies whether it checks the device assignment of the
// variables to match the TPUExecute op. This is optional in some context,
@ -187,7 +200,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
if (auto* resource_op = resource.getDefiningOp()) {
auto resource_attr = resource_op->getAttr(kDeviceAttr);
// Check device matching for the node defining the resource.
if (!resource_attr || resource_attr != device_attr) continue;
if (!IsResourceMergeable(resource_attr, device_attr)) continue;
} else {
auto resource_arg = resource.dyn_cast<BlockArgument>();
assert(resource_arg);
@ -195,7 +208,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
// Check device matching for the argument defining the resource.
auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
resource_arg.getArgNumber(), kFuncDeviceAttr);
if (!resource_attr || resource_attr != device_attr) continue;
if (!IsResourceMergeable(resource_attr, device_attr)) continue;
}
}


@ -11,6 +11,7 @@ limitations under the License.
==============================================================================*/
#include <cstddef>
#include <optional>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
@ -41,47 +42,73 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
return replicated_input.emitOpError()
<< "expects all inputs from 'tf.TPUPartitionedInputV2' ops";
const auto metadata_iter =
replicated_input->getBlock()->getOps<TF::TPUReplicateMetadataOp>();
TF::TPUReplicateMetadataOp metadata;
if (!metadata_iter.empty()) metadata = *(metadata_iter.begin());
auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputV2Op>(
replicated_input.getOperand(0).getDefiningOp());
std::optional<::llvm::StringRef> xla_sharding =
first_partitioned_input.get_XlaSharding();
auto partition_dims = first_partitioned_input.getPartitionDims();
size_t num_cores_per_replica = first_partitioned_input.getNumOperands();
const std::optional<::llvm::StringRef> xla_sharding =
first_partitioned_input.get_XlaSharding();
for (auto operand : replicated_input.getInputs().drop_front()) {
size_t num_cores_per_replica = first_partitioned_input.getNumOperands();
if (metadata) {
num_cores_per_replica = metadata.getNumCoresPerReplica();
} else if (first_partitioned_input.getIsPacked()) {
return first_partitioned_input->emitOpError()
<< "num cores per replica unavailable, metadata missing?";
}
const bool packed_input = first_partitioned_input.getIsPacked();
const size_t num_operands_expected = packed_input ? 1 : num_cores_per_replica;
if (metadata &&
num_operands_expected != first_partitioned_input.getNumOperands()) {
return first_partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< first_partitioned_input.getNumOperands();
}
for (const auto& operand : replicated_input.getInputs().drop_front()) {
auto partitioned_input =
llvm::cast<TF::TPUPartitionedInputV2Op>(operand.getDefiningOp());
std::optional<::llvm::StringRef> op_xla_sharding =
const std::optional<::llvm::StringRef> op_xla_sharding =
partitioned_input.get_XlaSharding();
auto op_partition_dims = partitioned_input.getPartitionDims();
const auto op_partition_dims = partitioned_input.getPartitionDims();
// Abort if TPUPartitionedInputV2(s) do not have the same attributes.
if (!llvm::equal(partition_dims, op_partition_dims))
if (!llvm::equal(partition_dims, op_partition_dims)) {
return partitioned_input->emitOpError()
<< "expects partition_dims = " << partition_dims << " but found "
<< op_partition_dims;
if (partitioned_input.getNumOperands() != num_cores_per_replica)
} else if (partitioned_input.getIsPacked() !=
first_partitioned_input.getIsPacked()) {
return partitioned_input->emitOpError()
<< "expects " << num_cores_per_replica << " operands but found "
<< "packing should match across ops";
} else if (partitioned_input.getNumOperands() != num_operands_expected) {
return partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< partitioned_input.getNumOperands();
if (xla_sharding != op_xla_sharding)
} else if (xla_sharding != op_xla_sharding) {
return replicated_input.emitOpError()
<< "expects all inputs from 'tf.TPUPartitionedInputV2' ops to "
"have identical XLA sharding";
}
}
// 2D Matrix to store per core per replica operands. The matrix dimensions are
// num_cores_per_replica x num_replicas. i-th row holds the operands for i-th
// core. j-th column holds the operands for j-th replica.
llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
operands_per_replica_per_core;
operands_per_replica_per_core.resize(num_cores_per_replica);
operands_per_replica_per_core(num_cores_per_replica);
// Collect all operands in the 2D matrix.
for (auto operand : replicated_input.getInputs()) {
auto pi = llvm::cast<TF::TPUPartitionedInputV2Op>(operand.getDefiningOp());
for (auto& pi_operand : pi->getOpOperands()) {
unsigned core_id = pi_operand.getOperandNumber();
operands_per_replica_per_core[core_id].push_back(pi_operand.get());
Operation* pi = operand.getDefiningOp();
for (unsigned core_id = 0; core_id < num_cores_per_replica; ++core_id) {
const auto pi_operand =
packed_input ? pi->getOperand(0) : pi->getOperand(core_id);
operands_per_replica_per_core[core_id].push_back(pi_operand);
}
}
@ -89,16 +116,23 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
// `tf.TPUPartitionedInputV2` op.
OpBuilder builder(replicated_input);
llvm::SmallVector<Value, 4> operands_per_core;
for (const auto& operands_per_replica : operands_per_replica_per_core) {
for (auto& operands_per_replica : operands_per_replica_per_core) {
const bool is_packed =
packed_input && llvm::all_equal(operands_per_replica);
if (is_packed) // reduce the duplicates to one input for packed vars
operands_per_replica.erase(operands_per_replica.begin() + 1,
operands_per_replica.end());
auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
replicated_input.getLoc(), replicated_input.getType(),
operands_per_replica, replicated_input->getAttrs());
replicate_op.setIsPacked(is_packed);
operands_per_core.push_back(replicate_op);
}
auto pi = builder.create<TF::TPUPartitionedInputV2Op>(
first_partitioned_input.getLoc(), replicated_input.getType(),
operands_per_core, first_partitioned_input->getAttrs());
pi.setIsPacked(false); // inputs are now ops--not resources
replicated_input.replaceAllUsesWith(pi.getOutput());
return success();
}


@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <tuple>
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
@ -38,6 +39,9 @@ namespace {
#define GEN_PASS_DEF_TPURESOURCEREADSWRITESPARTITIONINGPASS
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
struct TPUResourceReadsWritesPartitioningPass
: public impl::TPUResourceReadsWritesPartitioningPassBase<
TPUResourceReadsWritesPartitioningPass> {
@ -109,12 +113,14 @@ LogicalResult UpdateReadUses(TF::ReadVariableOp old_read,
LogicalResult PartitionResourceReadsWrites(
tf_device::ClusterFuncOp cluster_func) {
bool use_spmd = false;
if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
"use_spmd_for_xla_partitioning"))
if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
use_spmd = use_spmd_attr.getValue();
if (!use_spmd) return success();
auto num_cores_per_replica_attr =
cluster_func->getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
// Wrap the ClusterFunc with a ParallelExecute if it does not already exist.
OpBuilder builder(cluster_func);
tf_device::ParallelExecuteOp parallel_execute =
@ -138,20 +144,40 @@ LogicalResult PartitionResourceReadsWrites(
!AllResourceTypesHaveSubtypes(partitioned_input.getInputs().getTypes()))
continue;
const auto inputs = partitioned_input.getInputs();
const bool packed_input = partitioned_input.getIsPacked();
int num_cores_per_replica = partitioned_input.getN();
if (num_cores_per_replica_attr) {
num_cores_per_replica = num_cores_per_replica_attr.getInt();
} else if (packed_input) {
return partitioned_input->emitOpError()
<< "num cores per replica unavailable";
}
const int num_operands_expected = packed_input ? 1 : num_cores_per_replica;
if (num_cores_per_replica_attr && num_operands_expected != inputs.size()) {
return partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< partitioned_input.getNumOperands();
}
builder.setInsertionPoint(assign_var);
llvm::SmallVector<Type, 4> partitioned_output_types;
partitioned_output_types.reserve(partitioned_input.getN());
for (Type input_type : partitioned_input.getInputs().getTypes())
partitioned_output_types.push_back(GetResourceSubtype(input_type));
partitioned_output_types.reserve(num_cores_per_replica);
for (int i = 0; i < num_cores_per_replica; ++i) {
const auto& input = packed_input ? inputs[0] : inputs[i];
partitioned_output_types.push_back(GetResourceSubtype(input.getType()));
}
auto partitioned_output = builder.create<TF::TPUPartitionedOutputV2Op>(
cluster_func->getLoc(), partitioned_output_types, result,
partitioned_input.getPartitionDimsAttr(),
partitioned_input.get_XlaShardingAttr());
for (auto resource_write : llvm::zip(partitioned_input.getInputs(),
partitioned_output.getOutput()))
for (auto [i, value] : llvm::enumerate(partitioned_output.getOutput())) {
const auto& resource = packed_input ? inputs[0] : inputs[i];
builder.create<TF::AssignVariableOp>(
assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
/*value=*/std::get<1>(resource_write));
assign_var->getLoc(), /*resource=*/resource, /*value=*/value);
}
assign_var.erase();
}
@ -167,13 +193,24 @@ LogicalResult PartitionResourceReadsWrites(
continue;
}
builder.setInsertionPoint(partitioned_input);
    // We only want to create one read variable op per unique input;
    // otherwise TPU rewriting will fail to clean up the duplicates.
llvm::SmallMapVector<Value, Value, 4> read_variable_ops;
llvm::SmallVector<Value, 4> partitioned_reads;
builder.setInsertionPoint(partitioned_input);
for (Value input : partitioned_input.getInputs()) {
auto partitioned_read = builder.create<TF::ReadVariableOp>(
read_var->getLoc(), GetResourceSubtype(input), input);
partitioned_reads.push_back(partitioned_read.getValue());
auto search = read_variable_ops.find(input);
      // If a read variable op doesn't already exist for this input, create it.
if (search == read_variable_ops.end()) {
auto partitioned_read = builder.create<TF::ReadVariableOp>(
read_var->getLoc(), GetResourceSubtype(input), input);
search = read_variable_ops.insert({input, partitioned_read.getValue()})
.first;
}
partitioned_reads.push_back(search->second);
}
auto partitioned_read = builder.create<TF::TPUPartitionedInputV2Op>(
partitioned_input->getLoc(), read_var.getValue().getType(),
partitioned_reads, partitioned_input.getPartitionDimsAttr(),
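
The read-deduplication above keeps one ReadVariableOp per unique resource and reuses it on later occurrences. A simplified sketch of the same memoization pattern with ordinary standard-library types (the names are illustrative, not taken from the pass):

#include <map>
#include <string>
#include <vector>

using Resource = std::string;  // Stand-in for an MLIR resource Value.
using ReadOp = std::string;    // Stand-in for the value produced by a read op.

// Creates at most one read per unique resource; later occurrences reuse the
// cached result, mirroring the SmallMapVector lookup in the pass.
std::vector<ReadOp> CreateReadsOncePerResource(
    const std::vector<Resource>& inputs) {
  std::map<Resource, ReadOp> created_reads;
  std::vector<ReadOp> partitioned_reads;
  for (const Resource& input : inputs) {
    auto it = created_reads.find(input);
    if (it == created_reads.end()) {
      it = created_reads.emplace(input, "read(" + input + ")").first;
    }
    partitioned_reads.push_back(it->second);
  }
  return partitioned_reads;
}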

View File

@ -1,136 +0,0 @@
// RUN: tf-tfrt-opt %s -split-input-file \
// RUN: -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: | FileCheck %s --check-prefix=MARKED
// RUN: tf-tfrt-opt %s -split-input-file \
// RUN: -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: | FileCheck %s --check-prefix=TRANSFORMED
// RUN: tf-tfrt-opt %s -split-input-file -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: -canonicalize -vectorize-perfectly-tiled-loops \
// RUN: | FileCheck %s --check-prefix=VECTORIZED
// RUN: tf-tfrt-opt %s -split-input-file -xla-cpu-transform-matmul="lower-to-mmt4d=true" \
// RUN: -vectorize-perfectly-tiled-loops \
// RUN: | FileCheck %s --check-prefix=MMT4D
func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%c1 = arith.constant 1 : index
%1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%2 = tensor.empty(%0, %1) : tensor<?x?xf32>
%cst = arith.constant 0.000000e+00 : f32
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%3 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// TRANSFORMED-LABEL: func @matmul(
// TRANSFORMED-SAME: %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>)
// TRANSFORMED-DAG: %[[C0:.*]] = arith.constant 0 : index
// TRANSFORMED: %[[INIT:.*]] = tensor.empty
// TRANSFORMED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// TRANSFORMED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// TRANSFORMED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// TRANSFORMED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// TRANSFORMED: %[[MAIN_PAR_MAIN_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[MAIN_FOR]]:
// TRANSFORMED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_FOR]]
// TRANSFORMED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// TRANSFORMED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// TRANSFORMED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// TRANSFORMED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_FOR]]
// TRANSFORMED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// TRANSFORMED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// TRANSFORMED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// TRANSFORMED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_FOR]]
// -----
// VECTORIZED-LABEL: func @matmul(
// VECTORIZED-SAME: %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>)
// VECTORIZED-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x4xf32>
// VECTORIZED-DAG: %[[C0:.*]] = arith.constant 0 : index
// VECTORIZED-DAG: %[[INIT:.*]] = tensor.empty
// VECTORIZED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// VECTORIZED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs (%[[ARG:.*]] =
// VECTORIZED: %[[LHS_READ:.*]] = vector.transfer_read {{.*}} vector<8x2xf32>
// VECTORIZED: %[[RHS_READ:.*]] = vector.transfer_read {{.*}} vector<2x4xf32>
// VECTORIZED: %[[CONTRACT:.*]] = vector.contract {{.*}} %[[LHS_READ]], %[[RHS_READ]], %[[ARG]]
// VECTORIZED: gml_st.set_yield %[[CONTRACT]]
// VECTORIZED: %[[WRITE:.*]] = vector.transfer_write %[[MAIN_FOR]]
// VECTORIZED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[WRITE]]:
// VECTORIZED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_FOR]]
// VECTORIZED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// VECTORIZED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// VECTORIZED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// VECTORIZED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// VECTORIZED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_RHS_FOR]]
// VECTORIZED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// VECTORIZED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// VECTORIZED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// VECTORIZED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// VECTORIZED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_LHS_FOR]]
// -----
// MARKED-LABEL: func @matmul(
// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) step
// MARKED: } {__peeling_applied_label__, __perfectly_tiled_loop_label__}
// MARKED: gml_st.for (%[[K:.*]]) = (%[[KUB]])
// MARKED: } {__peeling_applied_label__
// MARKED: } {__peeling_applied_label__
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: }
// MARKED: } {__peeling_applied_label__
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: }
// MARKED: } {__peeling_applied_label__
// -----
// MMT4D-LABEL: func @matmul(
// MMT4D-NOT: linalg.matmul
// MMT4D: scf.for {{.*}} = %c0 to %[[DIM0:.*]] step %c1
// MMT4D: scf.for {{.*}} = %c0 to %[[DIM1:.*]] step %c1
// MMT4D: vector.transfer_read
// MMT4D: %[[KERNEL:.*]] = scf.for {{.*}} = %c0 to %[[DIM2:.*]] step %c1 {{.*}} -> (vector<1x1x8x8xf32>)
// MMT4D: vector.transfer_read
// MMT4D: vector.transfer_read
// MMT4D: %[[CONTRACT:.*]] = vector.contract
// MMT4D: scf.yield %[[CONTRACT]]
// MMT4D: %[[WRITE:.*]] = vector.transfer_write %[[KERNEL]]

View File

@ -0,0 +1,33 @@
// RUN: xla-opt -hlo-legalize-to-linalg -hlo-xla-runtime-sparsification %s | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }>
// CHECK-LABEL: func.func @mult_sparse_dense(
// CHECK-SAME: %[[PTR:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL:.*2]]: memref<?xf64>,
// CHECK-SAME: %[[SPEC:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>
// CHECK-SAME: %[[DENSE:.*4]]: memref<10xf64>) -> memref<10xf64> {
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[A:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf64>
// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[A]] : memref<10xf64>)
// CHECK: %[[LO:.*]] = memref.load %[[PTR]][%[[C0]]] : memref<?xindex>
// CHECK: %[[HI:.*]] = memref.load %[[PTR]][%[[C1]]] : memref<?xindex>
// CHECK: scf.for %[[II:.*]] = %[[LO]] to %[[HI]] step %[[C1]] {
// CHECK: %[[I:.*]] = memref.load %[[IDX]][%[[II]]] : memref<?xindex>
// CHECK: %[[T0:.*]] = memref.load %[[VAL]][%[[II]]] : memref<?xf64>
// CHECK: %[[T1:.*]] = memref.load %[[DENSE]][%[[I]]] : memref<10xf64>
// CHECK: %[[T3:.*]] = arith.mulf %[[T0]], %[[T1]] : f64
// CHECK: memref.store %[[T3]], %[[A]][%[[I]]] : memref<10xf64>
// CHECK: }
// CHECK: return %[[A]] : memref<10xf64>
// CHECK: }
func.func @mult_sparse_dense(%arg0: tensor<10xf64, #SparseVector>,
%arg1: tensor<10xf64>)
-> tensor<10xf64> {
%0 = mhlo.multiply %arg0, %arg1 : (tensor<10xf64, #SparseVector>,
tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}

View File

@ -35,10 +35,7 @@ bool IsGoogleTensorRTEnabled() {
#else // TF_USE_TENSORRT_STATIC
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
if (!handle_or.ok()) {
LOG_WARNING_WITH_PREFIX
<< "Cannot dlopen some TensorRT libraries. If you would like "
"to use Nvidia GPU with TensorRT, please make sure the "
"missing libraries mentioned above are installed properly.";
LOG_WARNING_WITH_PREFIX << "Could not find TensorRT";
}
return handle_or.ok();
#endif // TF_USE_TENSORRT_STATIC

View File

@ -3951,7 +3951,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) {
}
}
const int num_threads = tsl::port::MaxParallelism() + 1;
const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1;
std::vector<std::unique_ptr<HloEvaluator>> embedded_evaluators;
embedded_evaluators.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {

View File

@ -1819,7 +1819,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Shape window_shape = ShapeUtil::MakeShape(
input_arrays[0]->shape().element_type(), window_dimension_sizes);
const int num_threads = tsl::port::MaxParallelism() + 1;
const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1;
std::vector<std::unique_ptr<HloEvaluator>> embedded_evaluators;
embedded_evaluators.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {

View File

@ -1,4 +1,7 @@
cc_binary(
load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary")
load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library")
xla_cc_binary(
name = "mlir_replay",
srcs = ["mlir_replay.cc"],
deps = [

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
@ -429,11 +430,13 @@ LogicalResult fuseOutputFill(PatternRewriter& rewriter, Operation* op) {
return success();
}
FailureOr<Operation*> tileAndFuseGreedily(
FailureOr<ParallelOp> tileUsingGmlStParallelAndFuseGreedily(
PatternRewriter& rewriter, Operation* op,
const mlir::gml_st::TilingOptions& opts, StringRef label,
llvm::function_ref<bool(Operation*)> fuseFilterFn) {
auto tilingResult = tile(opts, rewriter, cast<TilingInterface>(op));
assert(opts.distribute == true &&
"gml_st.for should not be used for CPU pipeline");
auto tilingResult = tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
if (failed(tilingResult)) return failure();
// If we did not tile (e.g. when all tile sizes are 0), do not replace
@ -446,7 +449,25 @@ FailureOr<Operation*> tileAndFuseGreedily(
fuseFilterFn);
}
setLabel(tilingResult->tiledOps.front(), label);
return tilingResult->loop;
return cast<ParallelOp>(tilingResult->loop);
}
FailureOr<scf::SCFTilingResult> tileUsingSCFForOpAndFuseGreedily(
PatternRewriter& rewriter, Operation* op, const scf::SCFTilingOptions& opts,
StringRef label, llvm::function_ref<bool(Operation*)> fuseFilterFn) {
auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, opts);
if (failed(tilingResult)) return failure();
// If we did not tile (e.g. when all tile sizes are 0), do not replace
// original op and just mark it as transformed then return.
if (!tilingResult->loops.empty()) {
rewriter.replaceOp(op, tilingResult->replacements);
// Fuse ops into the loop.
fuseGreedily(rewriter, *tilingResult->loops.back().getBody(), fuseFilterFn);
}
setLabel(tilingResult->tiledOps.front(), label);
return tilingResult;
}
LogicalResult tilePeeledOpsToScalars(
@ -464,8 +485,8 @@ LogicalResult tilePeeledOpsToScalars(
opts.setTileSizeComputationFn(SmallVector<int64_t>(
cast<linalg::LinalgOp>(definingOp).getNumLoops(), 1));
if (failed(tileAndFuseGreedily(rewriter, definingOp, opts, label,
fuseFilterFn)))
if (failed(tileUsingGmlStParallelAndFuseGreedily(rewriter, definingOp, opts,
label, fuseFilterFn)))
return failure();
}
return success();
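
For context, a hypothetical call site for the new tileUsingSCFForOpAndFuseGreedily helper; the tile sizes, label string, and fuse filter below are invented for illustration, and the surrounding includes and namespaces are assumed to match the file above:

// Sketch only: tiles `op` to an scf.for nest with 8x4 tiles and fuses
// linalg.map producers into the loop body.
LogicalResult tileExampleTo8x4(PatternRewriter &rewriter, Operation *op) {
  scf::SCFTilingOptions opts;
  opts.setTileSizes(SmallVector<int64_t>{8, 4});
  auto tilingResult = tileUsingSCFForOpAndFuseGreedily(
      rewriter, op, opts, /*label=*/"__example_transformed_label__",
      /*fuseFilterFn=*/[](Operation *candidate) {
        return isa<linalg::MapOp>(candidate);
      });
  return failed(tilingResult) ? failure() : success();
}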

View File

@ -61,12 +61,17 @@ FusionCluster findMapFusionCluster(Operation *op);
// Fuses linalg.fill that is used in init argument of the op.
LogicalResult fuseOutputFill(PatternRewriter &rewriter, Operation *op);
// Tiles the op and fuses greedily according to the filter function.
FailureOr<Operation *> tileAndFuseGreedily(
// Tiles the op to gml_st.parallel and fuses greedily according to the filter.
FailureOr<ParallelOp> tileUsingGmlStParallelAndFuseGreedily(
PatternRewriter &rewriter, Operation *op,
const mlir::gml_st::TilingOptions &opts, StringRef label,
llvm::function_ref<bool(Operation *)> fuseFilterFn);
// Tiles the op to scf.for and fuses greedily according to the filter.
FailureOr<scf::SCFTilingResult> tileUsingSCFForOpAndFuseGreedily(
PatternRewriter &rewriter, Operation *op, const scf::SCFTilingOptions &opts,
StringRef label, llvm::function_ref<bool(Operation *)> fuseFilterFn);
// Tiles the op to 1 for all dimensions and fuses greedily according to the
// filter function.
LogicalResult tilePeeledOpsToScalars(

View File

@ -23,13 +23,10 @@ limitations under the License.
#include "gml_st/IR/gml_st_ops.h"
#include "gml_st/transforms/transforms.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
namespace mlir {
namespace gml_st {
@ -43,25 +40,8 @@ bool hasTensorSemantics(Operation *op) {
llvm::all_of(op->getOperandTypes(), isATensor);
}
/// Rewrite a LoopOp/ParallelOp/ForOp with bounds/step that potentially do not
/// divide evenly into two LoopOp/ParallelOp/ForOps: One where the step divides
/// the iteration space evenly, followed by another one for the last (partial)
/// iteration (if any). This function only rewrites the `idx`-th loop of the
/// loop nest represented by the LoopOp/ParallelOp/ForOp. To peel the entire
/// loop nest, this function must be called multiple times.
///
/// This function rewrites the given LoopOp/ParallelOp/ForOp in-place and
/// creates a new LoopOp/ParallelOp/ForOp for the last iteration. It replaces
/// all uses of the original LoopOp/ParallelOp/ForOp with the results of the
/// newly generated one.
///
/// The newly generated LoopOp/ParallelOp/ForOp is returned via `result`. The
/// boundary at which the loop is split (new upper bound) is returned via
/// `splitBound`. The return value indicates whether the
/// LoopOp/ParallelOp/ForOp was rewritten or not.
template <typename LoopTy>
LogicalResult peelLoop(RewriterBase &b, LoopTy loopOp, int64_t idx,
LoopTy &result, Value &splitBound) {
LogicalResult peelLoop(RewriterBase &b, ParallelOp loopOp, int64_t idx,
ParallelOp &result, Value &splitBound) {
if (!hasTensorSemantics(loopOp)) return failure();
Value lb = loopOp.getLowerBound()[idx], ub = loopOp.getUpperBound()[idx],
@ -93,7 +73,7 @@ LogicalResult peelLoop(RewriterBase &b, LoopTy loopOp, int64_t idx,
bvm.map(termDst, res);
}
b.setInsertionPointAfter(loopOp);
auto remainderLoop = cast<LoopTy>(b.clone(*loopOp.getOperation(), bvm));
auto remainderLoop = cast<ParallelOp>(b.clone(*loopOp.getOperation(), bvm));
Operation *remainderLoopOp = remainderLoop.getOperation();
@ -137,17 +117,31 @@ void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, Operation *mainLoop,
});
}
template <typename LoopTy>
FailureOr<LoopTy> peelAndCanonicalizeGmlStLoopImpl(RewriterBase &rewriter,
LoopTy loopOp, int64_t idx) {
} // namespace
PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter) {
setLabel(loop, kPeelingAppliedLabel);
PeelingResult peelingResult;
for (unsigned peeledIdx = 0; peeledIdx < loop.getNumLoops(); ++peeledIdx) {
auto peel = peelAndCanonicalizeGmlStLoop(rewriter, loop, peeledIdx);
if (failed(peel)) continue;
// Mark the new loop if one was created.
setLabel(peel->getOperation(), kPeelingAppliedLabel);
peelingResult.push_back(*peel);
}
return peelingResult;
}
FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx) {
int64_t numLoops = loopOp.getNumLoops();
if (idx < 0 || numLoops <= idx) return failure();
Value ub = loopOp.getUpperBound()[idx];
LoopTy remainderLoop;
ParallelOp remainderLoop;
Value splitBound;
if (failed(
peelLoop<LoopTy>(rewriter, loopOp, idx, remainderLoop, splitBound)))
if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
return failure();
// Rewrite affine.min and affine.max ops.
@ -162,39 +156,13 @@ FailureOr<LoopTy> peelAndCanonicalizeGmlStLoopImpl(RewriterBase &rewriter,
return remainderLoop;
}
template <typename LoopTy>
PeelingResult peelAllLoopsImpl(LoopTy loop, mlir::PatternRewriter &rewriter) {
setLabel(loop, kPeelingAppliedLabel);
PeelingResult peelingResult;
for (unsigned peeledIdx = 0; peeledIdx < loop.getNumLoops(); ++peeledIdx) {
auto peel =
peelAndCanonicalizeGmlStLoopImpl<LoopTy>(rewriter, loop, peeledIdx);
if (failed(peel)) continue;
// Mark the new loop if one was created.
setLabel(peel->getOperation(), kPeelingAppliedLabel);
peelingResult.push_back(*peel);
}
return peelingResult;
}
} // namespace
PeelingResult peelAllLoops(ForOp loop, mlir::PatternRewriter &rewriter) {
return peelAllLoopsImpl<ForOp>(loop, rewriter);
}
PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter) {
return peelAllLoopsImpl<ParallelOp>(loop, rewriter);
}
FailureOr<ForOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ForOp loopOp, int64_t idx) {
return peelAndCanonicalizeGmlStLoopImpl<ForOp>(rewriter, loopOp, idx);
}
FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx) {
return peelAndCanonicalizeGmlStLoopImpl<ParallelOp>(rewriter, loopOp, idx);
SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp loop) {
  // Peeling fails if the step divides the upper bound. In that case,
  // we still want to return {loop, nullptr}.
scf::ForOp tailLoop;
return succeeded(scf::peelAndCanonicalizeForLoop(rewriter, loop, tailLoop))
? SCFForPeelingResult{loop, tailLoop}
: SCFForPeelingResult{loop, nullptr};
}
} // namespace gml_st

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "gml_st/IR/gml_st_ops.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@ -30,25 +31,22 @@ constexpr llvm::StringRef kPeelingAppliedLabel = "__peeling_applied_label__";
using PeelingResult = SmallVector<Operation *>;
/// Rewrite a gml_st::ParallelOp/ForOp with bounds/step that potentially
/// do not divide evenly into a gml_st::ParallelOp/ForOp where the step
/// divides the iteration space evenly, followed by another
/// gml_st::ParallelOp/ForOp for the last (partial) iteration (if any).
/// This transformation is called "loop peeling".
/// Rewrite a gml_st::ParallelOp with bounds/step that potentially do not divide
/// evenly into a gml_st::ParallelOp where the step divides the iteration space
/// evenly, followed by another gml_st::ParallelOp for the last (partial)
/// iteration (if any). This transformation is called "loop peeling".
///
/// These functions peel all loops in the loop nest by calling
/// peelAndCanonicalizeGmlStLoop. Additionally, they mark all loops (main and
/// remainder loops) as peeled, so the same loop is not rewritten a second time.
PeelingResult peelAllLoops(ForOp loop, mlir::PatternRewriter &rewriter);
PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter);
/// These functions peel the `idx`-th loop of the
/// gml_st::ParallelOp/ForOp. To peel all loops in the loop nest, these
/// functions must be called multiple times.
/// These functions peel the `idx`-th loop of the gml_st::ParallelOp. To peel
/// all loops in the loop nest, these functions must be called multiple times.
///
/// After loop peeling, these functions try to simplify/canonicalize affine.min
/// and affine.max ops in the body of the two gml_st::ParallelOp/ForOps.
/// For more details, refer to `mlir::scf::peelAndCanonicalizeForLoop`.
/// and affine.max ops in the body of the two gml_st::ParallelOps. For more
/// details, refer to `mlir::scf::peelAndCanonicalizeForLoop`.
///
/// The return value indicates whether the loop was rewritten or not. Loops are
/// not rewritten if:
@ -56,18 +54,23 @@ PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter);
/// * Loop bounds and step size are static, and step already divides the
/// iteration space evenly.
///
/// Note: These functions rewrite the given gml_st::ParallelOp/ForOp
/// in-place and clone the gml_st::ParallelOp/ForOp operation for the last
/// iteration. They replace all uses of the unpeeled gml_st::ParallelOp/ForOp
/// with the results of the newly generated gml_st::ParallelOp/ForOp.
/// Note: These functions rewrite the given gml_st::ParallelOp in-place and
/// clone the gml_st::ParallelOp operation for the last iteration. They replace
/// all uses of the unpeeled gml_st::ParallelOp with the results of the newly
/// generated gml_st::ParallelOp.
///
/// Note: These functions do not mark the loops as peeled. This should be
/// handled by the caller.
FailureOr<ForOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ForOp loopOp, int64_t idx);
FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx);
struct SCFForPeelingResult {
scf::ForOp mainLoop = nullptr;
scf::ForOp tailLoop = nullptr;
};
SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp);
} // namespace gml_st
} // namespace mlir
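
Loop peeling, as documented above, splits an iteration range so that the main loop's trip count is a multiple of the step and at most one partial iteration remains. A small arithmetic sketch of the split-bound computation (illustration only, independent of the MLIR rewrites; assumes lb <= ub and step > 0):

#include <cstdint>
#include <utility>

struct PeeledBounds {
  std::pair<int64_t, int64_t> main;       // [lb, splitBound): step divides it evenly.
  std::pair<int64_t, int64_t> remainder;  // [splitBound, ub): at most one partial iteration.
};

PeeledBounds PeelBounds(int64_t lb, int64_t ub, int64_t step) {
  // Largest bound not exceeding ub such that (splitBound - lb) is a multiple of step.
  const int64_t split_bound = lb + ((ub - lb) / step) * step;
  return {{lb, split_bound}, {split_bound, ub}};
}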

View File

@ -175,7 +175,7 @@ struct TilingPattern : public OpInterfaceRewritePattern<TilingInterface> {
if (!filterFn || failed(filterFn(op)) || hasLabel(op, kTileAppliedLabel))
return failure();
auto tilingResult = tile(options, rewriter, op);
auto tilingResult = tileUsingGmlSt(options, rewriter, op);
if (failed(tilingResult)) return failure();
// If we did not tile (e.g. when all tile sizes are 0), do not replace
@ -243,8 +243,9 @@ struct TilingPass : public impl::TilingPassBase<TilingPass> {
} // namespace
FailureOr<TilingResult> tile(const TilingOptions &options,
PatternRewriter &rewriter, TilingInterface op) {
FailureOr<TilingResult> tileUsingGmlSt(const TilingOptions &options,
PatternRewriter &rewriter,
TilingInterface op) {
rewriter.setInsertionPoint(op);
if (!options.tileSizeComputationFn) {
return rewriter.notifyMatchFailure(

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <functional>
#include <string>
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
@ -53,8 +54,9 @@ struct TilingOptions {
/// Creates a tiled operation based on the specified tiling options. The result
/// is equivalent to the original op.
FailureOr<TilingResult> tile(const TilingOptions &options,
PatternRewriter &rewriter, TilingInterface op);
FailureOr<TilingResult> tileUsingGmlSt(const TilingOptions &options,
PatternRewriter &rewriter,
TilingInterface op);
/// Populate tiling patterns.
void populateTilingPatterns(

View File

@ -178,7 +178,7 @@ struct TilePartialSoftmaxPattern
tilingOptions.distributionLabel = distributionLabel;
// Tile.
FailureOr<TilingResult> tilingResult =
tile(tilingOptions, rewriter, op);
tileUsingGmlSt(tilingOptions, rewriter, op);
if (failed(tilingResult)) return failure();
rewriter.replaceOp(op, tilingResult->loop->getResults());

View File

@ -74,21 +74,19 @@ struct TileMapPattern : public OpRewritePattern<linalg::MapOp> {
return tiles;
};
auto tiledLoop = tileAndFuseGreedily(rewriter, op, opts,
kMapTransformedLabel, fuseFilterFn);
auto tiledLoop = tileUsingGmlStParallelAndFuseGreedily(
rewriter, op, opts, kMapTransformedLabel, fuseFilterFn);
if (failed(tiledLoop)) return failure();
// Peel parallel loops.
if (auto loop = dyn_cast_or_null<ParallelOp>(*tiledLoop)) {
auto peelingResult = peelAllLoops(loop, rewriter);
setLabel(loop, kPerfectlyTiledLoopLabel);
auto peelingResult = peelAllLoops(*tiledLoop, rewriter);
setLabel(*tiledLoop, kPerfectlyTiledLoopLabel);
// Tile ops in the peeled loop again, to size 1, so they can be
// scalarized.
if (failed(tilePeeledOpsToScalars(rewriter, peelingResult,
kMapTransformedLabel, fuseFilterFn)))
return failure();
}
// Tile ops in the peeled loop again, to size 1, so they can be
// scalarized.
if (failed(tilePeeledOpsToScalars(rewriter, peelingResult,
kMapTransformedLabel, fuseFilterFn)))
return failure();
return success();
}

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "gml_st/transforms/tiling/tiling.h"
#include "gml_st/transforms/transforms.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@ -36,9 +35,9 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@ -363,12 +362,11 @@ struct MatmulToMmt4dPattern : public OpRewritePattern<linalg::MatmulOp> {
};
FailureOr<TilingResult> tileMatmul(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> tileSizes,
bool distribute) {
ArrayRef<int64_t> tileSizes) {
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
return tile(opts, rewriter, cast<TilingInterface>(op));
opts.distribute = true;
return tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
}
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
@ -387,8 +385,9 @@ void splitParallelAndReductionTiles(linalg::LinalgOp op,
}
}
FailureOr<Operation *> tile(PatternRewriter &rewriter, Operation *op,
const scf::SCFTilingOptions &tilingOptions) {
FailureOr<Operation *> tileUsingSCFForAndReplace(
PatternRewriter &rewriter, Operation *op,
const scf::SCFTilingOptions &tilingOptions) {
auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, tilingOptions);
if (failed(tilingResult) || tilingResult->loops.empty()) return failure();
rewriter.replaceOp(op, tilingResult->replacements);
@ -422,13 +421,16 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
});
auto *lhsOp = mmt4dOp.getInputs()[0].getDefiningOp();
if (failed(tile(rewriter, lhsOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, lhsOp, packTilingOptions)))
return failure();
auto *rhsOp = mmt4dOp.getInputs()[1].getDefiningOp();
if (failed(tile(rewriter, rhsOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, rhsOp, packTilingOptions)))
return failure();
auto *accOp = mmt4dOp.getOutputs()[0].getDefiningOp();
if (failed(tile(rewriter, accOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, accOp, packTilingOptions)))
return failure();
// Tile tensor.unpack op.
auto unpackTilingOptions =
@ -452,7 +454,9 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
});
auto *unpackOp = *mmt4dOp->user_begin();
if (failed(tile(rewriter, unpackOp, unpackTilingOptions))) return failure();
if (failed(
tileUsingSCFForAndReplace(rewriter, unpackOp, unpackTilingOptions)))
return failure();
// Compute the tile sizes. Note that at this stage we only do layout tiling.
// Later we might also want to do traversal tiling (only on M and N dims).
@ -484,15 +488,16 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
reductionTileSizes);
// Tile the parallel loops.
auto tiledOp =
tile(rewriter, mmt4dOp,
scf::SCFTilingOptions().setTileSizes(parallelTileSizes));
auto tiledOp = tileUsingSCFForAndReplace(
rewriter, mmt4dOp.getOperation(),
scf::SCFTilingOptions().setTileSizes(parallelTileSizes));
if (failed(tiledOp)) return failure();
mmt4dOp = cast<linalg::Mmt4DOp>(*tiledOp);
// Tile the reduction loops.
tiledOp = tile(rewriter, mmt4dOp,
scf::SCFTilingOptions().setTileSizes(reductionTileSizes));
tiledOp = tileUsingSCFForAndReplace(
rewriter, mmt4dOp.getOperation(),
scf::SCFTilingOptions().setTileSizes(reductionTileSizes));
if (failed(tiledOp)) return failure();
mmt4dOp = cast<linalg::Mmt4DOp>(*tiledOp);
@ -521,7 +526,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
if (hasLabel(matmulOp, kMatmulTransformedLabel))
return rewriter.notifyMatchFailure(matmulOp,
"has already been transformed.");
if (isa<gml_st::ParallelOp, gml_st::ForOp>(matmulOp->getParentOp()))
if (isa<gml_st::ParallelOp, scf::ForOp>(matmulOp->getParentOp()))
return rewriter.notifyMatchFailure(
matmulOp, "has already been tiled by another pass.");
@ -536,8 +541,8 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
if (isa<linalg::MatmulOp>(tilingRoot)) parallelDimsTileSizes.push_back(0);
// First level tiling: parallel dimensions.
auto tilingParallelDimsResult = tileMatmul(
rewriter, tilingRoot, parallelDimsTileSizes, /*distribute=*/true);
auto tilingParallelDimsResult =
tileMatmul(rewriter, tilingRoot, parallelDimsTileSizes);
if (failed(tilingParallelDimsResult)) return failure();
// Update the results if tiling occurred.
@ -552,7 +557,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
}
// Second level tiling: reduction dimension for matmuls.
SmallVector<TilingResult> tilingReductionDimsResults;
SmallVector<scf::SCFTilingResult> tilingReductionDimsResults;
for (auto op :
llvm::to_vector(tilingRoot->getBlock()->getOps<linalg::MatmulOp>())) {
// Fusion into the output.
@ -560,7 +565,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
auto result = tileMatmulReductionDims(rewriter, op);
if (failed(result)) return failure();
tilingReductionDimsResults.push_back(result.value());
tilingReductionDimsResults.push_back(*result);
}
// Peel parallel loops.
@ -574,27 +579,27 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
    // Peel the reduction loop inside the main parallel loop, and label the main
    // loop "perfectly tiled" to enable vectorization after canonicalization.
for (auto &res : tilingReductionDimsResults) {
if (auto loop = dyn_cast_or_null<ForOp>(res.loop)) {
auto peelingResult = peelAllLoops(loop, rewriter);
setLabel(loop, kPerfectlyTiledLoopLabel);
if (res.loops.size() == 1) {
auto peelingResult = peelSCFForOp(rewriter, res.loops.front());
setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel);
}
}
return success();
}
private:
FailureOr<TilingResult> tileMatmulReductionDims(
FailureOr<scf::SCFTilingResult> tileMatmulReductionDims(
PatternRewriter &rewriter, linalg::MatmulOp matmulOp) const {
SmallVector<int64_t> reductionDimsTileSizes{0, 0, reductionDimTileSize};
auto tilingReductionDimsResult = tileMatmul(
rewriter, matmulOp, reductionDimsTileSizes, /*distribute=*/false);
scf::SCFTilingOptions opts;
opts.setTileSizes(reductionDimsTileSizes);
auto tilingReductionDimsResult =
scf::tileUsingSCFForOp(rewriter, matmulOp.getOperation(), opts);
if (failed(tilingReductionDimsResult)) return failure();
// Update the results if tiling occurred.
if (tilingReductionDimsResult->loop != nullptr) {
rewriter.replaceOp(matmulOp,
tilingReductionDimsResult->loop->getResults());
if (!tilingReductionDimsResult->loops.empty()) {
rewriter.replaceOp(matmulOp, tilingReductionDimsResult->replacements);
matmulOp =
cast<linalg::MatmulOp>(tilingReductionDimsResult->tiledOps.front());
}
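
The matmul pattern above tiles the parallel i/j dimensions to gml_st.parallel and the reduction k dimension to scf.for. A conceptual sketch with plain C++ loops (not MLIR), using the 8x4x2 tile sizes from the tests; C is assumed to be zero-initialized, mirroring the fused linalg.fill:

#include <algorithm>

void tiledMatmulSketch(const float* A, const float* B, float* C,
                       int M, int N, int K) {
  constexpr int TM = 8, TN = 4, TK = 2;
  for (int i = 0; i < M; i += TM)        // parallel tile loop over rows
    for (int j = 0; j < N; j += TN)      // parallel tile loop over columns
      for (int k = 0; k < K; k += TK)    // reduction tile loop
        for (int ii = i; ii < std::min(i + TM, M); ++ii)
          for (int jj = j; jj < std::min(j + TN, N); ++jj)
            for (int kk = k; kk < std::min(k + TK, K); ++kk)
              C[ii * N + jj] += A[ii * K + kk] * B[kk * N + jj];
}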

View File

@ -49,7 +49,7 @@ FailureOr<TilingResult> tileMatmul(PatternRewriter &rewriter, Operation *op,
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
opts.distributionLabel = distributionLabel;
return tile(opts, rewriter, cast<TilingInterface>(op));
return tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
}
/// Pattern to tile `linalg.matmul`, fuse `linalg.fill` into generated

View File

@ -29,6 +29,8 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
@ -46,12 +48,12 @@ constexpr llvm::StringRef kReduceTransformedLabel =
FailureOr<TilingResult> tileReduce(PatternRewriter &rewriter,
linalg::ReduceOp reduceOp,
ArrayRef<int64_t> tileSizes,
bool distribute) {
ArrayRef<int64_t> tileSizes) {
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
return tile(opts, rewriter, cast<TilingInterface>(reduceOp.getOperation()));
opts.distribute = true;
return tileUsingGmlSt(opts, rewriter,
cast<TilingInterface>(reduceOp.getOperation()));
}
SmallVector<int64_t> getParallelDimTileSizes(int64_t reductionDim,
@ -106,7 +108,7 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
"has already been transformed.");
}
if (isa<gml_st::ParallelOp, gml_st::ForOp>(reduceOp->getParentOp())) {
if (isa<gml_st::ParallelOp, scf::ForOp>(reduceOp->getParentOp())) {
return rewriter.notifyMatchFailure(
reduceOp, "has already been tiled by another pass.");
}
@ -136,7 +138,6 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
auto fillOp = reduceOp.getInits().front().getDefiningOp<linalg::FillOp>();
if (!fillOp) return failure();
auto neutralValue = fillOp.value();
// .get
// fillOp.getValue();
Type elementType = neutralValue.getType();
@ -149,12 +150,11 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
rewriter.create<linalg::FillOp>(loc, neutralValue, emptyVector)
.getResult(0);
auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, ValueRange ivs,
auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv,
ValueRange inits) {
// Tile input as tensor<TILE_SIZExELEM_TYPE> and reshape into
// tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
Value inputSlice =
tileAndReshapeInput(b, loc, ivs.front(), input, elementType);
Value inputSlice = tileAndReshapeInput(b, loc, iv, input, elementType);
tensor::ExtractSliceOp initSlice = create1DSlice(
b, loc, inits.front(), b.getIndexAttr(0), b.getIndexAttr(vectorSize));
@ -171,18 +171,13 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
rewriter.cloneRegionBefore(reduceOp.getRegion(), region, region.end());
setLabel(tiledReduceOp, kReduceTransformedLabel);
b.create<gml_st::SetYieldOp>(
loc, tiledReduceOp.getResults(), inits,
b.create<TileOp>(loc, initSlice.getMixedOffsets(),
initSlice.getMixedSizes(),
initSlice.getMixedStrides())
.getResult());
b.create<scf::YieldOp>(loc, tiledReduceOp.getResults());
};
// Create a tiled loop
auto tiledLoop = rewriter.create<ForOp>(loc, filledVector.getType(), zero,
tileableBound, tileSizeValue,
filledVector, tiledLoopBodyBuilder);
auto tiledLoop =
rewriter.create<scf::ForOp>(loc, zero, tileableBound, tileSizeValue,
filledVector, tiledLoopBodyBuilder);
setLabel(tiledLoop, kPerfectlyTiledLoopLabel);
// Create `linalg.reduce` from tensor<VECTOR_SIZExELEM_TYPE> to
@ -191,30 +186,25 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
cloneReduceOp(rewriter, reduceOp, tiledLoop.getResult(0),
reduceOp.getInits().front());
auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc,
ValueRange ivs, ValueRange inits) {
Value inputSlice =
create1DSlice(b, loc, input, ivs.front(), remainderSize);
auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv,
ValueRange inits) {
Value inputSlice = create1DSlice(b, loc, input, iv, remainderSize);
Value initSlice = b.create<tensor::ExtractSliceOp>(
loc, inits.front(), /*offsets=*/SmallVector<OpFoldResult>{},
/*sizes=*/SmallVector<OpFoldResult>{},
/*strides=*/SmallVector<OpFoldResult>{});
auto newReduceOp = cloneReduceOp(b, reduceOp, inputSlice, initSlice);
Value initTile = b.create<gml_st::TileOp>(
loc, /*offsets=*/SmallVector<OpFoldResult>{});
b.create<gml_st::SetYieldOp>(loc, newReduceOp, inits, initTile);
auto newReduce = cloneReduceOp(b, reduceOp, inputSlice, initSlice);
b.create<scf::YieldOp>(loc, newReduce);
};
// Combine `horizontal reduce` with the tail of the input. The tail is
// always smaller than TILE_SIZE.
auto remainderLoop =
rewriter
.create<ForOp>(loc, reduceOp.getResultTypes(), tileableBound,
inputSize, tileSizeValue, horizontalReduce,
remainderLoopBodyBuilder)
.create<scf::ForOp>(loc, tileableBound, inputSize, tileSizeValue,
horizontalReduce, remainderLoopBodyBuilder)
.getResult(0);
rewriter.replaceOp(reduceOp, remainderLoop);
@ -329,26 +319,28 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
[&](Operation *op) { return fusionCluster.contains(op); });
// Process all reduces in a fusion cluster.
for (auto op :
for (auto tiledReduceOp :
llvm::to_vector(tilingRoot->getBlock()->getOps<linalg::ReduceOp>())) {
// Fuse Fill.
if (failed(fuseOutputFill(rewriter, op))) return failure();
if (failed(fuseOutputFill(rewriter, tiledReduceOp))) return failure();
// Second level tiling: reduction dimension.
auto tilingReductionDimsResult = tileReductionDims(rewriter, op);
auto tilingReductionDimsResult =
tileReductionDims(rewriter, tiledReduceOp);
if (failed(tilingReductionDimsResult)) return failure();
// Update the results if tiling occurred.
if (tilingReductionDimsResult->loop != nullptr) {
rewriter.replaceOp(op, tilingReductionDimsResult->loop->getResults());
op =
if (!tilingReductionDimsResult->loops.empty()) {
rewriter.replaceOp(tiledReduceOp,
tilingReductionDimsResult->replacements);
tiledReduceOp =
cast<linalg::ReduceOp>(tilingReductionDimsResult->tiledOps.front());
fuseGreedily(rewriter, *op->getBlock(),
fuseGreedily(rewriter, *tiledReduceOp->getBlock(),
[&](Operation *op) { return isa<linalg::MapOp>(op); });
}
setLabel(op, kReduceTransformedLabel);
setLabel(tiledReduceOp, kReduceTransformedLabel);
// Peel parallel loops.
// Peel reduction loops.
if (failed(peelReduction(rewriter, tilingParallelDimsResult.value(),
tilingReductionDimsResult.value())))
return failure();
@ -404,15 +396,14 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
tilingParallelDimsResult =
tileReduce(rewriter, reduceOp,
getParallelDimTileSizes(reduceOp.getDimensions()[0],
parallelDimTileSize),
/*distribute=*/true);
parallelDimTileSize));
} else if (isa<linalg::MapOp>(tilingRoot)) {
TilingOptions opts;
opts.setTileSizeComputationFn({parallelDimTileSize});
opts.distribute = true;
tilingParallelDimsResult =
tile(opts, rewriter, cast<TilingInterface>(tilingRoot));
tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(tilingRoot));
} else {
return failure();
}
@ -420,19 +411,18 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
return tilingParallelDimsResult;
}
FailureOr<TilingResult> tileReductionDims(PatternRewriter &rewriter,
linalg::ReduceOp reduceOp) const {
auto tilingReductionDimsResult =
tileReduce(rewriter, reduceOp,
getReductionDimTileSizes(reduceOp.getDimensions()[0],
reductionDimTileSize),
/*distribute=*/false);
return tilingReductionDimsResult;
FailureOr<scf::SCFTilingResult> tileReductionDims(
PatternRewriter &rewriter, linalg::ReduceOp reduceOp) const {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(getReductionDimTileSizes(
reduceOp.getDimensions()[0], reductionDimTileSize));
return scf::tileUsingSCFForOp(rewriter, reduceOp.getOperation(),
tilingOptions);
}
LogicalResult peelReduction(
PatternRewriter &rewriter, const TilingResult &tilingParallelDimsResult,
const TilingResult &tilingReductionDimsResult) const {
const scf::SCFTilingResult &tilingReductionDimsResult) const {
// Peel parallel loops.
if (auto loop =
dyn_cast_or_null<ParallelOp>(tilingParallelDimsResult.loop)) {
@ -441,36 +431,44 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
    // Peel the reduction loop inside the main parallel loop, and label the main
    // loop "perfectly tiled" to enable vectorization after canonicalization.
if (auto forLoop =
dyn_cast_or_null<ForOp>(tilingReductionDimsResult.loop)) {
auto peelingResult = peelAllLoops(forLoop, rewriter);
setLabel(forLoop, kPerfectlyTiledLoopLabel);
if (!tilingReductionDimsResult.loops.empty()) {
scf::ForOp forLoop = tilingReductionDimsResult.loops.front();
SCFForPeelingResult peelingResult = peelSCFForOp(rewriter, forLoop);
if (peelingResult.mainLoop) {
setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel);
}
if (!peelingResult.tailLoop) return success();
// Tile ops in the peeled loop again, to size 1, so they can be
// scalarized.
for (auto *loop : peelingResult) {
ForOp peeledLoop = dyn_cast<ForOp>(loop);
auto *terminatorOp = peeledLoop->getRegion(0).front().getTerminator();
if (!terminatorOp) return failure();
scf::ForOp peeledLoop = peelingResult.tailLoop;
auto yieldOp = cast<scf::YieldOp>(peeledLoop.getBody()->getTerminator());
auto reduceOp = getRootReduce(yieldOp);
if (!reduceOp) return failure();
auto reduceOp =
terminatorOp->getOperand(0).getDefiningOp<linalg::ReduceOp>();
if (!reduceOp) return failure();
scf::SCFTilingOptions opts;
opts.setTileSizes(
getReductionDimTileSizes(reduceOp.getDimensions()[0], 1));
mlir::gml_st::TilingOptions opts;
opts.setTileSizeComputationFn(
getReductionDimTileSizes(reduceOp.getDimensions()[0], 1));
opts.distribute = false;
if (failed(tileAndFuseGreedily(
rewriter, reduceOp, opts, kReduceTransformedLabel,
[&](Operation *op) { return isa<linalg::MapOp>(op); })))
return failure();
}
if (failed(tileUsingSCFForOpAndFuseGreedily(
rewriter, reduceOp, opts, kReduceTransformedLabel,
[&](Operation *op) { return isa<linalg::MapOp>(op); })))
return failure();
}
return success();
}
linalg::ReduceOp getRootReduce(scf::YieldOp yieldOp) const {
if (yieldOp.getResults().size() != 1) return nullptr;
Value reduceResult = yieldOp.getResults().front();
if (auto insertSliceOp =
reduceResult.getDefiningOp<tensor::InsertSliceOp>()) {
reduceResult = insertSliceOp.getSource();
}
return reduceResult.getDefiningOp<linalg::ReduceOp>();
}
int64_t parallelDimTileSize;
int64_t reductionDimTileSize;
};
@ -489,7 +487,8 @@ struct TransformReduceForCpuPass
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<mlir::gml_st::GmlStDialect, arith::ArithDialect,
linalg::LinalgDialect, tensor::TensorDialect>();
linalg::LinalgDialect, scf::SCFDialect,
tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
}
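
The 1-D reduce rewrite above relies on the scf::ForOp builder overload that takes a body-builder callback and threads values through iter_args. A stripped-down sketch of that pattern; the bounds and the f32 `init` value are assumed to already exist, and the body just doubles the running value as a placeholder for the tiled reduce:

// Sketch only: mirrors the builder signature used by tiledLoopBodyBuilder above.
Value buildAccumulatingLoop(OpBuilder &builder, Location loc, Value lb,
                            Value ub, Value step, Value init) {
  auto loop = builder.create<scf::ForOp>(
      loc, lb, ub, step, ValueRange{init},
      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange iterArgs) {
        // A real pattern extracts a tile here and reduces it; this placeholder
        // just adds the running value to itself.
        Value updated = b.create<arith::AddFOp>(nestedLoc, iterArgs.front(),
                                                iterArgs.front());
        // scf.for requires a terminating scf.yield carrying the next iter_args.
        b.create<scf::YieldOp>(nestedLoc, updated);
      });
  return loop.getResult(0);
}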

View File

@ -44,8 +44,8 @@ FailureOr<TilingResult> tileReverseAndUpdateResultIfTiled(
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
auto tilingResult =
tile(opts, rewriter, cast<TilingInterface>(reverseOp.getOperation()));
auto tilingResult = tileUsingGmlSt(
opts, rewriter, cast<TilingInterface>(reverseOp.getOperation()));
if (failed(tilingResult)) return failure();

View File

@ -19,9 +19,11 @@ limitations under the License.
#include "gml_st/IR/gml_st_ops.h"
#include "gml_st/transforms/passes.h"
#include "gml_st/transforms/tiling/tiling.h"
#include "gml_st/transforms/transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -33,23 +35,24 @@ namespace {
#define GEN_PASS_DEF_TRANSFORMSCATTERFORCPUPASS
#include "gml_st/transforms/passes.h.inc"
struct TransformScatterForCpuPass
: public impl::TransformScatterForCpuPassBase<TransformScatterForCpuPass> {
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<mlir::gml_st::GmlStDialect, arith::ArithDialect,
tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
}
constexpr llvm::StringRef kScatterTransformedLabel =
"__scatter_transformed_label__";
void runOnOperation() override {
func::FuncOp f = getOperation();
MLIRContext *ctx = &getContext();
struct TileScatterPattern : public OpRewritePattern<thlo::ScatterOp> {
using OpRewritePattern<thlo::ScatterOp>::OpRewritePattern;
mlir::gml_st::TilingOptions opts;
opts.distribute = false; // Tile to `for` loops.
LogicalResult matchAndRewrite(thlo::ScatterOp op,
PatternRewriter &rewriter) const override {
if (hasLabel(op, kScatterTransformedLabel)) return failure();
if (isa<scf::ForOp>(op->getParentOp())) {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");
}
// Tile everything to points.
opts.tileSizeComputationFn = [](OpBuilder &b, Operation *op) {
scf::SCFTilingOptions opts;
opts.setTileSizeComputationFunction([](OpBuilder &b, Operation *op) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(
&op->getParentOfType<func::FuncOp>().getBody().front());
@ -57,22 +60,43 @@ struct TransformScatterForCpuPass
auto loops = cast<TilingInterface>(op).getLoopIteratorTypes();
return SmallVector<Value>(
loops.size(), b.create<arith::ConstantIndexOp>(op->getLoc(), 1));
};
});
auto filterFn = [&](Operation *op) {
if (isa<mlir::thlo::ScatterOp>(op))
return success();
return failure();
};
auto tilingResult = scf::tileUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()), opts);
if (failed(tilingResult)) return failure();
// If we did not tile, do not replace original op and just mark it as
// transformed then return.
if (!tilingResult->loops.empty()) {
rewriter.replaceOp(op, tilingResult->replacements);
}
setLabel(tilingResult->tiledOps.front(), kScatterTransformedLabel);
return success();
}
};
struct TransformScatterForCpuPass
: public impl::TransformScatterForCpuPassBase<TransformScatterForCpuPass> {
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<arith::ArithDialect, gml_st::GmlStDialect, scf::SCFDialect,
tensor::TensorDialect>();
}
void runOnOperation() override {
func::FuncOp f = getOperation();
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateTilingPatterns(ctx, filterFn, opts, &patterns);
patterns.add<TileScatterPattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns))))
return signalPassFailure();
}
removeTilingLabels(f);
    // Ensure we drop the marker at the end.
f.walk([](thlo::ScatterOp scatterOp) {
removeLabel(scatterOp, kScatterTransformedLabel);
});
}
};

View File

@ -53,8 +53,8 @@ struct TileSortPattern : public OpRewritePattern<SortOp> {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");
auto tilingResult =
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
auto tilingResult = tileUsingGmlSt(
options, rewriter, cast<TilingInterface>(op.getOperation()));
if (failed(tilingResult)) return failure();
// If we did not tile (e.g. when all tile sizes are 0), do not replace

View File

@ -56,8 +56,8 @@ struct TileTransposePattern : public OpRewritePattern<linalg::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");
auto tilingResult =
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
auto tilingResult = tileUsingGmlSt(
options, rewriter, cast<TilingInterface>(op.getOperation()));
if (failed(tilingResult)) return failure();
// If we did not tile (e.g. when all tile sizes are 0), do not replace

View File

@ -868,7 +868,7 @@ struct VectorizePerfectlyTiledLoopsPass
});
};
auto isPerfectlyTiledLoop = [&](Operation *op) {
return (isa<ForOp>(op) || isa<ParallelOp>(op)) &&
return (isa<ForOp, ParallelOp, scf::ForOp>(op)) &&
hasLabel(op, kPerfectlyTiledLoopLabel);
};
auto isInsidePerfectlyTiledLoop = [&](Operation *op) {

View File

@ -18,10 +18,10 @@ func.func @matmul_static(%arg0: tensor<128x16xf32>, %arg1: tensor<16x64xf32>,
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: %[[FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK-SAME: -> tensor<8x4xf32>
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: scf.yield %[[MATMUL]]
// CHECK: gml_st.set_yield %[[FOR]]
// -----
@ -59,28 +59,36 @@ func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
// TRANSFORMED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// TRANSFORMED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// TRANSFORMED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// TRANSFORMED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// TRANSFORMED: %[[MAIN_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[KUB:[a-z0-9]+]]
// TRANSFORMED-SAME: iter_args(%{{.*}} = %[[MAIN_FILL]])
// TRANSFORMED: %[[MAIN_PAR_MAIN_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[MAIN_FOR]]:
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: %[[REM_FOR:.*]] = scf.for %[[K:.*]] = %[[KUB]]
// TRANSFORMED-SAME: iter_args(%{{.*}} = %[[MAIN_FOR]])
// TRANSFORMED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_FOR]]
// TRANSFORMED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// TRANSFORMED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// TRANSFORMED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// TRANSFORMED-SAME: iter_args({{.*}} = %[[REM_RHS_FILL]])
// TRANSFORMED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_FOR]]
// TRANSFORMED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// TRANSFORMED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// TRANSFORMED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// TRANSFORMED-SAME: iter_args({{.*}} = %[[REM_LHS_FILL]])
// TRANSFORMED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_FOR]]
// -----
@ -89,15 +97,15 @@ func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) step
// MARKED: scf.for %[[K:.*]] = %[[C0]] to %[[KUB:.*]] step
// MARKED: __perfectly_tiled_loop_label__
// MARKED: gml_st.for (%[[K:.*]]) = (%[[KUB]])
// MARKED: scf.for %[[K:.*]] = %[[KUB]]
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: scf.for %[[K:.*]] = %[[C0]]
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: scf.for %[[K:.*]] = %[[C0]]
// -----
@ -139,21 +147,25 @@ func.func @matmul_fuse_output(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: linalg.map
// CHECK: linalg.map
@ -161,23 +173,27 @@ func.func @matmul_fuse_output(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
// CHECK: gml_st.set_yield
// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: gml_st.for
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: linalg.map
// CHECK: linalg.map
// CHECK: gml_st.set_yield
// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: gml_st.for
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: linalg.map
// CHECK: linalg.map
// CHECK: gml_st.set_yield


@ -1,11 +1,6 @@
// RUN: mlir-hlo-opt %s \
// RUN: mlir-hlo-opt %s --split-input-file \
// RUN: -xla-cpu-transform-reduce="vector-size=8 tile-size-1d=32 tile-sizes-2d=4,2" \
// RUN: --split-input-file \
// RUN: | FileCheck %s --check-prefixes=CHECK,PEELED
// RUN: mlir-hlo-opt %s \
// RUN: -xla-cpu-transform-reduce="vector-size=8 tile-size-1d=32 tile-sizes-2d=4,2" \
// RUN: --split-input-file \
// RUN: | FileCheck %s --check-prefixes=MARKED
// RUN: | FileCheck %s
func.func @reduce_add_static(%input: tensor<100x10xf32>,
%output: tensor<10xf32>) -> tensor<10xf32> {
@ -15,29 +10,14 @@ func.func @reduce_add_static(%input: tensor<100x10xf32>,
dimensions = [0]
return %res : tensor<10xf32>
}
// CHECK-LABEL: func @reduce_add_static
// CHECK-LABEL: func @reduce_add_static(
// CHECK-SAME: %[[IN:.*]]: tensor<100x10xf32>,
// CHECK-SAME: %[[OUT:.*]]: tensor<10xf32>)
// CHECK-SAME: -> tensor<10xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: %[[IN_SLICE_1:.*]] = tensor.extract_slice %[[IN]]
// CHECK: %[[OUT_SLICE_1:.*]] = tensor.extract_slice %[[OUT]]
// CHECK: %[[FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]])
// CHECK: %[[IN_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[OUT_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[REDUCED:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[IN_SLICE_2]] : tensor<2x?xf32>)
// CHECK-SAME: outs(%[[OUT_SLICE_2]] : tensor<?xf32>)
// CHECK-SAME: dimensions = [0]
// CHECK: gml_st.set_yield %[[REDUCED]]
// CHECK: gml_st.set_yield %[[FOR]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.reduce
// CHECK-NEXT: tensor.insert_slice
// CHECK-NEXT: scf.yield
// CHECK: gml_st.set_yield
// -----
@ -56,54 +36,23 @@ func.func @reduce_mulf(%input: tensor<?x?xf32>,
return %res : tensor<?xf32>
}
// PEELED-LABEL: func @reduce_mulf(
// PEELED-SAME: %[[IN:.*]]: tensor<?x?xf32>,
// PEELED-SAME: %[[OUT:.*]]: tensor<?xf32>)
// PEELED-SAME: -> tensor<?xf32> {
// CHECK-LABEL: func @reduce_mulf
// PEELED-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// PEELED-DAG: %[[C0:.*]] = arith.constant 0 : index
// PEELED-DAG: %[[C1:.*]] = arith.constant 1 : index
// PEELED: %[[INIT:.*]] = tensor.empty
// PEELED: %[[DIM0:.*]] = tensor.dim %[[IN]], %[[C0]]
// CHECK: gml_st.parallel
// CHECK: linalg.fill
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: gml_st.set_yield
// PEELED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[IUB:.*]]) step
// PEELED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// PEELED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// PEELED: %[[MAIN_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[JUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// PEELED: %[[MAIN_PAR_MAIN_FOR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_REDUCE]]
// PEELED: %[[REM_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[JUB]]) {{.*}} outs (%[[REM_FOR_ARG:.*]] = %[[MAIN_FOR]]:
// PEELED: %[[REM_FOR_SLICE:.*]] = tensor.extract_slice %[[REM_FOR_ARG]]
// PEELED: %[[SCALAR_REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_FOR_SLICE]]:
// PEELED: %[[MAIN_PAR_REM_FOR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_REDUCE]]
// PEELED: gml_st.set_yield %[[SCALAR_REM_FOR]]
// PEELED: gml_st.set_yield %[[REM_FOR]]
// PEELED: %[[REM_PAR:.*]] = gml_st.parallel (%[[I:.*]]) = (%[[IUB]])
// PEELED: %[[REM_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// PEELED: %[[REM_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_SLICE]]
// PEELED: %[[REM_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_FILL]]:
// PEELED: %[[REM_PAR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[REM_PAR_REDUCE]]
// PEELED: gml_st.set_yield %[[REM_FOR]]
// -----
// MARKED-LABEL: func @reduce_mulf(
// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[IUB:.*]]) step
// MARKED: gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[JUB:.*]]) step
// MARKED: __perfectly_tiled_loop_label__
// MARKED: gml_st.for (%[[J:.*]]) = (%[[JUB]])
// MARKED: } {__peeling_applied_label__}
// MARKED: } {__peeling_applied_label__}
// MARKED: gml_st.parallel (%[[I:.*]]) = (%[[IUB]])
// MARKED: gml_st.for (%[[J:.*]]) = (%[[C0]])
// MARKED: } {__peeling_applied_label__}
// MARKED: } {__peeling_applied_label__}
// CHECK: gml_st.parallel
// CHECK: linalg.fill
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: gml_st.set_yield
// -----
@ -120,37 +69,19 @@ func.func @reduce_map_fuse(%arg0: tensor<10x100xf32>,
return %res : tensor<10xf32>
}
// CHECK-LABEL: func @reduce_map_fuse(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x100xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<10x100xf32>,
// CHECK-SAME: %[[OUT:.*]]: tensor<10xf32>)
// CHECK-SAME: -> tensor<10xf32> {
// CHECK-LABEL: func @reduce_map_fuse
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[INIT:.*]] = tensor.empty
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: gml_st.set_yield
// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: %[[ARG0_SLICE_1:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[ARG1_SLICE_1:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[INIT_SLICE_1:.*]] = tensor.extract_slice %[[INIT]]
// CHECK: %[[OUT_SLICE_1:.*]] = tensor.extract_slice %[[OUT]]
// CHECK: %[[FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]])
// CHECK: %[[ARG0_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[ARG1_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[INIT_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK-SAME: ins(%[[ARG0_SLICE_2]], %[[ARG1_SLICE_2]]
// CHECK-SAME: outs(%[[INIT_SLICE_2]] : tensor<?x2xf32>)
// CHECK: %[[OUT_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[REDUCED:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[MAPPED]] : tensor<?x2xf32>)
// CHECK-SAME: outs(%[[OUT_SLICE_2]] : tensor<?xf32>)
// CHECK-SAME: dimensions = [1]
// CHECK: gml_st.set_yield %[[REDUCED]]
// CHECK: gml_st.set_yield %[[FOR]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: gml_st.set_yield
// -----
@ -179,31 +110,29 @@ func.func @reduce_1d_static(%arg0: tensor<100xf32>) -> tensor<f32> {
// CHECK-DAG: %[[FILL0:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMP0]] : tensor<f32>)
// CHECK-DAG: %[[FILL1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMP1]] : tensor<8xf32>)
// CHECK: %[[TILE_RESULT:.*]] = gml_st.for (%[[I:.*]]) = (%[[C0]]) to
// CHECK-SAME: (%[[C96]]) step (%[[C32]]) outs (%[[ACC:.*]] = %[[FILL1]]
// CHECK: %[[TILE_RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to
// CHECK-SAME: %[[C96]] step %[[C32]] iter_args(%[[ACC:.*]] = %[[FILL1]]
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[I]]] [32] [1]
// CHECK: %[[SHAPED_SLICE:.*]] = tensor.expand_shape %[[INPUT_SLICE]]
// CHECK: %[[TILED_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[SHAPED_SLICE]]
// CHECK-SAME: outs(%[[ACC]]
// CHECK-SAME: dimensions = [0]
// CHECK: gml_st.set_yield %[[TILED_REDUCE]]
// CHECK-NEXT: } : tensor<8xf32>
// CHECK: scf.yield %[[TILED_REDUCE]]
// CHECK: %[[HORIZONTAL_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[TILE_RESULT]]
// CHECK-SAME: outs(%[[FILL0]]
// CHECK-SAME: dimensions = [0]
// CHECK: %[[REMAINDER_RESULT:.*]] = gml_st.for (%[[J:.*]]) = (%[[C96]]) to
// CHECK-SAME: (%[[C100]]) step (%[[C32]]) outs (%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[REMAINDER_RESULT:.*]] = scf.for %[[J:.*]] = %[[C96]] to
// CHECK-SAME: %[[C100]] step %[[C32]] iter_args(%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[ARG0]][%[[J]]] [%[[C4]]] [1]
// CHECK: %[[REMAINDER_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[INPUT_SLICE1]]
// CHECK-SAME: outs(%[[ACC1]]
// CHECK-SAME: dimensions = [0]
// CHECK: gml_st.set_yield %[[REMAINDER_REDUCE]]
// CHECK-NEXT: } : tensor<f32>
// CHECK: scf.yield %[[REMAINDER_REDUCE]]
// CHECK: return %[[REMAINDER_RESULT]]
// -----
@ -231,15 +160,15 @@ func.func @reduce_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<f32> {
// CHECK-DAG: %[[TILABLE_BOUND:.*]] = affine.apply #map()[%[[INPUT_SIZE]]]
// CHECK-DAG: %[[REMAINDER_SIZE:.*]] = affine.apply #map1()[%[[TILABLE_BOUND]], %[[INPUT_SIZE]]]
// CHECK: %[[TILE_RESULT:.*]] = gml_st.for (%[[I:.*]]) = (%[[C0]]) to
// CHECK-SAME: (%[[TILABLE_BOUND]]) step (%[[C32]])
// CHECK: %[[TILE_RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to
// CHECK-SAME: %[[TILABLE_BOUND]] step %[[C32]]
// CHECK: %[[TILED_REDUCE:.*]] = linalg.reduce
// CHECK: __perfectly_tiled_loop_label__
// CHECK: %[[HORIZONTAL_REDUCE:.*]] = linalg.reduce
// CHECK: %[[REMAINDER_RESULT:.*]] = gml_st.for (%[[J:.*]]) = (%[[TILABLE_BOUND]]) to
// CHECK-SAME: (%[[INPUT_SIZE]]) step (%[[C32]]) outs (%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[REMAINDER_RESULT:.*]] = scf.for %[[J:.*]] = %[[TILABLE_BOUND]] to
// CHECK-SAME: %[[INPUT_SIZE]] step %[[C32]] iter_args(%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[J]]] [%[[REMAINDER_SIZE]]] [1]
// CHECK: return %[[REMAINDER_RESULT]]
@ -265,23 +194,20 @@ func.func @reduce_map_fuse_map(%arg0: tensor<10x100xf32>,
return %res : tensor<10xf32>
}
// CHECK-LABEL: func @reduce_map_fuse_map(
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-LABEL: func @reduce_map_fuse_map
// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: gml_st.for
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: %[[REDUCE:.*]] = linalg.reduce
// CHECK: gml_st.set_yield %[[REDUCE]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: scf.yield
// CHECK: linalg.map
// CHECK: gml_st.set_yield
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: gml_st.set_yield %[[MAP]]
// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: %[[REDUCE:.*]] = linalg.reduce
// CHECK: gml_st.set_yield %[[REDUCE]]
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: gml_st.set_yield %[[MAP]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: scf.yield
// CHECK: linalg.map
// CHECK: gml_st.set_yield


@ -15,6 +15,6 @@ func.func @scatter_small_vector_dim(%indices: tensor<?x2xindex>,
}
// CHECK-LABEL: @scatter_small_vector_dim
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: thlo.scatter
// CHECK-SAME: ins(%{{.*}} : tensor<1x2xindex>, %{{.*}} : tensor<1x?x?xf32>)


@ -2337,6 +2337,14 @@ func.func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<
// -----
func.func @dynamic_slice_mismatch_indices_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<i32>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// expected-error@+1 {{start indices must have same element type}}
%0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i32>, tensor<i64>) -> tensor<1x4xi32>
func.return %0 : tensor<1x4xi32>
}
// -----
func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>) -> tensor<1x4xi32> {
// expected-error@+1 {{has mismatched number of start indices (1) and the rank of operand (2)}}
%0 = "mhlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = dense<[1]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i64>) -> tensor<1x4xi32>


@ -55,9 +55,6 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const {
if (has_color()) {
proto.set_color(color());
}
// TODO(b/239098765): Stop populating these fields and delete them when
// profiler finishes adaptation.
proto.mutable_defined_at()->set_instruction_name(instruction()->name());
return proto;
}


@ -3384,5 +3384,83 @@ ROOT %arg_tuple.1 = (f32[]{:T(256)}, f32[]{:T(256)}) parameter(0), parameter_rep
VLOG(2) << module->ToString();
}
TEST_F(CopyInsertionTest, AsyncCallDUSNoCopy) {
const char* const kModuleString = R"(
HloModule async_call
%called_computation {
%out_param = s32[1024]{0} parameter(1)
%input = s32[1024]{0} parameter(0)
%size = s32[] constant(256)
%index = s32[] custom-call(), custom_call_target="Baz"
%start = s32[] multiply(s32[] %size, s32[] %index)
%input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
%output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"
%async_wrapped {
%async_param = s32[1024]{0} parameter(0)
%async_param.1 = s32[1024]{0} parameter(1)
ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"
ENTRY %main {
%input.1 = s32[1024]{0} parameter(0)
%buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer"
%async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnUnverifiedModule(kModuleString));
CopyInsertion copy_insertion(nullptr,
/*use_region_based_live_range_analysis=*/-1);
ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status());
VLOG(2) << module->ToString();
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, AsyncCallDUSCopy) {
const char* const kModuleString = R"(
HloModule async_call
%called_computation {
%out_param = s32[1024]{0} parameter(1)
%input = s32[1024]{0} parameter(0)
%size = s32[] constant(256)
%index = s32[] custom-call(), custom_call_target="Baz"
%start = s32[] multiply(s32[] %size, s32[] %index)
%input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
%output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"
%async_wrapped {
%async_param = s32[1024]{0} parameter(0)
%async_param.1 = s32[1024]{0} parameter(1)
ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"
ENTRY %main {
%input.1 = s32[1024]{0} parameter(0)
%input.2 = s32[1024]{0} parameter(1)
%async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %input.2), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnUnverifiedModule(kModuleString));
CopyInsertion copy_insertion(nullptr,
/*use_region_based_live_range_analysis=*/-1);
ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status());
VLOG(2) << module->ToString();
EXPECT_EQ(CountCopies(*module), 1);
}
} // namespace
} // namespace xla


@ -369,7 +369,6 @@ cc_library(
"//tensorflow/tsl/platform:logging",
"//tensorflow/tsl/platform:human_readable_json",
"//tensorflow/tsl/platform:status",
"//tensorflow/tsl/profiler/lib:nvtx_utils",
"//tensorflow/tsl/protobuf:dnn_proto_cc",
] + if_gpu_is_configured([
":triangular_solve_thunk",
@ -714,7 +713,6 @@ cc_library(
compatible_with = get_compatible_with_cloud(),
deps = [
":target_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
@ -722,7 +720,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util",
"//tensorflow/compiler/xla/stream_executor",
"//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils",
"//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:ArithDialect",
@ -1248,16 +1245,13 @@ cc_library(
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
":gpu_device_info",
":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
"//tensorflow/compiler/xla/service:hlo_query",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
@ -1269,13 +1263,11 @@ xla_cc_test(
srcs = ["instruction_fusion_test.cc"],
tags = ["no_pip"],
deps = [
":gpu_device_info_for_tests",
":gpu_fusible",
":instruction_fusion",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -1291,7 +1283,6 @@ cc_library(
":gpu_fusible",
":gpu_hlo_cost_analysis",
":gpu_performance_model",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_util",
@ -1300,8 +1291,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/tsl/platform:logging",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -1413,16 +1402,13 @@ cc_library(
":gpu_fusible",
":gpu_hlo_cost_analysis",
":gpu_performance_model",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
@ -2391,6 +2377,7 @@ cc_library(
srcs = ["gpu_fusible.cc"],
hdrs = ["gpu_fusible.h"],
deps = [
":gpu_device_info",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
@ -2633,10 +2620,9 @@ xla_cc_test(
srcs = ["horizontal_loop_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":fusion_merger",
":gpu_device_info_for_tests",
":horizontal_loop_fusion",
":instruction_fusion",
":multi_output_fusion",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -2648,7 +2634,6 @@ xla_cc_test(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/tsl/lib/core:status_test_util",
@ -2660,16 +2645,12 @@ cc_library(
srcs = ["horizontal_input_fusion.cc"],
hdrs = ["horizontal_input_fusion.h"],
deps = [
":gpu_device_info",
":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:logging",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -2679,16 +2660,13 @@ xla_cc_test(
srcs = ["horizontal_input_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_device_info_for_tests",
":horizontal_input_fusion",
":multi_output_fusion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)


@ -240,8 +240,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
// Skip 'fusion' instruction if merging it into at least one of the users
// would make the fusion use too much shared memory or registers.
FusionDecision fits = FusionFitsInBudget(
*user, *producer, /*is_consumer_producer_fusion=*/true,
&fusion_info_cache_);
*user, *producer, gpu_device_info_,
/*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
if (!fits) {
++num_fail_fusion_too_large_;
return fits;


@ -43,7 +43,6 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
@ -703,6 +702,8 @@ Status GpuCompiler::OptimizeHloModule(
OptimizeHloPostLayoutAssignment(hlo_module, stream_exec, device_allocator,
gpu_target_config, autotune_results));
const GpuDeviceInfo& gpu_device_info = gpu_target_config.gpu_device_info;
{
HloPassFix<HloPassPipeline> fusion("fusion");
// We try to split variadic ops with many parameters into several such ops
@ -713,9 +714,10 @@ Status GpuCompiler::OptimizeHloModule(
HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
LayoutAssignment::InstructionCanChangeLayout),
/*debug_only=*/true);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
const GpuDeviceInfo gpu_device_info = gpu_target_config.gpu_device_info;
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
gpu_device_info);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true,
gpu_device_info);
fusion.AddPass<FusionMerger>(gpu_device_info, ShapeSizeBytesFunction());
fusion.AddPass<GpuMultiOutputFusion>(gpu_device_info,
ShapeSizeBytesFunction());
@ -728,7 +730,7 @@ Status GpuCompiler::OptimizeHloModule(
{
HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
horizontal_fusion.AddPass<GpuHorizontalInputFusion>(gpu_device_info);
horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
horizontal_fusion.AddPass<HloDCE>();


@ -486,13 +486,14 @@ static int64_t NumUnnestedReductions(const HloInstruction& instr,
// to true to enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
const HloInstruction& instr2,
const GpuDeviceInfo& device_info,
bool is_consumer_producer_fusion,
FusionInfoCache* cache /*=nullptr*/) {
if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
kSharedMemoryBudgetInBytes) {
device_info.shared_memory_per_block) {
return FusionDecision{}
<< "shared memory usage would be over the budget of "
<< kSharedMemoryBudgetInBytes << "B";
<< device_info.shared_memory_per_block << "B";
}
if (NumUnnestedReductions(instr1, cache) +


@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
@ -92,6 +93,7 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
// the producer, set consumer_producer_fusion to true to enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
const HloInstruction& instr2,
const GpuDeviceInfo& device_info,
bool is_consumer_producer_fusion = false,
FusionInfoCache* cache = nullptr);
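
A minimal caller-side sketch of the updated API, assuming only what the hunks above show (FusionFitsInBudget now takes a GpuDeviceInfo whose shared_memory_per_block supplies the budget); the wrapper name and surrounding code are hypothetical:
// Hypothetical helper, for illustration only: the budget now comes from the
// device description instead of the removed kSharedMemoryBudgetInBytes.
FusionDecision WouldFitOnDevice(const HloInstruction& producer,
                                const HloInstruction& consumer,
                                const GpuDeviceInfo& device_info,
                                FusionInfoCache* cache) {
  return FusionFitsInBudget(consumer, producer, device_info,
                            /*is_consumer_producer_fusion=*/true, cache);
}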


@ -18,13 +18,9 @@ limitations under the License.
#include <algorithm>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/tsl/platform/errors.h"
namespace xla {
namespace gpu {
@ -46,8 +42,9 @@ Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
class HorizontalInputFusionImpl {
public:
explicit HorizontalInputFusionImpl(HloComputation* computation)
: computation_(computation) {}
explicit HorizontalInputFusionImpl(HloComputation* computation,
const GpuDeviceInfo& d)
: computation_(computation), device_info_(d) {}
~HorizontalInputFusionImpl() {}
@ -55,6 +52,7 @@ class HorizontalInputFusionImpl {
private:
HloComputation* computation_;
const GpuDeviceInfo device_info_;
}; // HorizontalInputFusionImpl
// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
@ -138,7 +136,7 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
HloInstruction* fused = candidates[j];
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
FusionFitsInBudget(*fusion_anchor, *fused)) {
FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
VLOG(3) << "Fuse " << fused->ToString() << " into "
<< fusion_anchor->ToString();
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
@ -159,7 +157,7 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
HloComputation* computation) {
HorizontalInputFusionImpl horizontal_fusion_impl(computation);
HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
return horizontal_fusion_impl.Run();
}


@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
@ -38,7 +38,7 @@ namespace gpu {
// ROOT tuple of the entry computation.
class GpuHorizontalInputFusion : public HloModulePass {
public:
GpuHorizontalInputFusion() {}
explicit GpuHorizontalInputFusion(const GpuDeviceInfo& d) : device_info_(d) {}
absl::string_view name() const override {
return "gpu_horizontal_input_fusion";
@ -51,6 +51,8 @@ class GpuHorizontalInputFusion : public HloModulePass {
private:
StatusOr<bool> RunOnComputation(HloComputation*);
const GpuDeviceInfo device_info_;
};
} // namespace gpu


@ -15,14 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
namespace xla {
namespace gpu {
@ -30,7 +27,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class HorizontalInputFusionTest : public GpuCodegenTest {};
class HorizontalInputFusionTest : public GpuCodegenTest {
public:
GpuHorizontalInputFusion horizontal_input_fusion_{
TestGpuDeviceInfo::RTXA6000DeviceInfo()};
};
TEST_F(HorizontalInputFusionTest, BasicTest) {
auto module = ParseAndReturnVerifiedModule(R"(
@ -64,7 +65,7 @@ TEST_F(HorizontalInputFusionTest, BasicTest) {
)")
.value();
EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
const HloInstruction* entry_root =
module->entry_computation()->root_instruction();
@ -208,7 +209,7 @@ TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
)")
.value();
EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
}
TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
@ -232,7 +233,7 @@ TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
)")
.value();
EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
const HloInstruction* entry_root =
module->entry_computation()->root_instruction();


@ -16,9 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
@ -28,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/tsl/lib/core/status_test_util.h"
@ -192,8 +190,11 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
.value();
HloPassPipeline fusion("fusion");
fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true);
const GpuDeviceInfo device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
device_info);
fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
device_info);
EXPECT_TRUE(fusion.Run(module.get()).value());
EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
TF_ASSERT_OK(verifier().Run(module.get()).status());


@ -101,7 +101,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
// The following checks are potentially expensive.
if (NoFusionPossible too_large =
!FusionFitsInBudget(*consumer, *producer,
!FusionFitsInBudget(*consumer, *producer, device_info_,
/*is_consumer_producer_fusion=*/true)) {
return !too_large;
}


@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
namespace xla {
@ -30,8 +31,9 @@ namespace gpu {
class GpuInstructionFusion : public InstructionFusion {
public:
explicit GpuInstructionFusion(bool may_duplicate)
: InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {}
explicit GpuInstructionFusion(bool may_duplicate, const GpuDeviceInfo& d)
: InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
device_info_(d) {}
static bool IsExpensive(const HloInstruction& instruction);
@ -69,6 +71,8 @@ class GpuInstructionFusion : public InstructionFusion {
// indexed with different index vectors.
absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
fusion_node_evaluations_;
const GpuDeviceInfo device_info_;
};
} // namespace gpu


@ -15,11 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@ -29,7 +27,11 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
using InstructionFusionTest = HloTestBase;
class InstructionFusionTest : public HloTestBase {
public:
GpuInstructionFusion duplicating_instruction_fusion_{
/*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
};
TEST_F(InstructionFusionTest,
CostlyProducerAndOperandElementReusingConsumerNotFused) {
@ -45,8 +47,7 @@ TEST_F(InstructionFusionTest,
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
EXPECT_EQ(broadcast2, computation->root_instruction());
}
@ -64,8 +65,7 @@ TEST_F(InstructionFusionTest,
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
@ -82,8 +82,7 @@ TEST_F(InstructionFusionTest,
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
@ -100,8 +99,7 @@ TEST_F(InstructionFusionTest,
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
@ -119,8 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(log, computation->root_instruction());
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
@ -135,8 +132,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
// Tests that broadcasts are fused into a fusion with a reduce root.
@ -159,8 +155,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -186,8 +181,7 @@ TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
@ -215,8 +209,7 @@ TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
@ -241,8 +234,7 @@ TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
@ -255,8 +247,7 @@ TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -275,8 +266,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -296,8 +286,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -316,8 +305,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) {
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@ -341,8 +329,7 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()))
@ -371,8 +358,7 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value())
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
<< module->ToString();
}
@ -390,8 +376,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -444,8 +429,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) {
// Multi-output fusion is disabled here and performed in the
// GpuMultiOutputFusion pass instead.
ASSERT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, FuseScalarConstant) {
@ -462,8 +446,7 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -491,8 +474,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
}
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(b.Build());
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
SCOPED_TRACE(module->ToString());
for (const HloInstruction* instr : computation->instructions()) {
EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
@ -527,8 +509,7 @@ TEST_F(InstructionFusionTest, FuseIntoScatter) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion()));
@ -557,8 +538,7 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
// The f32[16] constant should not be fused into the reduce, but the f32[]
// constant should be.
auto* root = module->entry_computation()->root_instruction();
@ -578,8 +558,7 @@ TEST_F(InstructionFusionTest, FuseReverse) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
@ -698,8 +677,7 @@ TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()))
@ -725,8 +703,7 @@ TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
HloInstruction* root = module->entry_computation()->root_instruction();
ASSERT_THAT(root, op::Fusion());
@ -767,8 +744,7 @@ TEST_F(InstructionFusionTest, DontTouchSoftmaxCustomCall) {
ROOT %custom-call = f32[554112,10]{1,0} custom-call(f32[554112,10]{1,0} %param_0), custom_call_target="__softmax_fusion", called_computations={%softmax_computation}
})")
.value();
EXPECT_FALSE(
GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}
TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
@ -806,8 +782,10 @@ TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
})")
.value();
EXPECT_TRUE(
GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()).value());
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
TestGpuDeviceInfo::RTXA6000DeviceInfo())
.Run(module.get())
.value());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Fusion(op::Parameter()));
EXPECT_THAT(


@ -31,12 +31,6 @@ limitations under the License.
namespace xla {
namespace gpu {
// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side, this is smaller than full 64kB, but allows
// some extra space for cache.
inline constexpr int64_t kSharedMemoryBudgetInBytes = 48 * 1024;
// If a dimension is smaller than this, untiled transposition may be more
// efficient.
inline constexpr int64_t kMinDimensionToTransposeTiled = 16;


@ -131,7 +131,6 @@ limitations under the License.
#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/human_readable_json.h"
#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/profiler/lib/nvtx_utils.h"
#include "tensorflow/tsl/protobuf/dnn.pb.h"
#if GOOGLE_CUDA
@ -4569,12 +4568,8 @@ static bool CanVectorizeReduction(
se::CudaComputeCapability cc, mlir::lmhlo::FusionOp fusion,
HloComputation* fused_computation,
const ReductionDimensions& reduction_dimensions, int num_threads_x,
Vector3 reduction_tiling, const Shape& input_shape, int64_t shmem_usage,
Vector3 reduction_tiling, const Shape& input_shape,
bool reduction_is_race_free) {
// Vectorization might cause us to run out of budget.
if (shmem_usage * 2 > kSharedMemoryBudgetInBytes) {
return false;
}
if (!reduction_dimensions.is_row_reduction) {
return IsUnrollingColumnReductionBeneficial(
fusion, fused_computation, input_shape,
@ -4678,10 +4673,15 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
: kLinearIndexingX;
int64_t shmem_usage =
ProjectedShmemUsageBytes(reduction_dimensions, instr_index_groups);
const int64_t shmem_budget =
ir_emitter_context_->gpu_device_info().shared_memory_per_block;
bool reduction_is_race_free = ReductionIsRaceFree(reduction_dimensions);
bool vectorize = CanVectorizeReduction(
cc, fusion, fused_computation, reduction_dimensions, num_threads_x,
reduction_tiling, input_shape, shmem_usage, reduction_is_race_free);
bool vectorize =
// Vectorization might cause us to run out of budget.
(shmem_usage * 2 <= shmem_budget) &&
CanVectorizeReduction(cc, fusion, fused_computation, reduction_dimensions,
num_threads_x, reduction_tiling, input_shape,
reduction_is_race_free);
int vector_size = vectorize ? 2 : 1;
int num_partial_results = 1;
@ -4707,7 +4707,7 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
}
}
while (shmem_usage * num_partial_results > kSharedMemoryBudgetInBytes) {
while (shmem_usage * num_partial_results > shmem_budget) {
num_partial_results /= 2;
if (num_partial_results == 1) {
break;
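
A worked example of the budget math above with hypothetical numbers; everything below is illustrative, and the real vectorization decision additionally goes through CanVectorizeReduction:
// Hypothetical values, for illustration only.
const int64_t shmem_budget = 48 * 1024;  // device_info.shared_memory_per_block
const int64_t shmem_usage = 20 * 1024;   // ProjectedShmemUsageBytes(...)
// Vectorization doubles the usage: 40 KiB <= 48 KiB, so it stays within budget.
const bool fits_when_vectorized = (shmem_usage * 2 <= shmem_budget);  // true
// Partial results are halved until the product fits: 4 -> 2 (80 KiB -> 40 KiB).
int num_partial_results = 4;
while (shmem_usage * num_partial_results > shmem_budget) {
  num_partial_results /= 2;
  if (num_partial_results == 1) break;
}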


@ -50,6 +50,7 @@ bool IsProfitableOperand(HloInstruction* instr) {
}
FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
const GpuDeviceInfo& device_info,
FusionInfoCache* fusion_info_cache) {
CHECK(instr1->opcode() == HloOpcode::kFusion);
@ -66,7 +67,7 @@ FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
}
// Do this check last, as it may be expensive.
return FusionFitsInBudget(*instr1, *instr2,
return FusionFitsInBudget(*instr1, *instr2, device_info,
/*is_consumer_producer_fusion=*/false,
fusion_info_cache);
}
@ -161,7 +162,7 @@ std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
<< " would introduce a cycle when fused.");
continue;
}
if (!FusionFitsInBudget(*producer, *consumer,
if (!FusionFitsInBudget(*producer, *consumer, device_info,
/*is_consumer_producer_fusion=*/false,
fusion_info_cache)) {
dump_negative_explanation(
@ -260,7 +261,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
if (NoFusionPossible sibling_fusible =
(!IsSiblingFusionCandidate(*j) || !is_disconnected(*i, *j) ||
!ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
!LegalToFuse(*i, *j, fusion_info_cache))) {
!LegalToFuse(*i, *j, device_info_, fusion_info_cache))) {
// We pick `j` arbitrarily as a consumer.
if (dump_fusion) {
RegisterFusionState(


@ -643,6 +643,7 @@ xla_cc_test(
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:gpu_device_info_for_tests",
"//tensorflow/compiler/xla/service/gpu:gpu_fusible",
"//tensorflow/compiler/xla/service/gpu:instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@ -657,12 +658,9 @@ xla_cc_test(
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service/gpu:fusion_merger",
"//tensorflow/compiler/xla/service/gpu:gpu_device_info_for_tests",
"//tensorflow/compiler/xla/service/gpu:gpu_fusible",
"//tensorflow/compiler/xla/service/gpu:instruction_fusion",
"//tensorflow/compiler/xla/service/gpu:multi_output_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",


@ -19,12 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/tsl/platform/test.h"
@ -45,12 +42,13 @@ class GpuFusionPipelineTest : public GpuCodegenTest {
void CheckGpuFusionPipeline(absl::string_view hlo,
std::optional<absl::string_view> expected) {
HloPassPipeline pipeline("gpu-fusion");
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
pipeline.AddPass<FusionMerger>(TestGpuDeviceInfo::RTXA6000DeviceInfo(),
ShapeSizeBytesFunction());
pipeline.AddPass<GpuMultiOutputFusion>(
TestGpuDeviceInfo::RTXA6000DeviceInfo(), ShapeSizeBytesFunction());
const GpuDeviceInfo device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
device_info);
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true, device_info);
pipeline.AddPass<FusionMerger>(device_info, ShapeSizeBytesFunction());
pipeline.AddPass<GpuMultiOutputFusion>(device_info,
ShapeSizeBytesFunction());
RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected);
}


@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
@ -81,7 +82,10 @@ TEST_F(GpuFusionTest, FusedBiggerThenThresholdButDoNotChangeTheFusionl) {
b.AddInstruction(
HloInstruction::CreateConcatenate(concat_shape, slice_params, 1));
module->AddEntryComputation(b.Build());
EXPECT_TRUE(GpuInstructionFusion(false).Run(module.get()).value());
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
TestGpuDeviceInfo::RTXA6000DeviceInfo())
.Run(module.get())
.value());
EXPECT_TRUE(module->entry_computation()->root_instruction()->opcode() ==
HloOpcode::kFusion);
for (HloInstruction* instr : module->entry_computation()->instructions()) {
@ -93,8 +97,11 @@ class TransposeFusionTest : public GpuFusionTest {
public:
void CheckGpuFusion(absl::string_view hlo,
std::optional<absl::string_view> expected) {
RunAndFilecheckHloRewrite(hlo, GpuInstructionFusion{/*may_duplicate=*/true},
expected);
RunAndFilecheckHloRewrite(
hlo,
GpuInstructionFusion{/*may_duplicate=*/true,
TestGpuDeviceInfo::RTXA6000DeviceInfo()},
expected);
}
};


@ -327,6 +327,19 @@ bool HloOrdering::UsesBeforeValueDefinition(
return true;
}
}
// The use at an async call occurs before values that are defined in the
// called computation of the async wrapped instruction.
if (use.instruction->IsAsynchronous() &&
use.instruction->async_wrapped_opcode() == HloOpcode::kCall) {
const HloInstruction* async = use.instruction;
if (call_graph_->InstructionIsNestedIn(
value.defining_instruction(),
async->async_wrapped_instruction()->to_apply())) {
VLOG(4) << " use is async " << use.instruction->name()
<< " and def is in called computation";
return true;
}
}
if (use.instruction->opcode() == HloOpcode::kConditional) {
const HloInstruction* conditional = use.instruction;
// In general the use of a value in the conditional parameter should be


@ -589,5 +589,52 @@ ENTRY entry {
ordering.UsesBeforeValueDefinition({&tuple_use}, value, *dataflow));
}
TEST_F(HloOrderingTest, AsyncCallUses) {
absl::string_view hlo_string = R"(
HloModule single_sc_async_call
%called_computation {
%out_param = s32[1024]{0} parameter(1)
%input = s32[1024]{0} parameter(0)
%size = s32[] constant(256)
%index = s32[] custom-call(), custom_call_target="Baz"
%start = s32[] multiply(s32[] %size, s32[] %index)
%input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
%output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"
%async_wrapped {
%async_param = s32[1024]{0} parameter(0)
%async_param.1 = s32[1024]{0} parameter(1)
ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"
ENTRY %main {
%input.1 = s32[1024]{0} parameter(0)
%buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer"
%async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";
HloModuleConfig hlo_config;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, hlo_config));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
auto async_start = FindInstruction(module.get(), "async-start");
auto async_done = FindInstruction(module.get(), "async-done");
auto call = FindInstruction(module.get(), "call");
auto output2 = FindInstruction(module.get(), "output2");
auto async_start_use = HloUse{async_start, 1};
auto async_done_use = HloUse{async_done, 0, {0, 1}};
auto call_use = HloUse{call, 1};
const HloValue& value = dataflow->GetUniqueValueAt(output2, {});
EXPECT_TRUE(ordering.UsesBeforeValueDefinition(
{&async_start_use, &call_use, &async_done_use}, value, *dataflow));
}
} // namespace
} // namespace xla

View File

@ -1694,6 +1694,8 @@ namespace {
struct ParallelState {
explicit ParallelState(int64_t task_count) : counter(task_count) {
// If this method is changed, please remember to change
// GetForEachIndexParallelThreadCount() as well.
static auto* global_pool = new tsl::thread::ThreadPool(
tsl::Env::Default(), "foreach", tsl::port::MaxParallelism());
pool = global_pool;
@ -1743,6 +1745,11 @@ struct ParallelState {
return pstate.status;
}
/* static */ int ShapeUtil::GetForEachIndexParallelThreadCount() {
ParallelState pstate(/*task_count=*/0);
return pstate.pool->NumThreads();
}
/* static */ Shape ShapeUtil::DeleteDimensions(
absl::Span<int64_t const> dims_to_delete, Shape shape) {
std::vector<int64_t> dims_to_delete_v(dims_to_delete.begin(),

View File

@ -722,11 +722,19 @@ class ShapeUtil {
// A parallel version of ForEachIndex(WithStatus). This can only be used if
// the visitor_function is thread-safe and the order of iteration does not
// matter.
//
// Please use GetForEachIndexParallelThreadCount() to get the number of
// threads in the threadpool of ForEachIndexParallel*. This will not change
// during the runtime of the process. Please DO NOT use
// tsl::port::MaxParallelism() for this purpose, as it may change.
static void ForEachIndexParallel(
const Shape& shape, absl::Span<const int64_t> base,
absl::Span<const int64_t> count, absl::Span<const int64_t> incr,
const ForEachParallelVisitorFunction& visitor_function);
// Returns the number of threads in the threadpool of ForEachIndexParallel*.
static int GetForEachIndexParallelThreadCount();
static Status ForEachIndexParallelWithStatus(
const Shape& shape, absl::Span<const int64_t> base,
absl::Span<const int64_t> count, absl::Span<const int64_t> incr,

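Below is a minimal usage sketch (an editorial addition, not part of this diff) showing why the new query matters: callers that keep per-thread scratch state can size it with GetForEachIndexParallelThreadCount() instead of tsl::port::MaxParallelism(), which may change over time. The helper AccumulatePerThread and the rank-2 shape are illustrative assumptions; the ForEachIndexParallel signature is the one exercised by the new GetForEachIndexParallelThreadCount test in this change.
#include <cstdint>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// Counts how many indices each worker visits. Slot 0 is the calling thread
// (thread_id == -1); assumes a rank-2 shape for brevity.
void AccumulatePerThread(const Shape& shape) {
  const int kThreadCount = ShapeUtil::GetForEachIndexParallelThreadCount();
  std::vector<int64_t> visits(kThreadCount + 1, 0);
  ShapeUtil::ForEachIndexParallel(
      shape, /*base=*/{0, 0},
      /*count=*/{shape.dimensions(0), shape.dimensions(1)},
      /*incr=*/{1, 1},
      [&visits](absl::Span<const int64_t> /*indexes*/,
                int thread_id) -> StatusOr<bool> {
        // Each thread_id only ever runs on one thread, so this is race-free.
        ++visits[thread_id + 1];
        return true;
      });
}
}  // namespace xla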
View File

@ -630,6 +630,23 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
EXPECT_EQ(invocations, 5);
}
TEST(ShapeUtilTest, GetForEachIndexParallelThreadCount) {
const int kThreadCount = ShapeUtil::GetForEachIndexParallelThreadCount();
Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
auto check_func = [kThreadCount](absl::Span<const int64_t> /*indexes*/,
int thread_id) -> StatusOr<bool> {
EXPECT_GE(thread_id, -1);
EXPECT_LT(thread_id, kThreadCount);
return true;
};
for (int i = 0; i < 10; ++i) {
ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 100},
/*incr=*/{1, 1}, check_func);
}
}
TEST(ShapeUtilTest, ForEachIndexParallel) {
Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
int64_t output[10][10];

View File

@ -115,6 +115,7 @@ cc_library(
"//tensorflow/compiler/xla/stream_executor/platform",
"//tensorflow/compiler/xla/stream_executor/platform:dso_loader",
"//tensorflow/tsl/platform:env",
"//tensorflow/tsl/platform:static_threadlocal",
] + tf_additional_cuda_driver_deps()) + select({
# include the dynamic loading implementation only when if_cuda_is_configured is true and the build links dynamically
"//tensorflow/tsl:is_cuda_enabled_and_oss": ["cudart_stub"],

View File

@ -152,15 +152,15 @@ void Diagnostician::LogDiagnosticInformation() {
CFRelease(kext_infos);
#elif !defined(PLATFORM_WINDOWS)
if (access(kDriverVersionPath, F_OK) != 0) {
LOG(INFO) << "kernel driver does not appear to be running on this host "
<< "(" << tsl::port::Hostname() << "): "
<< "/proc/driver/nvidia/version does not exist";
VLOG(1) << "kernel driver does not appear to be running on this host "
<< "(" << tsl::port::Hostname() << "): "
<< "/proc/driver/nvidia/version does not exist";
return;
}
auto dev0_path = GetDevNodePath(0);
if (access(dev0_path.c_str(), F_OK) != 0) {
LOG(INFO) << "no NVIDIA GPU device is present: " << dev0_path
<< " does not exist";
VLOG(1) << "no NVIDIA GPU device is present: " << dev0_path
<< " does not exist";
return;
}
#endif

View File

@ -36,11 +36,11 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/tsl/platform/env.h"
#include "tensorflow/tsl/platform/stacktrace.h"
#include "tensorflow/tsl/platform/static_threadlocal.h"
#include "tensorflow/tsl/platform/threadpool.h"
bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
@ -130,7 +130,7 @@ struct ThreadLocalData {
int depth;
};
SE_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
} // namespace
@ -261,7 +261,7 @@ static tsl::Status InternalInit() {
if (res == CUDA_SUCCESS) {
return ::tsl::OkStatus();
} else if (res == CUDA_ERROR_SHARED_OBJECT_INIT_FAILED) {
LOG(WARNING) << "failed call to cuInit: " << ToString(res);
VLOG(1) << "failed call to cuInit: " << ToString(res);
} else {
LOG(ERROR) << "failed call to cuInit: " << ToString(res);
}

View File

@ -53,6 +53,7 @@ cc_library(
"//tensorflow/tsl/platform:env",
"//tensorflow/tsl/platform:numbers",
"//tensorflow/tsl/platform:stacktrace",
"//tensorflow/tsl/platform:static_threadlocal",
]),
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
#include "rocm/include/rocblas.h"
#include "rocm/include/rocblas/rocblas.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"

View File

@ -29,13 +29,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "tensorflow/tsl/platform/env.h"
#include "tensorflow/tsl/platform/numbers.h"
#include "tensorflow/tsl/platform/stacktrace.h"
#include "tensorflow/tsl/platform/static_threadlocal.h"
#include "tensorflow/tsl/platform/threadpool.h"
bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
@ -168,7 +168,7 @@ struct ThreadLocalData {
int depth;
};
SE_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
} // namespace

View File

@ -112,6 +112,22 @@ xla_cc_binary(
],
)
xla_cc_test(
name = "replay_computation_bin_test",
srcs = ["replay_computation_bin_test.cc"],
data = [
"add.hlo",
":replay_computation_cpu",
":replay_computation_interpreter",
],
deps = [
"//tensorflow/tsl/platform:path",
"//tensorflow/tsl/platform:subprocess",
"//tensorflow/tsl/platform:test",
"//tensorflow/tsl/platform:test_main",
],
)
xla_cc_binary(
name = "replay_computation_gpu",
tags = ["gpu"],

View File

@ -0,0 +1,64 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/tsl/platform/path.h"
#include "tensorflow/tsl/platform/subprocess.h"
#include "tensorflow/tsl/platform/test.h"
namespace xla {
namespace {
// TODO(ddunleavy): test something more specific.
std::string PathToAddHlo() {
return tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "add.hlo");
}
TEST(ReplayComputation, AddHloHost) {
// Get relevant paths to replay_computation_cpu and add.hlo
std::string replay_computation_bin = tsl::io::JoinPath(
tsl::testing::XlaSrcRoot(), "tools", "replay_computation_cpu");
tsl::SubProcess proc;
proc.SetProgram(replay_computation_bin,
{replay_computation_bin, PathToAddHlo(), "--use_fake_data"});
EXPECT_TRUE(proc.Start());
// Just make sure that the process's exit code is 0
int status = proc.Communicate(nullptr, nullptr, nullptr);
EXPECT_TRUE(WIFEXITED(status));
ASSERT_EQ(0, WEXITSTATUS(status));
}
TEST(ReplayComputation, AddHloInterpreter) {
// Get relevant paths to replay_computation_interpreter and add.hlo
std::string replay_computation_bin = tsl::io::JoinPath(
tsl::testing::XlaSrcRoot(), "tools", "replay_computation_interpreter");
tsl::SubProcess proc;
proc.SetProgram(replay_computation_bin,
{replay_computation_bin, PathToAddHlo(), "--use_fake_data"});
EXPECT_TRUE(proc.Start());
// Just make sure that the process's exit code is 0
int status = proc.Communicate(nullptr, nullptr, nullptr);
EXPECT_TRUE(WIFEXITED(status));
ASSERT_EQ(0, WEXITSTATUS(status));
}
} // namespace
} // namespace xla

View File

@ -629,16 +629,23 @@ cc_library(
"//tensorflow/core/kernels/image:image",
"//tensorflow/core/kernels/sparse:kernels",
] + if_mkl([
"//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
"//tensorflow/core/kernels/mkl:mkl_concat_op",
"//tensorflow/core/kernels/mkl:mkl_dequantize_op",
"//tensorflow/core/kernels/mkl:mkl_conv_op",
"//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
"//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels/mkl:mkl_identity_op",
"//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
"//tensorflow/core/kernels/mkl:mkl_layer_norm_op",
"//tensorflow/core/kernels/mkl:mkl_lrn_op",
"//tensorflow/core/kernels/mkl:mkl_pooling_ops",
"//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
"//tensorflow/core/kernels/mkl:mkl_requantize_ops",
"//tensorflow/core/kernels/mkl:mkl_quantize_op",
"//tensorflow/core/kernels/mkl:mkl_relu_op",
"//tensorflow/core/kernels/mkl:mkl_reshape_op",
"//tensorflow/core/kernels/mkl:mkl_slice_op",
"//tensorflow/core/kernels/mkl:mkl_softmax_op",
"//tensorflow/core/kernels/mkl:mkl_swish_op",
"//tensorflow/core/kernels/mkl:mkl_fused_mish_op",
@ -646,8 +653,8 @@ cc_library(
"//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_einsum_op",
"//tensorflow/core/kernels/mkl:mkl_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_tfconv_op",
"//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
"//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:cudnn_rnn_kernels",
]) + if_cuda([
@ -1956,21 +1963,28 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
] + if_mkl([
"//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
"//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_einsum_op",
"//tensorflow/core/kernels/mkl:mkl_concat_op",
"//tensorflow/core/kernels/mkl:mkl_conv_op",
"//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
"//tensorflow/core/kernels/mkl:mkl_dequantize_op",
"//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels/mkl:mkl_identity_op",
"//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
"//tensorflow/core/kernels/mkl:mkl_lrn_op",
"//tensorflow/core/kernels/mkl:mkl_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_pooling_ops",
"//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
"//tensorflow/core/kernels/mkl:mkl_quantize_op",
"//tensorflow/core/kernels/mkl:mkl_relu_op",
"//tensorflow/core/kernels/mkl:mkl_reshape_op",
"//tensorflow/core/kernels/mkl:mkl_slice_op",
"//tensorflow/core/kernels/mkl:mkl_softmax_op",
"//tensorflow/core/kernels/mkl:mkl_tfconv_op",
"//tensorflow/core/kernels/mkl:mkl_transpose_op",
"//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
"//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
]),
)

View File

@ -0,0 +1,5 @@
op {
graph_op_name: "CollectiveReduceScatterV2"
summary: "Mutually reduces multiple tensors of identical type and shape and scatters the result."
visibility: HIDDEN
}

View File

@ -282,6 +282,7 @@ filegroup(
"memory_types.h",
"mkl_cpu_allocator.h",
"mkl_layout_pass.h",
"mkl_tfconversion_pass.h",
"node_file_writer.h",
"optimization_registry.h",
"partitioning_utils.h",
@ -1314,6 +1315,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "mkl_tfconversion_pass",
srcs = ["mkl_tfconversion_pass.cc"],
hdrs = [
"mkl_tfconversion_pass.h",
"//tensorflow/core/graph:mkl_graph_util_header",
],
copts = tf_copts(),
deps = [
":function",
":optimization_registry",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
alwayslink = 1,
)
cc_library(
name = "node_file_writer",
srcs = ["node_file_writer.cc"],
@ -1888,6 +1909,7 @@ tf_cuda_library(
":memory_types",
":mkl_cpu_allocator",
":mkl_layout_pass",
":mkl_tfconversion_pass",
":optimization_registry",
":optimized_function_graph_info",
":parallel_concat_optimizer",
@ -3073,6 +3095,59 @@ tf_cc_test(
],
)
tf_cc_test_mkl(
name = "mkl_related_tests",
size = "small",
srcs = [
"mkl_layout_pass_test.cc",
"mkl_tfconversion_pass_test.cc",
],
linkstatic = 1,
deps = [
":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
] + if_mkl([
"//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
"//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_concat_op",
"//tensorflow/core/kernels/mkl:mkl_conv_op",
"//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
"//tensorflow/core/kernels/mkl:mkl_dequantize_op",
"//tensorflow/core/kernels/mkl:mkl_einsum_op",
"//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels/mkl:mkl_identity_op",
"//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
"//tensorflow/core/kernels/mkl:mkl_lrn_op",
"//tensorflow/core/kernels/mkl:mkl_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_pooling_ops",
"//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
"//tensorflow/core/kernels/mkl:mkl_quantize_op",
"//tensorflow/core/kernels/mkl:mkl_relu_op",
"//tensorflow/core/kernels/mkl:mkl_reshape_op",
"//tensorflow/core/kernels/mkl:mkl_slice_op",
"//tensorflow/core/kernels/mkl:mkl_softmax_op",
"//tensorflow/core/kernels/mkl:mkl_tfconv_op",
"//tensorflow/core/kernels/mkl:mkl_transpose_op",
"//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
]),
)
# TODO(bmzhao): Refactor this target to use granular dependencies
# after stage 4 of the TF build refactor is complete:
# https://github.com/tensorflow/community/pull/179

View File

@ -303,14 +303,16 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
}
Tensor* output = ctx->mutable_output(0);
const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
col_params->instance.type == GATHER_COLLECTIVE ||
col_params->instance.type == PERMUTE_COLLECTIVE ||
col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
(col_params->instance.type == BROADCAST_COLLECTIVE &&
col_params->is_source))
? &ctx->input(0)
: nullptr;
const Tensor* input =
(col_params->instance.type == REDUCTION_COLLECTIVE ||
col_params->instance.type == GATHER_COLLECTIVE ||
col_params->instance.type == PERMUTE_COLLECTIVE ||
col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
col_params->instance.type == REDUCE_SCATTER_COLLECTIVE ||
(col_params->instance.type == BROADCAST_COLLECTIVE &&
col_params->is_source))
? &ctx->input(0)
: nullptr;
CollectiveImplementationInterface* col_impl = nullptr;
Status status = CreateCollective(*col_params, &col_impl);
if (!status.ok()) {

View File

@ -78,6 +78,9 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
case ALL_TO_ALL_COLLECTIVE:
return "AllToAll";
case REDUCE_SCATTER_COLLECTIVE:
return nccl ? "NcclReduceScatter" : "undef";
default:
return "undef";
}

View File

@ -1193,7 +1193,9 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
// iterating through a tf.data input pipeline.
if (!errors::IsOutOfRange(s)) {
LOG(INFO) << "[" << immutable_state_.params().device->name()
<< "] Executor start aborting: " << s;
<< "] (DEBUG INFO) Executor start aborting (this does not "
"indicate an error and you can ignore this message): "
<< s;
} else {
VLOG(1) << "[" << immutable_state_.params().device->name()
<< "] Executor start aborting: " << s;

View File

@ -66,7 +66,8 @@ namespace tensorflow {
namespace {
bool IsCollectiveV2(const string& op) {
return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2" ||
op == "ColectiveReduceScatterV2";
}
} // namespace

File diff suppressed because it is too large

View File

@ -0,0 +1,462 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"
#include <memory>
#include <queue>
#include <set>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
// in the graph in between A and B, where A and B match any one
// of the following cases:
//
// 1) A = a node that generates output in the Mkl format and,
// B = a node that does not accept input in the Mkl format and,
// A -> B (there is a direct edge between A and B), then
// We will insert C such that A->C->B.
//
// 2) A = a node that generates output in the Mkl format and,
// B = NULL (in other words, A is the last node in the graph), then
// We will insert C such that A->C. (C will be the last node.)
//
// Note that case 1 applies to all outputs of A that are input to B.
// In other words, the conversions will be required for every output
// of A that is input to B. For example, let us say the output of A
// is A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not
// in Mkl format, and all of them are inputs to B. In such a case, we will
// do the conversion for A1 and A2 only. We do not need to do any conversion
// for A3.
//
// This pass relies on ops registering themselves about their Mkl compliance.
// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs
// in the Mkl format. Non-compliant ops accept inputs and outputs in the
// TensorFlow format.
//
// ADDENDUM: For element-wise ops, we may or may not need a conversion to
// take place before we hit the op. For this, we add a new op before each
// element-wise MKL op to deal with the inputs, called _MklInputConversion.
// This pass has been enhanced to add this capability.
//
// The _MklInputConversion op will check the inputs to the elementwise op and
// make sure that either both are in MKL format or both are in TF format,
// depending on their initial state and whether broadcast is needed or not.
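//
// Illustrative example (editorial addition, not part of the original comment):
// with contiguous tensor ordering, a graph such as
//
//   C = _MklConv2D(A, B, M, N);  E = Sub(C, D)
//
// where _MklConv2D emits Mkl-format output and Sub is a plain TF op, is
// rewritten so that the C->Sub data edge (and its Mkl metadata edge) passes
// through a conversion node (named Mkl2Tf/_0 by the pass):
//
//   C = _MklConv2D(A, B, M, N);  X = _MklToTf(C:0, C:2);  E = Sub(X, D)
//
// This is the shape exercised by the "Positive" unit test for this pass.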
class MklToTfConversionPass : public GraphOptimizationPass {
public:
MklToTfConversionPass() {}
Status Run(const GraphOptimizationPassOptions& options);
// Insert layout conversion nodes in the graph pointed to by g.
// Function scans the graph for candidate edges where we
// need to insert conversion nodes.
//
// @return true if even a single conversion node is inserted;
// false otherwise.
bool RunPass(std::unique_ptr<Graph>* g);
private:
// Is the input Op supported by Mkl-specific layout?
//
// @input op_name string of the op
// @input T Datatype to use for checking input op
// @return true if op is Mkl supported; false, otherwise.
inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
return mkl_op_registry::IsMklOp(op_name, T, false);
}
// Is the input Op supported by Mkl-specific layout AND
// is it element-wise?
//
// @input op_name string of the op
// @input T Datatype to use for checking input op
// @return true if op is Mkl supported; false, otherwise.
inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
return mkl_op_registry::IsMklElementWiseOp(op_name, T);
}
// Insert a layout conversion node on the edge pointed to by 'e' in graph 'g'.
//
// Edge will be deleted once a call to this function is successful.
// Any attempt to use the edge after this call
// will lead to undefined behaviors.
//
// @return Success:OK() if insertion is successful, otherwise returns
// appropriate error status code.
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
// For element-wise ops, we need to sanitize the inputs. For this, we add a
// new node at the input of the replacement element-wise node that checks
// the inputs and converts one/both of them as required. See the op code
// comments for details.
//
// Insert an input conversion node as a parent of 'n' in graph 'g'.
//
// @return Success:OK() if insertion is successful, otherwise returns
// appropriate error status code.
Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};
// We register MklToTf insertion for phase 2 in post-partition grouping
// because we register MklLayoutRewritePass for phase 1 in post-partition
// grouping. We register this pass after partitioning so that we get a
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
CHECK_NOTNULL(e);
Node* src = e->src();
Node* dst = e->dst();
CHECK_NOTNULL(src);
CHECK_NOTNULL(dst);
Node* conversion_node = nullptr;
DataType src_datatype = src->output_type(e->src_output());
DataType dst_datatype = dst->input_type(e->dst_input());
string data_format;
// We compare source and destination datatypes only when both are found.
if (src_datatype != dst_datatype) {
string err_msg = "T attribute of " + src->name() + ":" +
std::to_string(e->src_output()) + " and " + dst->name() +
":" + std::to_string(e->dst_input()) +
" do not"
" match. Will not insert MklToTf node in such case.";
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}
TF_CHECK_OK(
NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
.Input(src, e->src_output())
.Input(src, DataIndexToMetaDataIndex(
e->src_output(),
src->num_outputs())) // Get an Mkl tensor slot
// from the Tf tensor slot.
.Device(src->def().device()) // We want to get conversion node
// on same device as source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
// TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
// using data_format. This code might be redundant.
if (GetNodeAttr(src->def(), "data_format", &data_format) == OkStatus() &&
(data_format == ToString(FORMAT_NHWC) ||
data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
// Get assigned device from source node and apply it to conversion node.
// We want conversion node to be on the same device as the source node.
conversion_node->set_assigned_device_name(src->assigned_device_name());
// Set the Mkl op label for this op.
conversion_node->AddAttr("_kernel",
mkl_op_registry::kMklLayoutDependentOpLabel);
// Now that we have added the edge from src->conversion_node, let's add an edge
// from the output of conversion_node to the dest node. Since conversion_node
// has only 1 output, the src_output of conversion_node is 0.
CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input()));
VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: "
<< src->type_string() << " and " << dst->type_string()
<< " successful.";
// Remove src->dst edge now.
(*g)->RemoveEdge(e);
return OkStatus();
}
Status MklToTfConversionPass::InsertInputConversionNode(
std::unique_ptr<Graph>* g, Node* n) {
CHECK_NOTNULL(n);
// Get the input nodes and edges
std::vector<const Edge*> edges;
TF_CHECK_OK(n->input_edges(&edges));
if (edges.size() != 4) {
return Status(error::Code::INVALID_ARGUMENT,
"MKL Binary Element-wise op should have exactly 2 data"
" inputs and 2 metadata inputs");
}
// Sanity check: ensure that both data inputs have the same base type, and
// that it matches the op's input type.
CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
BaseType(n->input_type(0)));
// Check ordering of edges
for (uint32 i = 0; i < 4; i++) {
CHECK_EQ((edges[i]->dst_input() == i), true);
}
// Build the conversion node and specify src as input.
Node* conversion_node = nullptr;
TF_CHECK_OK(
NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
.Input(edges[0]->src(), edges[0]->src_output())
.Input(edges[1]->src(), edges[1]->src_output())
.Input(edges[2]->src(), edges[2]->src_output())
.Input(edges[3]->src(), edges[3]->src_output())
.Device(n->def().device())
.Attr("T", n->input_type(0))
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
// Change the destination of any control edges to the InputConversion node
if (edges.size() != n->in_edges().size()) {
std::vector<const Edge*> edges_to_remove;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
edges_to_remove.push_back(e);
}
}
for (const Edge* e : edges_to_remove) {
(*g)->RemoveEdge(e);
}
}
// TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
// seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
OkStatus() &&
(data_format == ToString(FORMAT_NHWC) ||
data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
// Get assigned device from destination node and apply it to conversion node.
// We want conversion node to be on the same device as the destination node.
conversion_node->set_assigned_device_name(n->assigned_device_name());
// Set the Mkl op label for this op.
conversion_node->AddAttr("_kernel",
mkl_op_registry::kMklLayoutDependentOpLabel);
// Now that we have added edges from src->conversion_node, let's add edges
// from the outputs of conversion_node to the element-wise node.
CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));
VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
<< "conversion node on: " << n->type_string() << " successful.";
// Remove the original input edges now.
(*g)->RemoveEdge(edges[0]);
(*g)->RemoveEdge(edges[1]);
(*g)->RemoveEdge(edges[2]);
(*g)->RemoveEdge(edges[3]);
return OkStatus();
}
bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
bool result = false;
CHECK_NOTNULL(g);
DumpGraph("Before MklToTfConversionPass", &**g);
// Since we are looking for an Mkl-supported op node immediately
// followed by a non-Mkl op node, we will just iterate over the edge
// set of the graph.
// Edge set whose source and destination are candidates for
// inserting a conversion node.
std::vector<Edge*> candidate_edges;
for (const Edge* e : (*g)->edges()) {
Node* src = e->src();
Node* dst = e->dst();
// We skip control edges.
if (e->IsControlEdge()) {
continue;
}
// We skip adding MklToTf on an edge between X->MklToTf or
// MklToTf->X, where X is any node.
if (src->type_string().compare("_MklToTf") == 0 ||
dst->type_string().compare("_MklToTf") == 0) {
continue;
}
VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
<< src->type_string() << " and " << dst->type_string();
// Let's get source and destination data type.
// We cannot check the datatype on the destination node because it
// may not be an Mkl node.
DataType src_datatype;
DataType dst_datatype;
bool src_is_mkl_op =
(GetNodeAttr(src->def(), "T", &src_datatype) == OkStatus() &&
IsMklSupportedOp(src->type_string(), src_datatype));
bool dst_is_mkl_op =
(GetNodeAttr(dst->def(), "T", &dst_datatype) == OkStatus() &&
IsMklSupportedOp(dst->type_string(), dst_datatype));
// Check if src is Mkl-compliant while dst is not.
if (src_is_mkl_op && !dst_is_mkl_op) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
<< " and " << dst->name() << " for inserting conversion nodes";
candidate_edges.push_back(const_cast<Edge*>(e));
}
}
// Process all candidate edges and insert conversion nodes on them.
for (Edge* e : candidate_edges) {
// Even if we insert a conversion node on only a single edge, we
// need to return true.
string src_name = e->src()->name();
string dst_name = e->dst()->name();
if (InsertConversionNodeOnEdge(g, e) == OkStatus()) {
VLOG(1) << "MklToTfConversionPass: Inserted conversion "
<< "node on edge between " << src_name << " and " << dst_name;
result = true;
}
}
DumpGraph("After MklToTfConversionPass", &**g);
//---------------------------------------------------------------------------
// Check all nodes and add an input-conversion-node if the node is an mkl
// element-wise node.
VLOG(1) << "Before running MklToTfConversionPass - InputConversion";
std::vector<Node*> candidate_nodes;
std::vector<Node*> order;
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
// If node is not an op or it does not have a datatype, then skip.
DataType datatype;
if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != OkStatus())) {
continue;
}
if (IsMklElementWiseOp(n->type_string(), datatype)) {
// If the input node is an input-conversion op, skip
Node* input_node = nullptr;
TF_CHECK_OK(n->input_node(0, &input_node));
DataType input_datatype;
if ((GetNodeAttr(n->def(), "T", &input_datatype) == OkStatus()) &&
(input_node->type_string().compare("_MklInputConversion") == 0)) {
continue;
}
VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
<< n->name() << " for inserting input conversion node";
candidate_nodes.push_back(const_cast<Node*>(n));
}
}
// Process all candidate nodes and insert input conversion nodes for them.
for (Node* n : candidate_nodes) {
// Even if we insert an input conversion node for only a single node, we
// need to return true.
if (InsertInputConversionNode(g, n) == OkStatus()) {
VLOG(1) << "MklToTfConversionPass: Inserted conversion "
<< "on node " << n->name();
result = true;
}
}
DumpGraph("After MklToTfConversionPass - InputConversion", &**g);
// We need to return true even if we insert one conversion node
// anywhere in the graph.
return result;
}
//////////////////////////////////////////////////////////////////////////////
// Run function for the pass
//////////////////////////////////////////////////////////////////////////////
bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
return MklToTfConversionPass().RunPass(g);
}
Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return OkStatus();
}
if (!IsMKLEnabled()) {
VLOG(2) << "TF-MKL: MKL is not enabled";
return OkStatus();
}
if (NativeFormatEnabled()) {
VLOG(2)
<< "Running in native format mode, MklToTfConversionPass won't run.";
return OkStatus();
}
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of graph
std::unique_ptr<Graph>* ng = std::move(g);
RunPass(ng);
// Return the ownership of graph back
g->reset(ng->release());
};
if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) {
// For any pre-partitioning phase, graph is stored in options.graph.
process_graph(options.graph);
} else {
// For post partitioning phase, graphs are stored in
// options.partition_graphs.
for (auto& pg : *options.partition_graphs) {
process_graph(&pg.second);
}
}
return OkStatus();
}
} // namespace tensorflow
#endif // defined(INTEL_MKL) && defined(ENABLE_MKL)

View File

@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An optimization pass that inserts MklToTf conversion nodes in the graph
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_
#ifdef INTEL_MKL
#include <sys/types.h>
#include <memory>
#include "tensorflow/core/graph/graph.h"
#ifdef _WIN32
typedef unsigned int uint;
#endif
namespace tensorflow {
// Interface to invoke the pass for unit test
//
// Returns true if and only if 'g' is mutated.
extern bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g);
} // namespace tensorflow
#endif  // INTEL_MKL
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_
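A minimal sketch (an editorial addition, not part of this commit) of how the test-only entry point declared above can be driven; it mirrors the pattern used by the unit test later in this diff. RunMklToTfPassForTest is a hypothetical helper, the graph body is left as a placeholder, and the call assumes an INTEL_MKL build.
#include <memory>
#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
bool RunMklToTfPassForTest() {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  // ... populate *graph so that an Mkl-format op feeds a plain TF op ...
  // Returns true iff at least one _MklToTf conversion node was inserted.
  return InsertMklToTfConversionNodes(&graph);
}
}  // namespace tensorflow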

View File

@ -0,0 +1,311 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"
#include <algorithm>
#include <vector>
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
class MklToTfConversionPass : public ::testing::Test {
public:
MklToTfConversionPass() : graph_(OpRegistry::Global()) {}
static void InitGraph(const string& s, Graph* graph) {
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
}
void InitGraph(const string& s) {
InitGraph(s, &graph_);
original_ = CanonicalGraphString(&graph_);
}
static bool IncludeNode(const Node* n) { return n->IsOp(); }
static string EdgeId(const Node* n, int index) {
if (index == 0) {
return n->name();
} else if (index == Graph::kControlSlot) {
return strings::StrCat(n->name(), ":control");
} else {
return strings::StrCat(n->name(), ":", index);
}
}
string CanonicalGraphString(Graph* g) {
std::vector<string> nodes;
std::vector<string> edges;
for (const Node* n : g->nodes()) {
if (IncludeNode(n)) {
nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
}
}
for (const Edge* e : g->edges()) {
if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
EdgeId(e->dst(), e->dst_input())));
}
}
// Canonicalize
std::sort(nodes.begin(), nodes.end());
std::sort(edges.begin(), edges.end());
return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
absl::StrJoin(edges, ";"));
}
string DoRunMklToTfConversionPass() {
string before = CanonicalGraphString(&graph_);
LOG(ERROR) << "Before MklToTf conversion pass: " << before;
std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
InsertMklToTfConversionNodes(ug);
string result = CanonicalGraphString(&graph_);
LOG(ERROR) << "After MklToTf conversion pass: " << result;
return result;
}
const string& OriginalGraph() const { return original_; }
Graph graph_;
string original_;
};
REGISTER_OP("Float_Input").Output("o: float").SetIsStateful();
REGISTER_OP("_Mkl_Input").Output("o: uint8").SetIsStateful();
TEST_F(MklToTfConversionPass, Basic) {
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(Mul);D(Mul)|"
"A->C;A->D;B->C:1;B->D:1");
}
// MklConv2D followed by Non-Mkl layer
// C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
// C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
TEST_F(MklToTfConversionPass, Positive) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'M' op: '_Mkl_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'N' op: '_Mkl_Input'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'Float_Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(_MklConv2D);D(Float_Input);E("
"Sub);M(_Mkl_Input);"
"Mkl2Tf/_0(_MklToTf);N(_Mkl_Input)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'M' op: '_Mkl_Input'}"
"node { name: 'N' op: '_Mkl_Input'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Float_Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(_MklConv2D);D(Float_Input);E("
"Sub);M(_Mkl_Input);"
"Mkl2Tf/_0(_MklToTf);N(_Mkl_Input)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:2->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
}
}
// MklConv2D followed by MklToTf op followed by Non-Mkl layer.
// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:2) F=Sub(D,E) (for contiguous)
// MklToTf node should not be inserted again.
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'M' op: '_Mkl_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'N' op: '_Mkl_Input'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Float_Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(_MklConv2D);D(_MklToTf);E(Float_"
"Input);"
"F(Sub);M(_Mkl_Input);N(_Mkl_Input)|"
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'M' op: '_Mkl_Input'}"
"node { name: 'N' op: '_Mkl_Input'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:2']}"
"node { name: 'E' op: 'Float_Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(_MklConv2D);D(_MklToTf);E(Float_"
"Input);"
"F(Sub);M(_Mkl_Input);N(_Mkl_Input)|"
"A->C;B->C:1;C->D;C:2->D:1;D->F;E->F:1;M->C:2;N->C:3");
}
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
// There is no Mkl layer so no conversion op should be inserted.
TEST_F(MklToTfConversionPass, Negative_NoMklLayer) {
InitGraph(
"node { name: 'A' op: 'Float_Input'}"
"node { name: 'B' op: 'Float_Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Float_Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Float_Input'}"
"node { name: 'Z' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Float_Input);B(Float_Input);C(Conv2D);D(Float_Input);E(BiasAdd);"
"Y(Float_Input);Z(Sub)|"
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
}
static void BM_RunMklToTfConversionPass(int iters, int op_nodes) {
testing::StopTiming();
string s;
for (int in = 0; in < 10; in++) {
s += strings::Printf("node { name: 'in%04d' op: 'Float_Input'}", in);
}
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
for (int op = 0; op < op_nodes; op++) {
s += strings::Printf(
"node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
"type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
op, rnd.Uniform(10), rnd.Uniform(10));
}
bool first = true;
while (iters > 0) {
Graph* graph = new Graph(OpRegistry::Global());
MklToTfConversionPass::InitGraph(s, graph);
int N = graph->num_node_ids();
if (first) {
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
first = false;
}
{
testing::StartTiming();
std::unique_ptr<Graph> ug(graph);
InsertMklToTfConversionNodes(&ug);
testing::StopTiming();
}
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
// delete graph;
}
}
BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL

Some files were not shown because too many files have changed in this diff