Supported in this CL:
* Attaching sharding descriptors to HLO ops.
* Partitioning the HLO graph into per-device computations based on those sharding descriptors.
* All operator support for device placement and ops replicated on all devices.
* Elementwise op support for tiled shardings.
* 2D convolution support for tiled shardings (no stride or dilation support).

PiperOrigin-RevId: 173946036
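Example (not part of this commit): a hedged sketch of the client-side flow the CL enables, attaching a sharding descriptor to the ops built between SetSharding() and ClearSharding(). The function name, parameter shapes, and builder setup are illustrative only; the SetSharding/ShardingBuilder calls are the APIs added below.

#include "tensorflow/compiler/xla/client/computation_builder.h"

xla::ComputationDataHandle BuildShardedAdd(xla::ComputationBuilder* b,
                                           const xla::Shape& shape) {
  // Every instruction created while a sharding is set gets that descriptor
  // attached; the service later partitions the HLO graph accordingly.
  b->SetSharding(xla::ShardingBuilder::AssignDevice(/*device=*/0));
  xla::ComputationDataHandle x = b->Parameter(0, shape, "x");
  xla::ComputationDataHandle y = b->Parameter(1, shape, "y");
  xla::ComputationDataHandle sum = b->Add(x, y);
  b->ClearSharding();
  return sum;
}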
This commit is contained in:
parent 682a6ed64f
commit efcbf6e34e
@@ -103,20 +103,17 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
      DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
      errors::Internal("Unable to parse device name: ",
                       op_kernel->requested_device()));
  xla::OpDeviceAssignment assignment;
  // If no device ID assignment is found, XLA is free to use whatever device it
  // wants. In practice this usually has the effect of placing things on
  // device 0.
  if (parsed.has_id) {
    assignment.set_has_device(true);
    assignment.set_device(parsed.id);
    b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
  }
  b->SetDeviceAssignment(assignment);

  op_kernel->Compute(context);

  b->ClearOpMetadata();
  b->ClearDeviceAssignment();
  b->ClearSharding();
  VLOG(4) << "Done";
}
@@ -40,12 +40,13 @@ template <typename T>
class Array {
 public:
  // Creates a new array with the specified dimensions.
  explicit Array(const std::vector<int64>& sizes) : Array(sizes, T()) {}
  explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
      : Array(sizes, T()) {}

  // Creates a new array with the specified dimensions and specified value for
  // every cell.
  Array(const std::vector<int64>& sizes, T value)
      : sizes_(sizes), values_(new T[num_elements()]) {
  Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
      : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
    Fill(value);
  }
@@ -192,6 +193,18 @@ class Array {
    return values_[calculate_index(indexes)];
  }

  // Returns the value at the cell specified by the indexes. The number of
  // arguments have to match with the number of dimensions for the array.
  const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
    return values_[calculate_index(indexes)];
  }

  // Returns the value at the cell specified by the indexes. The number of
  // arguments have to match with the number of dimensions for the array.
  T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
    return values_[calculate_index(indexes)];
  }

  // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
  // to the underlying storage of the array (similarly to std::vector::data()).
  T* data() const {
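Example (not part of this commit): a hedged sketch of how the ArraySlice-based constructors and accessors added above can be used, e.g. for a small tile-assignment array. The function name and values are made up for illustration.

#include <vector>
#include "tensorflow/compiler/xla/array.h"

void ArraySliceAccessExample() {
  xla::Array<xla::int64> assignment({2, 2});  // dimensions passed as a slice
  assignment.Fill(0);
  assignment({0, 1}) = 3;                     // write via operator()(ArraySlice)
  std::vector<xla::int64> index = {0, 1};
  xla::int64 device = assignment(index);      // read via the const overload
  (void)device;
}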
@@ -218,6 +231,11 @@ class Array {
                           std::multiplies<int64>());
  }

  const T* begin() const { return &values_[0]; }
  T* begin() { return &values_[0]; }
  const T* end() const { return &values_[num_elements()]; }
  T* end() { return &values_[num_elements()]; }

  bool operator==(const Array<T>& other) const {
    if (sizes_.size() != other.sizes_.size()) {
      return false;
@@ -170,6 +170,7 @@ cc_library(
        ":computation",
        ":global_data",
        ":padding",
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:array4d",
@@ -1794,14 +1794,9 @@ StatusOr<Computation> ComputationBuilder::Build() {

void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const {
  *request->mutable_metadata() = metadata_;
  *request->mutable_device_assignment() = device_assignment_;
}

void ComputationBuilder::ClearDeviceAssignment() { device_assignment_.Clear(); }

void ComputationBuilder::SetDeviceAssignment(
    const OpDeviceAssignment& assignment) {
  device_assignment_ = assignment;
  if (sharding_) {
    *request->mutable_sharding() = *sharding_;
  }
}

/* static */ ConvolutionDimensionNumbers
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -42,6 +43,58 @@ limitations under the License.

namespace xla {

class ShardingBuilder {
 public:
  // A shaped array used to describe the assignment of tiles to devices.
  using TileAssignment = Array<int64>;

  // Creates a replicated sharding - replicate a tensor on every device.
  static OpSharding Replicate() {
    OpSharding result;
    result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
    return result;
  }
  // Creates a sharding that assigns a tensor to just one device.
  static OpSharding AssignDevice(int device) {
    OpSharding result;
    result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
    result.add_tile_assignment_dimensions(1);
    result.add_tile_assignment_devices(device);
    return result;
  }
  // Creates a tiled sharding with the given tile shape and assignment of tiles
  // to devices.
  static OpSharding Tile(Shape tile_shape,
                         const TileAssignment& tile_assignment) {
    OpSharding result;
    result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
    for (int64 dim : tile_assignment.dimensions()) {
      result.add_tile_assignment_dimensions(dim);
    }
    for (uint32 device : tile_assignment) {
      result.add_tile_assignment_devices(device);
    }
    return result;
  }
  // Creates a sharding in one dimension, with the given tile shape which must
  // be rank 1 and using devices 0..num_tiles.
  static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) {
    OpSharding result;
    result.set_type(OpSharding::Type::OpSharding_Type_OTHER);

    CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
    std::vector<int64> dimensions(1, num_tiles);
    auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
    tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
    *result.mutable_tile_shape() = tile_shape;
    result.add_tile_assignment_dimensions(num_tiles);
    for (int64 i = 0; i < num_tiles; ++i) {
      result.add_tile_assignment_devices(i);
    }
    return result;
  }
};

// Wraps an XLA client with a convenient interface for building up
// computations. Any errors encountered in building up the computation are
// deferred from being handled until Build() is called.
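Example (not part of this commit): a hedged sketch of using the ShardingBuilder helpers above to produce OpSharding protos. The 2x2 tile assignment and the u32 tile shape are made-up illustration values.

#include <numeric>
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ShardingBuilderExample() {
  // Replicate a value on every device.
  xla::OpSharding replicated = xla::ShardingBuilder::Replicate();

  // Pin a value to device 3 (maximal sharding).
  xla::OpSharding on_device_3 = xla::ShardingBuilder::AssignDevice(3);

  // Tile a value into u32[2,3] tiles laid out on a 2x2 grid of devices 0..3.
  xla::ShardingBuilder::TileAssignment assignment({2, 2});
  std::iota(assignment.begin(), assignment.end(), 0);  // devices 0,1,2,3
  xla::OpSharding tiled = xla::ShardingBuilder::Tile(
      xla::ShapeUtil::MakeShape(xla::U32, {2, 3}), assignment);
  (void)replicated; (void)on_device_3; (void)tiled;
}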
@@ -78,11 +131,11 @@ class ComputationBuilder {

  // Sets an OpDeviceAssignment that will be attached to all instructions
  // until cleared.
  void SetDeviceAssignment(const OpDeviceAssignment& assignment);
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Clears the device assignment. Ops will be placed according to the default
  // placement policy.
  void ClearDeviceAssignment();
  void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
@@ -894,8 +947,9 @@ class ComputationBuilder {
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Device assignment for the operator.
  OpDeviceAssignment device_assignment_;
  // Sharding for this operator. This is structured as a "model"-like operation,
  // in order to simplify client code, similar to metadata_.
  tensorflow::gtl::optional<OpSharding> sharding_;

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};
@@ -133,6 +133,7 @@ cc_library(
        "hlo_instruction.cc",
        "hlo_module.cc",
        "hlo_opcode.cc",
        "hlo_sharding.cc",
    ],
    hdrs = [
        "dfs_hlo_visitor.h",
@@ -141,6 +142,7 @@ cc_library(
        "hlo_instruction.h",
        "hlo_module.h",
        "hlo_opcode.h",
        "hlo_sharding.h",
    ],
    deps = [
        ":hlo_module_config",
@@ -148,6 +150,7 @@ cc_library(
        ":hlo_reachability",
        ":name_uniquer",
        ":versioned_computation_handle",
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:shape_tree",
@@ -238,6 +241,22 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "hlo_sharding_test",
    srcs = ["hlo_sharding_test.cc"],
    deps = [
        ":hlo",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "call_graph",
    srcs = ["call_graph.cc"],
@@ -56,7 +56,6 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
  HloInstruction* root =
      root_instruction ? root_instruction : last_added_instruction_;
  CHECK_NE(nullptr, root);

  return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
                                       root, fusion_instruction_));
}
@@ -735,6 +734,10 @@ std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) {
    }

    new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands);
    new_instr->set_metadata(instr->metadata());
    if (instr->has_sharding()) {
      new_instr->set_sharding(instr->sharding());
    }
    InsertOrDie(&clone_map, instr, new_instr.get());
    instructions.push_back(std::move(new_instr));
  }
@@ -1039,8 +1039,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
  if (!opcode_specific_info.empty()) {
    lines.push_back(opcode_specific_info);
  }
  if (instr->device_assignment().has_device()) {
    lines.push_back(StrCat("device=", instr->device_assignment().device()));
  if (instr->has_sharding()) {
    lines.push_back(StrCat("sharding=", instr->sharding().ToString()));
  }
  // Show the shape and layout of the instruction, unless it's an inlined fusion
  // node -- there the shape and layout is present in the output node.
@@ -1212,6 +1212,9 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
    }
  }
  clone->set_parent(parent_);
  if (has_sharding()) {
    clone->set_sharding(sharding());
  }
  return clone;
}
@@ -1889,8 +1892,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
  if (opcode() == HloOpcode::kGetTupleElement) {
    extra.push_back(StrCat("index=", tuple_index()));
  }
  if (device_assignment_.has_device()) {
    extra.push_back(StrCat("device=", device_assignment_.device()));
  if (has_sharding()) {
    extra.push_back(StrCat("sharding=", sharding().ToString()));
  }
  if (!control_successors_.empty()) {
    extra.push_back(StrCat(
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -713,6 +714,26 @@ class HloInstruction {
    fusion_kind_ = kind;
  }

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
  // Sets the sharding of this operator. Should only be called by HloModule or
  // HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = MakeUnique<HloSharding>(sharding);
  }
  // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
  // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }

  // Merges the fused instructions from 'instruction_to_merge' into the
  // fused instruction set of 'this', updating operands as necessary.
  //
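Example (not part of this commit): a hedged sketch of how a pass might consume the new sharding accessors on HloInstruction. The function name is hypothetical; `instruction` is assumed to be obtained elsewhere.

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

void LogSharding(const xla::HloInstruction* instruction) {
  const xla::HloSharding replicated = xla::HloSharding::Replicate();
  // Fall back to "replicated" when no explicit sharding was attached.
  const xla::HloSharding& sharding =
      instruction->sharding_or_default(replicated);
  if (instruction->has_sharding()) {
    LOG(INFO) << instruction->name() << " sharded as " << sharding.ToString();
  } else {
    LOG(INFO) << instruction->name() << " uses the default (replicated) sharding";
  }
}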
@@ -984,14 +1005,6 @@ class HloInstruction {
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Gets/sets the device assignment.
  const OpDeviceAssignment& device_assignment() const {
    return device_assignment_;
  }
  void set_device_assignment(const OpDeviceAssignment& device_assignment) {
    device_assignment_ = device_assignment;
  }

 private:
  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
@@ -1124,6 +1137,9 @@ class HloInstruction {
  // The type of the fusion. Used by kFusion only.
  FusionKind fusion_kind_;

  // The sharding, if one exists.
  std::unique_ptr<HloSharding> sharding_;

  // For parameter instructions this field holds the parameter number.
  int64 parameter_number_ = 0;
  string parameter_name_;
@@ -1184,9 +1200,6 @@ class HloInstruction {
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  // Device assignment for the instruction.
  OpDeviceAssignment device_assignment_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};
tensorflow/compiler/xla/service/hlo_sharding.cc (new file, 232 lines)

@@ -0,0 +1,232 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using ::tensorflow::strings::StrCat;

HloSharding HloSharding::AssignDevice(int64 device_id) {
  return HloSharding(device_id);
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
  CHECK_EQ(1, ShapeUtil::Rank(input_shape));
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Shape tile_shape = input_shape;
  auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
  tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(tile_shape, assignment);
}

string HloSharding::ToString() const {
  string result = StrCat("{", (replicated_ ? " replicated" : ""),
                         (maximal_ ? " maximal" : ""));

  if (replicated_) {
    return "{replicated}";
  } else if (maximal_) {
    return StrCat(
        "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
  } else {
    return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ",
                  "devices=", VectorString(tile_assignment_), "}");
  }
}

bool HloSharding::UsesDevice(int64 device) const {
  const auto& devices = tile_assignment_;
  return replicated_ ||
         std::find(devices.begin(), devices.end(), device) != devices.end();
}

std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
  CHECK(!ShapeUtil::IsTuple(tile_shape_));
  CHECK(!maximal_);
  std::vector<int64> ret_index;
  tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  return ret_index;
}

int64 HloSharding::DeviceForTileIndex(
    tensorflow::gtl::ArraySlice<int64> index) const {
  CHECK(!replicated_);
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
  return tile_assignment_(index);
}

std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
  CHECK(!ShapeUtil::IsTuple(tile_shape_));

  std::vector<int64> index = TileIndexForDevice(device);
  if (maximal_) {
    // Index will always be all zeroes if we're maximal, and tile_shape_ is not
    // valid.
    return index;
  }
  for (int64 i = 0; i < index.size(); ++i) {
    index[i] *= tile_shape_.dimensions(i);
  }
  return index;
}

std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
  CHECK(!ShapeUtil::IsTuple(tile_shape_));
  CHECK(!maximal_);  // Maximal shardings do not have a valid tile shape.

  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
  }
  return index;
}

StatusOr<int64> HloSharding::UniqueDevice() const {
  if (!replicated_ && maximal_) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return tensorflow::errors::InvalidArgument(
      "UniqueDevice() called on sharding that executes on multiple devices");
}

Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
  if (replicated_) {
    return Status::OK();
  }

  // All tile assignments must be less than the number of available cores and
  // unique.
  Status status = Status::OK();
  std::set<int64> seen_cores;
  tile_assignment_.Each(
      [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 core) {
        // Don't overwrite a bad status, so we report the first error.
        if (status.ok()) {
          if (core >= num_devices) {
            status =
                tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
                    "core ", core, " > ", num_devices, " in tile assignment"));
          } else if (seen_cores.count(core) != 0) {
            status =
                tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
                    "core ", core, " is not unique in tile assignment"));
          }
        }
        seen_cores.insert(core);
      });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal()) {
    return Status::OK();
  }

  // The tile rank must be the same as the input rank.
  if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
    return tensorflow::errors::InvalidArgument(
        "Tile rank is different to the input rank");
  }

  // The tile shape must not be the same as the input shape without maximal_
  // also set. If this is the case, we're not actually sharded and the correct
  // constructor should have been used.
  if (ShapeUtil::Equal(shape, tile_shape_)) {
    return tensorflow::errors::InvalidArgument(
        "Tile shape is the same as the input shape. If a replicated sharding "
        "was intended, use HloSharding::Replicated(). If a device placement "
        "was intended, use HloSharding::AssignDevice()");
  }

  // The tile shape must not be greater than the input shape in any dimension.
  for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
    auto tile_dim = tile_shape_.dimensions(i);
    auto shape_dim = shape.dimensions(i);
    if (tile_dim > shape_dim) {
      return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
          "Tile is larger than input shape (dimension ", i, ", ", tile_dim,
          " > ", shape_dim));
    }
  }

  // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim]
  // tile[dim]) for every dimension contained within tile.
  for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
    int64 expected_dim =
        CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
    if (tile_assignment_.dimensions()[i] != expected_dim) {
      return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
          "Tile assignment tensor has incorrect shape. Dimension ", i,
          " expected ", expected_dim, " but got ",
          tile_assignment_.dimensions()[i]));
    }
  }

  return Status::OK();
}

/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
    return Replicate();
  } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
    return HloSharding(proto.tile_assignment_devices(0));
  }
  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  std::vector<int64> devices(proto.tile_assignment_devices().begin(),
                             proto.tile_assignment_devices().end());
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(proto.tile_assignment_devices().begin(),
            proto.tile_assignment_devices().end(), tile_assignment.begin());
  return HloSharding(proto.tile_shape(), tile_assignment);
}

OpSharding HloSharding::ToProto() const {
  OpSharding result;
  *result.mutable_tile_shape() = tile_shape_;
  for (int64 dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
  } else {
    result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
  }
  return result;
}

}  // namespace xla
tensorflow/compiler/xla/service/hlo_sharding.h (new file, 165 lines)

@@ -0,0 +1,165 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <string>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // A shaped array used to describe the assignment of tiles to devices.
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate() { return HloSharding(); }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64 device_id);

  // Creates a new sharding which splits a shape into tiles each with shape
  // `tile_shape`. Each tile is assigned to one device, which is specified by
  // `tile_assignment`. Any tensor not a multiple of the tile size in any
  // dimension is implicitly padded to the tile size.
  //
  // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
  //      2     1 padding
  //   <------><->
  //   +----+----+
  //   | 0  |  1 |
  //   +----+----+
  //
  // Split into two tiles, one of which is implicitly padded by one.
  static HloSharding Tile(const Shape& tile_shape,
                          const Array<int64>& tile_assignment) {
    return HloSharding(tile_shape, tile_assignment);
  }

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);

  // Create a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  OpSharding ToProto() const;
  string ToString() const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const { return replicated_; }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const { return maximal_; }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;

  // Returns the tile that should be executed on the given device.
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if is_replicated() is true.
  int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;

  // Given a device ID, returns the offset within the input space of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  std::vector<int64> TileOffsetForDevice(int64 device) const;

  // Given a device ID, returns the limit within the input space of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  std::vector<int64> TileLimitForDevice(int64 device) const;

  // Returns the single device this op operates on.
  // Requires !Replicated() && IsTileMaximal().
  StatusOr<int64> UniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
           tile_assignment_ == other.tile_assignment_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const {
    if (replicated_) {
      return 0;
    }
    size_t h = 0;
    for (uint32 v : tile_assignment_) {
      h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
    }
    for (uint32 v : tile_shape_.dimensions()) {
      h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
    }
    return h;
  }

  // Gets the tile shape.
  // It is an error to call this if IsTileMaximal() is true.
  const Shape& tile_shape() const { return tile_shape_; }
  // Gets the tile assignment tensor.
  // It is an error to call this if IsReplicated() is true.
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

 private:
  HloSharding()
      : replicated_(true),
        maximal_(true),
        tile_shape_(),
        tile_assignment_({0}) {}
  explicit HloSharding(int64 device_id)
      : replicated_(false),
        maximal_(true),
        tile_shape_(),
        tile_assignment_({1}, device_id) {}
  HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
      : replicated_(false),
        maximal_(false),
        tile_shape_(tile_shape),
        tile_assignment_(tile_assignment) {}

  bool replicated_;
  bool maximal_;
  Shape tile_shape_;
  Array<int64> tile_assignment_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
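Example (not part of this commit): a hedged sketch that mirrors the semantics documented above — a u32[4,6] value split into u32[2,3] tiles over a 2x2 device grid, then queried for a device's tile extents. The function name is hypothetical; the expected offsets/limits follow the tile arithmetic in this CL.

#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"

void HloShardingExample() {
  xla::Shape tile_shape = xla::ShapeUtil::MakeShape(xla::U32, {2, 3});
  xla::Array<xla::int64> assignment({2, 2});
  std::iota(assignment.begin(), assignment.end(), 0);  // devices 0,1,2,3

  xla::HloSharding sharding = xla::HloSharding::Tile(tile_shape, assignment);
  TF_CHECK_OK(sharding.Validate(xla::ShapeUtil::MakeShape(xla::U32, {4, 6}),
                                /*num_devices=*/4));

  // Device 3 sits at tile index {1,1}: it owns rows [2,4) and columns [3,6).
  std::vector<xla::int64> offset = sharding.TileOffsetForDevice(3);  // {2, 3}
  std::vector<xla::int64> limit = sharding.TileLimitForDevice(3);    // {4, 6}
  (void)offset; (void)limit;
}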
tensorflow/compiler/xla/service/hlo_sharding_test.cc (new file, 190 lines)

@@ -0,0 +1,190 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace {

Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
                       tensorflow::gtl::ArraySlice<int64> contents) {
  Array<int64> a(dimensions);
  std::copy(contents.begin(), contents.end(), a.begin());
  return a;
}

class HloShardingTest : public HloTestBase {};

TEST_F(HloShardingTest, Replicate) {
  Shape tile_shape = ShapeUtil::MakeShape(U32, {4});
  HloSharding sharding = HloSharding::Replicate();
  EXPECT_TRUE(sharding.IsReplicated());
  EXPECT_TRUE(sharding.IsTileMaximal());
  EXPECT_TRUE(sharding.UsesDevice(0));
  EXPECT_TRUE(sharding.UsesDevice(65535));

  HloSharding other = HloSharding::Replicate();
  EXPECT_EQ(other, sharding);

  EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                 /*num_devices=*/2));
  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
}

TEST_F(HloShardingTest, DevicePlacement) {
  HloSharding sharding = HloSharding::AssignDevice(5);
  EXPECT_FALSE(sharding.IsReplicated());
  EXPECT_TRUE(sharding.IsTileMaximal());
  EXPECT_FALSE(sharding.UsesDevice(0));
  EXPECT_TRUE(sharding.UsesDevice(5));
  EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());

  HloSharding other = HloSharding::Replicate();
  EXPECT_NE(other, sharding);

  EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                 /*num_devices=*/6));
  EXPECT_IS_NOT_OK(
      sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5));
}

TEST_F(HloShardingTest, Tile) {
  {
    // Test should fail because of a duplicate tile assignment.
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
                                       /*num_devices=*/4));
  }

  {
    // Test should pass.
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                       /*num_devices=*/2));
  }

  {
    // Test should fail due to the tile being larger than the input space.
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
                                       /*num_devices=*/4));
  }

  {
    // Test should fail due to the tile not dividing the input space into 4
    // sections (even with padding).
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}),
                                       /*num_devices=*/4));
  }

  {
    // Test should pass.
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
                                   /*num_devices=*/5));

    EXPECT_EQ(0, sharding.DeviceForTileIndex({0, 0}));
    EXPECT_EQ(3, sharding.DeviceForTileIndex({0, 1}));
    EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
    EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));

    EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector<int64>{0, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector<int64>{0, 3}));
    EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
    EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));

    EXPECT_IS_NOT_OK(sharding.UniqueDevice());
  }
}

TEST_F(HloShardingTest, Hash) {
  auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
    if (a.Hash() != b.Hash()) {
      return false;
    }
    return a == b;
  };

  {
    HloSharding sharding1 = HloSharding::Replicate();
    HloSharding sharding2 = HloSharding::Replicate();
    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
  }

  {
    HloSharding sharding1 = HloSharding::AssignDevice(1);
    HloSharding sharding2 = HloSharding::AssignDevice(1);
    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
  }

  {
    HloSharding sharding1 = HloSharding::AssignDevice(1);
    HloSharding sharding2 = HloSharding::AssignDevice(2);
    EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
  }

  {
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding1 =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
                                              MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
  }

  {
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding1 =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
                                              MakeArray({2, 2}, {0, 3, 2, 1}));
    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
  }

  {
    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
    HloSharding sharding1 =
        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
                                              MakeArray({2, 2}, {0, 3, 1, 2}));
    EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
  }
}

}  // namespace
}  // namespace xla
@@ -198,9 +198,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
  NodeDef* node_def = graph_def_.add_node();
  node_def->set_name(GetNodeNameForInstruction(instruction));
  node_def->set_op(GetOpDefName(instruction));
  if (instruction->device_assignment().has_device()) {
    node_def->set_device(
        GetDeviceName(instruction->device_assignment().device()));
  if (instruction->has_sharding() &&
      instruction->sharding().HasUniqueDevice()) {
    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
    node_def->set_device(GetDeviceName(device));
  }
  SetNodeAttrs(instruction, node_def);
  if (instruction->opcode() == HloOpcode::kFusion) {
@@ -1415,9 +1415,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
  // proto in the above switch statement.
  TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
  TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
  TF_RETURN_IF_ERROR(
      computation->SetOpDeviceAssignment(handle, arg->device_assignment()));

  if (arg->has_sharding()) {
    TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding()));
  }
  return tensorflow::Status::OK();
}
@@ -1315,20 +1315,19 @@ Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
  return Status::OK();
}

Status UserComputation::SetOpDeviceAssignment(
    const ComputationDataHandle& handle,
    const OpDeviceAssignment& device_assignment) {
Status UserComputation::SetOpSharding(const ComputationDataHandle& handle,
                                      const OpSharding& sharding) {
  tensorflow::mutex_lock lock(mutex_);

  int64 handle_value = handle.handle();
  if (session_computation_.requests().count(handle_value) == 0) {
    return InvalidArgument("Invalid handle in SetOpDeviceAssignment (%lld)",
    return InvalidArgument("Invalid handle in SetOpSharding (%lld)",
                           handle_value);
  }
  *session_computation_.mutable_requests()
       ->at(handle_value)
       .mutable_request()
       ->mutable_device_assignment() = device_assignment;
       ->mutable_sharding() = sharding;
  return Status::OK();
}
@@ -2518,7 +2517,9 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
  if (ShapeUtil::IsScalar(operand->shape())) {
    HloInstruction* broadcast = hlo_builder_.AddInstruction(
        HloInstruction::CreateBroadcast(broadcast_shape, operand, {}));
    broadcast->set_device_assignment(operand->device_assignment());
    if (operand->has_sharding()) {
      broadcast->set_sharding(operand->sharding());
    }
    return broadcast;
  }
  // Do explicit broadcast for degenerate broadcast.
@@ -2536,12 +2537,16 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
          ShapeUtil::MakeShape(operand->shape().element_type(),
                               reshaped_dimensions),
          operand));
  reshaped_operand->set_device_assignment(operand->device_assignment());
  if (operand->has_sharding()) {
    reshaped_operand->set_sharding(operand->sharding());
  }
  // Broadcast 'reshape' up to the larger size.
  HloInstruction* broadcast =
      hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
          broadcast_shape, reshaped_operand, broadcast_dimensions));
  broadcast->set_device_assignment(operand->device_assignment());
  if (operand->has_sharding()) {
    broadcast->set_sharding(operand->sharding());
  }
  return broadcast;
}
@@ -2556,8 +2561,11 @@ void ComputationLowerer::Visit(
    HloInstruction* hlo_instruction =
        hlo_builder_.AddInstruction(std::move(instruction));
    hlo_instruction->set_metadata(request.request().metadata());
    hlo_instruction->set_device_assignment(
        request.request().device_assignment());
    if (request.request().has_sharding()) {
      OpSharding op_sharding = request.request().sharding();
      hlo_instruction->set_sharding(
          HloSharding::FromProto(op_sharding).ValueOrDie());
    }
    return hlo_instruction;
  };
  auto lookup_instruction = [&](const ComputationDataHandle& handle) {
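Example (not part of this commit): a hedged sketch of the OpSharding <-> HloSharding conversion used above, round-tripping a maximal (single-device) sharding. The function name is hypothetical.

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

void ShardingProtoRoundTrip() {
  xla::OpSharding proto = xla::ShardingBuilder::AssignDevice(2);
  xla::HloSharding sharding = xla::HloSharding::FromProto(proto).ValueOrDie();
  CHECK(sharding.HasUniqueDevice());
  CHECK_EQ(2, sharding.UniqueDevice().ValueOrDie());
  xla::OpSharding round_tripped = sharding.ToProto();
  CHECK_EQ(xla::OpSharding::Type::OpSharding_Type_MAXIMAL,
           round_tripped.type());
}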
@@ -262,8 +262,8 @@ class UserComputation {
                        const OpMetadata& metadata);

  // Sets the device assignment on the Hlo instruction referenced by 'handle'.
  Status SetOpDeviceAssignment(const ComputationDataHandle& handle,
                               const OpDeviceAssignment& device_assignment);
  Status SetOpSharding(const ComputationDataHandle& handle,
                       const OpSharding& sharding);

  // Builds a HLO computation from the UserComputation. The parameter "resolver"
  // is a function which returns a pointer to the HloComputation corresponding
@@ -224,10 +224,13 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddParameterInstruction(b_request));

  OpDeviceAssignment assignment;
  assignment.set_has_device(true);
  assignment.set_device(7);
  TF_EXPECT_OK(computation.SetOpDeviceAssignment(b_handle, assignment));
  const int64 kDevice = 7;
  OpSharding sharding;
  sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(kDevice);

  TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding));

  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
@@ -260,12 +263,10 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {

  const HloInstruction* broadcast =
      hlo_computation->root_instruction()->operand(1);
  EXPECT_TRUE(broadcast->device_assignment().has_device());
  EXPECT_EQ(assignment.device(), broadcast->device_assignment().device());
  EXPECT_TRUE(broadcast->has_sharding());

  const HloInstruction* reshape = broadcast->operand(0);
  EXPECT_TRUE(reshape->device_assignment().has_device());
  EXPECT_EQ(assignment.device(), reshape->device_assignment().device());
  EXPECT_TRUE(reshape->has_sharding());
}

TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
@@ -62,9 +62,16 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) {
#define EXPECT_IS_OK(expression)                    \
  EXPECT_EQ(tensorflow::Status::OK(),               \
            xla::testing::internal_status::GetStatus(expression))
#define EXPECT_IS_NOT_OK(expression)                \
  EXPECT_NE(tensorflow::Status::OK(),               \
            xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_OK
#define ASSERT_IS_OK(expression)                    \
  ASSERT_EQ(tensorflow::Status::OK(),               \
            xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_NOT_OK
#define ASSERT_IS_NOT_OK(expression)                \
  ASSERT_NE(tensorflow::Status::OK(),               \
            xla::testing::internal_status::GetStatus(expression))

#endif  // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_
@@ -204,6 +204,8 @@ TokKind HloLexer::LexIdentifier() {
  KEYWORD(HloModule);
  KEYWORD(ENTRY);
  KEYWORD(ROOT);
  KEYWORD(maximal);
  KEYWORD(replicated);

#undef KEYWORD
@@ -49,6 +49,7 @@ class HloParser {
  bool ParseInstructionList(HloComputation::Builder* builder,
                            string* root_name);
  bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
  bool ParseSharding(HloInstruction* instruction);
  bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fill parsed operands into 'operands' and expect a certain number of
@@ -409,21 +410,147 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
    return TokenError(StrCat("parsing not yet implemented for op: ",
                             HloOpcodeString(opcode)));
  }
  // Parse "device=".
  // Parse "sharding=".
  if (lexer_.GetKind() == TokKind::kComma) {
    int64 device;
    if (!ParseExtraAttribute(&device, /*expected_attribute=*/"device")) {
    if (!ParseSharding(instruction)) {
      return false;
    }
    OpDeviceAssignment assignment;
    assignment.set_has_device(true);
    assignment.set_device(device);
    instruction->set_device_assignment(assignment);
  }

  return AddInstruction(name, instruction);
}

// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
bool HloParser::ParseSharding(HloInstruction* instruction) {
  if (!ParseToken(TokKind::kComma,
                  "expects ',' in front of an extra attribute")) {
    return false;
  }
  string attribute_name;
  if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") {
    return TokenError("expects attribute name: sharding");
  }

  if (!ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  bool maximal = false;
  bool replicated = false;
  std::vector<int64> devices;
  std::vector<int64> tile_assignment_dimensions;
  Shape tile_shape;
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64 dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64 device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device= or devices=");
        }
        break;
      }
      case TokKind::kShape:
        tile_shape = lexer_.GetShapeVal();
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  OpSharding sharding;
  if (replicated) {
    if (!devices.empty()) {
      return TokenError(
          "replicated shardings should not have any devices assigned");
    }
    if (!ShapeUtil::Equal(tile_shape, Shape())) {
      return TokenError(
          "replicated shardings should not have any tile shape set");
    }
    sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return TokenError(
          "maximal shardings should have exactly one device assigned");
    }
    if (!ShapeUtil::Equal(tile_shape, Shape())) {
      return TokenError("maximal shardings should not have any tile shape set");
    }
    sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
    sharding.add_tile_assignment_devices(devices[0]);
  } else {
    if (devices.size() <= 1) {
      return TokenError(
          "non-maximal shardings must have more than one device assigned");
    }
    if (ShapeUtil::Equal(tile_shape, Shape())) {
      return TokenError("non-maximal shardings should have a tile shape set");
    }
    if (tile_assignment_dimensions.empty()) {
      return TokenError(
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
    *sharding.mutable_tile_shape() = tile_shape;
    for (int64 dim : tile_assignment_dimensions) {
      sharding.add_tile_assignment_dimensions(dim);
    }
    for (int64 device : devices) {
      sharding.add_tile_assignment_devices(device);
    }
  }

  instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
  lexer_.Lex();
  return true;
}

bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
                             const Shape& shape) {
  switch (shape.element_type()) {
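Illustration (not part of this commit): the HLO-text forms accepted by ParseSharding above, hedged and taken from the grammar comment and the parser tests changed in this CL, shown here as C++ string constants.

const char* const kMaximalShardingExample =
    R"(%v1 = f32[4]{0} parameter(0), sharding={maximal device=1})";
const char* const kReplicatedShardingExample =
    R"(%gt = pred[4]{0} greater-than(%v1, %v2), sharding={replicated})";
const char* const kTiledShardingExample =
    R"(%c = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4})";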
@@ -100,9 +100,9 @@ ENTRY %add_constants () -> f32[] {
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:

ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
  %v1 = f32[4]{0} parameter(0), device=1
  %v2 = f32[4]{0} parameter(1), device=1
  %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
  %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
  %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
  %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
@@ -164,9 +164,9 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module:

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = f32[] recv(), channel_id=15, device=1
  ROOT %constant = f32[] constant(2.1), device=0
  %send = () send(f32[] %constant), channel_id=16, device=0
  %recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
  ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}
}

)"
@@ -180,7 +180,7 @@ ENTRY %GetTupleElement.v4 () -> s32[] {
  %constant = f32[] constant(1.23)
  %constant.1 = s32[] constant(4)
  %tuple = (f32[], s32[]) tuple(f32[] %constant, s32[] %constant.1)
  ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, device=0
  ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, sharding={maximal device=0}
}

)"
@@ -289,7 +289,7 @@ TEST_F(HloParserTest, MoreConstants) {

ENTRY %SelectScalarS32True.v4 () -> s32[] {
  %constant.2 = pred[] constant(true)
  %constant.1 = s32[] constant(-42)
  %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
  %constant = s32[] constant(42)
  %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
@@ -44,6 +44,8 @@ enum class TokKind {
  kw_ROOT,
  kw_true,
  kw_false,
  kw_maximal,
  kw_replicated,

  // Typed tokens.
  kName,  // %foo
@@ -812,18 +812,32 @@ message RecvRequest {
  ChannelHandle channel_handle = 2;
}

message OpDeviceAssignment {
  bool has_device = 1;

  // Number of the device to which this operator is assigned. Ignored if
  // 'has_device' is false.
  int32 device = 2;
message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // Neither of the above; tile_shape and tile_assignment are both used.
    OTHER = 2;
  }
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
}

message OpRequest {
  ComputationHandle computation = 1;
  OpMetadata metadata = 33;
  OpDeviceAssignment device_assignment = 39;
  OpSharding sharding = 40;

  oneof op {
    BinaryOpRequest binary_op_request = 2;
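Example (not part of this commit): a hedged sketch of filling in the OpSharding message above from C++ for a value tiled 2x2 across devices 0..3, using the flattening order described in the field comments. The f32[2,3] tile shape and the function name are illustrative only.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

xla::OpSharding MakeTiledOpSharding() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::Type::OpSharding_Type_OTHER);
  *sharding.mutable_tile_shape() = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  sharding.add_tile_assignment_dimensions(2);  // 2x2 grid of tiles
  sharding.add_tile_assignment_dimensions(2);
  for (xla::int64 device = 0; device < 4; ++device) {
    // Flattened tile -> device mapping: tile (i, j) -> device i*2 + j.
    sharding.add_tile_assignment_devices(device);
  }
  return sharding;
}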
@@ -862,7 +876,7 @@ message OpRequest {
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    // Next: 40
    // Next: 41
  }
}