Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 00:19:58 +01:00)

Merge branch 'tensorflow:master' into to_ordinal
Commit a5303845f1

.github/workflows/trusted_partners.js (vendored, 6 changed lines)
@@ -39,9 +39,9 @@ const get_email_domain = async ({github, username}) => {
  return domain;
};

-/** For trusted parters like Intel, we want to auto-run tests and mark the PR as ready to pull
+/** For trusted parters like Intel, we want to auto-run tests
    This allows us to reduce the delay to external partners
-   Add Labels - kokoro:force-run, ready to pull
+   Add Labels - kokoro:force-run
    The PR is also assigned to specific teams to fast track review
    Additional reviewers can be added manually based on PR contents
   @param {!object}
@@ -50,7 +50,7 @@ const get_email_domain = async ({github, username}) => {
   @return {string} Returns the message with labels attached and assignees added
 */
const filter_action = async ({github, context, domain}) => {
-  const labels = ['kokoro:force-run', 'ready to pull'];
+  const labels = ['kokoro:force-run'];

  let assignees = [];
  const title = context.payload.pull_request && context.payload.pull_request.title;
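The comment block above describes the mechanism: PRs from trusted-partner domains get the `kokoro:force-run` label so CI starts without manual triage (the `ready to pull` label is no longer auto-applied). As a rough illustration of the same labeling step outside GitHub Actions — a sketch only, using the third-party PyGithub library, with the token and PR number as placeholder values:

import sys
from github import Github  # third-party client: pip install PyGithub

def fast_track_pr(token: str, repo_name: str, pr_number: int) -> None:
    """Apply the CI-trigger label to a PR, mirroring filter_action's effect."""
    repo = Github(token).get_repo(repo_name)
    # PR labels are managed through the issue endpoint.
    repo.get_issue(pr_number).add_to_labels('kokoro:force-run')

if __name__ == '__main__':
    fast_track_pr('<token>', 'tensorflow/tensorflow', int(sys.argv[1]))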
RELEASE.md (12 changed lines)

@@ -5,6 +5,16 @@
 * <DOCUMENT BREAKING CHANGES HERE>
 * <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>

+* Build, Compilation and Packaging
+
+  * Removal of redundant packages: the `tensorflow-gpu` and `tf-nightly-gpu`
+    packages have been effectively removed and replaced with packages that
+    direct users to switch to `tensorflow` or `tf-nightly` respectively.
+    The naming difference was the only difference between the two sets of
+    packages ever since TensorFlow 2.1, so there is no loss of functionality
+    or GPU support. See
+    https://pypi.org/project/tensorflow-gpu for more details.
+
 * `tf.function`:

   * tf.function now uses the Python inspect library directly for parsing
@@ -126,6 +136,8 @@
 * `tf.test`:
   * Added `tf.test.experimental.sync_devices`, which is useful for
     accurately measuring performance in benchmarks.
 * `tf.experimental.dtensor`:
   * Added experimental support to ReduceScatter fuse on GPU (NCCL).

 # Bug Fixes and Other Changes
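On the `tf.test.experimental.sync_devices` entry above: GPU/TPU ops run asynchronously, so timing a computation without a synchronization point under-measures it. A minimal benchmarking sketch, assuming a TF build that ships this API (the matrix size is illustrative):

import time
import tensorflow as tf

x = tf.random.uniform((4096, 4096))
tf.linalg.matmul(x, x)  # warm-up so one-time compilation is not timed

start = time.time()
y = tf.linalg.matmul(x, x)
tf.test.experimental.sync_devices()  # block until pending async device work finishes
print(f'matmul: {time.time() - start:.4f}s')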
@@ -970,6 +970,7 @@ package_group(
        "//learning/brain/tfrt/...",
        "//learning/lib/ami/simple_ml/...",
        "//learning/pathways/...",
        "//learning/serving/contrib/tfrt/mlir/canonical_ops/...",
        "//perftools/accelerators/xprof/integration_tests/...",
        "//smartass/brain/configure/...",
        "//tensorflow/...",
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/c/kernels_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
@@ -129,6 +130,17 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx,
  return new TF_VariableInfo(index, handle.name(), variable);
}

+void TF_LockVariableInfos(TF_VariableInfo** vars, int num_vars,
+                          TF_Status* status) {
+  std::vector<tensorflow::VariableInfo*> variable_ptrs;
+  variable_ptrs.reserve(num_vars);
+  for (int i = 0; i < num_vars; ++i) {
+    variable_ptrs.push_back(&(vars[i]->var_info));
+  }
+  tsl::Status cc_status = LockVariables(absl::MakeSpan(variable_ptrs));
+  tsl::Set_TF_Status_from_Status(status, cc_status);
+}
+
void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx,
                                    TF_VariableInfo* var_info,
                                    TF_Status* status) {
@@ -89,6 +89,10 @@ TF_CAPI_EXPORT extern void TF_LookupOrCreatePluginResource(
TF_CAPI_EXPORT extern TF_VariableInfo* TF_CreateVariableInfoFromContext(
    TF_OpKernelContext* ctx, int index, TF_Status* status);

+TF_CAPI_EXPORT extern void TF_LockVariableInfos(TF_VariableInfo** vars,
+                                                int num_vars,
+                                                TF_Status* status);
+
TF_CAPI_EXPORT extern void TF_AllocateTempForVariableInfo(
    TF_OpKernelContext* ctx, TF_VariableInfo* var_info, TF_Status* status);
tensorflow/compiler/mlir/quantization/stablehlo/BUILD (new file, 28 lines)

@@ -0,0 +1,28 @@
load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist")

package_group(
    name = "internal_visibility_allowlist_package",
    packages = [
        "//tensorflow/compiler/mlir/lite/...",
        "//tensorflow/compiler/mlir/quantization/...",
        "//third_party/cloud_tpu/inference_converter/...",  # TPU Inference Converter V1
    ] + internal_visibility_allowlist(),
)

tf_proto_library(
    name = "quantization_options_proto",
    srcs = ["quantization_options.proto"],
    cc_api_version = 2,
    make_default_target_header_only = True,
    visibility = [":internal_visibility_allowlist_package"],
)

# copybara:uncomment_begin(google-only)
# py_proto_library(
#     name = "quantization_options_py_pb2",
#     api_version = 2,
#     visibility = [":internal_visibility_allowlist_package"],
#     deps = [":quantization_options_proto"],
# )
# copybara:uncomment_end
internal_visibility_allowlist.bzl (new file, 10 lines)

@@ -0,0 +1,10 @@
"""Internal visibility rules."""

def internal_visibility_allowlist():
    """Returns a list of g3 packages that can depend on internal targets."""
    return [
        "//learning/brain/experimental/mlir/quantization/...",
        "//learning/brain/mlir/quantization/tensorflow/...",
        "//learning/brain/mobile/programmability/...",
        "//learning/brain/experimental/tfq/...",
    ]
quantization_options.proto (new file, 107 lines)

@@ -0,0 +1,107 @@
syntax = "proto3";

package stablehlo.quantization;

option cc_enable_arenas = true;

// Defines various options to specify and control the behavior of the
// StableHLO quantizer.
// NEXT ID: 2
message QuantizationOptions {
  QuantizationMethod quantization_method = 1;
}

// NEXT ID: 3
message QuantizationMethod {
  // Quantization Method can be either preset or custom.
  oneof quantization_method {
    PresetQuantizationMethod preset_quantization_method = 1;
    CustomQuantizationMethod custom_quantization_method = 2;
  }
}

// Preset model quantization method for optimization.
//
// Common quantization methods are defined as preset methods in this message.
// Note that the quantization specs (ex: bit width) for preset quantization
// methods are fixed. To use a different quantization spec for a given method,
// use CustomQuantizationMethod.
// NEXT ID: 2
message PresetQuantizationMethod {
  // Preset quantization methods that are supported as a stable API.
  // NEXT ID: 3
  enum PresetMethod {
    // TODO(b/266173150): Update preset methods after redefining quantization
    // pattern matching in DarwiNN.
    // This should never be used. Using this will generally result in an error.
    METHOD_UNSPECIFIED = 0;  // go/do-include-enum-unspecified

    // Apply default weight-only quantization. Weights are quantized during
    // conversion, then dequantized during inference. Data type is as follows:
    // Weight: i8, Bias: f32, Activation: f32, Input/output: f32
    WEIGHT_ONLY = 1;

    // Apply default dynamic range quantization. Quantized tensor value's
    // ranges are determined during graph runtime. Data type is as follows:
    // Weight: i8, Bias: f32, Activation: f32, Input/output: f32
    DYNAMIC_RANGE = 2;
  }
  PresetMethod preset_method = 1;
}

// Custom option for specifying quantization spec details.
// If the selected quantization option is not available, StableHLO quantizer
// will raise an error.
// NEXT ID: 2
message CustomQuantizationMethod {
  // Specify component name, bit width, and other specs for all components
  // intended to be quantized.
  repeated QuantizationComponentSpec quantization_component_spec = 1;
}

// Quantization spec per each component designated to be quantized.
// Components whose QuantizationComponentSpec is not specified will not be
// quantized, and remain f32.
// NEXT ID: 7
message QuantizationComponentSpec {
  // NEXT ID: 4
  enum QuantizationComponent {
    COMPONENT_UNSPECIFIED = 0;
    COMPONENT_ACTIVATION = 1;
    COMPONENT_WEIGHT = 2;
    COMPONENT_BIAS = 3;
  }

  // NEXT ID: 4
  enum BitWidth {
    BIT_WIDTH_UNSPECIFIED = 0;
    BIT_WIDTH_4 = 1;
    BIT_WIDTH_8 = 2;
    BIT_WIDTH_16 = 3;
  }

  // NEXT ID: 4
  enum BitType {
    BIT_TYPE_UNSPECIFIED = 0;
    BIT_TYPE_INT = 1;
    BIT_TYPE_FLOAT = 2;
    BIT_TYPE_BFLOAT = 3;
  }

  QuantizationComponent quantization_component = 1;

  // Defines the target bit of the data.
  BitWidth bit_width = 2;

  // Defines the type of data of the quantized component.
  BitType bit_type = 3;

  // Defines whether quantization is done in narrow range.
  bool enable_narrow_range = 4;

  // Defines whether quantization is done per-channel.
  bool enable_per_channel_quantization = 5;

  // Defines whether quantization is done symmetrically.
  bool enable_symmetric = 6;
}
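A sketch of populating these options from Python. The generated module path is an assumption — the `py_proto_library` target above is commented out upstream — so treat the import as illustrative only:

from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as qo_pb2  # assumed module path

# Quantize weights to symmetric 8-bit integers; unspecified components stay f32.
spec = qo_pb2.QuantizationComponentSpec(
    quantization_component=qo_pb2.QuantizationComponentSpec.COMPONENT_WEIGHT,
    bit_width=qo_pb2.QuantizationComponentSpec.BIT_WIDTH_8,
    bit_type=qo_pb2.QuantizationComponentSpec.BIT_TYPE_INT,
    enable_symmetric=True,
)
opts = qo_pb2.QuantizationOptions(
    quantization_method=qo_pb2.QuantizationMethod(
        custom_quantization_method=qo_pb2.CustomQuantizationMethod(
            quantization_component_spec=[spec]
        )
    )
)

For the fixed presets, the same message would instead set `preset_quantization_method` with `WEIGHT_ONLY` or `DYNAMIC_RANGE`.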
@@ -289,9 +289,9 @@ cc_library(
    hdrs = ["ops/uniform_op_quant_spec.h"],
    compatible_with = get_compatible_with_cloud(),
    deps = [
-       ":tf_quant_ops",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
+       "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/container:flat_hash_set",
        "@llvm-project//mlir:IR",
@@ -423,6 +423,7 @@ cc_library(
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        ":passes",
        "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
@@ -434,6 +435,7 @@ cc_library(
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:path",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
@@ -10,6 +10,7 @@ option cc_enable_arenas = true;
// metadata required for building a SavedModel. This message is primarily used
// to "export" the model produced from various quantization passes in c++ to
// Python layer.
// Next ID: 7
message ExportedModel {
  GraphDef graph_def = 1;

@@ -29,4 +30,11 @@ message ExportedModel {
  // fetching the restore op (see `restore_node_name`), this value is provided
  // to the "file_prefix" tensor to identify the checkpoint directory.
  string checkpoint_dir = 5;
+
+  // Function name -> function alias mapping. This associates the quantized
+  // functions to the original functions' aliases. This information will be used
+  // to populate `MetaInfoDef`s `function_aliases` when the quantized model is
+  // exported to the saved model. This field is usually only populated for the
+  // TF2 models.
+  map<string, string> function_aliases = 6;
}
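Downstream, the `function_aliases` field lands in the SavedModel's `MetaInfoDef`. A short check using the same loader calls that appear later in this diff (the model path is a placeholder):

from tensorflow.python.saved_model import loader_impl as saved_model_loader
from tensorflow.python.saved_model import tag_constants

loader = saved_model_loader.SavedModelLoader('/tmp/quantized_model')  # placeholder path
meta_graph_def = loader.get_meta_graph_def_from_tags({tag_constants.SERVING})
# Function name -> alias mapping, as populated from ExportedModel.function_aliases.
print(dict(meta_graph_def.meta_info_def.function_aliases))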
@@ -59,6 +59,8 @@ cc_library(
        "//tensorflow/tsl/platform:env",
        "//tensorflow/tsl/platform:path",
        "//tensorflow/tsl/platform:status",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
@@ -53,6 +53,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader_impl as saved_model_loader
from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import tag_constants
@@ -1852,6 +1853,75 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
    self.assertAllClose(new_outputs, got_outputs, atol=0.0666)
    self.assertAllClose(new_outputs, expected_outputs, atol=0.057)

  @test_util.run_in_graph_and_eager_modes
  def test_function_alias_preserved(self):
    model = self._create_conv2d_model(
        input_shape=(1, 3, 4, 3), filter_shape=(2, 3, 3, 2)
    )

    signatures = {
        'serving_default': model.conv.get_concrete_function(),
    }
    save_opts = save_options.SaveOptions(
        function_aliases={'conv_func': model.conv}
    )

    saved_model_save.save(
        model, self._input_saved_model_path, signatures, save_opts
    )

    def data_gen() -> repr_dataset.RepresentativeDataset:
      rng = np.random.default_rng(seed=123)
      for _ in range(2):
        yield {
            'input_tensor': rng.uniform(
                low=0, high=150, size=(1, 3, 4, 3)
            ).astype(np.float32),
        }

    tags = {tag_constants.SERVING}

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE
        ),
        op_set=quant_opts_pb2.OpSet.XLA,
    )

    converted_model = quantize_model.quantize(
        self._input_saved_model_path,
        ['serving_default'],
        tags,
        self._output_saved_model_path,
        quantization_options,
        representative_dataset=data_gen(),
    )

    self.assertIsNotNone(converted_model)
    self.assertCountEqual(
        converted_model.signatures._signatures.keys(), {'serving_default'}
    )

    # Test whether the aliased function exists.
    output_loader = saved_model_loader.SavedModelLoader(
        self._output_saved_model_path
    )

    # Confirm that the function alias is preserved.
    meta_graph_def = output_loader.get_meta_graph_def_from_tags(tags)
    function_aliases = meta_graph_def.meta_info_def.function_aliases
    self.assertNotEmpty(function_aliases)
    self.assertCountEqual(function_aliases.values(), {'conv_func'})

    # Test that the aliased function contains a quantized op.
    for func_name, alias in function_aliases.items():
      if alias == 'conv_func':
        for func in meta_graph_def.graph_def.library.function:
          if func.signature.name == func_name:
            self._contains_op_with_name_and_attribute(
                func.node_def, op_name='XlaConvV2', attr_name='', attr_val=None
            )

  @test_util.deprecated_graph_mode_only
  def test_matmul_ptq_model_with_unfreeze_constants(self):
    # Uses large weight to exceed the constant size threshold of 64KiB
@@ -176,11 +176,12 @@ PYBIND11_MODULE(pywrap_quantize_model, m) {
      [](const absl::string_view saved_model_path,
         const std::vector<std::string>& signature_keys,
         const std::unordered_set<std::string>& tags,
-        const QuantizationOptions& quant_opts)
+        const QuantizationOptions& quant_opts,
+        const absl::flat_hash_map<std::string, std::string>& function_aliases)
          -> absl::StatusOr<ExportedModel> {
        return tensorflow::quantization::internal::
            QuantizePtqModelPreCalibration(saved_model_path, signature_keys,
-                                          tags, quant_opts);
+                                          tags, quant_opts, function_aliases);
      },
      R"pbdoc(
      Returns serialized ExportedModel that contains the model's GraphDef and
@@ -196,11 +197,12 @@ PYBIND11_MODULE(pywrap_quantize_model, m) {
      [](const absl::string_view saved_model_path,
         const std::vector<std::string>& signature_keys,
         const std::unordered_set<std::string>& tags,
-        const QuantizationOptions& quant_opts)
+        const QuantizationOptions& quant_opts,
+        const absl::flat_hash_map<std::string, std::string>& function_aliases)
          -> absl::StatusOr<ExportedModel> {
        return tensorflow::quantization::internal::
            QuantizePtqModelPostCalibration(saved_model_path, signature_keys,
-                                           tags, quant_opts);
+                                           tags, quant_opts, function_aliases);
      },
      R"pbdoc(
      Returns serialized ExportedModel that contains the quantized model's
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

+#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
@@ -116,6 +117,11 @@ void AddExportPasses(const bool duplicate_shape_determining_constants,
      mlir::CreateFunctionalToExecutorDialectConversionPass());
  pm.addPass(mlir::CreateBreakUpIslandsPass());
  pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass());

+  // Used to clean up the "tf._noinliner" attribute that is previously used to
+  // prevent certain functions from being inlined (see
+  // `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass.
+  pm.addPass(mlir::TF::CreateStripNoinlineAttributePass());
}

// Finds and returns the name of the node from a set of control output nodes.
@@ -139,7 +145,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
    GraphDef &&graph_def, const absl::string_view init_node_name,
    const absl::string_view restore_node_name,
    const absl::string_view checkpoint_dir,
-   const std::vector<std::string> &variable_shared_names) {
+   const std::vector<std::string> &variable_shared_names,
+   const absl::flat_hash_map<std::string, std::string> &function_aliases) {
  ExportedModel exported_model{};
  *exported_model.mutable_graph_def() = graph_def;
  exported_model.set_init_node_name(std::string(init_node_name));
@@ -148,6 +155,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
  for (auto &shared_name : variable_shared_names) {
    *exported_model.mutable_variable_shared_names()->Add() = shared_name;
  }
+  exported_model.mutable_function_aliases()->insert(function_aliases.begin(),
+                                                    function_aliases.end());

  return exported_model;
}
@@ -156,7 +165,8 @@ std::string GetNodeName(const absl::flat_hash_set<Node *> &control_ret_nodes,
// when the conversion fails.
absl::StatusOr<ExportedModel> ConvertMlirModuleToExportedModel(
    const mlir::ModuleOp module_op, const absl::string_view checkpoint_dir,
-   const std::vector<std::string> &variable_shared_names) {
+   const std::vector<std::string> &variable_shared_names,
+   const absl::flat_hash_map<std::string, std::string> &function_aliases) {
  const GraphExportConfig config{};
  FunctionLibraryDefinition flib_def{OpRegistry::Global(),
                                     FunctionDefLibrary()};
@@ -179,7 +189,7 @@ absl::StatusOr<ExportedModel> ConvertMlirModuleToExportedModel(

  return CreateExportedModel(std::move(graph_def), init_node_name,
                             restore_node_name, checkpoint_dir,
-                            variable_shared_names);
+                            variable_shared_names, function_aliases);
}

// Runs MLIR passes with `module_op`. The passes are added by calling
@@ -387,14 +397,49 @@ absl::StatusOr<ExportedModel> QuantizeQatModel(
  }

  return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
-                                          *variable_shared_names);
+                                          *variable_shared_names,
+                                          /*function_aliases=*/{});
}

+// Returns the updated function aliases. `module_op` may have different function
+// names from the original model, so it re-associates the aliases with the new
+// function names. Both the input `function_aliases` and the returned value
+// are function name -> alias mappings. `function_aliases` is the function alias
+// mapping of the original function.
+absl::flat_hash_map<std::string, std::string> UpdateFunctionAliases(
+    const absl::flat_hash_map<std::string, std::string> function_aliases,
+    mlir::ModuleOp module_op) {
+  absl::flat_hash_map<std::string, std::string> updated_function_aliases;
+
+  module_op->walk([&](mlir::func::FuncOp func_op) {
+    // We may retrieve the original function's name from the attribute.
+    // Functions without this attribute are ignored.
+    auto original_func_name =
+        func_op->getAttrOfType<mlir::StringAttr>("tf._original_func_name");
+    if (original_func_name) {
+      if (auto alias_itr = function_aliases.find(original_func_name.str());
+          alias_itr != function_aliases.end()) {
+        const std::string alias = alias_itr->second;
+        const std::string new_func_name = func_op.getSymName().str();
+
+        updated_function_aliases[new_func_name] = alias;
+
+        VLOG(1) << "Updated function alias. Alias: " << alias
+                << ", New function name: " << new_func_name
+                << ", Old function name: " << original_func_name.str();
+      }
+    }
+  });
+
+  return updated_function_aliases;
+}
+
absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
    const absl::string_view saved_model_path,
    const std::vector<std::string> &signature_keys,
    const std::unordered_set<std::string> &tags,
-   const QuantizationOptions &quantization_options) {
+   const QuantizationOptions &quantization_options,
+   const absl::flat_hash_map<std::string, std::string> &function_aliases) {
  // Convert the SavedModelBundle to an MLIR module.
  mlir::MLIRContext context = CreateMlirContextForTfQuantization();
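`UpdateFunctionAliases` re-keys the alias map from old to new function names via the `tf._original_func_name` attribute. A minimal Python restatement of that logic, with plain dicts standing in for the MLIR module walk (names are illustrative):

def update_function_aliases(function_aliases: dict, original_names: dict) -> dict:
    """Re-keys old-name -> alias onto new function names.

    original_names maps each new function name to its recorded
    tf._original_func_name attribute; functions without it are skipped.
    """
    updated = {}
    for new_name, old_name in original_names.items():
        if old_name in function_aliases:
            updated[new_name] = function_aliases[old_name]
    return updated

# e.g. prints {'conv_fn_quantized_0': 'conv_func'}
print(update_function_aliases({'conv_fn': 'conv_func'},
                              {'conv_fn_quantized_0': 'conv_fn'}))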
@@ -416,10 +461,21 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
  }
  mlir::OwningOpRef<mlir::ModuleOp> module_ref = std::move(module).value();

+  const absl::flat_hash_map<std::string, std::string> updated_function_aliases =
+      UpdateFunctionAliases(function_aliases, *module_ref);
+
+  // Collect the names of the functions that have aliases so that they may not
+  // be inlined.
+  absl::flat_hash_set<std::string> aliased_function_names;
+  absl::c_for_each(updated_function_aliases, [&](const auto &aliases) {
+    return aliased_function_names.insert(aliases.first);
+  });
+
  if (const absl::Status preprocess_status = PreprocessAndFreezeGraph(
          /*mlir_dump_file_prefix=*/kTfQuantPtqPreCalibrationStepName,
-         /*is_inliner_run=*/true, module_ref.get(), &context,
-         bundle ? bundle->GetSession() : nullptr);
+         /*is_inliner_run=*/true,
+         /*noinline_functions=*/aliased_function_names, module_ref.get(),
+         &context, bundle ? bundle->GetSession() : nullptr);
      !preprocess_status.ok()) {
    return preprocess_status;
  }
@@ -456,14 +512,16 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
  }

  return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
-                                          *variable_shared_names);
+                                          *variable_shared_names,
+                                          updated_function_aliases);
}

absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
    const absl::string_view saved_model_path,
    const std::vector<std::string> &signature_keys,
    const std::unordered_set<std::string> &tags,
-   const QuantizationOptions &quantization_options) {
+   const QuantizationOptions &quantization_options,
+   const absl::flat_hash_map<std::string, std::string> &function_aliases) {
  // Convert the SavedModelBundle to an MLIR module.
  mlir::MLIRContext context = CreateMlirContextForTfQuantization();
@@ -486,13 +544,24 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(

  mlir::OwningOpRef<mlir::ModuleOp> module_ref = std::move(module).value();

+  const absl::flat_hash_map<std::string, std::string> updated_function_aliases =
+      UpdateFunctionAliases(function_aliases, *module_ref);
+
+  // Collect the names of the functions that have aliases so that they may not
+  // be inlined.
+  absl::flat_hash_set<std::string> aliased_function_names;
+  absl::c_for_each(updated_function_aliases, [&](const auto &aliases) {
+    return aliased_function_names.insert(aliases.first);
+  });
+
  // Freezing is required again since variables might have been produced during
  // the pre-calibration step. `is_inliner_run = false` to prevent the functions
  // lifted for quantization from being inlined.
  if (const absl::Status preprocess_status = PreprocessAndFreezeGraph(
          /*mlir_dump_file_prefix=*/kTfQuantPtqPostCalibrationStepName,
-         /*is_inliner_run=*/false, module_ref.get(), &context,
-         bundle ? bundle->GetSession() : nullptr);
+         /*is_inliner_run=*/false,
+         /*noinline_functions=*/aliased_function_names, module_ref.get(),
+         &context, bundle ? bundle->GetSession() : nullptr);
      !preprocess_status.ok()) {
    return preprocess_status;
  }
@@ -527,7 +596,8 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
  }

  return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
-                                          *variable_shared_names);
+                                          *variable_shared_names,
+                                          updated_function_aliases);
}

absl::StatusOr<ExportedModel> QuantizePtqDynamicRange(
@@ -592,7 +662,8 @@ absl::StatusOr<ExportedModel> QuantizePtqDynamicRange(
  }

  return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir,
-                                          *variable_shared_names);
+                                          *variable_shared_names,
+                                          /*function_aliases=*/{});
}

}  // namespace internal
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>

+#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
@@ -60,13 +61,15 @@ absl::StatusOr<ExportedModel> QuantizePtqModelPreCalibration(
    absl::string_view saved_model_path,
    const std::vector<std::string>& exported_names,
    const std::unordered_set<std::string>& tags,
-   const QuantizationOptions& quant_opts);
+   const QuantizationOptions& quant_opts,
+   const absl::flat_hash_map<std::string, std::string>& function_aliases);

absl::StatusOr<ExportedModel> QuantizePtqModelPostCalibration(
    absl::string_view saved_model_path,
    const std::vector<std::string>& signature_keys,
    const std::unordered_set<std::string>& tags,
-   const QuantizationOptions& quant_opts);
+   const QuantizationOptions& quant_opts,
+   const absl::flat_hash_map<std::string, std::string>& function_aliases);

}  // namespace internal
}  // namespace quantization
@@ -619,12 +619,19 @@ def _run_static_range_ptq(
      according to the quantized graph to match the original signature defs.
  """
  logging.info('Running post-training quantization pre-calibration step.')

+  loader = saved_model_loader.SavedModelLoader(saved_model_path)
+  function_aliases = loader.get_meta_graph_def_from_tags(
+      tags
+  ).meta_info_def.function_aliases
+
  exported_model_serialized = (
      quantize_model_wrapper.quantize_ptq_model_pre_calibration(
          saved_model_path,
          list(signature_def_keys),
          set(tags),
          quant_opts.SerializeToString(),
+         dict(function_aliases),
      )
  )

@@ -648,6 +655,7 @@ def _run_static_range_ptq(
      exported_model.restore_node_name,
      exported_model.checkpoint_dir,
      exported_model.variable_shared_names,
+     exported_model.function_aliases,
  )

  # Uses the representative dataset to collect statistics for calibration.
@@ -678,6 +686,7 @@ def _run_static_range_ptq(
          list(signature_def_keys),
          set(tags),
          quant_opts.SerializeToString(),
+         dict(exported_model.function_aliases),
      )
  )

@@ -685,10 +694,7 @@ def _run_static_range_ptq(
      exported_model_serialized
  )

-  return (
-      exported_model,
-      signature_def_map,
-  )
+  return exported_model, signature_def_map


def _static_range_quantize(
@@ -780,6 +786,7 @@ def _static_range_quantize(
      restore_op_name=exported_model.restore_node_name,
      checkpoint_dir=exported_model.checkpoint_dir,
      variable_shared_names=exported_model.variable_shared_names,
+     function_aliases=exported_model.function_aliases,
  )

  return saved_model_load(output_directory)
@@ -298,6 +298,39 @@ def _find_variables(
  return var_mapping


def _save_function_alias(
    saved_model_dir: str,
    tags: Collection[str],
    function_aliases: Mapping[str, str],
) -> None:
  """Saves the function alias to the SavedModel.

  SavedModelBuilder (TF1 saved model saver) does not support saving function
  aliases, so this function loads the SavedModel proto and adds the
  `function_aliases` field.

  Args:
    saved_model_dir: Path to the saved model directory.
    tags: A collection of tags to specify the meta graph.
    function_aliases: Function name -> function alias mapping.
  """
  loader = saved_model_loader.SavedModelLoader(saved_model_dir)
  meta_graph_def = loader.get_meta_graph_def_from_tags(tags)

  for function_name, function_alias in function_aliases.items():
    meta_graph_def.meta_info_def.function_aliases[function_name] = (
        function_alias
    )

  saved_model_proto_serialized = loader.saved_model.SerializeToString()

  # TODO(b/266015731): Also update and set the SavedModel fingerprint.
  path = file_io.join(
      saved_model_dir, saved_model_constants.SAVED_MODEL_FILENAME_PB
  )
  file_io.atomic_write_string_to_file(path, saved_model_proto_serialized)


def save_model_v1(
    graph_def: graph_pb2.GraphDef,
    output_dir: str,
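`_save_function_alias` patches the serialized proto because the TF1 builder has no alias support. For comparison, in TF2 the public API records aliases directly at save time; a standard usage sketch with placeholder model and path:

import tensorflow as tf

class Model(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
  def infer(self, x):
    return {'y': tf.nn.relu(x)}

m = Model()
# Alias name -> tf.function; the alias survives in MetaInfoDef.function_aliases.
options = tf.saved_model.SaveOptions(function_aliases={'infer_func': m.infer})
tf.saved_model.save(m, '/tmp/aliased_model', signatures=m.infer, options=options)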
@@ -307,6 +340,7 @@
    restore_op_name: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    variable_shared_names: Optional[Sequence[str]] = None,
+   function_aliases: Optional[Mapping[str, str]] = None,
) -> None:
  """Saves the model.

@@ -322,6 +356,7 @@
    restore_op_name: Name of the node for restoration.
    checkpoint_dir: Path to checkpoint file where variable values are saved.
    variable_shared_names: Shared name of the variables in the model.
+   function_aliases: Function name -> function alias mapping.

  Raises:
    ValueError iff the graph does not contain a valid signature.
@@ -372,3 +407,6 @@
  )

  v1_builder.save()

+  if function_aliases:
+    _save_function_alias(output_dir, tags, function_aliases)
@@ -23,6 +23,7 @@ option cc_enable_arenas = true;
// Various techniques for model quantization are defined within this message
// along with a field that specifies a method to be used for a particular
// quantization request.
+// NEXT ID: 3
message QuantizationMethod {
  // Quantization methods that are supported as a stable API.
  enum Method {
@@ -74,8 +75,10 @@ enum QuantizationPrecision {
// Unit (either nodes or ops at this moment) wise quantization method for
// mixed bit precision quantization. It contains the name of the unit,
// the granularity of the unit, and the quantization method for each unit.
+// NEXT ID: 6
message UnitWiseQuantizationPrecision {
  // Quantization unit granularity.
+  // NEXT ID: 3
  enum UnitType {
    // This should never be used. Using this will generally result in an error.
    UNIT_UNSPECIFIED = 0;
@@ -101,6 +104,7 @@ message UnitWiseQuantizationPrecision {

// List of supported opsets to deploy the quantized model.
// The quantized model contains different set of ops depending on the opset.
+// NEXT ID: 4
enum OpSet {
  OP_SET_UNSPECIFIED = 0;  // go/do-include-enum-unspecified
  // Uses TF ops that mimic quantization behavior. Used when the corresponding
@@ -113,6 +117,7 @@
}

// Configurations for variable freezing during quantization passes.
+// NEXT ID: 2
message FreezeAllVariables {
  // Setting this to true freezes all variables to constants during
  // quantization. Setting this to `false` is an experimental feature and does
@@ -127,6 +132,7 @@ message FreezeAllVariables {
// 2) A set of supported operations.
// 3) Unit wise quantization precision.
// 4) Target hardware name.
+// NEXT ID: 9
message QuantizationOptions {
  // The default quantization configuration for the model. If the below
  // unit-wise configuration does not exist, we use this default quantization
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"

#include <memory>
+#include <string>

+#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@@ -28,6 +30,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
@@ -64,6 +67,7 @@ absl::Status RunPassesOnModuleOp(const absl::string_view mlir_dump_file_name,

absl::Status PreprocessAndFreezeGraph(
    const absl::string_view mlir_dump_file_prefix, const bool is_inliner_run,
+   const absl::flat_hash_set<std::string>& noinline_functions,
    mlir::ModuleOp module_op, mlir::MLIRContext* context,
    llvm::Optional<Session*> session) {
  mlir::PassManager pm_before_freezing_variables(context);
@@ -82,6 +86,12 @@ absl::Status PreprocessAndFreezeGraph(
  mlir::PassManager pm_after_freezing_variables(context);
  pm_after_freezing_variables.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm_after_freezing_variables.addPass(mlir::createCanonicalizerPass());

+  // Makes certain functions immune to the `InlinerPass`. Used to preserve
+  // aliased functions.
+  pm_after_freezing_variables.addNestedPass<mlir::func::FuncOp>(
+      mlir::quant::CreateMarkFunctionsNoinlinePass(std::vector<std::string>(
+          noinline_functions.begin(), noinline_functions.end())));
  if (is_inliner_run) {
    pm_after_freezing_variables.addPass(mlir::createInlinerPass());
  }
@@ -15,6 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_

+#include <string>
+
+#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
@@ -36,11 +39,11 @@ inline constexpr absl::string_view kDefaultTfQuantMlirDumpFilePrefix =
// `mlir_dump_file_prefix` is primarily used for debugging and does not affect
// the preprocessing behavior. Instructions for producing MLIR dump files are in
// the comments of `tensorflow::quantization::MaybeEnableIrPrinting` function.
-absl::Status PreprocessAndFreezeGraph(absl::string_view mlir_dump_file_prefix,
-                                      bool is_inliner_run,
-                                      mlir::ModuleOp module_op,
-                                      mlir::MLIRContext* context,
-                                      llvm::Optional<Session*> session);
+absl::Status PreprocessAndFreezeGraph(
+    absl::string_view mlir_dump_file_prefix, bool is_inliner_run,
+    const absl::flat_hash_set<std::string>& noinline_functions,
+    mlir::ModuleOp module_op, mlir::MLIRContext* context,
+    llvm::Optional<Session*> session);

// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file
// prefix.
@@ -49,7 +52,8 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op,
                                             llvm::Optional<Session*> session) {
  return PreprocessAndFreezeGraph(
      /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix,
-     /*is_inliner_run=*/true, module_op, context, session);
+     /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context,
+     session);
}

}  // namespace quantization
@@ -2307,6 +2307,33 @@ Mutually reduces multiple tensors of identical type and shape.
  }];
}

+def TF_CollectiveReduceScatterV2Op : TF_Op<"CollectiveReduceScatterV2", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_CollectiveReduceOrderingEffect]> {
+  let summary = [{
+Mutually reduces multiple tensors of identical type and shape and scatters the result.
+  }];
+
+  let arguments = (ins
+    TF_FpOrI32OrI64Tensor:$input,
+    TF_Int32Tensor:$group_size,
+    TF_Int32Tensor:$group_key,
+    TF_Int32Tensor:$instance_key,
+    Variadic<TF_ResourceTensor>:$ordering_token,
+
+    TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
+    TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
+    DefaultValuedOptionalAttr<StrAttr, "\"auto\"">:$communication_hint,
+    DefaultValuedOptionalAttr<F32Attr, "0.0f">:$timeout_seconds,
+    DefaultValuedOptionalAttr<I64Attr, "-1">:$max_subdivs_per_device
+  );
+
+  let results = (outs
+    TF_FpOrI32OrI64Tensor:$data
+  );
+
+  TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_CollectiveReduceOrderingEffect]> {
  let summary = [{
Mutually reduces multiple tensors of identical type and shape.
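For intuition about the new op's semantics: the group reduces identical-shape tensors with `merge_op`, optionally applies `final_op`, and member i receives shard i of the result. A NumPy behavioral sketch only, not the collective runtime implementation:

import numpy as np

def reduce_scatter(tensors, merge_op=np.add.reduce, final_op=lambda x, n: x):
    n = len(tensors)                       # group_size
    reduced = merge_op(np.stack(tensors))  # merge_op: Add/Mul/Min/Max
    reduced = final_op(reduced, n)         # final_op: Id, or Div for means
    return np.split(reduced, n)            # shard i goes to group member i

print(reduce_scatter([np.ones(4), 3 * np.ones(4)]))
# [array([4., 4.]), array([4., 4.])]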
@@ -1476,6 +1476,8 @@ An op that groups a list of partitioned inputs together. Supports ND sharding.

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;

+  let hasVerifier = 1;
}

def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [Pure]> {
@@ -1074,6 +1074,12 @@ std::optional<std::string> CollectiveReduceV2Op::GetResourceInstanceStr() {
                                : std::nullopt;
}

+std::optional<std::string>
+CollectiveReduceScatterV2Op::GetResourceInstanceStr() {
+  return getNorderingToken() == 0 ? std::optional<std::string>("")
+                                  : std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// ConcatOp and ConcatV2Op
//===----------------------------------------------------------------------===//
@@ -2666,6 +2666,32 @@ LogicalResult ToBoolOp::inferReturnTypes(
  return success();
}

+//===----------------------------------------------------------------------===//
+// TPUPartitionedInputV2
+//===----------------------------------------------------------------------===//
+
+// This method mimics this op's core/TF-level shape inference logic
+LogicalResult TPUPartitionedInputV2Op::verify() {
+  TPUPartitionedInputV2Op op = *this;
+
+  int num_partitions = 1;
+  const mlir::ArrayAttr partition_dims = op.getPartitionDims();
+  for (const mlir::Attribute &dim : partition_dims) {
+    num_partitions *= dim.cast<IntegerAttr>().getInt();
+  }
+
+  const bool is_packed = op.getIsPacked();
+  const bool replicated = partition_dims.empty();
+  const int num_inputs_expected = is_packed ? 1 : num_partitions;
+
+  if (!((replicated && !is_packed) || (op.getN() == num_inputs_expected))) {
+    return op.emitOpError() << "expected " << num_inputs_expected
+                            << " inputs, got " << op.getN();
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
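The verifier's operand-count rule is compact enough to restate. A small Python sketch of the same check, with illustrative values mirroring the C++ above:

import math

def expected_inputs(partition_dims, is_packed):
    # A packed op carries one shared operand; otherwise one per partition.
    return 1 if is_packed else math.prod(partition_dims)

def verify(partition_dims, is_packed, n):
    replicated = not partition_dims
    want = expected_inputs(partition_dims, is_packed)
    if not ((replicated and not is_packed) or n == want):
        raise ValueError(f'expected {want} inputs, got {n}')

verify([2, 1], is_packed=True, n=1)   # ok: packed takes a single operand
verify([2, 1], is_packed=False, n=2)  # ok: one operand per partition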
@@ -2561,6 +2561,26 @@ func.func @testInvalidToBool(%arg0: tensor<i32>) -> tensor<1xi1> {

// -----

// Test invalid tf.TPUPartitionedInputV2 with packing
func.func @testPackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> {
^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>):
  // expected-error @+1 {{expected 1 inputs, got 2}}
  %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {partition_dims = [2, 1], is_packed = true} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32>
  func.return %0 : tensor<4x4xf32>
}

// -----

// Test invalid tf.TPUPartitionedInputV2 without packing
func.func @testUnpackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> {
^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>):
  // expected-error @+1 {{expected 2 inputs, got 1}}
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {partition_dims = [2, 1], is_packed = false} : (tensor<2x4xf32>) -> tensor<4x4xf32>
  func.return %0 : tensor<4x4xf32>
}

// -----

// Test valid tf.Transpose
// CHECK-LABEL: testTranspose
func.func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> {
@@ -13,6 +13,42 @@ func.func @simple(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: ten
  func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// CHECK-LABEL:func @simple_packed
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @simple_packed(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  // CHECK: "tf.TPUReplicateMetadata"()
  // CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]])
  // CHECK-SAME: is_packed = true
  // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]])
  // CHECK-SAME: is_packed = true
  // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]])
  // CHECK-SAME: is_packed = false
  "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  // CHECK: return [[PI]]
  func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// CHECK-LABEL:func @multi_arg_packed
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG1:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @multi_arg_packed(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  // CHECK: "tf.TPUReplicateMetadata"()
  // CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]])
  // CHECK-SAME: is_packed = false
  // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]])
  // CHECK-SAME: is_packed = false
  // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]])
  // CHECK-SAME: is_packed = false
  "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %1 = "tf.TPUPartitionedInputV2"(%arg1) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  // CHECK: return [[PI]]
  func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// CHECK-LABEL:func @missing_xla_sharding
// CHECK-SAME: ([[ARG0:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG1:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG2:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, [[ARG3:%.*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
func.func @missing_xla_sharding(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg3: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
@@ -44,6 +80,27 @@ func.func @no_change_to_dag(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>,

// -----

func.func @missing_metadata(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  // expected-error@+1 {{num cores per replica unavailable, metadata missing?}}
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// -----

func.func @inconsistent_packing(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> ()
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  // expected-error@+1 {{packing should match across ops}}
  %1 = "tf.TPUPartitionedInputV2"(%arg0, %arg0) {_XlaSharding = "", partition_dims = [], is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  func.return %2 : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// -----

func.func @xla_sharding_mismatch(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg3: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {_XlaSharding = "123", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
@@ -80,3 +137,13 @@ func.func @mixed_inputs_to_replicated_op(%arg0: tensor<!tf_type.resource<tensor<
  %ri = "tf.TPUReplicatedInput"(%pi_0, %arg2) {index = 1} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}

// -----

func.func @num_partitioned_inputs_mismatch_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<10x3xf32>>>, %arg2: tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>> {
  "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 1 : i64} : () -> ()
  // expected-error@+1 {{expects 2 operands but found 3}}
  %pi = "tf.TPUPartitionedInputV2"(%arg0, %arg1, %arg2) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>, tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  %ri = "tf.TPUReplicatedInput"(%pi) : (tensor<!tf_type.resource<tensor<10x3xf32>>>) -> tensor<!tf_type.resource<tensor<10x3xf32>>>
  func.return %ri : tensor<!tf_type.resource<tensor<10x3xf32>>>
}
@@ -23,6 +23,27 @@ func.func @read_write_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %a
  func.return
}

// CHECK-LABEL: func @read_write_packed_resource
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf_type.resource<tensor<i32>>>)
func.func @read_write_packed_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
  // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]])
  // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]])
  // CHECK-SAME: _XlaSharding = ""
  // CHECK-SAME: is_packed = true
  // CHECK-SAME: partition_dims = []
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  // CHECK: [[COMPUTATION:%.+]] = "tf_device.cluster_func"([[INPUT]])
  %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor<i32>) -> tensor<i32>
  // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[COMPUTATION]])
  // CHECK-SAME: _XlaSharding = ""
  // CHECK-SAME: partition_dims = []
  // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#0)
  // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#1)
  "tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  func.return
}

// CHECK-LABEL: func @read_only_resource
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf_type.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf_type.resource<tensor<i32>>>)
func.func @read_only_resource(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32> {
@@ -127,6 +148,28 @@ func.func @resource_missing_subtype(%arg0: tensor<!tf_type.resource>, %arg1: ten

// -----

func.func @missing_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
  // expected-error@+1 {{op num cores per replica unavailable}}
  %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  func.return
}

// -----

func.func @mismatch_num_cores_per_replica(%arg0: tensor<!tf_type.resource<tensor<i32>>>) {
  // expected-error@+1 {{expects 2 operands but found 3}}
  %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg0, %arg0) {_XlaSharding = "", partition_dims = []} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%0, %2) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  func.return
}

// -----

// Check outside compiled that uses a TPUPartitionedInputV2.

func.func private @computation(%arg0: tensor<i32>) -> tensor<i32>
@ -144,6 +144,19 @@ bool AddAccessedResourceIds(
|
|||
return false;
|
||||
}
|
||||
|
||||
/* Resources may be merged with an execute op when they are on its device or a
|
||||
* `COMPOSITE`. Note that a `COMPOSITE` represents a set of devices, they
|
||||
* are typically associated with packed variables. Presently, we assume this
|
||||
* set spans all the devices. So, a variable on a `COMPOSITE` will have a local
|
||||
* instance on the execute op's device.
|
||||
*/
|
||||
bool IsResourceMergeable(Attribute& resource_attr, Attribute& device_attr) {
|
||||
return resource_attr &&
|
||||
((resource_attr == device_attr) ||
|
||||
(resource_attr.cast<mlir::StringAttr>().getValue().find(
|
||||
"COMPOSITE") != llvm::StringRef::npos));
|
||||
}
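The device-matching rule above boils down to a string comparison: exact device equality, or a `COMPOSITE` device name. As a rough standalone illustration (plain `std::string` stand-ins for the MLIR attributes; the names here are hypothetical):

#include <string>

// Minimal sketch of the merge predicate: a resource is mergeable when its
// device string matches the execute op's device exactly, or when it names a
// COMPOSITE device, which is assumed to span all devices and therefore has a
// local instance on the execute op's device.
bool IsResourceMergeableSketch(const std::string& resource_device,
                               const std::string& execute_device) {
  if (resource_device.empty()) return false;  // missing device attribute
  return resource_device == execute_device ||
         resource_device.find("COMPOSITE") != std::string::npos;
}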
// Finds the variable access info for a TPUExecute op.
// - `check_device` specifies whether it checks the device assignment of the
// variables to match the TPUExecute op. This is optional in some context,

@@ -187,7 +200,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
if (auto* resource_op = resource.getDefiningOp()) {
auto resource_attr = resource_op->getAttr(kDeviceAttr);
// Check device matching for the node defining the resource.
if (!resource_attr || resource_attr != device_attr) continue;
if (!IsResourceMergeable(resource_attr, device_attr)) continue;
} else {
auto resource_arg = resource.dyn_cast<BlockArgument>();
assert(resource_arg);

@@ -195,7 +208,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
// Check device matching for the argument defining the resource.
auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
resource_arg.getArgNumber(), kFuncDeviceAttr);
if (!resource_attr || resource_attr != device_attr) continue;
if (!IsResourceMergeable(resource_attr, device_attr)) continue;
}
}

@@ -11,6 +11,7 @@ limitations under the License.
==============================================================================*/

#include <cstddef>
#include <optional>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"

@@ -41,47 +42,73 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
return replicated_input.emitOpError()
<< "expects all inputs from 'tf.TPUPartitionedInputV2' ops";

const auto metadata_iter =
replicated_input->getBlock()->getOps<TF::TPUReplicateMetadataOp>();
TF::TPUReplicateMetadataOp metadata;
if (!metadata_iter.empty()) metadata = *(metadata_iter.begin());

auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputV2Op>(
replicated_input.getOperand(0).getDefiningOp());
std::optional<::llvm::StringRef> xla_sharding =
first_partitioned_input.get_XlaSharding();
auto partition_dims = first_partitioned_input.getPartitionDims();
size_t num_cores_per_replica = first_partitioned_input.getNumOperands();
const std::optional<::llvm::StringRef> xla_sharding =
first_partitioned_input.get_XlaSharding();

for (auto operand : replicated_input.getInputs().drop_front()) {
size_t num_cores_per_replica = first_partitioned_input.getNumOperands();
if (metadata) {
num_cores_per_replica = metadata.getNumCoresPerReplica();
} else if (first_partitioned_input.getIsPacked()) {
return first_partitioned_input->emitOpError()
<< "num cores per replica unavailable, metadata missing?";
}

const bool packed_input = first_partitioned_input.getIsPacked();
const size_t num_operands_expected = packed_input ? 1 : num_cores_per_replica;
if (metadata &&
num_operands_expected != first_partitioned_input.getNumOperands()) {
return first_partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< first_partitioned_input.getNumOperands();
}
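The operand-count rule just established is the crux of packed inputs: a packed `tf.TPUPartitionedInputV2` carries a single operand shared by every core, while an unpacked one carries one operand per core. A minimal standalone sketch of the expectation (hypothetical names, plain integers; note the pass additionally treats "packed with no metadata" as a hard error, as shown above):

#include <cstddef>
#include <optional>

// Sketch: how many operands a partitioned-input-style op should carry, or
// std::nullopt when the count cannot be checked (metadata missing).
// Packed ops share one operand across cores; unpacked ops need one per core.
std::optional<std::size_t> ExpectedOperandCount(
    bool is_packed, std::optional<std::size_t> num_cores_per_replica) {
  if (!num_cores_per_replica) return std::nullopt;  // no metadata: can't check
  return is_packed ? std::size_t{1} : *num_cores_per_replica;
}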
for (const auto& operand : replicated_input.getInputs().drop_front()) {
auto partitioned_input =
llvm::cast<TF::TPUPartitionedInputV2Op>(operand.getDefiningOp());
std::optional<::llvm::StringRef> op_xla_sharding =
const std::optional<::llvm::StringRef> op_xla_sharding =
partitioned_input.get_XlaSharding();
auto op_partition_dims = partitioned_input.getPartitionDims();
const auto op_partition_dims = partitioned_input.getPartitionDims();
// Abort if TPUPartitionedInputV2(s) do not have the same attributes.
if (!llvm::equal(partition_dims, op_partition_dims))
if (!llvm::equal(partition_dims, op_partition_dims)) {
return partitioned_input->emitOpError()
<< "expects partition_dims = " << partition_dims << " but found "
<< op_partition_dims;
if (partitioned_input.getNumOperands() != num_cores_per_replica)
} else if (partitioned_input.getIsPacked() !=
first_partitioned_input.getIsPacked()) {
return partitioned_input->emitOpError()
<< "expects " << num_cores_per_replica << " operands but found "
<< "packing should match across ops";
} else if (partitioned_input.getNumOperands() != num_operands_expected) {
return partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< partitioned_input.getNumOperands();
if (xla_sharding != op_xla_sharding)
} else if (xla_sharding != op_xla_sharding) {
return replicated_input.emitOpError()
<< "expects all inputs from 'tf.TPUPartitionedInputV2' ops to "
"have identical XLA sharding";
}
}

// 2D Matrix to store per core per replica operands. The matrix dimensions are
// num_cores_per_replica x num_replicas. i-th row holds the operands for i-th
// core. j-th column holds the operands for j-th replica.
llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
operands_per_replica_per_core;
operands_per_replica_per_core.resize(num_cores_per_replica);
operands_per_replica_per_core(num_cores_per_replica);

// Collect all operands in the 2D matrix.
for (auto operand : replicated_input.getInputs()) {
auto pi = llvm::cast<TF::TPUPartitionedInputV2Op>(operand.getDefiningOp());
for (auto& pi_operand : pi->getOpOperands()) {
unsigned core_id = pi_operand.getOperandNumber();
operands_per_replica_per_core[core_id].push_back(pi_operand.get());
Operation* pi = operand.getDefiningOp();
for (unsigned core_id = 0; core_id < num_cores_per_replica; ++core_id) {
const auto pi_operand =
packed_input ? pi->getOperand(0) : pi->getOperand(core_id);
operands_per_replica_per_core[core_id].push_back(pi_operand);
}
}
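To make the gather above concrete: the matrix is indexed [core][replica], and a packed op's single operand is broadcast down its whole column. The same loop over plain vectors (hypothetical value type standing in for `Value`):

#include <cstddef>
#include <vector>

// Sketch: build the num_cores x num_replicas operand matrix. `replicas[r]`
// is replica r's operand list: size 1 when packed, num_cores otherwise.
std::vector<std::vector<int>> CollectPerCoreOperands(
    const std::vector<std::vector<int>>& replicas, std::size_t num_cores,
    bool packed) {
  std::vector<std::vector<int>> per_core(num_cores);
  for (const auto& operands : replicas) {
    for (std::size_t core = 0; core < num_cores; ++core) {
      // A packed input contributes its single operand to every core.
      per_core[core].push_back(packed ? operands[0] : operands[core]);
    }
  }
  return per_core;
}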
@@ -89,16 +116,23 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
// `tf.TPUPartitionedInputV2` op.
OpBuilder builder(replicated_input);
llvm::SmallVector<Value, 4> operands_per_core;
for (const auto& operands_per_replica : operands_per_replica_per_core) {
for (auto& operands_per_replica : operands_per_replica_per_core) {
const bool is_packed =
packed_input && llvm::all_equal(operands_per_replica);
if (is_packed)  // reduce the duplicates to one input for packed vars
operands_per_replica.erase(operands_per_replica.begin() + 1,
operands_per_replica.end());
auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
replicated_input.getLoc(), replicated_input.getType(),
operands_per_replica, replicated_input->getAttrs());
replicate_op.setIsPacked(is_packed);
operands_per_core.push_back(replicate_op);
}

auto pi = builder.create<TF::TPUPartitionedInputV2Op>(
first_partitioned_input.getLoc(), replicated_input.getType(),
operands_per_core, first_partitioned_input->getAttrs());
pi.setIsPacked(false);  // inputs are now ops, not resources
replicated_input.replaceAllUsesWith(pi.getOutput());
return success();
}

@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <tuple>

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"

@@ -38,6 +39,9 @@ namespace {
#define GEN_PASS_DEF_TPURESOURCEREADSWRITESPARTITIONINGPASS
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"

constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";

struct TPUResourceReadsWritesPartitioningPass
: public impl::TPUResourceReadsWritesPartitioningPassBase<
TPUResourceReadsWritesPartitioningPass> {

@@ -109,12 +113,14 @@ LogicalResult UpdateReadUses(TF::ReadVariableOp old_read,
LogicalResult PartitionResourceReadsWrites(
tf_device::ClusterFuncOp cluster_func) {
bool use_spmd = false;
if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
"use_spmd_for_xla_partitioning"))
if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
use_spmd = use_spmd_attr.getValue();

if (!use_spmd) return success();

auto num_cores_per_replica_attr =
cluster_func->getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);

// Wrap the ClusterFunc with a ParallelExecute if it does not already exist.
OpBuilder builder(cluster_func);
tf_device::ParallelExecuteOp parallel_execute =

@@ -138,20 +144,40 @@ LogicalResult PartitionResourceReadsWrites(
!AllResourceTypesHaveSubtypes(partitioned_input.getInputs().getTypes()))
continue;

const auto inputs = partitioned_input.getInputs();
const bool packed_input = partitioned_input.getIsPacked();
int num_cores_per_replica = partitioned_input.getN();
if (num_cores_per_replica_attr) {
num_cores_per_replica = num_cores_per_replica_attr.getInt();
} else if (packed_input) {
return partitioned_input->emitOpError()
<< "num cores per replica unavailable";
}

const int num_operands_expected = packed_input ? 1 : num_cores_per_replica;
if (num_cores_per_replica_attr && num_operands_expected != inputs.size()) {
return partitioned_input->emitOpError()
<< "expects " << num_operands_expected << " operands but found "
<< partitioned_input.getNumOperands();
}

builder.setInsertionPoint(assign_var);
llvm::SmallVector<Type, 4> partitioned_output_types;
partitioned_output_types.reserve(partitioned_input.getN());
for (Type input_type : partitioned_input.getInputs().getTypes())
partitioned_output_types.push_back(GetResourceSubtype(input_type));
partitioned_output_types.reserve(num_cores_per_replica);
for (int i = 0; i < num_cores_per_replica; ++i) {
const auto& input = packed_input ? inputs[0] : inputs[i];
partitioned_output_types.push_back(GetResourceSubtype(input.getType()));
}

auto partitioned_output = builder.create<TF::TPUPartitionedOutputV2Op>(
cluster_func->getLoc(), partitioned_output_types, result,
partitioned_input.getPartitionDimsAttr(),
partitioned_input.get_XlaShardingAttr());
for (auto resource_write : llvm::zip(partitioned_input.getInputs(),
partitioned_output.getOutput()))
for (auto [i, value] : llvm::enumerate(partitioned_output.getOutput())) {
const auto& resource = packed_input ? inputs[0] : inputs[i];
builder.create<TF::AssignVariableOp>(
assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
/*value=*/std::get<1>(resource_write));
assign_var->getLoc(), /*resource=*/resource, /*value=*/value);
}
assign_var.erase();
}

@@ -167,13 +193,24 @@ LogicalResult PartitionResourceReadsWrites(
continue;
}

builder.setInsertionPoint(partitioned_input);
// we only want to create one read variable op per unique input;
// otherwise tpu rewriting will fail to clean up the duplicates
llvm::SmallMapVector<Value, Value, 4> read_variable_ops;
llvm::SmallVector<Value, 4> partitioned_reads;
builder.setInsertionPoint(partitioned_input);

for (Value input : partitioned_input.getInputs()) {
auto partitioned_read = builder.create<TF::ReadVariableOp>(
read_var->getLoc(), GetResourceSubtype(input), input);
partitioned_reads.push_back(partitioned_read.getValue());
auto search = read_variable_ops.find(input);
// if a read variable op doesn't already exist for this input, create it
if (search == read_variable_ops.end()) {
auto partitioned_read = builder.create<TF::ReadVariableOp>(
read_var->getLoc(), GetResourceSubtype(input), input);
search = read_variable_ops.insert({input, partitioned_read.getValue()})
.first;
}
partitioned_reads.push_back(search->second);
}
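The `SmallMapVector` lookup above is plain memoization: one read per distinct resource, with duplicated (packed) inputs reusing the first read so later TPU rewriting finds no redundant ops to clean up. The same shape in generic C++ (`make_read` is a hypothetical stand-in for building the `tf.ReadVariableOp`):

#include <functional>
#include <map>
#include <vector>

// Sketch: create at most one read per unique resource; duplicates reuse it.
std::vector<int> DedupReads(const std::vector<int>& resources,
                            const std::function<int(int)>& make_read) {
  std::map<int, int> memo;  // resource -> memoized read value
  std::vector<int> reads;
  reads.reserve(resources.size());
  for (int resource : resources) {
    auto it = memo.find(resource);
    if (it == memo.end())  // first occurrence: actually create the read
      it = memo.emplace(resource, make_read(resource)).first;
    reads.push_back(it->second);
  }
  return reads;
}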
auto partitioned_read = builder.create<TF::TPUPartitionedInputV2Op>(
partitioned_input->getLoc(), read_var.getValue().getType(),
partitioned_reads, partitioned_input.getPartitionDimsAttr(),

@@ -1,136 +0,0 @@
// RUN: tf-tfrt-opt %s -split-input-file \
// RUN: -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: | FileCheck %s --check-prefix=MARKED

// RUN: tf-tfrt-opt %s -split-input-file \
// RUN: -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: | FileCheck %s --check-prefix=TRANSFORMED

// RUN: tf-tfrt-opt %s -split-input-file -xla-cpu-transform-matmul="tile-sizes=8,4,2" \
// RUN: -canonicalize -vectorize-perfectly-tiled-loops \
// RUN: | FileCheck %s --check-prefix=VECTORIZED

// RUN: tf-tfrt-opt %s -split-input-file -xla-cpu-transform-matmul="lower-to-mmt4d=true" \
// RUN: -vectorize-perfectly-tiled-loops \
// RUN: | FileCheck %s --check-prefix=MMT4D

func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%c1 = arith.constant 1 : index
%1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%2 = tensor.empty(%0, %1) : tensor<?x?xf32>
%cst = arith.constant 0.000000e+00 : f32
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%3 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}

// TRANSFORMED-LABEL: func @matmul(
// TRANSFORMED-SAME: %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>)

// TRANSFORMED-DAG: %[[C0:.*]] = arith.constant 0 : index
// TRANSFORMED: %[[INIT:.*]] = tensor.empty

// TRANSFORMED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// TRANSFORMED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// TRANSFORMED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// TRANSFORMED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// TRANSFORMED: %[[MAIN_PAR_MAIN_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[MAIN_FOR]]:
// TRANSFORMED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_FOR]]

// TRANSFORMED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// TRANSFORMED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// TRANSFORMED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// TRANSFORMED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_FOR]]

// TRANSFORMED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// TRANSFORMED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// TRANSFORMED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// TRANSFORMED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_FOR]]

// -----

// VECTORIZED-LABEL: func @matmul(
// VECTORIZED-SAME: %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>)

// VECTORIZED-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x4xf32>
// VECTORIZED-DAG: %[[C0:.*]] = arith.constant 0 : index
// VECTORIZED-DAG: %[[INIT:.*]] = tensor.empty

// VECTORIZED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// VECTORIZED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs (%[[ARG:.*]] =
// VECTORIZED: %[[LHS_READ:.*]] = vector.transfer_read {{.*}} vector<8x2xf32>
// VECTORIZED: %[[RHS_READ:.*]] = vector.transfer_read {{.*}} vector<2x4xf32>
// VECTORIZED: %[[CONTRACT:.*]] = vector.contract {{.*}} %[[LHS_READ]], %[[RHS_READ]], %[[ARG]]
// VECTORIZED: gml_st.set_yield %[[CONTRACT]]
// VECTORIZED: %[[WRITE:.*]] = vector.transfer_write %[[MAIN_FOR]]
// VECTORIZED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[WRITE]]:
// VECTORIZED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_FOR]]

// VECTORIZED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// VECTORIZED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// VECTORIZED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// VECTORIZED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// VECTORIZED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_RHS_FOR]]

// VECTORIZED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// VECTORIZED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// VECTORIZED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// VECTORIZED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// VECTORIZED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// VECTORIZED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// VECTORIZED: gml_st.set_yield %[[REM_LHS_FOR]]

// -----

// MARKED-LABEL: func @matmul(

// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) step
// MARKED: } {__peeling_applied_label__, __perfectly_tiled_loop_label__}
// MARKED: gml_st.for (%[[K:.*]]) = (%[[KUB]])
// MARKED: } {__peeling_applied_label__
// MARKED: } {__peeling_applied_label__

// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: }
// MARKED: } {__peeling_applied_label__

// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: }
// MARKED: } {__peeling_applied_label__

// -----

// MMT4D-LABEL: func @matmul(

// MMT4D-NOT: linalg.matmul
// MMT4D: scf.for {{.*}} = %c0 to %[[DIM0:.*]] step %c1
// MMT4D: scf.for {{.*}} = %c0 to %[[DIM1:.*]] step %c1
// MMT4D: vector.transfer_read
// MMT4D: %[[KERNEL:.*]] = scf.for {{.*}} = %c0 to %[[DIM2:.*]] step %c1 {{.*}} -> (vector<1x1x8x8xf32>)
// MMT4D: vector.transfer_read
// MMT4D: vector.transfer_read
// MMT4D: %[[CONTRACT:.*]] = vector.contract
// MMT4D: scf.yield %[[CONTRACT]]
// MMT4D: %[[WRITE:.*]] = vector.transfer_write %[[KERNEL]]

@@ -0,0 +1,33 @@
// RUN: xla-opt -hlo-legalize-to-linalg -hlo-xla-runtime-sparsification %s | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }>

// CHECK-LABEL: func.func @mult_sparse_dense(
// CHECK-SAME: %[[PTR:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL:.*2]]: memref<?xf64>,
// CHECK-SAME: %[[SPEC:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)>
// CHECK-SAME: %[[DENSE:.*4]]: memref<10xf64>) -> memref<10xf64> {
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[A:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf64>
// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[A]] : memref<10xf64>)
// CHECK: %[[LO:.*]] = memref.load %[[PTR]][%[[C0]]] : memref<?xindex>
// CHECK: %[[HI:.*]] = memref.load %[[PTR]][%[[C1]]] : memref<?xindex>
// CHECK: scf.for %[[II:.*]] = %[[LO]] to %[[HI]] step %[[C1]] {
// CHECK: %[[I:.*]] = memref.load %[[IDX]][%[[II]]] : memref<?xindex>
// CHECK: %[[T0:.*]] = memref.load %[[VAL]][%[[II]]] : memref<?xf64>
// CHECK: %[[T1:.*]] = memref.load %[[DENSE]][%[[I]]] : memref<10xf64>
// CHECK: %[[T3:.*]] = arith.mulf %[[T0]], %[[T1]] : f64
// CHECK: memref.store %[[T3]], %[[A]][%[[I]]] : memref<10xf64>
// CHECK: }
// CHECK: return %[[A]] : memref<10xf64>
// CHECK: }
func.func @mult_sparse_dense(%arg0: tensor<10xf64, #SparseVector>,
%arg1: tensor<10xf64>)
-> tensor<10xf64> {
%0 = mhlo.multiply %arg0, %arg1 : (tensor<10xf64, #SparseVector>,
tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}
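The CHECK lines above pin down the kernel shape: iterate only the stored segment `[ptr[0], ptr[1])` of the compressed vector, multiply each stored value against the matching dense element, and scatter into a zero-filled output. A plain C++ rendering of that expected loop (CSR-style positions/coordinates/values arrays; names are hypothetical):

#include <cstddef>
#include <vector>

// Sketch of the sparse-times-dense kernel the test expects: only stored
// entries are visited, so untouched outputs keep the fill value 0.0.
std::vector<double> MultSparseDense(const std::vector<std::size_t>& ptr,
                                    const std::vector<std::size_t>& idx,
                                    const std::vector<double>& val,
                                    const std::vector<double>& dense) {
  std::vector<double> out(dense.size(), 0.0);       // memref.alloc + linalg.fill
  for (std::size_t ii = ptr[0]; ii < ptr[1]; ++ii)  // scf.for over stored range
    out[idx[ii]] = val[ii] * dense[idx[ii]];        // arith.mulf + memref.store
  return out;
}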
@@ -35,10 +35,7 @@ bool IsGoogleTensorRTEnabled() {
#else  // TF_USE_TENSORRT_STATIC
auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries();
if (!handle_or.ok()) {
LOG_WARNING_WITH_PREFIX
<< "Cannot dlopen some TensorRT libraries. If you would like "
"to use Nvidia GPU with TensorRT, please make sure the "
"missing libraries mentioned above are installed properly.";
LOG_WARNING_WITH_PREFIX << "Could not find TensorRT";
}
return handle_or.ok();
#endif  // TF_USE_TENSORRT_STATIC

@@ -3951,7 +3951,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) {
}
}

const int num_threads = tsl::port::MaxParallelism() + 1;
const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1;
std::vector<std::unique_ptr<HloEvaluator>> embedded_evaluators;
embedded_evaluators.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {

@@ -1819,7 +1819,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Shape window_shape = ShapeUtil::MakeShape(
input_arrays[0]->shape().element_type(), window_dimension_sizes);

const int num_threads = tsl::port::MaxParallelism() + 1;
const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1;
std::vector<std::unique_ptr<HloEvaluator>> embedded_evaluators;
embedded_evaluators.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {

@@ -1,4 +1,7 @@
cc_binary(
load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary")
load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library")

xla_cc_binary(
name = "mlir_replay",
srcs = ["mlir_replay.cc"],
deps = [

@@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"

@@ -429,11 +430,13 @@ LogicalResult fuseOutputFill(PatternRewriter& rewriter, Operation* op) {
return success();
}

FailureOr<Operation*> tileAndFuseGreedily(
FailureOr<ParallelOp> tileUsingGmlStParallelAndFuseGreedily(
PatternRewriter& rewriter, Operation* op,
const mlir::gml_st::TilingOptions& opts, StringRef label,
llvm::function_ref<bool(Operation*)> fuseFilterFn) {
auto tilingResult = tile(opts, rewriter, cast<TilingInterface>(op));
assert(opts.distribute == true &&
"gml_st.for should not be used for CPU pipeline");
auto tilingResult = tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
if (failed(tilingResult)) return failure();

// If we did not tile (e.g. when all tile sizes are 0), do not replace

@@ -446,7 +449,25 @@ FailureOr<Operation*> tileAndFuseGreedily(
fuseFilterFn);
}
setLabel(tilingResult->tiledOps.front(), label);
return tilingResult->loop;
return cast<ParallelOp>(tilingResult->loop);
}

FailureOr<scf::SCFTilingResult> tileUsingSCFForOpAndFuseGreedily(
PatternRewriter& rewriter, Operation* op, const scf::SCFTilingOptions& opts,
StringRef label, llvm::function_ref<bool(Operation*)> fuseFilterFn) {
auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, opts);
if (failed(tilingResult)) return failure();

// If we did not tile (e.g. when all tile sizes are 0), do not replace
// original op and just mark it as transformed then return.
if (!tilingResult->loops.empty()) {
rewriter.replaceOp(op, tilingResult->replacements);

// Fuse ops into the loop.
fuseGreedily(rewriter, *tilingResult->loops.back().getBody(), fuseFilterFn);
}
setLabel(tilingResult->tiledOps.front(), label);
return tilingResult;
}

LogicalResult tilePeeledOpsToScalars(

@@ -464,8 +485,8 @@ LogicalResult tilePeeledOpsToScalars(
opts.setTileSizeComputationFn(SmallVector<int64_t>(
cast<linalg::LinalgOp>(definingOp).getNumLoops(), 1));

if (failed(tileAndFuseGreedily(rewriter, definingOp, opts, label,
fuseFilterFn)))
if (failed(tileUsingGmlStParallelAndFuseGreedily(rewriter, definingOp, opts,
label, fuseFilterFn)))
return failure();
}
return success();

@@ -61,12 +61,17 @@ FusionCluster findMapFusionCluster(Operation *op);
// Fuses linalg.fill that is used in init argument of the op.
LogicalResult fuseOutputFill(PatternRewriter &rewriter, Operation *op);

// Tiles the op and fuses greedily according to the filter function.
FailureOr<Operation *> tileAndFuseGreedily(
// Tiles the op to gml_st.parallel and fuses greedily according to the filter.
FailureOr<ParallelOp> tileUsingGmlStParallelAndFuseGreedily(
PatternRewriter &rewriter, Operation *op,
const mlir::gml_st::TilingOptions &opts, StringRef label,
llvm::function_ref<bool(Operation *)> fuseFilterFn);

// Tiles the op to scf.for and fuses greedily according to the filter.
FailureOr<scf::SCFTilingResult> tileUsingSCFForOpAndFuseGreedily(
PatternRewriter &rewriter, Operation *op, const scf::SCFTilingOptions &opts,
StringRef label, llvm::function_ref<bool(Operation *)> fuseFilterFn);

// Tiles the op to 1 for all dimensions and fuses greedily according to the
// filter function.
LogicalResult tilePeeledOpsToScalars(

@@ -23,13 +23,10 @@ limitations under the License.

#include "gml_st/IR/gml_st_ops.h"
#include "gml_st/transforms/transforms.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

namespace mlir {
namespace gml_st {

@@ -43,25 +40,8 @@ bool hasTensorSemantics(Operation *op) {
llvm::all_of(op->getOperandTypes(), isATensor);
}

/// Rewrite a LoopOp/ParallelOp/ForOp with bounds/step that potentially do not
/// divide evenly into two LoopOp/ParallelOp/ForOps: One where the step divides
/// the iteration space evenly, followed another one for the last (partial)
/// iteration (if any). This function only rewrites the `idx`-th loop of the
/// loop nest represented by the LoopOp/ParallelOp/ForOp. To peel the entire
/// loop nest, this function must be called multiple times.
///
/// This function rewrites the given LoopOp/ParallelOp/ForOp in-place and
/// creates a new LoopOp/ParallelOp/ForOp for the last iteration. It replaces
/// all uses of the original LoopOp/ParallelOp/ForOp with the results of the
/// newly generated one.
///
/// The newly generated LoopOp/ParallelOp/ForOp is returned via `result`. The
/// boundary at which the loop is split (new upper bound) is returned via
/// `splitBound`. The return value indicates whether the
/// LoopOp/ParallelOp/ForOp was rewritten or not.
template <typename LoopTy>
LogicalResult peelLoop(RewriterBase &b, LoopTy loopOp, int64_t idx,
LoopTy &result, Value &splitBound) {
LogicalResult peelLoop(RewriterBase &b, ParallelOp loopOp, int64_t idx,
ParallelOp &result, Value &splitBound) {
if (!hasTensorSemantics(loopOp)) return failure();

Value lb = loopOp.getLowerBound()[idx], ub = loopOp.getUpperBound()[idx],

@@ -93,7 +73,7 @@ LogicalResult peelLoop(RewriterBase &b, LoopTy loopOp, int64_t idx,
bvm.map(termDst, res);
}
b.setInsertionPointAfter(loopOp);
auto remainderLoop = cast<LoopTy>(b.clone(*loopOp.getOperation(), bvm));
auto remainderLoop = cast<ParallelOp>(b.clone(*loopOp.getOperation(), bvm));

Operation *remainderLoopOp = remainderLoop.getOperation();

@@ -137,17 +117,31 @@ void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, Operation *mainLoop,
});
}

template <typename LoopTy>
FailureOr<LoopTy> peelAndCanonicalizeGmlStLoopImpl(RewriterBase &rewriter,
LoopTy loopOp, int64_t idx) {
}  // namespace

PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter) {
setLabel(loop, kPeelingAppliedLabel);
PeelingResult peelingResult;
for (unsigned peeledIdx = 0; peeledIdx < loop.getNumLoops(); ++peeledIdx) {
auto peel = peelAndCanonicalizeGmlStLoop(rewriter, loop, peeledIdx);
if (failed(peel)) continue;
// Mark the new loop if one was created.
setLabel(peel->getOperation(), kPeelingAppliedLabel);
peelingResult.push_back(*peel);
}
return peelingResult;
}

FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx) {
int64_t numLoops = loopOp.getNumLoops();
if (idx < 0 || numLoops <= idx) return failure();

Value ub = loopOp.getUpperBound()[idx];
LoopTy remainderLoop;
ParallelOp remainderLoop;
Value splitBound;
if (failed(
peelLoop<LoopTy>(rewriter, loopOp, idx, remainderLoop, splitBound)))
if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
return failure();

// Rewrite affine.min and affine.max ops.

@@ -162,39 +156,13 @@ FailureOr<LoopTy> peelAndCanonicalizeGmlStLoopImpl(RewriterBase &rewriter,
return remainderLoop;
}

template <typename LoopTy>
PeelingResult peelAllLoopsImpl(LoopTy loop, mlir::PatternRewriter &rewriter) {
setLabel(loop, kPeelingAppliedLabel);
PeelingResult peelingResult;
for (unsigned peeledIdx = 0; peeledIdx < loop.getNumLoops(); ++peeledIdx) {
auto peel =
peelAndCanonicalizeGmlStLoopImpl<LoopTy>(rewriter, loop, peeledIdx);
if (failed(peel)) continue;
// Mark the new loop if one was created.
setLabel(peel->getOperation(), kPeelingAppliedLabel);
peelingResult.push_back(*peel);
}
return peelingResult;
}
}  // namespace

PeelingResult peelAllLoops(ForOp loop, mlir::PatternRewriter &rewriter) {
return peelAllLoopsImpl<ForOp>(loop, rewriter);
}

PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter) {
return peelAllLoopsImpl<ParallelOp>(loop, rewriter);
}

FailureOr<ForOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ForOp loopOp, int64_t idx) {
return peelAndCanonicalizeGmlStLoopImpl<ForOp>(rewriter, loopOp, idx);
}

FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx) {
return peelAndCanonicalizeGmlStLoopImpl<ParallelOp>(rewriter, loopOp, idx);
SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp loop) {
// Peeling fails if the step divides the upper bound. In that case,
// we still want to return {loop, nullptr}.
scf::ForOp tailLoop;
return succeeded(scf::peelAndCanonicalizeForLoop(rewriter, loop, tailLoop))
? SCFForPeelingResult{loop, tailLoop}
: SCFForPeelingResult{loop, nullptr};
}
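Both the gml_st and the scf peeling paths split at the largest step-multiple that fits the range, so the main loop becomes perfectly divisible and only the tail sees a partial tile. A minimal sketch of that split-bound arithmetic (plain integers; the real passes compute this with affine maps):

#include <cassert>

// Sketch: split [lb, ub) with a positive step into a main range [lb, split)
// that the step divides evenly and a tail [split, ub) for the partial tile.
long PeelSplitBound(long lb, long ub, long step) {
  assert(step > 0 && ub >= lb);
  return lb + ((ub - lb) / step) * step;
}
// Example: lb = 0, ub = 10, step = 4 gives split = 8; the main loop runs
// iterations 0 and 4, the tail handles indices 8..9. When the step already
// divides the range (e.g. ub = 8), split == ub and no tail loop is needed.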
}  // namespace gml_st

@@ -21,6 +21,7 @@ limitations under the License.

#include "gml_st/IR/gml_st_ops.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {

@@ -30,25 +31,22 @@ constexpr llvm::StringRef kPeelingAppliedLabel = "__peeling_applied_label__";

using PeelingResult = SmallVector<Operation *>;

/// Rewrite a gml_st::ParallelOp/ForOp with bounds/step that potentially
/// do not divide evenly into a gml_st::ParallelOp/ForOp where the step
/// divides the iteration space evenly, followed by another
/// gml_st::ParallelOp/ForOp for the last (partial) iteration (if any).
/// This transformation is called "loop peeling".
/// Rewrite a gml_st::ParallelOp with bounds/step that potentially do not divide
/// evenly into a gml_st::ParallelOp where the step divides the iteration space
/// evenly, followed by another gml_st::ParallelOp for the last (partial)
/// iteration (if any). This transformation is called "loop peeling".
///
/// These functions peel all loops in the loop nest by calling
/// peelAndCanonicalizeGmlStLoop. Additionally, they mark all loops (main and
/// remainder loops) as peeled, so the same loop is not rewritten a second time.
PeelingResult peelAllLoops(ForOp loop, mlir::PatternRewriter &rewriter);
PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter);

/// These functions peel the `idx`-th loop of the
/// gml_st::ParallelOp/ForOp. To peel all loops in the loop nest, these
/// functions must be called multiple times.
/// These functions peel the `idx`-th loop of the gml_st::ParallelOp. To peel
/// all loops in the loop nest, these functions must be called multiple times.
///
/// After loop peeling, these functions try to simplify/canonicalize affine.min
/// and affine.max ops in the body of the two gml_st::ParallelOp/ForOps.
/// For more details, refer to `mlir::scf::peelAndCanonicalizeForLoop`.
/// and affine.max ops in the body of the two gml_st::ParallelOps. For more
/// details, refer to `mlir::scf::peelAndCanonicalizeForLoop`.
///
/// The return value indicates whether the loop was rewritten or not. Loops are
/// not rewritten if:

@@ -56,18 +54,23 @@ PeelingResult peelAllLoops(ParallelOp loop, mlir::PatternRewriter &rewriter);
/// * Loop bounds and step size are static, and step already divides the
/// iteration space evenly.
///
/// Note: These functions rewrite the given gml_st::ParallelOp/ForOp
/// in-place and clone the gml_st::ParallelOp/ForOp operation for the last
/// iteration. They replace all uses of the unpeeled gml_st::ParallelOp/ForOp
/// with the results of the newly generated gml_st::ParallelOp/ForOp.
/// Note: These functions rewrite the given gml_st::ParallelOp in-place and
/// clone the gml_st::ParallelOp operation for the last iteration. They replace
/// all uses of the unpeeled gml_st::ParallelOp with the results of the newly
/// generated gml_st::ParallelOp.
///
/// Note: These functions do not mark the loops as peeled. This should be
/// handled by the caller.
FailureOr<ForOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ForOp loopOp, int64_t idx);
FailureOr<ParallelOp> peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
ParallelOp loopOp,
int64_t idx);

struct SCFForPeelingResult {
scf::ForOp mainLoop = nullptr;
scf::ForOp tailLoop = nullptr;
};
SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp);

}  // namespace gml_st
}  // namespace mlir

@@ -175,7 +175,7 @@ struct TilingPattern : public OpInterfaceRewritePattern<TilingInterface> {
if (!filterFn || failed(filterFn(op)) || hasLabel(op, kTileAppliedLabel))
return failure();

auto tilingResult = tile(options, rewriter, op);
auto tilingResult = tileUsingGmlSt(options, rewriter, op);
if (failed(tilingResult)) return failure();

// If we did not tile (e.g. when all tile sizes are 0), do not replace

@@ -243,8 +243,9 @@ struct TilingPass : public impl::TilingPassBase<TilingPass> {

}  // namespace

FailureOr<TilingResult> tile(const TilingOptions &options,
PatternRewriter &rewriter, TilingInterface op) {
FailureOr<TilingResult> tileUsingGmlSt(const TilingOptions &options,
PatternRewriter &rewriter,
TilingInterface op) {
rewriter.setInsertionPoint(op);
if (!options.tileSizeComputationFn) {
return rewriter.notifyMatchFailure(

@@ -19,6 +19,7 @@ limitations under the License.
#include <functional>
#include <string>

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

@@ -53,8 +54,9 @@ struct TilingOptions {

/// Create tiled operation based on the specified tiling options. The result is
/// equivalent to original op.
FailureOr<TilingResult> tile(const TilingOptions &options,
PatternRewriter &rewriter, TilingInterface op);
FailureOr<TilingResult> tileUsingGmlSt(const TilingOptions &options,
PatternRewriter &rewriter,
TilingInterface op);

/// Populate tiling patterns.
void populateTilingPatterns(

@@ -178,7 +178,7 @@ struct TilePartialSoftmaxPattern
tilingOptions.distributionLabel = distributionLabel;
// Tile.
FailureOr<TilingResult> tilingResult =
tile(tilingOptions, rewriter, op);
tileUsingGmlSt(tilingOptions, rewriter, op);
if (failed(tilingResult)) return failure();

rewriter.replaceOp(op, tilingResult->loop->getResults());

@@ -74,21 +74,19 @@ struct TileMapPattern : public OpRewritePattern<linalg::MapOp> {
return tiles;
};

auto tiledLoop = tileAndFuseGreedily(rewriter, op, opts,
kMapTransformedLabel, fuseFilterFn);
auto tiledLoop = tileUsingGmlStParallelAndFuseGreedily(
rewriter, op, opts, kMapTransformedLabel, fuseFilterFn);
if (failed(tiledLoop)) return failure();

// Peel parallel loops.
if (auto loop = dyn_cast_or_null<ParallelOp>(*tiledLoop)) {
auto peelingResult = peelAllLoops(loop, rewriter);
setLabel(loop, kPerfectlyTiledLoopLabel);
auto peelingResult = peelAllLoops(*tiledLoop, rewriter);
setLabel(*tiledLoop, kPerfectlyTiledLoopLabel);

// Tile ops in the peeled loop again, to size 1, so they can be
// scalarized.
if (failed(tilePeeledOpsToScalars(rewriter, peelingResult,
kMapTransformedLabel, fuseFilterFn)))
return failure();
}
// Tile ops in the peeled loop again, to size 1, so they can be
// scalarized.
if (failed(tilePeeledOpsToScalars(rewriter, peelingResult,
kMapTransformedLabel, fuseFilterFn)))
return failure();

return success();
}

@@ -28,7 +28,6 @@ limitations under the License.
#include "gml_st/transforms/tiling/tiling.h"
#include "gml_st/transforms/transforms.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"

@@ -36,9 +35,9 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"

@@ -363,12 +362,11 @@ struct MatmulToMmt4dPattern : public OpRewritePattern<linalg::MatmulOp> {
};

FailureOr<TilingResult> tileMatmul(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> tileSizes,
bool distribute) {
ArrayRef<int64_t> tileSizes) {
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
return tile(opts, rewriter, cast<TilingInterface>(op));
opts.distribute = true;
return tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
}
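For a plain matmul the three tile sizes map onto (i, j, k), with k the only reduction loop; my reading of the split helper described just below (stated as an assumption, not confirmed by this hunk) is that it zeroes out the complementary entries, so {8, 4, 2} becomes parallel {8, 4, 0} and reduction {0, 0, 2}:

#include <cstddef>
#include <vector>

// Sketch (assumed semantics): split per-loop tile sizes into parallel-only
// and reduction-only vectors by zeroing entries of the other loop kind.
void SplitTileSizes(const std::vector<long>& tiles,
                    const std::vector<bool>& is_reduction,
                    std::vector<long>& parallel_tiles,
                    std::vector<long>& reduction_tiles) {
  parallel_tiles.assign(tiles.size(), 0);
  reduction_tiles.assign(tiles.size(), 0);
  for (std::size_t d = 0; d < tiles.size(); ++d)
    (is_reduction[d] ? reduction_tiles : parallel_tiles)[d] = tiles[d];
}
// For linalg.matmul (i, j parallel; k reduction) and tiles {8, 4, 2}, this
// yields parallel {8, 4, 0} and reduction {0, 0, 2}, matching the two-level
// tiling order used by MatmulTransformPattern below.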
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the

@@ -387,8 +385,9 @@ void splitParallelAndReductionTiles(linalg::LinalgOp op,
}
}

FailureOr<Operation *> tile(PatternRewriter &rewriter, Operation *op,
const scf::SCFTilingOptions &tilingOptions) {
FailureOr<Operation *> tileUsingSCFForAndReplace(
PatternRewriter &rewriter, Operation *op,
const scf::SCFTilingOptions &tilingOptions) {
auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, tilingOptions);
if (failed(tilingResult) || tilingResult->loops.empty()) return failure();
rewriter.replaceOp(op, tilingResult->replacements);

@@ -422,13 +421,16 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
});

auto *lhsOp = mmt4dOp.getInputs()[0].getDefiningOp();
if (failed(tile(rewriter, lhsOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, lhsOp, packTilingOptions)))
return failure();

auto *rhsOp = mmt4dOp.getInputs()[1].getDefiningOp();
if (failed(tile(rewriter, rhsOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, rhsOp, packTilingOptions)))
return failure();

auto *accOp = mmt4dOp.getOutputs()[0].getDefiningOp();
if (failed(tile(rewriter, accOp, packTilingOptions))) return failure();
if (failed(tileUsingSCFForAndReplace(rewriter, accOp, packTilingOptions)))
return failure();

// Tile tensor.unpack op.
auto unpackTilingOptions =

@@ -452,7 +454,9 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
});

auto *unpackOp = *mmt4dOp->user_begin();
if (failed(tile(rewriter, unpackOp, unpackTilingOptions))) return failure();
if (failed(
tileUsingSCFForAndReplace(rewriter, unpackOp, unpackTilingOptions)))
return failure();

// Compute the tile sizes. Note that at this stage we only do layout tiling.
// Later we might also want to do traversal tiling (only on M and N dims).

@@ -484,15 +488,16 @@ struct Mmt4DTransformPattern : public OpRewritePattern<linalg::Mmt4DOp> {
reductionTileSizes);

// Tile the parallel loops.
auto tiledOp =
tile(rewriter, mmt4dOp,
scf::SCFTilingOptions().setTileSizes(parallelTileSizes));
auto tiledOp = tileUsingSCFForAndReplace(
rewriter, mmt4dOp.getOperation(),
scf::SCFTilingOptions().setTileSizes(parallelTileSizes));
if (failed(tiledOp)) return failure();
mmt4dOp = cast<linalg::Mmt4DOp>(*tiledOp);

// Tile the reduction loops.
tiledOp = tile(rewriter, mmt4dOp,
scf::SCFTilingOptions().setTileSizes(reductionTileSizes));
tiledOp = tileUsingSCFForAndReplace(
rewriter, mmt4dOp.getOperation(),
scf::SCFTilingOptions().setTileSizes(reductionTileSizes));
if (failed(tiledOp)) return failure();
mmt4dOp = cast<linalg::Mmt4DOp>(*tiledOp);

@@ -521,7 +526,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
if (hasLabel(matmulOp, kMatmulTransformedLabel))
return rewriter.notifyMatchFailure(matmulOp,
"has already been transformed.");
if (isa<gml_st::ParallelOp, gml_st::ForOp>(matmulOp->getParentOp()))
if (isa<gml_st::ParallelOp, scf::ForOp>(matmulOp->getParentOp()))
return rewriter.notifyMatchFailure(
matmulOp, "has already been tiled by another pass.");

@@ -536,8 +541,8 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
if (isa<linalg::MatmulOp>(tilingRoot)) parallelDimsTileSizes.push_back(0);

// First level tiling: parallel dimensions.
auto tilingParallelDimsResult = tileMatmul(
rewriter, tilingRoot, parallelDimsTileSizes, /*distribute=*/true);
auto tilingParallelDimsResult =
tileMatmul(rewriter, tilingRoot, parallelDimsTileSizes);
if (failed(tilingParallelDimsResult)) return failure();

// Update the results if tiling occurred.

@@ -552,7 +557,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
}

// Second level tiling: reduction dimension for matmuls.
SmallVector<TilingResult> tilingReductionDimsResults;
SmallVector<scf::SCFTilingResult> tilingReductionDimsResults;
for (auto op :
llvm::to_vector(tilingRoot->getBlock()->getOps<linalg::MatmulOp>())) {
// Fusion into the output.

@@ -560,7 +565,7 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {

auto result = tileMatmulReductionDims(rewriter, op);
if (failed(result)) return failure();
tilingReductionDimsResults.push_back(result.value());
tilingReductionDimsResults.push_back(*result);
}

// Peel parallel loops.

@@ -574,27 +579,27 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
// Peel reduction loop inside the main parallel loop, label the main loop as
// "perfectly tiled" one, to enable vectorization after canonicalization.
for (auto &res : tilingReductionDimsResults) {
if (auto loop = dyn_cast_or_null<ForOp>(res.loop)) {
auto peelingResult = peelAllLoops(loop, rewriter);
setLabel(loop, kPerfectlyTiledLoopLabel);
if (res.loops.size() == 1) {
auto peelingResult = peelSCFForOp(rewriter, res.loops.front());
setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel);
}
}

return success();
}

private:
FailureOr<TilingResult> tileMatmulReductionDims(
FailureOr<scf::SCFTilingResult> tileMatmulReductionDims(
PatternRewriter &rewriter, linalg::MatmulOp matmulOp) const {
SmallVector<int64_t> reductionDimsTileSizes{0, 0, reductionDimTileSize};
auto tilingReductionDimsResult = tileMatmul(
rewriter, matmulOp, reductionDimsTileSizes, /*distribute=*/false);
scf::SCFTilingOptions opts;
opts.setTileSizes(reductionDimsTileSizes);
auto tilingReductionDimsResult =
scf::tileUsingSCFForOp(rewriter, matmulOp.getOperation(), opts);
if (failed(tilingReductionDimsResult)) return failure();

// Update the results if tiling occurred.
if (tilingReductionDimsResult->loop != nullptr) {
rewriter.replaceOp(matmulOp,
tilingReductionDimsResult->loop->getResults());
if (!tilingReductionDimsResult->loops.empty()) {
rewriter.replaceOp(matmulOp, tilingReductionDimsResult->replacements);
matmulOp =
cast<linalg::MatmulOp>(tilingReductionDimsResult->tiledOps.front());
}

@@ -49,7 +49,7 @@ FailureOr<TilingResult> tileMatmul(PatternRewriter &rewriter, Operation *op,
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
opts.distributionLabel = distributionLabel;
return tile(opts, rewriter, cast<TilingInterface>(op));
return tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(op));
}

/// Pattern to tile `linalg.matmul`, fuse `linalg.fill` into generated

@@ -29,6 +29,8 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"

@@ -46,12 +48,12 @@ constexpr llvm::StringRef kReduceTransformedLabel =

FailureOr<TilingResult> tileReduce(PatternRewriter &rewriter,
linalg::ReduceOp reduceOp,
ArrayRef<int64_t> tileSizes,
bool distribute) {
ArrayRef<int64_t> tileSizes) {
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
return tile(opts, rewriter, cast<TilingInterface>(reduceOp.getOperation()));
opts.distribute = true;
return tileUsingGmlSt(opts, rewriter,
cast<TilingInterface>(reduceOp.getOperation()));
}

SmallVector<int64_t> getParallelDimTileSizes(int64_t reductionDim,

@@ -106,7 +108,7 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
"has already been transformed.");
}

if (isa<gml_st::ParallelOp, gml_st::ForOp>(reduceOp->getParentOp())) {
if (isa<gml_st::ParallelOp, scf::ForOp>(reduceOp->getParentOp())) {
return rewriter.notifyMatchFailure(
reduceOp, "has already been tiled by another pass.");
}

@@ -136,7 +138,6 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
auto fillOp = reduceOp.getInits().front().getDefiningOp<linalg::FillOp>();
if (!fillOp) return failure();
auto neutralValue = fillOp.value();
// .get

// fillOp.getValue();
Type elementType = neutralValue.getType();

@@ -149,12 +150,11 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
rewriter.create<linalg::FillOp>(loc, neutralValue, emptyVector)
.getResult(0);

auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, ValueRange ivs,
auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv,
ValueRange inits) {
// Tile input as tensor<TILE_SIZExELEM_TYPE> and reshape into
// tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
Value inputSlice =
tileAndReshapeInput(b, loc, ivs.front(), input, elementType);
Value inputSlice = tileAndReshapeInput(b, loc, iv, input, elementType);

tensor::ExtractSliceOp initSlice = create1DSlice(
b, loc, inits.front(), b.getIndexAttr(0), b.getIndexAttr(vectorSize));

@@ -171,18 +171,13 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
rewriter.cloneRegionBefore(reduceOp.getRegion(), region, region.end());
setLabel(tiledReduceOp, kReduceTransformedLabel);

b.create<gml_st::SetYieldOp>(
loc, tiledReduceOp.getResults(), inits,
b.create<TileOp>(loc, initSlice.getMixedOffsets(),
initSlice.getMixedSizes(),
initSlice.getMixedStrides())
.getResult());
b.create<scf::YieldOp>(loc, tiledReduceOp.getResults());
};

// Create a tiled loop
auto tiledLoop = rewriter.create<ForOp>(loc, filledVector.getType(), zero,
tileableBound, tileSizeValue,
filledVector, tiledLoopBodyBuilder);
auto tiledLoop =
rewriter.create<scf::ForOp>(loc, zero, tileableBound, tileSizeValue,
filledVector, tiledLoopBodyBuilder);
setLabel(tiledLoop, kPerfectlyTiledLoopLabel);

// Create `linalg.reduce` from tensor<VECTOR_SIZExELEM_TYPE> to

@@ -191,30 +186,25 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
cloneReduceOp(rewriter, reduceOp, tiledLoop.getResult(0),
reduceOp.getInits().front());

auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc,
ValueRange ivs, ValueRange inits) {
Value inputSlice =
create1DSlice(b, loc, input, ivs.front(), remainderSize);
auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv,
ValueRange inits) {
Value inputSlice = create1DSlice(b, loc, input, iv, remainderSize);

Value initSlice = b.create<tensor::ExtractSliceOp>(
loc, inits.front(), /*offsets=*/SmallVector<OpFoldResult>{},
/*sizes=*/SmallVector<OpFoldResult>{},
/*strides=*/SmallVector<OpFoldResult>{});

auto newReduceOp = cloneReduceOp(b, reduceOp, inputSlice, initSlice);

Value initTile = b.create<gml_st::TileOp>(
loc, /*offsets=*/SmallVector<OpFoldResult>{});
b.create<gml_st::SetYieldOp>(loc, newReduceOp, inits, initTile);
auto newReduce = cloneReduceOp(b, reduceOp, inputSlice, initSlice);
b.create<scf::YieldOp>(loc, newReduce);
};

// Combine `horizontal reduce` with the tail of the input. The tail is
// always smaller than TILE_SIZE.
auto remainderLoop =
rewriter
.create<ForOp>(loc, reduceOp.getResultTypes(), tileableBound,
inputSize, tileSizeValue, horizontalReduce,
remainderLoopBodyBuilder)
.create<scf::ForOp>(loc, tileableBound, inputSize, tileSizeValue,
horizontalReduce, remainderLoopBodyBuilder)
.getResult(0);

rewriter.replaceOp(reduceOp, remainderLoop);

@@ -329,26 +319,28 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
[&](Operation *op) { return fusionCluster.contains(op); });

// Process all reduces in a fusion cluster.
for (auto op :
for (auto tiledReduceOp :
llvm::to_vector(tilingRoot->getBlock()->getOps<linalg::ReduceOp>())) {
// Fuse Fill.
if (failed(fuseOutputFill(rewriter, op))) return failure();
if (failed(fuseOutputFill(rewriter, tiledReduceOp))) return failure();

// Second level tiling: reduction dimension.
auto tilingReductionDimsResult = tileReductionDims(rewriter, op);
auto tilingReductionDimsResult =
tileReductionDims(rewriter, tiledReduceOp);
if (failed(tilingReductionDimsResult)) return failure();

// Update the results if tiling occurred.
if (tilingReductionDimsResult->loop != nullptr) {
rewriter.replaceOp(op, tilingReductionDimsResult->loop->getResults());
op =
if (!tilingReductionDimsResult->loops.empty()) {
rewriter.replaceOp(tiledReduceOp,
tilingReductionDimsResult->replacements);
tiledReduceOp =
cast<linalg::ReduceOp>(tilingReductionDimsResult->tiledOps.front());
fuseGreedily(rewriter, *op->getBlock(),
fuseGreedily(rewriter, *tiledReduceOp->getBlock(),
[&](Operation *op) { return isa<linalg::MapOp>(op); });
}
setLabel(op, kReduceTransformedLabel);
setLabel(tiledReduceOp, kReduceTransformedLabel);

// Peel parallel loops.
// Peel reduction loops.
if (failed(peelReduction(rewriter, tilingParallelDimsResult.value(),
tilingReductionDimsResult.value())))
return failure();

@@ -404,15 +396,14 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
tilingParallelDimsResult =
tileReduce(rewriter, reduceOp,
getParallelDimTileSizes(reduceOp.getDimensions()[0],
parallelDimTileSize),
/*distribute=*/true);
parallelDimTileSize));
} else if (isa<linalg::MapOp>(tilingRoot)) {
TilingOptions opts;
opts.setTileSizeComputationFn({parallelDimTileSize});
opts.distribute = true;

tilingParallelDimsResult =
tile(opts, rewriter, cast<TilingInterface>(tilingRoot));
tileUsingGmlSt(opts, rewriter, cast<TilingInterface>(tilingRoot));
} else {
return failure();
}

@@ -420,19 +411,18 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
return tilingParallelDimsResult;
}

FailureOr<TilingResult> tileReductionDims(PatternRewriter &rewriter,
linalg::ReduceOp reduceOp) const {
auto tilingReductionDimsResult =
tileReduce(rewriter, reduceOp,
getReductionDimTileSizes(reduceOp.getDimensions()[0],
reductionDimTileSize),
/*distribute=*/false);
return tilingReductionDimsResult;
FailureOr<scf::SCFTilingResult> tileReductionDims(
PatternRewriter &rewriter, linalg::ReduceOp reduceOp) const {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(getReductionDimTileSizes(
reduceOp.getDimensions()[0], reductionDimTileSize));
return scf::tileUsingSCFForOp(rewriter, reduceOp.getOperation(),
tilingOptions);
}

LogicalResult peelReduction(
PatternRewriter &rewriter, const TilingResult &tilingParallelDimsResult,
|
||||
const TilingResult &tilingReductionDimsResult) const {
|
||||
const scf::SCFTilingResult &tilingReductionDimsResult) const {
|
||||
// Peel parallel loops.
|
||||
if (auto loop =
|
||||
dyn_cast_or_null<ParallelOp>(tilingParallelDimsResult.loop)) {
|
||||
|
|
@ -441,36 +431,44 @@ struct Reduce2DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
|
|||
|
||||
// Peel reduction loop inside the main parallel loop, label the main loop as
|
||||
// "perfectly tiled" one, to enable vectorization after canonicalization.
|
||||
if (auto forLoop =
|
||||
dyn_cast_or_null<ForOp>(tilingReductionDimsResult.loop)) {
|
||||
auto peelingResult = peelAllLoops(forLoop, rewriter);
|
||||
setLabel(forLoop, kPerfectlyTiledLoopLabel);
|
||||
if (!tilingReductionDimsResult.loops.empty()) {
|
||||
scf::ForOp forLoop = tilingReductionDimsResult.loops.front();
|
||||
SCFForPeelingResult peelingResult = peelSCFForOp(rewriter, forLoop);
|
||||
if (peelingResult.mainLoop) {
|
||||
setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel);
|
||||
}
|
||||
|
||||
if (!peelingResult.tailLoop) return success();
|
||||
// Tile ops in the peeled loop again, to size 1, so they can be
|
||||
// scalarized.
|
||||
for (auto *loop : peelingResult) {
|
||||
ForOp peeledLoop = dyn_cast<ForOp>(loop);
|
||||
auto *terminatorOp = peeledLoop->getRegion(0).front().getTerminator();
|
||||
if (!terminatorOp) return failure();
|
||||
scf::ForOp peeledLoop = peelingResult.tailLoop;
|
||||
auto yieldOp = cast<scf::YieldOp>(peeledLoop.getBody()->getTerminator());
|
||||
auto reduceOp = getRootReduce(yieldOp);
|
||||
if (!reduceOp) return failure();
|
||||
|
||||
auto reduceOp =
|
||||
terminatorOp->getOperand(0).getDefiningOp<linalg::ReduceOp>();
|
||||
if (!reduceOp) return failure();
|
||||
scf::SCFTilingOptions opts;
|
||||
opts.setTileSizes(
|
||||
getReductionDimTileSizes(reduceOp.getDimensions()[0], 1));
|
||||
|
||||
mlir::gml_st::TilingOptions opts;
|
||||
opts.setTileSizeComputationFn(
|
||||
getReductionDimTileSizes(reduceOp.getDimensions()[0], 1));
|
||||
opts.distribute = false;
|
||||
|
||||
if (failed(tileAndFuseGreedily(
|
||||
rewriter, reduceOp, opts, kReduceTransformedLabel,
|
||||
[&](Operation *op) { return isa<linalg::MapOp>(op); })))
|
||||
return failure();
|
||||
}
|
||||
if (failed(tileUsingSCFForOpAndFuseGreedily(
|
||||
rewriter, reduceOp, opts, kReduceTransformedLabel,
|
||||
[&](Operation *op) { return isa<linalg::MapOp>(op); })))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
linalg::ReduceOp getRootReduce(scf::YieldOp yieldOp) const {
|
||||
if (yieldOp.getResults().size() != 1) return nullptr;
|
||||
|
||||
Value reduceResult = yieldOp.getResults().front();
|
||||
if (auto insertSliceOp =
|
||||
reduceResult.getDefiningOp<tensor::InsertSliceOp>()) {
|
||||
reduceResult = insertSliceOp.getSource();
|
||||
}
|
||||
return reduceResult.getDefiningOp<linalg::ReduceOp>();
|
||||
}
|
||||
|
||||
int64_t parallelDimTileSize;
|
||||
int64_t reductionDimTileSize;
|
||||
};
|
||||
|
|
@ -489,7 +487,8 @@ struct TransformReduceForCpuPass
|
|||
|
||||
void getDependentDialects(DialectRegistry ®istry) const final {
|
||||
registry.insert<mlir::gml_st::GmlStDialect, arith::ArithDialect,
|
||||
linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
linalg::LinalgDialect, scf::SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
linalg::registerTilingInterfaceExternalModels(registry);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -44,8 +44,8 @@ FailureOr<TilingResult> tileReverseAndUpdateResultIfTiled(
TilingOptions opts;
opts.setTileSizeComputationFn(tileSizes);
opts.distribute = distribute;
auto tilingResult =
tile(opts, rewriter, cast<TilingInterface>(reverseOp.getOperation()));
auto tilingResult = tileUsingGmlSt(
opts, rewriter, cast<TilingInterface>(reverseOp.getOperation()));

if (failed(tilingResult)) return failure();

@@ -19,9 +19,11 @@ limitations under the License.
#include "gml_st/IR/gml_st_ops.h"
#include "gml_st/transforms/passes.h"
#include "gml_st/transforms/tiling/tiling.h"
#include "gml_st/transforms/transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -33,23 +35,24 @@ namespace {
#define GEN_PASS_DEF_TRANSFORMSCATTERFORCPUPASS
#include "gml_st/transforms/passes.h.inc"

struct TransformScatterForCpuPass
: public impl::TransformScatterForCpuPassBase<TransformScatterForCpuPass> {
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<mlir::gml_st::GmlStDialect, arith::ArithDialect,
tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
}
constexpr llvm::StringRef kScatterTransformedLabel =
"__scatter_transformed_label__";

void runOnOperation() override {
func::FuncOp f = getOperation();
MLIRContext *ctx = &getContext();
struct TileScatterPattern : public OpRewritePattern<thlo::ScatterOp> {
using OpRewritePattern<thlo::ScatterOp>::OpRewritePattern;

mlir::gml_st::TilingOptions opts;
opts.distribute = false; // Tile to `for` loops.
LogicalResult matchAndRewrite(thlo::ScatterOp op,
PatternRewriter &rewriter) const override {
if (hasLabel(op, kScatterTransformedLabel)) return failure();

if (isa<scf::ForOp>(op->getParentOp())) {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");
}

// Tile everything to points.
opts.tileSizeComputationFn = [](OpBuilder &b, Operation *op) {
scf::SCFTilingOptions opts;
opts.setTileSizeComputationFunction([](OpBuilder &b, Operation *op) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(
&op->getParentOfType<func::FuncOp>().getBody().front());
@@ -57,22 +60,43 @@ struct TransformScatterForCpuPass
auto loops = cast<TilingInterface>(op).getLoopIteratorTypes();
return SmallVector<Value>(
loops.size(), b.create<arith::ConstantIndexOp>(op->getLoc(), 1));
};
});

auto filterFn = [&](Operation *op) {
if (isa<mlir::thlo::ScatterOp>(op))
return success();
return failure();
};
auto tilingResult = scf::tileUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()), opts);
if (failed(tilingResult)) return failure();

// If we did not tile, do not replace original op and just mark it as
// transformed then return.
if (!tilingResult->loops.empty()) {
rewriter.replaceOp(op, tilingResult->replacements);
}
setLabel(tilingResult->tiledOps.front(), kScatterTransformedLabel);
return success();
}
};

struct TransformScatterForCpuPass
: public impl::TransformScatterForCpuPassBase<TransformScatterForCpuPass> {
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<arith::ArithDialect, gml_st::GmlStDialect, scf::SCFDialect,
tensor::TensorDialect>();
}

void runOnOperation() override {
func::FuncOp f = getOperation();
MLIRContext *ctx = &getContext();

RewritePatternSet patterns(ctx);
populateTilingPatterns(ctx, filterFn, opts, &patterns);
patterns.add<TileScatterPattern>(ctx);

if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns))))
return signalPassFailure();
}

removeTilingLabels(f);
// Ensure we drop the marker in the end.
f.walk([](thlo::ScatterOp scatterOp) {
removeLabel(scatterOp, kScatterTransformedLabel);
});
}
};

@@ -53,8 +53,8 @@ struct TileSortPattern : public OpRewritePattern<SortOp> {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");

auto tilingResult =
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
auto tilingResult = tileUsingGmlSt(
options, rewriter, cast<TilingInterface>(op.getOperation()));
if (failed(tilingResult)) return failure();

// If we did not tile (e.g. when all tile sizes are 0), do not replace

@@ -56,8 +56,8 @@ struct TileTransposePattern : public OpRewritePattern<linalg::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "has already been tiled by another pass.");

auto tilingResult =
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
auto tilingResult = tileUsingGmlSt(
options, rewriter, cast<TilingInterface>(op.getOperation()));
if (failed(tilingResult)) return failure();

// If we did not tile (e.g. when all tile sizes are 0), do not replace

@@ -868,7 +868,7 @@ struct VectorizePerfectlyTiledLoopsPass
});
};
auto isPerfectlyTiledLoop = [&](Operation *op) {
return (isa<ForOp>(op) || isa<ParallelOp>(op)) &&
return (isa<ForOp, ParallelOp, scf::ForOp>(op)) &&
hasLabel(op, kPerfectlyTiledLoopLabel);
};
auto isInsidePerfectlyTiledLoop = [&](Operation *op) {

@@ -18,10 +18,10 @@ func.func @matmul_static(%arg0: tensor<128x16xf32>, %arg1: tensor<16x64xf32>,

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: %[[FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK-SAME: -> tensor<8x4xf32>
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: scf.yield %[[MATMUL]]
// CHECK: gml_st.set_yield %[[FOR]]

// -----
@@ -59,28 +59,36 @@ func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
// TRANSFORMED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// TRANSFORMED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// TRANSFORMED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// TRANSFORMED: %[[MAIN_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// TRANSFORMED: %[[MAIN_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[KUB:[a-z0-9]+]]
// TRANSFORMED-SAME: iter_args(%{{.*}} = %[[MAIN_FILL]])
// TRANSFORMED: %[[MAIN_PAR_MAIN_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED: %[[REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[KUB]]) {{.*}} outs ({{.*}} = %[[MAIN_FOR]]:
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[MAIN_PAR_MAIN_FOR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: %[[REM_FOR:.*]] = scf.for %[[K:.*]] = %[[KUB]]
// TRANSFORMED-SAME: iter_args(%{{.*}} = %[[MAIN_FOR]])
// TRANSFORMED: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[MAIN_PAR_REM_FOR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_FOR]]

// TRANSFORMED: %[[REM_RHS_PAR:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// TRANSFORMED: %[[REM_RHS_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// TRANSFORMED: %[[REM_RHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_RHS_SLICE]]
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_RHS_FILL]]:
// TRANSFORMED: %[[REM_RHS_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// TRANSFORMED-SAME: iter_args({{.*}} = %[[REM_RHS_FILL]])
// TRANSFORMED: %[[REM_RHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[REM_RHS_PAR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_RHS_FOR]]

// TRANSFORMED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// TRANSFORMED: %[[REM_LHS_SLICE:.*]] = tensor.extract_slice %[[REM_RHS_PAR]]
// TRANSFORMED: %[[REM_LHS_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_LHS_SLICE]]
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_LHS_FILL]]:
// TRANSFORMED: %[[REM_LHS_FOR:.*]] = scf.for %[[K:.*]] = %[[C0]]
// TRANSFORMED-SAME: iter_args({{.*}} = %[[REM_LHS_FILL]])
// TRANSFORMED: %[[REM_LHS_PAR_MATMUL:.*]] = linalg.matmul
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED: %[[UPDATE:.*]] = tensor.insert_slice %[[REM_LHS_PAR_MATMUL]]
// TRANSFORMED-NEXT: scf.yield %[[UPDATE]]
// TRANSFORMED: gml_st.set_yield %[[REM_LHS_FOR]]

// -----
@@ -89,15 +97,15 @@ func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)

// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[IUB:.*]], %[[JUB:.*]]) step
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]]) to (%[[KUB:.*]]) step
// MARKED: scf.for %[[K:.*]] = %[[C0]] to %[[KUB:.*]] step
// MARKED: __perfectly_tiled_loop_label__
// MARKED: gml_st.for (%[[K:.*]]) = (%[[KUB]])
// MARKED: scf.for %[[K:.*]] = %[[KUB]]

// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[JUB]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: scf.for %[[K:.*]] = %[[C0]]

// MARKED: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[IUB]], %[[C0]])
// MARKED: gml_st.for (%[[K:.*]]) = (%[[C0]])
// MARKED: scf.for %[[K:.*]] = %[[C0]]

// -----

@@ -139,21 +147,25 @@ func.func @matmul_fuse_output(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
// CHECK: %[[C0:.*]] = arith.constant 0 : index

// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]

// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]

// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]])
// CHECK: scf.for %[[K:.*]] = %[[C0]]
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]

// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]

// CHECK: linalg.map
// CHECK: linalg.map
@@ -161,23 +173,27 @@ func.func @matmul_fuse_output(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
// CHECK: gml_st.set_yield

// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: gml_st.for
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: linalg.map
// CHECK: linalg.map
// CHECK: gml_st.set_yield

// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: gml_st.for
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: scf.for
// CHECK: %[[MATMUL:.*]] = linalg.matmul
// CHECK: gml_st.set_yield %[[MATMUL]]
// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[MATMUL]]
// CHECK-NEXT: scf.yield %[[UPDATE]]
// CHECK: linalg.map
// CHECK: linalg.map
// CHECK: gml_st.set_yield

@@ -1,11 +1,6 @@
// RUN: mlir-hlo-opt %s \
// RUN: mlir-hlo-opt %s --split-input-file \
// RUN: -xla-cpu-transform-reduce="vector-size=8 tile-size-1d=32 tile-sizes-2d=4,2" \
// RUN: --split-input-file \
// RUN: | FileCheck %s --check-prefixes=CHECK,PEELED
// RUN: mlir-hlo-opt %s \
// RUN: -xla-cpu-transform-reduce="vector-size=8 tile-size-1d=32 tile-sizes-2d=4,2" \
// RUN: --split-input-file \
// RUN: | FileCheck %s --check-prefixes=MARKED
// RUN: | FileCheck %s

func.func @reduce_add_static(%input: tensor<100x10xf32>,
%output: tensor<10xf32>) -> tensor<10xf32> {
@@ -15,29 +10,14 @@ func.func @reduce_add_static(%input: tensor<100x10xf32>,
dimensions = [0]
return %res : tensor<10xf32>
}
// CHECK-LABEL: func @reduce_add_static

// CHECK-LABEL: func @reduce_add_static(
// CHECK-SAME: %[[IN:.*]]: tensor<100x10xf32>,
// CHECK-SAME: %[[OUT:.*]]: tensor<10xf32>)
// CHECK-SAME: -> tensor<10xf32> {

// CHECK: %[[C0:.*]] = arith.constant 0 : index

// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: %[[IN_SLICE_1:.*]] = tensor.extract_slice %[[IN]]
// CHECK: %[[OUT_SLICE_1:.*]] = tensor.extract_slice %[[OUT]]

// CHECK: %[[FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]])
// CHECK: %[[IN_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[OUT_SLICE_2:.*]] = tensor.extract_slice

// CHECK: %[[REDUCED:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[IN_SLICE_2]] : tensor<2x?xf32>)
// CHECK-SAME: outs(%[[OUT_SLICE_2]] : tensor<?xf32>)
// CHECK-SAME: dimensions = [0]

// CHECK: gml_st.set_yield %[[REDUCED]]
// CHECK: gml_st.set_yield %[[FOR]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.reduce
// CHECK-NEXT: tensor.insert_slice
// CHECK-NEXT: scf.yield
// CHECK: gml_st.set_yield

// -----

@@ -56,54 +36,23 @@ func.func @reduce_mulf(%input: tensor<?x?xf32>,
return %res : tensor<?xf32>
}

// PEELED-LABEL: func @reduce_mulf(
// PEELED-SAME: %[[IN:.*]]: tensor<?x?xf32>,
// PEELED-SAME: %[[OUT:.*]]: tensor<?xf32>)
// PEELED-SAME: -> tensor<?xf32> {
// CHECK-LABEL: func @reduce_mulf

// PEELED-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// PEELED-DAG: %[[C0:.*]] = arith.constant 0 : index
// PEELED-DAG: %[[C1:.*]] = arith.constant 1 : index
// PEELED: %[[INIT:.*]] = tensor.empty
// PEELED: %[[DIM0:.*]] = tensor.dim %[[IN]], %[[C0]]
// CHECK: gml_st.parallel
// CHECK: linalg.fill
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: gml_st.set_yield

// PEELED: %[[MAIN_PAR:.*]] = gml_st.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[IUB:.*]]) step
// PEELED: %[[MAIN_SLICE:.*]] = tensor.extract_slice %[[INIT]]
// PEELED: %[[MAIN_FILL:.*]] = linalg.fill{{.*}}outs(%[[MAIN_SLICE]]
// PEELED: %[[MAIN_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[JUB:.*]]) {{.*}} outs ({{.*}} = %[[MAIN_FILL]]:
// PEELED: %[[MAIN_PAR_MAIN_FOR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_REDUCE]]
// PEELED: %[[REM_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[JUB]]) {{.*}} outs (%[[REM_FOR_ARG:.*]] = %[[MAIN_FOR]]:
// PEELED: %[[REM_FOR_SLICE:.*]] = tensor.extract_slice %[[REM_FOR_ARG]]
// PEELED: %[[SCALAR_REM_FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_FOR_SLICE]]:
// PEELED: %[[MAIN_PAR_REM_FOR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[MAIN_PAR_REM_FOR_REDUCE]]
// PEELED: gml_st.set_yield %[[SCALAR_REM_FOR]]
// PEELED: gml_st.set_yield %[[REM_FOR]]

// PEELED: %[[REM_PAR:.*]] = gml_st.parallel (%[[I:.*]]) = (%[[IUB]])
// PEELED: %[[REM_SLICE:.*]] = tensor.extract_slice %[[MAIN_PAR]]
// PEELED: %[[REM_FILL:.*]] = linalg.fill{{.*}}outs(%[[REM_SLICE]]
// PEELED: %[[REM_FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]]) {{.*}} outs ({{.*}} = %[[REM_FILL]]:
// PEELED: %[[REM_PAR_REDUCE:.*]] = linalg.reduce
// PEELED: gml_st.set_yield %[[REM_PAR_REDUCE]]
// PEELED: gml_st.set_yield %[[REM_FOR]]

// -----

// MARKED-LABEL: func @reduce_mulf(
// MARKED: %[[C0:.*]] = arith.constant 0 : index
// MARKED: gml_st.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[IUB:.*]]) step
// MARKED: gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[JUB:.*]]) step
// MARKED: __perfectly_tiled_loop_label__
// MARKED: gml_st.for (%[[J:.*]]) = (%[[JUB]])
// MARKED: } {__peeling_applied_label__}
// MARKED: } {__peeling_applied_label__}

// MARKED: gml_st.parallel (%[[I:.*]]) = (%[[IUB]])
// MARKED: gml_st.for (%[[J:.*]]) = (%[[C0]])
// MARKED: } {__peeling_applied_label__}
// MARKED: } {__peeling_applied_label__}
// CHECK: gml_st.parallel
// CHECK: linalg.fill
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: gml_st.set_yield

// -----

@@ -120,37 +69,19 @@ func.func @reduce_map_fuse(%arg0: tensor<10x100xf32>,
return %res : tensor<10xf32>
}

// CHECK-LABEL: func @reduce_map_fuse(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x100xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<10x100xf32>,
// CHECK-SAME: %[[OUT:.*]]: tensor<10xf32>)
// CHECK-SAME: -> tensor<10xf32> {
// CHECK-LABEL: func @reduce_map_fuse

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[INIT:.*]] = tensor.empty
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: gml_st.set_yield

// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: %[[ARG0_SLICE_1:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[ARG1_SLICE_1:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[INIT_SLICE_1:.*]] = tensor.extract_slice %[[INIT]]
// CHECK: %[[OUT_SLICE_1:.*]] = tensor.extract_slice %[[OUT]]

// CHECK: %[[FOR:.*]] = gml_st.for (%[[J:.*]]) = (%[[C0]])
// CHECK: %[[ARG0_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[ARG1_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[INIT_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK-SAME: ins(%[[ARG0_SLICE_2]], %[[ARG1_SLICE_2]]
// CHECK-SAME: outs(%[[INIT_SLICE_2]] : tensor<?x2xf32>)

// CHECK: %[[OUT_SLICE_2:.*]] = tensor.extract_slice
// CHECK: %[[REDUCED:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[MAPPED]] : tensor<?x2xf32>)
// CHECK-SAME: outs(%[[OUT_SLICE_2]] : tensor<?xf32>)
// CHECK-SAME: dimensions = [1]

// CHECK: gml_st.set_yield %[[REDUCED]]
// CHECK: gml_st.set_yield %[[FOR]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: gml_st.set_yield

// -----

@@ -179,31 +110,29 @@ func.func @reduce_1d_static(%arg0: tensor<100xf32>) -> tensor<f32> {
// CHECK-DAG: %[[FILL0:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMP0]] : tensor<f32>)
// CHECK-DAG: %[[FILL1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMP1]] : tensor<8xf32>)

// CHECK: %[[TILE_RESULT:.*]] = gml_st.for (%[[I:.*]]) = (%[[C0]]) to
// CHECK-SAME: (%[[C96]]) step (%[[C32]]) outs (%[[ACC:.*]] = %[[FILL1]]
// CHECK: %[[TILE_RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to
// CHECK-SAME: %[[C96]] step %[[C32]] iter_args(%[[ACC:.*]] = %[[FILL1]]
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[I]]] [32] [1]
// CHECK: %[[SHAPED_SLICE:.*]] = tensor.expand_shape %[[INPUT_SLICE]]
// CHECK: %[[TILED_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[SHAPED_SLICE]]
// CHECK-SAME: outs(%[[ACC]]
// CHECK-SAME: dimensions = [0]
// CHECK: gml_st.set_yield %[[TILED_REDUCE]]
// CHECK-NEXT: } : tensor<8xf32>
// CHECK: scf.yield %[[TILED_REDUCE]]

// CHECK: %[[HORIZONTAL_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[TILE_RESULT]]
// CHECK-SAME: outs(%[[FILL0]]
// CHECK-SAME: dimensions = [0]

// CHECK: %[[REMAINDER_RESULT:.*]] = gml_st.for (%[[J:.*]]) = (%[[C96]]) to
// CHECK-SAME: (%[[C100]]) step (%[[C32]]) outs (%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[REMAINDER_RESULT:.*]] = scf.for %[[J:.*]] = %[[C96]] to
// CHECK-SAME: %[[C100]] step %[[C32]] iter_args(%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[ARG0]][%[[J]]] [%[[C4]]] [1]
// CHECK: %[[REMAINDER_REDUCE:.*]] = linalg.reduce
// CHECK-SAME: ins(%[[INPUT_SLICE1]]
// CHECK-SAME: outs(%[[ACC1]]
// CHECK-SAME: dimensions = [0]
// CHECK: gml_st.set_yield %[[REMAINDER_REDUCE]]
// CHECK-NEXT: } : tensor<f32>
// CHECK: scf.yield %[[REMAINDER_REDUCE]]
// CHECK: return %[[REMAINDER_RESULT]]

// -----
@@ -231,15 +160,15 @@ func.func @reduce_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<f32> {
// CHECK-DAG: %[[TILABLE_BOUND:.*]] = affine.apply #map()[%[[INPUT_SIZE]]]
// CHECK-DAG: %[[REMAINDER_SIZE:.*]] = affine.apply #map1()[%[[TILABLE_BOUND]], %[[INPUT_SIZE]]]

// CHECK: %[[TILE_RESULT:.*]] = gml_st.for (%[[I:.*]]) = (%[[C0]]) to
// CHECK-SAME: (%[[TILABLE_BOUND]]) step (%[[C32]])
// CHECK: %[[TILE_RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to
// CHECK-SAME: %[[TILABLE_BOUND]] step %[[C32]]
// CHECK: %[[TILED_REDUCE:.*]] = linalg.reduce
// CHECK: __perfectly_tiled_loop_label__

// CHECK: %[[HORIZONTAL_REDUCE:.*]] = linalg.reduce

// CHECK: %[[REMAINDER_RESULT:.*]] = gml_st.for (%[[J:.*]]) = (%[[TILABLE_BOUND]]) to
// CHECK-SAME: (%[[INPUT_SIZE]]) step (%[[C32]]) outs (%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[REMAINDER_RESULT:.*]] = scf.for %[[J:.*]] = %[[TILABLE_BOUND]] to
// CHECK-SAME: %[[INPUT_SIZE]] step %[[C32]] iter_args(%[[ACC1:.*]] = %[[HORIZONTAL_REDUCE]]
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[J]]] [%[[REMAINDER_SIZE]]] [1]

// CHECK: return %[[REMAINDER_RESULT]]
@@ -265,23 +194,20 @@ func.func @reduce_map_fuse_map(%arg0: tensor<10x100xf32>,
return %res : tensor<10xf32>
}

// CHECK-LABEL: func @reduce_map_fuse_map(
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-LABEL: func @reduce_map_fuse_map

// CHECK: gml_st.parallel (%[[I:.*]]) = (%[[C0]])
// CHECK: gml_st.for
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: %[[REDUCE:.*]] = linalg.reduce
// CHECK: gml_st.set_yield %[[REDUCE]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: scf.yield
// CHECK: linalg.map
// CHECK: gml_st.set_yield

// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: gml_st.set_yield %[[MAP]]

// CHECK: gml_st.parallel
// CHECK: gml_st.for
// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: %[[REDUCE:.*]] = linalg.reduce
// CHECK: gml_st.set_yield %[[REDUCE]]

// CHECK: %[[MAP:.*]] = linalg.map
// CHECK: gml_st.set_yield %[[MAP]]
// CHECK: gml_st.parallel
// CHECK: scf.for
// CHECK: linalg.map
// CHECK: linalg.reduce
// CHECK: scf.yield
// CHECK: linalg.map
// CHECK: gml_st.set_yield

@@ -15,6 +15,6 @@ func.func @scatter_small_vector_dim(%indices: tensor<?x2xindex>,
}

// CHECK-LABEL: @scatter_small_vector_dim
// CHECK: gml_st.for
// CHECK: scf.for
// CHECK: thlo.scatter
// CHECK-SAME: ins(%{{.*}} : tensor<1x2xindex>, %{{.*}} : tensor<1x?x?xf32>)

@@ -2337,6 +2337,14 @@ func.func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<

// -----

func.func @dynamic_slice_mismatch_indices_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<i32>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// expected-error@+1 {{start indices must have same element type}}
%0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i32>, tensor<i64>) -> tensor<1x4xi32>
func.return %0 : tensor<1x4xi32>
}

// -----

func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>) -> tensor<1x4xi32> {
// expected-error@+1 {{has mismatched number of start indices (1) and the rank of operand (2)}}
%0 = "mhlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = dense<[1]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i64>) -> tensor<1x4xi32>

@@ -55,9 +55,6 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const {
if (has_color()) {
proto.set_color(color());
}
// TODO(b/239098765): Stop populating these fields and delete them when
// profiler finishes adaptation.
proto.mutable_defined_at()->set_instruction_name(instruction()->name());
return proto;
}

@@ -3384,5 +3384,83 @@ ROOT %arg_tuple.1 = (f32[]{:T(256)}, f32[]{:T(256)}) parameter(0), parameter_rep
VLOG(2) << module->ToString();
}

TEST_F(CopyInsertionTest, AsyncCallDUSNoCopy) {
const char* const kModuleString = R"(
HloModule async_call

%called_computation {
%out_param = s32[1024]{0} parameter(1)
%input = s32[1024]{0} parameter(0)
%size = s32[] constant(256)
%index = s32[] custom-call(), custom_call_target="Baz"
%start = s32[] multiply(s32[] %size, s32[] %index)
%input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
%output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"

%async_wrapped {
%async_param = s32[1024]{0} parameter(0)
%async_param.1 = s32[1024]{0} parameter(1)
ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"

ENTRY %main {
%input.1 = s32[1024]{0} parameter(0)
%buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer"
%async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnUnverifiedModule(kModuleString));

CopyInsertion copy_insertion(nullptr,
/*use_region_based_live_range_analysis=*/-1);
ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status());
VLOG(2) << module->ToString();
EXPECT_EQ(CountCopies(*module), 0);
}

TEST_F(CopyInsertionTest, AsyncCallDUSCopy) {
const char* const kModuleString = R"(
HloModule async_call

%called_computation {
%out_param = s32[1024]{0} parameter(1)
%input = s32[1024]{0} parameter(0)
%size = s32[] constant(256)
%index = s32[] custom-call(), custom_call_target="Baz"
%start = s32[] multiply(s32[] %size, s32[] %index)
%input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
%output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"

%async_wrapped {
%async_param = s32[1024]{0} parameter(0)
%async_param.1 = s32[1024]{0} parameter(1)
ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"

ENTRY %main {
%input.1 = s32[1024]{0} parameter(0)
%input.2 = s32[1024]{0} parameter(1)
%async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %input.2), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnUnverifiedModule(kModuleString));

CopyInsertion copy_insertion(nullptr,
/*use_region_based_live_range_analysis=*/-1);
ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status());
VLOG(2) << module->ToString();
EXPECT_EQ(CountCopies(*module), 1);
}

} // namespace
} // namespace xla

@@ -369,7 +369,6 @@ cc_library(
"//tensorflow/tsl/platform:logging",
"//tensorflow/tsl/platform:human_readable_json",
"//tensorflow/tsl/platform:status",
"//tensorflow/tsl/profiler/lib:nvtx_utils",
"//tensorflow/tsl/protobuf:dnn_proto_cc",
] + if_gpu_is_configured([
":triangular_solve_thunk",
@@ -714,7 +713,6 @@ cc_library(
compatible_with = get_compatible_with_cloud(),
deps = [
":target_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
@@ -722,7 +720,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util",
"//tensorflow/compiler/xla/stream_executor",
"//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils",
"//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:ArithDialect",
@@ -1248,16 +1245,13 @@ cc_library(
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
":gpu_device_info",
":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
"//tensorflow/compiler/xla/service:hlo_query",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
@@ -1269,13 +1263,11 @@ xla_cc_test(
srcs = ["instruction_fusion_test.cc"],
tags = ["no_pip"],
deps = [
":gpu_device_info_for_tests",
":gpu_fusible",
":instruction_fusion",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1291,7 +1283,6 @@ cc_library(
":gpu_fusible",
":gpu_hlo_cost_analysis",
":gpu_performance_model",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_util",
@@ -1300,8 +1291,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/tsl/platform:logging",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -1413,16 +1402,13 @@ cc_library(
":gpu_fusible",
":gpu_hlo_cost_analysis",
":gpu_performance_model",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
@@ -2391,6 +2377,7 @@ cc_library(
srcs = ["gpu_fusible.cc"],
hdrs = ["gpu_fusible.h"],
deps = [
":gpu_device_info",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
@@ -2633,10 +2620,9 @@ xla_cc_test(
srcs = ["horizontal_loop_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":fusion_merger",
":gpu_device_info_for_tests",
":horizontal_loop_fusion",
":instruction_fusion",
":multi_output_fusion",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -2648,7 +2634,6 @@ xla_cc_test(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/tsl/lib/core:status_test_util",
@@ -2660,16 +2645,12 @@ cc_library(
srcs = ["horizontal_input_fusion.cc"],
hdrs = ["horizontal_input_fusion.h"],
deps = [
":gpu_device_info",
":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:logging",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@@ -2679,16 +2660,13 @@ xla_cc_test(
srcs = ["horizontal_input_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_device_info_for_tests",
":horizontal_input_fusion",
":multi_output_fusion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)

@@ -240,8 +240,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
// Skip 'fusion' instruction if merging it into at least one of the users
// would make the fusion use too much shared memory or registers.
FusionDecision fits = FusionFitsInBudget(
*user, *producer, /*is_consumer_producer_fusion=*/true,
&fusion_info_cache_);
*user, *producer, gpu_device_info_,
/*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
if (!fits) {
++num_fail_fusion_too_large_;
return fits;

@@ -43,7 +43,6 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
@@ -703,6 +702,8 @@ Status GpuCompiler::OptimizeHloModule(
OptimizeHloPostLayoutAssignment(hlo_module, stream_exec, device_allocator,
gpu_target_config, autotune_results));

const GpuDeviceInfo& gpu_device_info = gpu_target_config.gpu_device_info;

{
HloPassFix<HloPassPipeline> fusion("fusion");
// We try to split variadic ops with many parameters into several such ops
@@ -713,9 +714,10 @@ Status GpuCompiler::OptimizeHloModule(
HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
LayoutAssignment::InstructionCanChangeLayout),
/*debug_only=*/true);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
const GpuDeviceInfo gpu_device_info = gpu_target_config.gpu_device_info;
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
gpu_device_info);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true,
gpu_device_info);
fusion.AddPass<FusionMerger>(gpu_device_info, ShapeSizeBytesFunction());
fusion.AddPass<GpuMultiOutputFusion>(gpu_device_info,
ShapeSizeBytesFunction());
@@ -728,7 +730,7 @@ Status GpuCompiler::OptimizeHloModule(
{
HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
horizontal_fusion.AddPass<GpuHorizontalInputFusion>(gpu_device_info);
horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
horizontal_fusion.AddPass<HloDCE>();

@@ -486,13 +486,14 @@ static int64_t NumUnnestedReductions(const HloInstruction& instr,
// to true to enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
const HloInstruction& instr2,
const GpuDeviceInfo& device_info,
bool is_consumer_producer_fusion,
FusionInfoCache* cache /*=nullptr*/) {
if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
kSharedMemoryBudgetInBytes) {
device_info.shared_memory_per_block) {
return FusionDecision{}
<< "shared memory usage would be over the budget of "
<< kSharedMemoryBudgetInBytes << "B";
<< device_info.shared_memory_per_block << "B";
}

if (NumUnnestedReductions(instr1, cache) +

@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_

#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"

// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
@@ -92,6 +93,7 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
// the producer, set consumer_producer_fusion to true to enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
const HloInstruction& instr2,
const GpuDeviceInfo& device_info,
bool is_consumer_producer_fusion = false,
FusionInfoCache* cache = nullptr);

@@ -18,13 +18,9 @@ limitations under the License.
#include <algorithm>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/tsl/platform/errors.h"

namespace xla {
namespace gpu {
@@ -46,8 +42,9 @@ Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {

class HorizontalInputFusionImpl {
public:
explicit HorizontalInputFusionImpl(HloComputation* computation)
: computation_(computation) {}
explicit HorizontalInputFusionImpl(HloComputation* computation,
const GpuDeviceInfo& d)
: computation_(computation), device_info_(d) {}

~HorizontalInputFusionImpl() {}

@@ -55,6 +52,7 @@ class HorizontalInputFusionImpl {

private:
HloComputation* computation_;
const GpuDeviceInfo device_info_;
}; // HorizontalInputFusionImpl

// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
@@ -138,7 +136,7 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
HloInstruction* fused = candidates[j];
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
FusionFitsInBudget(*fusion_anchor, *fused)) {
FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
VLOG(3) << "Fuse " << fused->ToString() << " into "
<< fusion_anchor->ToString();
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
@@ -159,7 +157,7 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {

StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
HloComputation* computation) {
HorizontalInputFusionImpl horizontal_fusion_impl(computation);
HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
return horizontal_fusion_impl.Run();
}

@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_

#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {
@@ -38,7 +38,7 @@ namespace gpu {
// ROOT tuple of the entry computation.
class GpuHorizontalInputFusion : public HloModulePass {
public:
GpuHorizontalInputFusion() {}
explicit GpuHorizontalInputFusion(const GpuDeviceInfo& d) : device_info_(d) {}

absl::string_view name() const override {
return "gpu_horizontal_input_fusion";
@@ -51,6 +51,8 @@ class GpuHorizontalInputFusion : public HloModulePass {

private:
StatusOr<bool> RunOnComputation(HloComputation*);

const GpuDeviceInfo device_info_;
};

} // namespace gpu

@@ -15,14 +15,11 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"

namespace xla {
namespace gpu {
@@ -30,7 +27,11 @@ namespace {

namespace op = xla::testing::opcode_matchers;

class HorizontalInputFusionTest : public GpuCodegenTest {};
class HorizontalInputFusionTest : public GpuCodegenTest {
public:
GpuHorizontalInputFusion horizontal_input_fusion_{
TestGpuDeviceInfo::RTXA6000DeviceInfo()};
};

TEST_F(HorizontalInputFusionTest, BasicTest) {
auto module = ParseAndReturnVerifiedModule(R"(
@@ -64,7 +65,7 @@ TEST_F(HorizontalInputFusionTest, BasicTest) {
)")
.value();

EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());

const HloInstruction* entry_root =
module->entry_computation()->root_instruction();
@@ -208,7 +209,7 @@ TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
)")
.value();

EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
}

TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
@@ -232,7 +233,7 @@ TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
)")
.value();

EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).value());
EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());

const HloInstruction* entry_root =
module->entry_computation()->root_instruction();

@ -16,9 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
@ -28,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/tsl/lib/core/status_test_util.h"

@ -192,8 +190,11 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
          .value();

  HloPassPipeline fusion("fusion");
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false);
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true);
  const GpuDeviceInfo device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
                                                 device_info);
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
                                                 device_info);
  EXPECT_TRUE(fusion.Run(module.get()).value());
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
  TF_ASSERT_OK(verifier().Run(module.get()).status());
@ -101,7 +101,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,

  // The following checks are potentially expensive.
  if (NoFusionPossible too_large =
          !FusionFitsInBudget(*consumer, *producer,
          !FusionFitsInBudget(*consumer, *producer, device_info_,
                              /*is_consumer_producer_fusion=*/true)) {
    return !too_large;
  }
@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"

namespace xla {
@ -30,8 +31,9 @@ namespace gpu {

class GpuInstructionFusion : public InstructionFusion {
 public:
  explicit GpuInstructionFusion(bool may_duplicate)
      : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {}
  explicit GpuInstructionFusion(bool may_duplicate, const GpuDeviceInfo& d)
      : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
        device_info_(d) {}

  static bool IsExpensive(const HloInstruction& instruction);

@ -69,6 +71,8 @@ class GpuInstructionFusion : public InstructionFusion {
  // indexed with different index vectors.
  absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
      fusion_node_evaluations_;

  const GpuDeviceInfo device_info_;
};

}  // namespace gpu
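
With the device-aware constructor in place, the rest of this change migrates callers away from the single-argument form. A minimal sketch of wiring the fusion pass into a pipeline, mirroring the test updates elsewhere in this diff:

    HloPassPipeline fusion("fusion");
    const GpuDeviceInfo device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false, device_info);
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true, device_info);
    EXPECT_TRUE(fusion.Run(module.get()).value());
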
@ -15,11 +15,9 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"

#include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@ -29,7 +27,11 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {

using InstructionFusionTest = HloTestBase;
class InstructionFusionTest : public HloTestBase {
 public:
  GpuInstructionFusion duplicating_instruction_fusion_{
      /*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
};

TEST_F(InstructionFusionTest,
       CostlyProducerAndOperandElementReusingConsumerNotFused) {
@ -45,8 +47,7 @@ TEST_F(InstructionFusionTest,
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(broadcast2, computation->root_instruction());
  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
  EXPECT_EQ(broadcast2, computation->root_instruction());
}

@ -64,8 +65,7 @@ TEST_F(InstructionFusionTest,
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(broadcast2, computation->root_instruction());
  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}

@ -82,8 +82,7 @@ TEST_F(InstructionFusionTest,
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(reshape2, computation->root_instruction());
  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}

@ -100,8 +99,7 @@ TEST_F(InstructionFusionTest,
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(transpose2, computation->root_instruction());
  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}

@ -119,8 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(log, computation->root_instruction());
  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
@ -135,8 +132,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(transpose2, computation->root_instruction());
  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

// Tests that broadcasts fused into a fusion with a reduce root.
@ -159,8 +155,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -186,8 +181,7 @@ TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
})")
                    .value();

  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
@ -215,8 +209,7 @@ TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
})")
                    .value();

  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
@ -241,8 +234,7 @@ TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
})")
                    .value();

  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
@ -255,8 +247,7 @@ TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -275,8 +266,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -296,8 +286,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -316,8 +305,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) {
})")
                    .value();

  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@ -341,8 +329,7 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()))
@ -371,8 +358,7 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
})")
                    .value();

  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value())
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
      << module->ToString();
}

@ -390,8 +376,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -444,8 +429,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) {

  // Multi-output fusion is disabled here and performed in the
  // GpuMultiOutputFusion pass instead.
  ASSERT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, FuseScalarConstant) {
@ -462,8 +446,7 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -491,8 +474,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
  }
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(b.Build());
  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
  SCOPED_TRACE(module->ToString());
  for (const HloInstruction* instr : computation->instructions()) {
    EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
@ -527,8 +509,7 @@ TEST_F(InstructionFusionTest, FuseIntoScatter) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion()));
@ -557,8 +538,7 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
  // The f32[16] constant should not be fused into the reduce, but the f32[]
  // constant should be.
  auto* root = module->entry_computation()->root_instruction();
@ -578,8 +558,7 @@ TEST_F(InstructionFusionTest, FuseReverse) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
@ -698,8 +677,7 @@ TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()))
@ -725,8 +703,7 @@ TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());

  HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Fusion());
@ -767,8 +744,7 @@ TEST_F(InstructionFusionTest, DontTouchSoftmaxCustomCall) {
  ROOT %custom-call = f32[554112,10]{1,0} custom-call(f32[554112,10]{1,0} %param_0), custom_call_target="__softmax_fusion", called_computations={%softmax_computation}
})")
                    .value();
  EXPECT_FALSE(
      GpuInstructionFusion(/*may_duplicate=*/true).Run(module.get()).value());
  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
}

TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
@ -806,8 +782,10 @@ TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
})")
                    .value();

  EXPECT_TRUE(
      GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()).value());
  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
                                   TestGpuDeviceInfo::RTXA6000DeviceInfo())
                  .Run(module.get())
                  .value());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Fusion(op::Parameter()));
  EXPECT_THAT(
@ -31,12 +31,6 @@ limitations under the License.
namespace xla {
namespace gpu {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, but
// allows some extra space for cache.
inline constexpr int64_t kSharedMemoryBudgetInBytes = 48 * 1024;

// If a dimension is smaller than this, untiled transposition may be more
// efficient.
inline constexpr int64_t kMinDimensionToTransposeTiled = 16;
@ -131,7 +131,6 @@ limitations under the License.
#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/human_readable_json.h"
#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/profiler/lib/nvtx_utils.h"
#include "tensorflow/tsl/protobuf/dnn.pb.h"

#if GOOGLE_CUDA
@ -4569,12 +4568,8 @@ static bool CanVectorizeReduction(
    se::CudaComputeCapability cc, mlir::lmhlo::FusionOp fusion,
    HloComputation* fused_computation,
    const ReductionDimensions& reduction_dimensions, int num_threads_x,
    Vector3 reduction_tiling, const Shape& input_shape, int64_t shmem_usage,
    Vector3 reduction_tiling, const Shape& input_shape,
    bool reduction_is_race_free) {
  // Vectorization might cause us to run out of budget.
  if (shmem_usage * 2 > kSharedMemoryBudgetInBytes) {
    return false;
  }
  if (!reduction_dimensions.is_row_reduction) {
    return IsUnrollingColumnReductionBeneficial(
        fusion, fused_computation, input_shape,
@ -4678,10 +4673,15 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
          : kLinearIndexingX;
  int64_t shmem_usage =
      ProjectedShmemUsageBytes(reduction_dimensions, instr_index_groups);
  const int64_t shmem_budget =
      ir_emitter_context_->gpu_device_info().shared_memory_per_block;
  bool reduction_is_race_free = ReductionIsRaceFree(reduction_dimensions);
  bool vectorize = CanVectorizeReduction(
      cc, fusion, fused_computation, reduction_dimensions, num_threads_x,
      reduction_tiling, input_shape, shmem_usage, reduction_is_race_free);
  bool vectorize =
      // Vectorization might cause us to run out of budget.
      (shmem_usage * 2 <= shmem_budget) &&
      CanVectorizeReduction(cc, fusion, fused_computation, reduction_dimensions,
                            num_threads_x, reduction_tiling, input_shape,
                            reduction_is_race_free);
  int vector_size = vectorize ? 2 : 1;

  int num_partial_results = 1;
@ -4707,7 +4707,7 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
    }
  }

  while (shmem_usage * num_partial_results > kSharedMemoryBudgetInBytes) {
  while (shmem_usage * num_partial_results > shmem_budget) {
    num_partial_results /= 2;
    if (num_partial_results == 1) {
      break;
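
Taken together, these hunks replace the removed 48KB constant with a budget read from the queried device. A condensed sketch of the new control flow, using only names that appear in the hunks above:

    const int64_t shmem_budget =
        ir_emitter_context_->gpu_device_info().shared_memory_per_block;
    // Vectorization doubles shared-memory pressure, so gate it on the budget.
    bool vectorize =
        (shmem_usage * 2 <= shmem_budget) &&
        CanVectorizeReduction(cc, fusion, fused_computation,
                              reduction_dimensions, num_threads_x,
                              reduction_tiling, input_shape,
                              reduction_is_race_free);
    // Partial results also scale shared-memory use; halve until they fit.
    while (shmem_usage * num_partial_results > shmem_budget) {
      num_partial_results /= 2;
      if (num_partial_results == 1) break;
    }
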
@ -50,6 +50,7 @@ bool IsProfitableOperand(HloInstruction* instr) {
}

FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
                           const GpuDeviceInfo& device_info,
                           FusionInfoCache* fusion_info_cache) {
  CHECK(instr1->opcode() == HloOpcode::kFusion);

@ -66,7 +67,7 @@ FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
  }

  // Do this check last, as it may be expensive.
  return FusionFitsInBudget(*instr1, *instr2,
  return FusionFitsInBudget(*instr1, *instr2, device_info,
                            /*is_consumer_producer_fusion=*/false,
                            fusion_info_cache);
}
@ -161,7 +162,7 @@ std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
                  << " would introduce a cycle when fused.");
      continue;
    }
    if (!FusionFitsInBudget(*producer, *consumer,
    if (!FusionFitsInBudget(*producer, *consumer, device_info,
                            /*is_consumer_producer_fusion=*/false,
                            fusion_info_cache)) {
      dump_negative_explanation(
@ -260,7 +261,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
      if (NoFusionPossible sibling_fusible =
              (!IsSiblingFusionCandidate(*j) || !is_disconnected(*i, *j) ||
               !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
               !LegalToFuse(*i, *j, fusion_info_cache))) {
               !LegalToFuse(*i, *j, device_info_, fusion_info_cache))) {
        // We pick `j` arbitrarily as a consumer.
        if (dump_fusion) {
          RegisterFusionState(
@ -643,6 +643,7 @@ xla_cc_test(
        ":gpu_codegen_test",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service/gpu:gpu_device_info_for_tests",
        "//tensorflow/compiler/xla/service/gpu:gpu_fusible",
        "//tensorflow/compiler/xla/service/gpu:instruction_fusion",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
@ -657,12 +658,9 @@ xla_cc_test(
    tags = tf_cuda_tests_tags(),
    deps = [
        ":gpu_codegen_test",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service/gpu:fusion_merger",
        "//tensorflow/compiler/xla/service/gpu:gpu_device_info_for_tests",
        "//tensorflow/compiler/xla/service/gpu:gpu_fusible",
        "//tensorflow/compiler/xla/service/gpu:instruction_fusion",
        "//tensorflow/compiler/xla/service/gpu:multi_output_fusion",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
@ -19,12 +19,9 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/tsl/platform/test.h"

@ -45,12 +42,13 @@ class GpuFusionPipelineTest : public GpuCodegenTest {
  void CheckGpuFusionPipeline(absl::string_view hlo,
                              std::optional<absl::string_view> expected) {
    HloPassPipeline pipeline("gpu-fusion");
    pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
    pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
    pipeline.AddPass<FusionMerger>(TestGpuDeviceInfo::RTXA6000DeviceInfo(),
                                   ShapeSizeBytesFunction());
    pipeline.AddPass<GpuMultiOutputFusion>(
        TestGpuDeviceInfo::RTXA6000DeviceInfo(), ShapeSizeBytesFunction());
    const GpuDeviceInfo device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
    pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
                                           device_info);
    pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true, device_info);
    pipeline.AddPass<FusionMerger>(device_info, ShapeSizeBytesFunction());
    pipeline.AddPass<GpuMultiOutputFusion>(device_info,
                                           ShapeSizeBytesFunction());

    RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected);
  }
@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"

@ -81,7 +82,10 @@ TEST_F(GpuFusionTest, FusedBiggerThenThresholdButDoNotChangeTheFusionl) {
  b.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, slice_params, 1));
  module->AddEntryComputation(b.Build());
  EXPECT_TRUE(GpuInstructionFusion(false).Run(module.get()).value());
  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
                                   TestGpuDeviceInfo::RTXA6000DeviceInfo())
                  .Run(module.get())
                  .value());
  EXPECT_TRUE(module->entry_computation()->root_instruction()->opcode() ==
              HloOpcode::kFusion);
  for (HloInstruction* instr : module->entry_computation()->instructions()) {
@ -93,8 +97,11 @@ class TransposeFusionTest : public GpuFusionTest {
 public:
  void CheckGpuFusion(absl::string_view hlo,
                      std::optional<absl::string_view> expected) {
    RunAndFilecheckHloRewrite(hlo, GpuInstructionFusion{/*may_duplicate=*/true},
                              expected);
    RunAndFilecheckHloRewrite(
        hlo,
        GpuInstructionFusion{/*may_duplicate=*/true,
                             TestGpuDeviceInfo::RTXA6000DeviceInfo()},
        expected);
  }
};
@ -327,6 +327,19 @@ bool HloOrdering::UsesBeforeValueDefinition(
      return true;
    }
  }
  // The use at an async call occurs before values that are defined in the
  // called computation of the async wrapped instruction.
  if (use.instruction->IsAsynchronous() &&
      use.instruction->async_wrapped_opcode() == HloOpcode::kCall) {
    const HloInstruction* async = use.instruction;
    if (call_graph_->InstructionIsNestedIn(
            value.defining_instruction(),
            async->async_wrapped_instruction()->to_apply())) {
      VLOG(4) << " use is async " << use.instruction->name()
              << " and def is in called computation";
      return true;
    }
  }
  if (use.instruction->opcode() == HloOpcode::kConditional) {
    const HloInstruction* conditional = use.instruction;
    // In general the use of a value in the conditional parameter should be
@ -589,5 +589,52 @@ ENTRY entry {
      ordering.UsesBeforeValueDefinition({&tuple_use}, value, *dataflow));
}

TEST_F(HloOrderingTest, AsyncCallUses) {
  absl::string_view hlo_string = R"(
HloModule single_sc_async_call

%called_computation {
  %out_param = s32[1024]{0} parameter(1)
  %input = s32[1024]{0} parameter(0)
  %size = s32[] constant(256)
  %index = s32[] custom-call(), custom_call_target="Baz"
  %start = s32[] multiply(s32[] %size, s32[] %index)
  %input2 = s32[256]{0} dynamic-slice(s32[1024]{0} %input, s32[] %start), dynamic_slice_sizes={256}
  %output = s32[256]{0} add(s32[256]{0} %input2, s32[256]{0} %input2)
  ROOT %output2 = s32[1024]{0} dynamic-update-slice(s32[1024]{0} %out_param, s32[256]{0} %output, s32[] %start)
}, execution_thread="foobar"

%async_wrapped {
  %async_param = s32[1024]{0} parameter(0)
  %async_param.1 = s32[1024]{0} parameter(1)
  ROOT %call = s32[1024]{0} call(s32[1024]{0} %async_param, s32[1024]{0} %async_param.1), to_apply=%called_computation
}, execution_thread="foobar"

ENTRY %main {
  %input.1 = s32[1024]{0} parameter(0)
  %buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer"
  %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
  ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped
}
)";
  HloModuleConfig hlo_config;
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, hlo_config));
  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
  DependencyHloOrdering ordering(module.get());
  auto async_start = FindInstruction(module.get(), "async-start");
  auto async_done = FindInstruction(module.get(), "async-done");
  auto call = FindInstruction(module.get(), "call");
  auto output2 = FindInstruction(module.get(), "output2");

  auto async_start_use = HloUse{async_start, 1};
  auto async_done_use = HloUse{async_done, 0, {0, 1}};
  auto call_use = HloUse{call, 1};
  const HloValue& value = dataflow->GetUniqueValueAt(output2, {});
  EXPECT_TRUE(ordering.UsesBeforeValueDefinition(
      {&async_start_use, &call_use, &async_done_use}, value, *dataflow));
}

}  // namespace
}  // namespace xla
@ -1694,6 +1694,8 @@ namespace {

struct ParallelState {
  explicit ParallelState(int64_t task_count) : counter(task_count) {
    // If this method is changed, please remember to change
    // GetForEachIndexParallelThreadCount() as well.
    static auto* global_pool = new tsl::thread::ThreadPool(
        tsl::Env::Default(), "foreach", tsl::port::MaxParallelism());
    pool = global_pool;
@ -1743,6 +1745,11 @@ struct ParallelState {
  return pstate.status;
}

/* static */ int ShapeUtil::GetForEachIndexParallelThreadCount() {
  ParallelState pstate(/*task_count=*/0);
  return pstate.pool->NumThreads();
}

/* static */ Shape ShapeUtil::DeleteDimensions(
    absl::Span<int64_t const> dims_to_delete, Shape shape) {
  std::vector<int64_t> dims_to_delete_v(dims_to_delete.begin(),
@ -722,11 +722,19 @@ class ShapeUtil {
  // A parallel version of ForEachIndex(WithStatus). This can only be used if
  // the visitor_function is thread-safe and the order of iteration does not
  // matter.
  //
  // Please use GetForEachIndexParallelThreadCount() to get the number of
  // threads in the threadpool of ForEachIndexParallel*. This will not change
  // during the runtime of the process. Please DO NOT use
  // tsl::port::MaxParallelism() for this purpose, as it may change.
  static void ForEachIndexParallel(
      const Shape& shape, absl::Span<const int64_t> base,
      absl::Span<const int64_t> count, absl::Span<const int64_t> incr,
      const ForEachParallelVisitorFunction& visitor_function);

  // Returns the number of threads in the threadpool of ForEachIndexParallel*.
  static int GetForEachIndexParallelThreadCount();

  static Status ForEachIndexParallelWithStatus(
      const Shape& shape, absl::Span<const int64_t> base,
      absl::Span<const int64_t> count, absl::Span<const int64_t> incr,
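
A minimal usage sketch of the pinned thread count, assuming a caller that wants one accumulator slot per thread (the new test below shows thread_id can be -1 for the calling thread, hence the extra slot):

    const int kThreadCount = ShapeUtil::GetForEachIndexParallelThreadCount();
    std::vector<int64_t> per_thread(kThreadCount + 1, 0);
    ShapeUtil::ForEachIndexParallel(
        shape, /*base=*/{0, 0}, /*count=*/{10, 100}, /*incr=*/{1, 1},
        [&](absl::Span<const int64_t> indexes,
            int thread_id) -> StatusOr<bool> {
          per_thread[thread_id + 1] += indexes[0];  // one slot per thread: no races
          return true;
        });
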
@ -630,6 +630,23 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
  EXPECT_EQ(invocations, 5);
}

TEST(ShapeUtilTest, GetForEachIndexParallelThreadCount) {
  const int kThreadCount = ShapeUtil::GetForEachIndexParallelThreadCount();

  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
  auto check_func = [kThreadCount](absl::Span<const int64_t> /*indexes*/,
                                   int thread_id) -> StatusOr<bool> {
    EXPECT_GE(thread_id, -1);
    EXPECT_LT(thread_id, kThreadCount);
    return true;
  };

  for (int i = 0; i < 10; ++i) {
    ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 100},
                                    /*incr=*/{1, 1}, check_func);
  }
}

TEST(ShapeUtilTest, ForEachIndexParallel) {
  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
  int64_t output[10][10];
@ -115,6 +115,7 @@ cc_library(
        "//tensorflow/compiler/xla/stream_executor/platform",
        "//tensorflow/compiler/xla/stream_executor/platform:dso_loader",
        "//tensorflow/tsl/platform:env",
        "//tensorflow/tsl/platform:static_threadlocal",
    ] + tf_additional_cuda_driver_deps()) + select({
        # Include the dynamic loading implementation only when CUDA is
        # configured and the build links dynamically.
        "//tensorflow/tsl:is_cuda_enabled_and_oss": ["cudart_stub"],
@ -152,15 +152,15 @@ void Diagnostician::LogDiagnosticInformation() {
  CFRelease(kext_infos);
#elif !defined(PLATFORM_WINDOWS)
  if (access(kDriverVersionPath, F_OK) != 0) {
    LOG(INFO) << "kernel driver does not appear to be running on this host "
              << "(" << tsl::port::Hostname() << "): "
              << "/proc/driver/nvidia/version does not exist";
    VLOG(1) << "kernel driver does not appear to be running on this host "
            << "(" << tsl::port::Hostname() << "): "
            << "/proc/driver/nvidia/version does not exist";
    return;
  }
  auto dev0_path = GetDevNodePath(0);
  if (access(dev0_path.c_str(), F_OK) != 0) {
    LOG(INFO) << "no NVIDIA GPU device is present: " << dev0_path
              << " does not exist";
    VLOG(1) << "no NVIDIA GPU device is present: " << dev0_path
            << " does not exist";
    return;
  }
#endif
@ -36,11 +36,11 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/tsl/platform/env.h"
#include "tensorflow/tsl/platform/stacktrace.h"
#include "tensorflow/tsl/platform/static_threadlocal.h"
#include "tensorflow/tsl/platform/threadpool.h"

bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
@ -130,7 +130,7 @@ struct ThreadLocalData {
  int depth;
};

SE_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);

}  // namespace

@ -261,7 +261,7 @@ static tsl::Status InternalInit() {
  if (res == CUDA_SUCCESS) {
    return ::tsl::OkStatus();
  } else if (res == CUDA_ERROR_SHARED_OBJECT_INIT_FAILED) {
    LOG(WARNING) << "failed call to cuInit: " << ToString(res);
    VLOG(1) << "failed call to cuInit: " << ToString(res);
  } else {
    LOG(ERROR) << "failed call to cuInit: " << ToString(res);
  }
@ -53,6 +53,7 @@ cc_library(
        "//tensorflow/tsl/platform:env",
        "//tensorflow/tsl/platform:numbers",
        "//tensorflow/tsl/platform:stacktrace",
        "//tensorflow/tsl/platform:static_threadlocal",
    ]),
)
@ -20,7 +20,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_

#include "rocm/include/rocblas.h"
#include "rocm/include/rocblas/rocblas.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
@ -29,13 +29,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "tensorflow/tsl/platform/env.h"
#include "tensorflow/tsl/platform/numbers.h"
#include "tensorflow/tsl/platform/stacktrace.h"
#include "tensorflow/tsl/platform/static_threadlocal.h"
#include "tensorflow/tsl/platform/threadpool.h"

bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
@ -168,7 +168,7 @@ struct ThreadLocalData {
  int depth;
};

SE_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);

}  // namespace
@ -112,6 +112,22 @@ xla_cc_binary(
    ],
)

xla_cc_test(
    name = "replay_computation_bin_test",
    srcs = ["replay_computation_bin_test.cc"],
    data = [
        "add.hlo",
        ":replay_computation_cpu",
        ":replay_computation_interpreter",
    ],
    deps = [
        "//tensorflow/tsl/platform:path",
        "//tensorflow/tsl/platform:subprocess",
        "//tensorflow/tsl/platform:test",
        "//tensorflow/tsl/platform:test_main",
    ],
)

xla_cc_binary(
    name = "replay_computation_gpu",
    tags = ["gpu"],
64
tensorflow/compiler/xla/tools/replay_computation_bin_test.cc
Normal file
@ -0,0 +1,64 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "tensorflow/tsl/platform/path.h"
#include "tensorflow/tsl/platform/subprocess.h"
#include "tensorflow/tsl/platform/test.h"

namespace xla {
namespace {

// TODO(ddunleavy): test something more specific.

std::string PathToAddHlo() {
  return tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "add.hlo");
}

TEST(ReplayComputation, AddHloHost) {
  // Get relevant paths to replay_computation_cpu and add.hlo.
  std::string replay_computation_bin = tsl::io::JoinPath(
      tsl::testing::XlaSrcRoot(), "tools", "replay_computation_cpu");

  tsl::SubProcess proc;
  proc.SetProgram(replay_computation_bin,
                  {replay_computation_bin, PathToAddHlo(), "--use_fake_data"});
  EXPECT_TRUE(proc.Start());

  // Just make sure that the process's exit code is 0.
  int status = proc.Communicate(nullptr, nullptr, nullptr);
  EXPECT_TRUE(WIFEXITED(status));
  ASSERT_EQ(0, WEXITSTATUS(status));
}

TEST(ReplayComputation, AddHloInterpreter) {
  // Get relevant paths to replay_computation_interpreter and add.hlo.
  std::string replay_computation_bin = tsl::io::JoinPath(
      tsl::testing::XlaSrcRoot(), "tools", "replay_computation_interpreter");

  tsl::SubProcess proc;
  proc.SetProgram(replay_computation_bin,
                  {replay_computation_bin, PathToAddHlo(), "--use_fake_data"});
  EXPECT_TRUE(proc.Start());

  // Just make sure that the process's exit code is 0.
  int status = proc.Communicate(nullptr, nullptr, nullptr);
  EXPECT_TRUE(WIFEXITED(status));
  ASSERT_EQ(0, WEXITSTATUS(status));
}

}  // namespace
}  // namespace xla
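
For reference, each test above spawns the equivalent of running the tool by hand (paths shown are illustrative, not literal):

    replay_computation_cpu add.hlo --use_fake_data
    replay_computation_interpreter add.hlo --use_fake_data
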
@ -629,16 +629,23 @@ cc_library(
        "//tensorflow/core/kernels/image:image",
        "//tensorflow/core/kernels/sparse:kernels",
    ] + if_mkl([
        "//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
        "//tensorflow/core/kernels/mkl:mkl_concat_op",
        "//tensorflow/core/kernels/mkl:mkl_dequantize_op",
        "//tensorflow/core/kernels/mkl:mkl_conv_op",
        "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
        "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
        "//tensorflow/core/kernels/mkl:mkl_identity_op",
        "//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
        "//tensorflow/core/kernels/mkl:mkl_layer_norm_op",
        "//tensorflow/core/kernels/mkl:mkl_lrn_op",
        "//tensorflow/core/kernels/mkl:mkl_pooling_ops",
        "//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
        "//tensorflow/core/kernels/mkl:mkl_requantize_ops",
        "//tensorflow/core/kernels/mkl:mkl_quantize_op",
        "//tensorflow/core/kernels/mkl:mkl_relu_op",
        "//tensorflow/core/kernels/mkl:mkl_reshape_op",
        "//tensorflow/core/kernels/mkl:mkl_slice_op",
        "//tensorflow/core/kernels/mkl:mkl_softmax_op",
        "//tensorflow/core/kernels/mkl:mkl_swish_op",
        "//tensorflow/core/kernels/mkl:mkl_fused_mish_op",
@ -646,8 +653,8 @@ cc_library(
        "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_einsum_op",
        "//tensorflow/core/kernels/mkl:mkl_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_tfconv_op",
        "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
        "//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
    ]) + if_cuda_or_rocm([
        "//tensorflow/core/kernels:cudnn_rnn_kernels",
    ]) + if_cuda([
@ -1956,21 +1963,28 @@ tf_cc_test_mkl(
        "//tensorflow/core/kernels:ops_util",
        "//third_party/eigen3",
    ] + if_mkl([
        "//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
        "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_einsum_op",
        "//tensorflow/core/kernels/mkl:mkl_concat_op",
        "//tensorflow/core/kernels/mkl:mkl_conv_op",
        "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
        "//tensorflow/core/kernels/mkl:mkl_dequantize_op",
        "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
        "//tensorflow/core/kernels/mkl:mkl_identity_op",
        "//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
        "//tensorflow/core/kernels/mkl:mkl_lrn_op",
        "//tensorflow/core/kernels/mkl:mkl_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_pooling_ops",
        "//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
        "//tensorflow/core/kernels/mkl:mkl_quantize_op",
        "//tensorflow/core/kernels/mkl:mkl_relu_op",
        "//tensorflow/core/kernels/mkl:mkl_reshape_op",
        "//tensorflow/core/kernels/mkl:mkl_slice_op",
        "//tensorflow/core/kernels/mkl:mkl_softmax_op",
        "//tensorflow/core/kernels/mkl:mkl_tfconv_op",
        "//tensorflow/core/kernels/mkl:mkl_transpose_op",
        "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
        "//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
    ]),
)
@ -0,0 +1,5 @@
op {
  graph_op_name: "CollectiveReduceScatterV2"
  summary: "Mutually reduces multiple tensors of identical type and shape and scatters the result."
  visibility: HIDDEN
}
@ -282,6 +282,7 @@ filegroup(
        "memory_types.h",
        "mkl_cpu_allocator.h",
        "mkl_layout_pass.h",
        "mkl_tfconversion_pass.h",
        "node_file_writer.h",
        "optimization_registry.h",
        "partitioning_utils.h",
@ -1314,6 +1315,26 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "mkl_tfconversion_pass",
    srcs = ["mkl_tfconversion_pass.cc"],
    hdrs = [
        "mkl_tfconversion_pass.h",
        "//tensorflow/core/graph:mkl_graph_util_header",
    ],
    copts = tf_copts(),
    deps = [
        ":function",
        ":optimization_registry",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
    alwayslink = 1,
)

cc_library(
    name = "node_file_writer",
    srcs = ["node_file_writer.cc"],
@ -1888,6 +1909,7 @@ tf_cuda_library(
        ":memory_types",
        ":mkl_cpu_allocator",
        ":mkl_layout_pass",
        ":mkl_tfconversion_pass",
        ":optimization_registry",
        ":optimized_function_graph_info",
        ":parallel_concat_optimizer",
@ -3073,6 +3095,59 @@ tf_cc_test(
    ],
)

tf_cc_test_mkl(
    name = "mkl_related_tests",
    size = "small",
    srcs = [
        "mkl_layout_pass_test.cc",
        "mkl_tfconversion_pass_test.cc",
    ],
    linkstatic = 1,
    deps = [
        ":core",
        ":core_cpu",
        ":core_cpu_internal",
        ":direct_session_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:sendrecv_ops",
        "//tensorflow/core/kernels:ops_util",
        "//third_party/eigen3",
    ] + if_mkl([
        "//tensorflow/core/kernels/mkl:mkl_aggregate_ops",
        "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_concat_op",
        "//tensorflow/core/kernels/mkl:mkl_conv_op",
        "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common",
        "//tensorflow/core/kernels/mkl:mkl_dequantize_op",
        "//tensorflow/core/kernels/mkl:mkl_einsum_op",
        "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op",
        "//tensorflow/core/kernels/mkl:mkl_identity_op",
        "//tensorflow/core/kernels/mkl:mkl_input_conversion_op",
        "//tensorflow/core/kernels/mkl:mkl_lrn_op",
        "//tensorflow/core/kernels/mkl:mkl_matmul_op",
        "//tensorflow/core/kernels/mkl:mkl_pooling_ops",
        "//tensorflow/core/kernels/mkl:mkl_qmatmul_op",
        "//tensorflow/core/kernels/mkl:mkl_quantize_op",
        "//tensorflow/core/kernels/mkl:mkl_relu_op",
        "//tensorflow/core/kernels/mkl:mkl_reshape_op",
        "//tensorflow/core/kernels/mkl:mkl_slice_op",
        "//tensorflow/core/kernels/mkl:mkl_softmax_op",
        "//tensorflow/core/kernels/mkl:mkl_tfconv_op",
        "//tensorflow/core/kernels/mkl:mkl_transpose_op",
        "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
    ]),
)

# TODO(bmzhao): Refactor this target to use granular dependencies
# after stage 4 of the TF build refactor is complete:
# https://github.com/tensorflow/community/pull/179
@ -303,14 +303,16 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
  }

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  const Tensor* input =
      (col_params->instance.type == REDUCTION_COLLECTIVE ||
       col_params->instance.type == GATHER_COLLECTIVE ||
       col_params->instance.type == PERMUTE_COLLECTIVE ||
       col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
       col_params->instance.type == REDUCE_SCATTER_COLLECTIVE ||
       (col_params->instance.type == BROADCAST_COLLECTIVE &&
        col_params->is_source))
          ? &ctx->input(0)
          : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
@ -78,6 +78,9 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
    case ALL_TO_ALL_COLLECTIVE:
      return "AllToAll";

    case REDUCE_SCATTER_COLLECTIVE:
      return nccl ? "NcclReduceScatter" : "undef";

    default:
      return "undef";
  }
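
ReduceScatter currently has only an NCCL-backed implementation, so the lookup deliberately degrades to "undef" without NCCL. An illustrative sketch, assuming cp points at CollectiveParams with type REDUCE_SCATTER_COLLECTIVE:

    const char* with_nccl = GetCollectiveName(cp, /*nccl=*/true);      // "NcclReduceScatter"
    const char* without_nccl = GetCollectiveName(cp, /*nccl=*/false);  // "undef"
    // "undef" signals that no implementation is registered for this combination.
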
@ -1193,7 +1193,9 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
    // iterating through a tf.data input pipeline.
    if (!errors::IsOutOfRange(s)) {
      LOG(INFO) << "[" << immutable_state_.params().device->name()
                << "] Executor start aborting: " << s;
                << "] (DEBUG INFO) Executor start aborting (this does not "
                   "indicate an error and you can ignore this message): "
                << s;
    } else {
      VLOG(1) << "[" << immutable_state_.params().device->name()
              << "] Executor start aborting: " << s;
@ -66,7 +66,8 @@ namespace tensorflow {
namespace {
bool IsCollectiveV2(const string& op) {
  return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
         op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
         op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2" ||
         op == "CollectiveReduceScatterV2";
}
}  // namespace
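
A quick illustrative check of the extended predicate (op names as registered in the op registry):

    CHECK(IsCollectiveV2("CollectiveReduceScatterV2"));
    CHECK(IsCollectiveV2("CollectiveReduceV2"));
    CHECK(!IsCollectiveV2("CollectiveReduce"));  // the V1 op stays excluded
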
5211
tensorflow/core/common_runtime/mkl_layout_pass_test.cc
Normal file
File diff suppressed because it is too large
462
tensorflow/core/common_runtime/mkl_tfconversion_pass.cc
Normal file
@ -0,0 +1,462 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)

#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"

#include <memory>
#include <queue>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
// in the graph in between A and B, where A and B match any one
// of the following cases:
//
//  1) A = a node that generates output in the Mkl format and,
//     B = a node that does not accept input in the Mkl format and,
//     A -> B (there is a direct edge between A and B), then
//     We will insert C such that A->C->B.
//
//  2) A = a node that generates output in the Mkl format and,
//     B = NULL (in other words, A is the last node in the graph), then
//     We will insert C such that A->C. (C will be the last node.)
//
// Note that case 1 applies to all outputs of A that are input to B.
// In other words, the conversions will be required for every output
// of A that is input to B. For example, let us say the output of A
// is A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not
// in Mkl format, and all of them are input to B. In such a case, we will
// do the conversion for A1 and A2 only. We do not need to do any conversion
// for A3.
//
// This pass relies on ops registering their Mkl compliance. An Mkl-compliant
// op can accept inputs in the Mkl format and produce outputs in the Mkl
// format. Non-compliant ops accept inputs and outputs in the TensorFlow
// format.
//
// ADDENDUM: For element-wise ops, we may or may not need a conversion to
// take place before we hit the op. For this, we add a new op before each
// element-wise MKL op to deal with the inputs, called _MklInputConversion.
// This pass has been enhanced to add this capability.
//
// The _MklInputConversion op will check the inputs to the elementwise op and
// make sure that either both are in MKL format or both are in TF format,
// depending on their initial state and whether broadcast is needed or not.
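// Illustrative example (hypothetical graph, not part of this pass's code):
// under case 1 above, an edge from an Mkl-format producer into a plain
// TensorFlow consumer,
//
//   _MklConv2D -> Print
//
// is rewritten by this pass into
//
//   _MklConv2D -> _MklToTf -> Print
//
// where _MklToTf converts the Mkl-format tensor (together with its metadata
// slot, see DataIndexToMetaDataIndex below) back into a plain TF tensor.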
|

class MklToTfConversionPass : public GraphOptimizationPass {
 public:
  MklToTfConversionPass() {}
  Status Run(const GraphOptimizationPassOptions& options);

  // Insert layout conversion node in the graph pointed to by g.
  // The function scans the graph for candidate edges where we
  // need to insert conversion nodes.
  //
  // @return true if at least one conversion node is inserted;
  // false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  // Is the input Op supported by Mkl-specific layout?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported; false otherwise.
  inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklOp(op_name, T, false);
  }

  // Is the input Op supported by Mkl-specific layout AND
  // is it element-wise?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported and element-wise; false otherwise.
  inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklElementWiseOp(op_name, T);
  }

  // Insert a layout conversion node on the edge pointed to by 'e' in
  // graph 'g'.
  //
  // The edge will be deleted once a call to this function is successful.
  // Any attempt to use the edge after this call will lead to undefined
  // behavior.
  //
  // @return OkStatus() if insertion is successful; otherwise an
  // appropriate error status code.
  Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);

  // For element-wise ops, we need to sanitize the inputs. For this, we add a
  // new node at the input of the replacement element-wise node that checks
  // the inputs and converts one/both of them as required. See the op code
  // comments for details.
  //
  // Insert input conversion node as parent of 'n' in graph 'g'.
  //
  // @return OkStatus() if insertion is successful; otherwise an
  // appropriate error status code.
  Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};

// We register MklToTf insertion for phase 2 in post-partition grouping
// because we register MklLayoutRewritePass for phase 1 in post-partition
// grouping. We register this pass after partitioning so that we get a
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
#endif  // ENABLE_MKL

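// Illustration only (not part of this change): additional passes can hook
// into the same registry; e.g. a hypothetical cleanup pass registered for
// phase 3 of the same grouping would run after this one:
//
//   REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 3, MyHypotheticalCleanupPass);
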
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
    std::unique_ptr<Graph>* g, Edge* e) {
  CHECK_NOTNULL(e);

  Node* src = e->src();
  Node* dst = e->dst();

  CHECK_NOTNULL(src);
  CHECK_NOTNULL(dst);

  Node* conversion_node = nullptr;
  DataType src_datatype = src->output_type(e->src_output());
  DataType dst_datatype = dst->input_type(e->dst_input());
  string data_format;

  // Source and destination must agree on the datatype.
  if (src_datatype != dst_datatype) {
    string err_msg = "T attribute of " + src->name() + ":" +
                     std::to_string(e->src_output()) + " and " + dst->name() +
                     ":" + std::to_string(e->dst_input()) +
                     " do not"
                     " match. Will not insert MklToTf node in such case.";
    return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
  }

  TF_CHECK_OK(
      NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
          .Input(src, e->src_output())
          .Input(src, DataIndexToMetaDataIndex(
                          e->src_output(),
                          src->num_outputs()))  // Get the Mkl tensor slot
                                                // from the Tf tensor slot.
          .Device(src->def().device())  // We want the conversion node to be
                                        // on the same device as the source.
          .Attr("T", src_datatype)
          .Finalize(&**g, &conversion_node));
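  // A sketch of the metadata slot math above (my reading of
  // DataIndexToMetaDataIndex, stated as an assumption rather than spec):
  // under contiguous ordering, a node with num_outputs = 2*N keeps data
  // tensors in slots [0, N) and their Mkl metadata in slots [N, 2*N), so
  // data slot k maps to metadata slot k + num_outputs / 2; under
  // interleaved ordering it maps to slot k + 1.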

  CHECK_NOTNULL(conversion_node);
  // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
  // using data_format. This code might be redundant.
  if (GetNodeAttr(src->def(), "data_format", &data_format) == OkStatus() &&
      (data_format == ToString(FORMAT_NHWC) ||
       data_format == ToString(FORMAT_NCHW))) {
    conversion_node->AddAttr("data_format", data_format);
  }

  // Get assigned device from source node and apply it to conversion node.
  // We want conversion node to be on the same device as the source node.
  conversion_node->set_assigned_device_name(src->assigned_device_name());

  // Set the Mkl op label for this op.
  conversion_node->AddAttr("_kernel",
                           mkl_op_registry::kMklLayoutDependentOpLabel);

  // Now that we have added edge from src->conversion_node, let's add edge
  // from output of conversion_node to the dest node. Since conversion_node
  // has only 1 output, the src_output of conversion_node is 0.
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input()));

  VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: "
          << src->type_string() << " and " << dst->type_string()
          << " successful.";

  // Remove src->dst edge now.
  (*g)->RemoveEdge(e);
  return OkStatus();
}

Status MklToTfConversionPass::InsertInputConversionNode(
    std::unique_ptr<Graph>* g, Node* n) {
  CHECK_NOTNULL(n);

  // Get the input nodes and edges.
  std::vector<const Edge*> edges;
  TF_CHECK_OK(n->input_edges(&edges));
  if (edges.size() != 4) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "MKL Binary Element-wise op should have exactly 2 data"
                  " inputs and 2 metadata inputs");
  }
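
  // Under contiguous ordering, the four inputs are expected to arrive as
  // [x, y, x_mkl_metadata, y_mkl_metadata]; the checks below verify that
  // both data inputs share the op's datatype and that edge i feeds input
  // slot i.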
  // Sanity check: ensure that both inputs are of the expected type, and the
  // same type as the op's input type.
  CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
           BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
  CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
           BaseType(n->input_type(0)));

  // Check the ordering of the edges.
  for (uint32 i = 0; i < 4; i++) {
    CHECK_EQ((edges[i]->dst_input() == i), true);
  }

  // Build the conversion node and specify src as input.
  Node* conversion_node = nullptr;

  TF_CHECK_OK(
      NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
          .Input(edges[0]->src(), edges[0]->src_output())
          .Input(edges[1]->src(), edges[1]->src_output())
          .Input(edges[2]->src(), edges[2]->src_output())
          .Input(edges[3]->src(), edges[3]->src_output())
          .Device(n->def().device())
          .Attr("T", n->input_type(0))
          .Finalize(&**g, &conversion_node));

  CHECK_NOTNULL(conversion_node);

  // Change the destination of any control edges to the InputConversion node.
  if (edges.size() != n->in_edges().size()) {
    std::vector<const Edge*> edges_to_remove;
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
        edges_to_remove.push_back(e);
      }
    }
    for (const Edge* e : edges_to_remove) {
      (*g)->RemoveEdge(e);
    }
  }

  // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
  // seem to be using data_format. This code might be redundant.
  string data_format;
  if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
          OkStatus() &&
      (data_format == ToString(FORMAT_NHWC) ||
       data_format == ToString(FORMAT_NCHW))) {
    conversion_node->AddAttr("data_format", data_format);
  }

  // Get assigned device from destination node and apply it to conversion
  // node. We want the conversion node to be on the same device as the
  // destination node.
  conversion_node->set_assigned_device_name(n->assigned_device_name());

  // Set the Mkl op label for this op.
  conversion_node->AddAttr("_kernel",
                           mkl_op_registry::kMklLayoutDependentOpLabel);

  // Now that we have added the edges from src->conversion_node, let's add
  // the edges from the outputs of conversion_node to the element-wise node.
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));

  VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
          << "conversion node on: " << n->type_string() << " successful.";

  // Remove the original input edges now.
  (*g)->RemoveEdge(edges[0]);
  (*g)->RemoveEdge(edges[1]);
  (*g)->RemoveEdge(edges[2]);
  (*g)->RemoveEdge(edges[3]);

  return OkStatus();
}

bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;

  CHECK_NOTNULL(g);

  DumpGraph("Before MklToTfConversionPass", &**g);

  // Since we are looking for an Mkl-supported op node immediately followed
  // by a non-Mkl op node, we just iterate over the edge set of the graph.
  // Edges whose source and destination are candidates for inserting a
  // conversion node are collected here.
  std::vector<Edge*> candidate_edges;
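  // We collect candidates first and mutate the graph afterwards; inserting
  // nodes and removing edges while iterating over (*g)->edges() would
  // invalidate the iteration.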

  for (const Edge* e : (*g)->edges()) {
    Node* src = e->src();
    Node* dst = e->dst();

    // We skip control edges.
    if (e->IsControlEdge()) {
      continue;
    }

    // We skip adding MklToTf on an edge between X->MklToTf or
    // MklToTf->X, where X is any node.
    if (src->type_string().compare("_MklToTf") == 0 ||
        dst->type_string().compare("_MklToTf") == 0) {
      continue;
    }

    VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
            << src->type_string() << " and " << dst->type_string();

    // Let's get the source and destination datatypes.
    // We cannot assume a datatype on the destination node because the
    // destination node may not be an Mkl node.
    DataType src_datatype;
    DataType dst_datatype;
    bool src_is_mkl_op =
        (GetNodeAttr(src->def(), "T", &src_datatype) == OkStatus() &&
         IsMklSupportedOp(src->type_string(), src_datatype));
    bool dst_is_mkl_op =
        (GetNodeAttr(dst->def(), "T", &dst_datatype) == OkStatus() &&
         IsMklSupportedOp(dst->type_string(), dst_datatype));

    // Check if src is Mkl-compliant while dst is not Mkl-compliant.
    if (src_is_mkl_op && !dst_is_mkl_op) {
      VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
              << " and " << dst->name() << " for inserting conversion nodes";
      candidate_edges.push_back(const_cast<Edge*>(e));
    }
  }

  // Process all candidate edges and insert conversion nodes on them.
  for (Edge* e : candidate_edges) {
    // Even if we insert a conversion node on only a single edge, we
    // need to return true.
    string src_name = e->src()->name();
    string dst_name = e->dst()->name();
    if (InsertConversionNodeOnEdge(g, e) == OkStatus()) {
      VLOG(1) << "MklToTfConversionPass: Inserted conversion "
              << "node on edge between " << src_name << " and " << dst_name;
      result = true;
    }
  }

  DumpGraph("After MklToTfConversionPass", &**g);

  //---------------------------------------------------------------------------
  // Check all nodes and add an input-conversion node if the node is an Mkl
  // element-wise node.
  VLOG(1) << "Before running MklToTfConversionPass - InputConversion";

  std::vector<Node*> candidate_nodes;
  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This will give us a topological sort.

  for (Node* n : order) {
    // If the node is not an op or it does not have a datatype, then skip it.
    DataType datatype;
    if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != OkStatus())) {
      continue;
    }
    if (IsMklElementWiseOp(n->type_string(), datatype)) {
      // If the input node is already an input-conversion op, skip it.
      Node* input_node = nullptr;
      TF_CHECK_OK(n->input_node(0, &input_node));
      DataType input_datatype;
      if ((GetNodeAttr(n->def(), "T", &input_datatype) == OkStatus()) &&
          (input_node->type_string().compare("_MklInputConversion") == 0)) {
        continue;
      }

      VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
              << n->name() << " for inserting input conversion node";
      candidate_nodes.push_back(const_cast<Node*>(n));
    }
  }

  // Process all candidate nodes and insert input conversion nodes on them.
  for (Node* n : candidate_nodes) {
    // Even if we insert a conversion node on only a single node, we
    // need to return true.
    if (InsertInputConversionNode(g, n) == OkStatus()) {
      VLOG(1) << "MklToTfConversionPass: Inserted conversion "
              << "on node " << n->name();
      result = true;
    }
  }
  DumpGraph("After MklToTfConversionPass - InputConversion", &**g);

  // We need to return true if we inserted even one conversion node
  // anywhere in the graph.
  return result;
}

//////////////////////////////////////////////////////////////////////////////
// Run function for the pass
//////////////////////////////////////////////////////////////////////////////

bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
  return MklToTfConversionPass().RunPass(g);
}

Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return OkStatus();
  }
  if (!IsMKLEnabled()) {
    VLOG(2) << "TF-MKL: MKL is not enabled";
    return OkStatus();
  }
  if (NativeFormatEnabled()) {
    VLOG(2)
        << "Running in native format mode, MklToTfConversionPass won't run.";
    return OkStatus();
  }
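  // (With native-format Mkl kernels, Mkl ops consume and produce tensors in
  // plain TensorFlow layout, so there are no Mkl-layout tensors left to
  // convert and this pass has nothing to do.)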

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Get the ownership of graph
    std::unique_ptr<Graph>* ng = std::move(g);
    RunPass(ng);
    // Return the ownership of graph back
    g->reset(ng->release());
  };

  if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) {
    // For any pre-partitioning phase, the graph is stored in options.graph.
    process_graph(options.graph);
  } else {
    // For the post-partitioning phase, graphs are stored in
    // options.partition_graphs.
    for (auto& pg : *options.partition_graphs) {
      process_graph(&pg.second);
    }
  }

  return OkStatus();
}

}  // namespace tensorflow

#endif  // defined(INTEL_MKL) && defined(ENABLE_MKL)
40
tensorflow/core/common_runtime/mkl_tfconversion_pass.h
Normal file

@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// An optimization pass that inserts MklToTf conversion nodes in the graph.

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_

#ifdef INTEL_MKL

#include <sys/types.h>

#include <memory>

#include "tensorflow/core/graph/graph.h"

#ifdef _WIN32
typedef unsigned int uint;
#endif

namespace tensorflow {
// Interface to invoke the pass for unit test.
//
// Returns true if and only if 'g' is mutated.
extern bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g);
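//
// Example (a minimal sketch, mirroring how the unit tests drive the pass):
//
//   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
//   // ... build or import a graph into *g ...
//   bool mutated = InsertMklToTfConversionNodes(&g);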
}  // namespace tensorflow

#endif  // INTEL_MKL

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_MKL_TFCONVERSION_PASS_H_
311
tensorflow/core/common_runtime/mkl_tfconversion_pass_test.cc
Normal file

@ -0,0 +1,311 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)

#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

class MklToTfConversionPass : public ::testing::Test {
 public:
  MklToTfConversionPass() : graph_(OpRegistry::Global()) {}

  static void InitGraph(const string& s, Graph* graph) {
    GraphDef graph_def;

    auto parser = protobuf::TextFormat::Parser();
    CHECK(parser.MergeFromString(s, &graph_def)) << s;
    GraphConstructorOptions opts;
    TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
  }

  void InitGraph(const string& s) {
    InitGraph(s, &graph_);
    original_ = CanonicalGraphString(&graph_);
  }

  static bool IncludeNode(const Node* n) { return n->IsOp(); }

  static string EdgeId(const Node* n, int index) {
    if (index == 0) {
      return n->name();
    } else if (index == Graph::kControlSlot) {
      return strings::StrCat(n->name(), ":control");
    } else {
      return strings::StrCat(n->name(), ":", index);
    }
  }

  string CanonicalGraphString(Graph* g) {
    std::vector<string> nodes;
    std::vector<string> edges;
    for (const Node* n : g->nodes()) {
      if (IncludeNode(n)) {
        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
      }
    }
    for (const Edge* e : g->edges()) {
      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
                                        EdgeId(e->dst(), e->dst_input())));
      }
    }
    // Canonicalize.
    std::sort(nodes.begin(), nodes.end());
    std::sort(edges.begin(), edges.end());
    return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
                           absl::StrJoin(edges, ";"));
  }
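
  // For reference: a tiny graph in which node A feeds node B canonicalizes
  // to something like "A(Float_Input);B(Mul)|A->B": the sorted node list, a
  // '|' separator, then the sorted edge list.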

  string DoRunMklToTfConversionPass() {
    string before = CanonicalGraphString(&graph_);
    LOG(ERROR) << "Before MklToTf conversion pass: " << before;

    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
    InsertMklToTfConversionNodes(ug);

    string result = CanonicalGraphString(&graph_);
    LOG(ERROR) << "After MklToTf conversion pass: " << result;
    return result;
  }

  const string& OriginalGraph() const { return original_; }

  Graph graph_;
  string original_;
};

REGISTER_OP("Float_Input").Output("o: float").SetIsStateful();
REGISTER_OP("_Mkl_Input").Output("o: uint8").SetIsStateful();
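// In these tests, 'Float_Input' stands in for a producer of a plain
// TF-format float tensor, while '_Mkl_Input' produces the uint8 tensor that
// serves as the Mkl metadata accompanying an Mkl-format tensor.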

TEST_F(MklToTfConversionPass, Basic) {
  InitGraph(
      "node { name: 'A' op: 'Float_Input'}"
      "node { name: 'B' op: 'Float_Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoRunMklToTfConversionPass(),
            "A(Float_Input);B(Float_Input);C(Mul);D(Mul)|"
            "A->C;A->D;B->C:1;B->D:1");
}

// MklConv2D followed by a non-Mkl layer.
// C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
// C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
TEST_F(MklToTfConversionPass, Positive) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    InitGraph(
        "node { name: 'A' op: 'Float_Input'}"
        "node { name: 'M' op: '_Mkl_Input'}"
        "node { name: 'B' op: 'Float_Input'}"
        "node { name: 'N' op: '_Mkl_Input'}"
        "node { name: 'C' op: '_MklConv2D'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
        "}"
        " attr { key: 'padding' value { s: 'SAME' } }"
        " input: ['A', 'M', 'B', 'N']}"
        "node { name: 'D' op: 'Float_Input'}"
        "node { name: 'E' op: 'Sub'"
        " attr {key: 'T' value { type: DT_FLOAT } }"
        " input: ['C', 'D']}");
    EXPECT_EQ(DoRunMklToTfConversionPass(),
              "A(Float_Input);B(Float_Input);C(_MklConv2D);D(Float_Input);E("
              "Sub);M(_Mkl_Input);"
              "Mkl2Tf/_0(_MklToTf);N(_Mkl_Input)|A->C;B->C:2;C->Mkl2Tf/_0;"
              "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    InitGraph(
        "node { name: 'A' op: 'Float_Input'}"
        "node { name: 'B' op: 'Float_Input'}"
        "node { name: 'M' op: '_Mkl_Input'}"
        "node { name: 'N' op: '_Mkl_Input'}"
        "node { name: 'C' op: '_MklConv2D'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
        "}"
        " attr { key: 'padding' value { s: 'SAME' } }"
        " input: ['A', 'B', 'M', 'N']}"
        "node { name: 'D' op: 'Float_Input'}"
        "node { name: 'E' op: 'Sub'"
        " attr {key: 'T' value { type: DT_FLOAT } }"
        " input: ['C', 'D']}");
    EXPECT_EQ(DoRunMklToTfConversionPass(),
              "A(Float_Input);B(Float_Input);C(_MklConv2D);D(Float_Input);E("
              "Sub);M(_Mkl_Input);"
              "Mkl2Tf/_0(_MklToTf);N(_Mkl_Input)|A->C;B->C:1;C->Mkl2Tf/_0;"
              "C:2->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
  }
}

// MklConv2D followed by an MklToTf op followed by a non-Mkl layer.
// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1); F=Sub(D,E) (for interleaved)
// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:2); F=Sub(D,E) (for contiguous)
// An MklToTf node should not be inserted again.
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    InitGraph(
        "node { name: 'A' op: 'Float_Input'}"
        "node { name: 'M' op: '_Mkl_Input'}"
        "node { name: 'B' op: 'Float_Input'}"
        "node { name: 'N' op: '_Mkl_Input'}"
        "node { name: 'C' op: '_MklConv2D'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
        "}"
        " attr { key: 'padding' value { s: 'SAME' } }"
        " input: ['A', 'M', 'B', 'N']}"
        "node { name: 'D' op: '_MklToTf'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " input: ['C:0', 'C:1']}"
        "node { name: 'E' op: 'Float_Input'}"
        "node { name: 'F' op: 'Sub'"
        " attr {key: 'T' value { type: DT_FLOAT } }"
        " input: ['D', 'E']}");
    EXPECT_EQ(DoRunMklToTfConversionPass(),
              "A(Float_Input);B(Float_Input);C(_MklConv2D);D(_MklToTf);E(Float_"
              "Input);"
              "F(Sub);M(_Mkl_Input);N(_Mkl_Input)|"
              "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    InitGraph(
        "node { name: 'A' op: 'Float_Input'}"
        "node { name: 'B' op: 'Float_Input'}"
        "node { name: 'M' op: '_Mkl_Input'}"
        "node { name: 'N' op: '_Mkl_Input'}"
        "node { name: 'C' op: '_MklConv2D'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
        "}"
        " attr { key: 'padding' value { s: 'SAME' } }"
        " input: ['A', 'B', 'M', 'N']}"
        "node { name: 'D' op: '_MklToTf'"
        " attr { key: 'T' value { type: DT_FLOAT } }"
        " attr { key: 'data_format' value { s: 'NCHW' } }"
        " input: ['C:0', 'C:2']}"
        "node { name: 'E' op: 'Float_Input'}"
        "node { name: 'F' op: 'Sub'"
        " attr {key: 'T' value { type: DT_FLOAT } }"
        " input: ['D', 'E']}");
    EXPECT_EQ(DoRunMklToTfConversionPass(),
              "A(Float_Input);B(Float_Input);C(_MklConv2D);D(_MklToTf);E(Float_"
              "Input);"
              "F(Sub);M(_Mkl_Input);N(_Mkl_Input)|"
              "A->C;B->C:1;C->D;C:2->D:1;D->F;E->F:1;M->C:2;N->C:3");
  }
}

// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
// There is no Mkl layer, so no conversion op should be inserted.
TEST_F(MklToTfConversionPass, Negative_NoMklLayer) {
  InitGraph(
      "node { name: 'A' op: 'Float_Input'}"
      "node { name: 'B' op: 'Float_Input'}"
      "node { name: 'C' op: 'Conv2D'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Float_Input'}"
      "node { name: 'E' op: 'BiasAdd'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
      " input: ['C', 'D'] }"
      "node { name: 'Y' op: 'Float_Input'}"
      "node { name: 'Z' op: 'Sub'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['E', 'Y']}");
  EXPECT_EQ(DoRunMklToTfConversionPass(),
            "A(Float_Input);B(Float_Input);C(Conv2D);D(Float_Input);E(BiasAdd);"
            "Y(Float_Input);Z(Sub)|"
            "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
}

static void BM_RunMklToTfConversionPass(int iters, int op_nodes) {
  testing::StopTiming();
  string s;
  for (int in = 0; in < 10; in++) {
    s += strings::Printf("node { name: 'in%04d' op: 'Float_Input'}", in);
  }
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  for (int op = 0; op < op_nodes; op++) {
    s += strings::Printf(
        "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
        "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
        op, rnd.Uniform(10), rnd.Uniform(10));
  }

  bool first = true;
  while (iters > 0) {
    Graph* graph = new Graph(OpRegistry::Global());
    MklToTfConversionPass::InitGraph(s, graph);
    int N = graph->num_node_ids();
    if (first) {
      testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
      first = false;
    }
    {
      testing::StartTiming();
      std::unique_ptr<Graph> ug(graph);
      InsertMklToTfConversionNodes(&ug);
      testing::StopTiming();
    }
    iters -= N;  // Our benchmark units are individual graph nodes,
                 // not whole graphs.
    // delete graph;
  }
}
BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);

}  // namespace
}  // namespace tensorflow

#endif  // INTEL_MKL && ENABLE_MKL
Some files were not shown because too many files have changed in this diff.