Supported in this CL:

* Attaching sharding descriptors to HLO ops (see the sketch after this list).
* Partitioning the HLO graph into per-device computations based on those sharding descriptors.
* All operator support for device placement and ops replicated on all devices.
* Elementwise op support for tiled shardings.
* 2D Convolution support for tiled shardings (no stride or dilation support).
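
As a sketch of what the first item looks like from the client side (a
hypothetical example, assuming a Client* client and a Shape shape are in
scope; SetSharding/ClearSharding and ShardingBuilder are the APIs added in
this CL):

  // Ops recorded between SetSharding() and ClearSharding() carry the
  // sharding descriptor.
  xla::ComputationBuilder b(client, "sharded_example");
  b.SetSharding(xla::ShardingBuilder::AssignDevice(/*device=*/0));
  auto x = b.Parameter(0, shape, "x");  // placed on device 0
  b.ClearSharding();
  auto y = b.Neg(x);                    // no sharding attached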

PiperOrigin-RevId: 173946036
A. Unique TensorFlower 2017-10-30 14:05:29 -07:00 committed by TensorFlower Gardener
parent 682a6ed64f
commit efcbf6e34e
24 changed files with 937 additions and 85 deletions

View File

@@ -103,20 +103,17 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
errors::Internal("Unable to parse device name: ",
op_kernel->requested_device()));
xla::OpDeviceAssignment assignment;
// If no device ID assignment is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on
// device 0.
if (parsed.has_id) {
assignment.set_has_device(true);
assignment.set_device(parsed.id);
b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
}
b->SetDeviceAssignment(assignment);
op_kernel->Compute(context);
b->ClearOpMetadata();
b->ClearDeviceAssignment();
b->ClearSharding();
VLOG(4) << "Done";
}

View File

@@ -40,12 +40,13 @@ template <typename T>
class Array {
public:
// Creates a new array with the specified dimensions.
explicit Array(const std::vector<int64>& sizes) : Array(sizes, T()) {}
explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
: Array(sizes, T()) {}
// Creates a new array with the specified dimensions and specified value for
// every cell.
Array(const std::vector<int64>& sizes, T value)
: sizes_(sizes), values_(new T[num_elements()]) {
Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
: sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
Fill(value);
}
@@ -192,6 +193,18 @@ class Array {
return values_[calculate_index(indexes)];
}
// Returns the value at the cell specified by the indexes. The number of
// indexes has to match the number of dimensions of the array.
const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
return values_[calculate_index(indexes)];
}
// Returns the value at the cell specified by the indexes. The number of
// indexes has to match the number of dimensions of the array.
T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
return values_[calculate_index(indexes)];
}
// Low-level accessor for use cases such as memcmp; handle with care. Returns
// a pointer to the underlying storage of the array (similar to
// std::vector::data()).
T* data() const {
@@ -218,6 +231,11 @@ class Array {
std::multiplies<int64>());
}
const T* begin() const { return &values_[0]; }
T* begin() { return &values_[0]; }
const T* end() const { return &values_[num_elements()]; }
T* end() { return &values_[num_elements()]; }
bool operator==(const Array<T>& other) const {
if (sizes_.size() != other.sizes_.size()) {
return false;

View File

@@ -170,6 +170,7 @@ cc_library(
":computation",
":global_data",
":padding",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",

View File

@@ -1794,14 +1794,9 @@ StatusOr<Computation> ComputationBuilder::Build() {
void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const {
*request->mutable_metadata() = metadata_;
*request->mutable_device_assignment() = device_assignment_;
}
void ComputationBuilder::ClearDeviceAssignment() { device_assignment_.Clear(); }
void ComputationBuilder::SetDeviceAssignment(
const OpDeviceAssignment& assignment) {
device_assignment_ = assignment;
if (sharding_) {
*request->mutable_sharding() = *sharding_;
}
}
/* static */ ConvolutionDimensionNumbers

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -42,6 +43,58 @@ limitations under the License.
namespace xla {
class ShardingBuilder {
public:
// A shaped array used to describe the assignment of tiles to devices.
using TileAssignment = Array<int64>;
// Creates a replicated sharding - replicate a tensor on every device.
static OpSharding Replicate() {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
return result;
}
// Creates a sharding that assigns a tensor to just one device.
static OpSharding AssignDevice(int device) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
result.add_tile_assignment_dimensions(1);
result.add_tile_assignment_devices(device);
return result;
}
// Creates a tiled sharding with the given tile shape and assignment of tiles
// to devices.
static OpSharding Tile(Shape tile_shape,
const TileAssignment& tile_assignment) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
for (int64 device : tile_assignment) {
result.add_tile_assignment_devices(device);
}
return result;
}
// Creates a sharding in one dimension, with the given tile shape, which must
// be rank 1, using devices 0 through num_tiles-1.
static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
*result.mutable_tile_shape() = tile_shape;
result.add_tile_assignment_dimensions(num_tiles);
for (int64 i = 0; i < num_tiles; ++i) {
result.add_tile_assignment_devices(i);
}
return result;
}
};
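// For illustration only (a sketch, not part of this CL), each factory above
// can be used as follows; ShapeUtil::MakeShape and std::iota (<numeric>) are
// assumed to be available:
//
//   OpSharding replicated = ShardingBuilder::Replicate();
//   OpSharding on_device_3 = ShardingBuilder::AssignDevice(3);
//
//   // A 2x2 grid of f32[2,3] tiles assigned to devices 0..3, row-major.
//   ShardingBuilder::TileAssignment assignment({2, 2});
//   std::iota(assignment.begin(), assignment.end(), 0);
//   OpSharding tiled =
//       ShardingBuilder::Tile(ShapeUtil::MakeShape(F32, {2, 3}), assignment);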
// Wraps an XLA client with a convenient interface for building up
// computations. Any errors encountered in building up the computation are
// deferred from being handled until Build() is called.
@@ -78,11 +131,11 @@ class ComputationBuilder {
// Sets an OpDeviceAssignment that will be attached to all instructions
// until cleared.
void SetDeviceAssignment(const OpDeviceAssignment& assignment);
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
// Clears the device assignment. Ops will be placed according to the default
// placement policy.
void ClearDeviceAssignment();
void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
@@ -894,8 +947,9 @@ class ComputationBuilder {
// throughout the TensorFlow op kernel implementations).
OpMetadata metadata_;
// Device assignment for the operator.
OpDeviceAssignment device_assignment_;
// Sharding for this operator. Like metadata_, it is set in a sticky,
// mode-like fashion to simplify client code: once set, it is attached to
// every op request until cleared.
tensorflow::gtl::optional<OpSharding> sharding_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};

View File

@@ -133,6 +133,7 @@ cc_library(
"hlo_instruction.cc",
"hlo_module.cc",
"hlo_opcode.cc",
"hlo_sharding.cc",
],
hdrs = [
"dfs_hlo_visitor.h",
@@ -141,6 +142,7 @@ cc_library(
"hlo_instruction.h",
"hlo_module.h",
"hlo_opcode.h",
"hlo_sharding.h",
],
deps = [
":hlo_module_config",
@@ -148,6 +150,7 @@ cc_library(
":hlo_reachability",
":name_uniquer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_tree",
@@ -238,6 +241,22 @@ tf_cc_test(
],
)
tf_cc_test(
name = "hlo_sharding_test",
srcs = ["hlo_sharding_test.cc"],
deps = [
":hlo",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "call_graph",
srcs = ["call_graph.cc"],

View File

@@ -56,7 +56,6 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root =
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
root, fusion_instruction_));
}
@@ -735,6 +734,10 @@ std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) {
}
new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands);
new_instr->set_metadata(instr->metadata());
if (instr->has_sharding()) {
new_instr->set_sharding(instr->sharding());
}
InsertOrDie(&clone_map, instr, new_instr.get());
instructions.push_back(std::move(new_instr));
}

View File

@@ -1039,8 +1039,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
if (!opcode_specific_info.empty()) {
lines.push_back(opcode_specific_info);
}
if (instr->device_assignment().has_device()) {
lines.push_back(StrCat("device=", instr->device_assignment().device()));
if (instr->has_sharding()) {
lines.push_back(StrCat("sharding=", instr->sharding().ToString()));
}
// Show the shape and layout of the instruction, unless it's an inlined fusion
// node -- there the shape and layout is present in the output node.

View File

@@ -1212,6 +1212,9 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
}
}
clone->set_parent(parent_);
if (has_sharding()) {
clone->set_sharding(sharding());
}
return clone;
}
@@ -1889,8 +1892,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
if (opcode() == HloOpcode::kGetTupleElement) {
extra.push_back(StrCat("index=", tuple_index()));
}
if (device_assignment_.has_device()) {
extra.push_back(StrCat("device=", device_assignment_.device()));
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
if (!control_successors_.empty()) {
extra.push_back(StrCat(

View File

@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -713,6 +714,26 @@ class HloInstruction {
fusion_kind_ = kind;
}
// Returns the sharding applied to this operator.
// REQUIRES: has_sharding() is true.
const HloSharding& sharding() const {
CHECK(has_sharding());
return *sharding_;
}
// Returns the sharding applied to this operator, or default_ if none exists.
const HloSharding& sharding_or_default(const HloSharding& default_) const {
return sharding_ ? *sharding_ : default_;
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
sharding_ = MakeUnique<HloSharding>(sharding);
}
// Removes any sharding from this operator.
void clear_sharding() { sharding_ = nullptr; }
// Returns true if this operator has a sharding assigned.
bool has_sharding() const { return sharding_ != nullptr; }
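// Typical guarded read (a sketch; the same pattern appears in
// hlo_tfgraph_builder.cc in this change):
//   if (instr->has_sharding() && instr->sharding().HasUniqueDevice()) {
//     TF_ASSIGN_OR_RETURN(int64 device, instr->sharding().UniqueDevice());
//   }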
// Merges the fused instructions from 'instruction_to_merge' into the
// fused instruction set of 'this', updating operands as necessary.
//
@@ -984,14 +1005,6 @@ class HloInstruction {
void RelayoutConstant(const Layout& new_layout,
const ShapeIndex& shape_index = {});
// Gets/sets the device assignment.
const OpDeviceAssignment& device_assignment() const {
return device_assignment_;
}
void set_device_assignment(const OpDeviceAssignment& device_assignment) {
device_assignment_ = device_assignment;
}
private:
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
@@ -1124,6 +1137,9 @@ class HloInstruction {
// The type of the fusion. Used by kFusion only.
FusionKind fusion_kind_;
// The sharding, if one exists.
std::unique_ptr<HloSharding> sharding_;
// For parameter instructions this field holds the parameter number.
int64 parameter_number_ = 0;
string parameter_name_;
@@ -1184,9 +1200,6 @@
// outer-most dimension first).
std::vector<int64> outer_dimension_partitions_;
// Device assignment for the instruction.
OpDeviceAssignment device_assignment_;
TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

View File

@@ -0,0 +1,232 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
using ::tensorflow::strings::StrCat;
HloSharding HloSharding::AssignDevice(int64 device_id) {
return HloSharding(device_id);
}
HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
CHECK_EQ(1, ShapeUtil::Rank(input_shape));
CHECK_GT(num_tiles, 1);
std::vector<int64> dimensions(1, num_tiles);
Shape tile_shape = input_shape;
auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
Array<int64> assignment(dimensions);
std::iota(assignment.begin(), assignment.end(), 0);
return HloSharding(tile_shape, assignment);
}
string HloSharding::ToString() const {
if (replicated_) {
return "{replicated}";
} else if (maximal_) {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ",
"devices=", VectorString(tile_assignment_), "}");
}
}
bool HloSharding::UsesDevice(int64 device) const {
const auto& devices = tile_assignment_;
return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end();
}
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
std::vector<int64> ret_index;
tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
if (d == device) {
ret_index = {index.begin(), index.end()};
}
});
CHECK(!ret_index.empty());
return ret_index;
}
int64 HloSharding::DeviceForTileIndex(
tensorflow::gtl::ArraySlice<int64> index) const {
CHECK(!replicated_);
if (maximal_) {
return *tile_assignment_.begin();
}
CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
return tile_assignment_(index);
}
std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
// Index will always be all zeroes if we're maximal, and tile_shape_ is not
// valid.
return index;
}
for (int64 i = 0; i < index.size(); ++i) {
index[i] *= tile_shape_.dimensions(i);
}
return index;
}
std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
std::vector<int64> index = TileIndexForDevice(device);
for (int64 i = 0; i < index.size(); ++i) {
index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
}
return index;
}
StatusOr<int64> HloSharding::UniqueDevice() const {
if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices");
}
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
if (replicated_) {
return Status::OK();
}
// All tile assignments must be less than the number of available cores and
// unique.
Status status = Status::OK();
std::set<int64> seen_cores;
tile_assignment_.Each(
[&](tensorflow::gtl::ArraySlice<int64> indices, uint32 core) {
// Don't overwrite a bad status, so we report the first error.
if (status.ok()) {
if (core >= num_devices) {
status =
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"core ", core, " > ", num_devices, " in tile assignment"));
} else if (seen_cores.count(core) != 0) {
status =
tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"core ", core, " is not unique in tile assignment"));
}
}
seen_cores.insert(core);
});
if (!status.ok()) {
return status;
}
if (IsTileMaximal()) {
return Status::OK();
}
// The tile rank must be the same as the input rank.
if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
return tensorflow::errors::InvalidArgument(
"Tile rank differs from the input rank");
}
// The tile shape must not be the same as the input shape without maximal_
// also set. If this is the case, we're not actually sharded and the correct
// constructor should have been used.
if (ShapeUtil::Equal(shape, tile_shape_)) {
return tensorflow::errors::InvalidArgument(
"Tile shape is the same as the input shape. If a replicated sharding "
"was intended, use HloSharding::Replicated(). If a device placement "
"was intended, use HloSharding::AssignDevice()");
}
// The tile shape must not be greater than the input shape in any dimension.
for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
auto tile_dim = tile_shape_.dimensions(i);
auto shape_dim = shape.dimensions(i);
if (tile_dim > shape_dim) {
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"Tile is larger than input shape (dimension ", i, ", ", tile_dim,
" > ", shape_dim));
}
}
// The tile assignment tensor must be exactly dimensioned to
// ceil(shape[dim] / tile[dim]) for every dimension contained within tile.
for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
int64 expected_dim =
CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
if (tile_assignment_.dimensions()[i] != expected_dim) {
return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat(
"Tile assignment tensor has incorrect shape. Dimension ", i,
" expected ", expected_dim, " but got ",
tile_assignment_.dimensions()[i]));
}
}
return Status::OK();
}
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
const OpSharding& proto) {
if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
} else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
return HloSharding(proto.tile_assignment_devices(0));
}
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
proto.tile_assignment_devices().end());
Array<int64> tile_assignment(
std::vector<int64>(proto.tile_assignment_dimensions().begin(),
proto.tile_assignment_dimensions().end()));
std::copy(proto.tile_assignment_devices().begin(),
proto.tile_assignment_devices().end(), tile_assignment.begin());
return HloSharding(proto.tile_shape(), tile_assignment);
}
OpSharding HloSharding::ToProto() const {
OpSharding result;
*result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
for (auto device : tile_assignment_) {
result.add_tile_assignment_devices(device);
}
if (IsReplicated()) {
result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (IsTileMaximal()) {
result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
} else {
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
}
return result;
}
} // namespace xla

View File

@@ -0,0 +1,165 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// HLO shardings describe how an HLO instruction is split across multiple
// computations.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#include <string>
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
public:
// Creates a trivial sharding that replicates a maximal tile across all
// devices.
static HloSharding Replicate() { return HloSharding(); }
// Creates a sharding that emulates device placement; a tile shape equal to
// the input shape (one tile) assigned to a single device.
static HloSharding AssignDevice(int64 device_id);
// Creates a new sharding which splits a shape into tiles each with shape
// `tile_shape`. Each tile is assigned to one device, which is specified by
// `tile_assignment`. Any tensor not a multiple of the tile size in any
// dimension is implicitly padded to the tile size.
//
// e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
// 2 1 padding
// <------><->
// +----+----+
// | 0 | 1 |
// +----+----+
//
// Split into two tiles, one of which is implicitly padded by one.
static HloSharding Tile(const Shape& tile_shape,
const Array<int64>& tile_assignment) {
return HloSharding(tile_shape, tile_assignment);
}
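// For illustration (hypothetical, not part of this CL), the example above
// could be built as:
//   Array<int64> assignment({2, 1});                     // 2x1 grid of tiles
//   std::iota(assignment.begin(), assignment.end(), 0);  // devices 0 and 1
//   HloSharding s =
//       HloSharding::Tile(ShapeUtil::MakeShape(F32, {2, 2}), assignment);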
// Creates a new sharding which splits a one-dimensional input shape into
// `num_tiles` tiles.
static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
// Creates a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
OpSharding ToProto() const;
string ToString() const;
// Validate that this sharding can be applied to a tensor with shape `shape`.
Status Validate(const Shape& shape, int64 num_devices) const;
// Returns true if the sharding is trivial: replicate on all devices.
bool IsReplicated() const { return replicated_; }
// Returns true if the tile size is the same as the input size.
bool IsTileMaximal() const { return maximal_; }
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
// Returns the tile that should be executed on the given device.
std::vector<int64> TileIndexForDevice(int64 device) const;
// Returns the device that should execute the given tile.
// It is an error to call this if IsReplicated() is true.
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
// Given a device ID, returns the offset within the input space of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
std::vector<int64> TileOffsetForDevice(int64 device) const;
// Given a device ID, returns the limit within the input space of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
std::vector<int64> TileLimitForDevice(int64 device) const;
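// For example (a sketch): with tile_shape {2, 3}, the device whose tile index
// is {1, 1} gets TileOffsetForDevice() == {2, 3} and
// TileLimitForDevice() == {4, 6}.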
// Returns the single device this op operates on.
// Requires !IsReplicated() && IsTileMaximal().
StatusOr<int64> UniqueDevice() const;
// Returns true if this op only uses a single device.
bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
tile_assignment_ == other.tile_assignment_;
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
size_t Hash() const {
if (replicated_) {
return 0;
}
size_t h = 0;
for (uint32 v : tile_assignment_) {
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
}
for (uint32 v : tile_shape_.dimensions()) {
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
}
return h;
}
// Gets the tile shape.
// It is an error to call this if IsTileMaximal() is true.
const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
// It is an error to call this if IsReplicated() is true.
const Array<int64>& tile_assignment() const { return tile_assignment_; }
private:
HloSharding()
: replicated_(true),
maximal_(true),
tile_shape_(),
tile_assignment_({0}) {}
explicit HloSharding(int64 device_id)
: replicated_(false),
maximal_(true),
tile_shape_(),
tile_assignment_({1}, device_id) {}
HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
bool replicated_;
bool maximal_;
Shape tile_shape_;
Array<int64> tile_assignment_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

View File

@@ -0,0 +1,190 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace {
Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> contents) {
Array<int64> a(dimensions);
std::copy(contents.begin(), contents.end(), a.begin());
return a;
}
class HloShardingTest : public HloTestBase {};
TEST_F(HloShardingTest, Replicate) {
Shape tile_shape = ShapeUtil::MakeShape(U32, {4});
HloSharding sharding = HloSharding::Replicate();
EXPECT_TRUE(sharding.IsReplicated());
EXPECT_TRUE(sharding.IsTileMaximal());
EXPECT_TRUE(sharding.UsesDevice(0));
EXPECT_TRUE(sharding.UsesDevice(65535));
HloSharding other = HloSharding::Replicate();
EXPECT_EQ(other, sharding);
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
/*num_devices=*/2));
EXPECT_IS_NOT_OK(sharding.UniqueDevice());
}
TEST_F(HloShardingTest, DevicePlacement) {
HloSharding sharding = HloSharding::AssignDevice(5);
EXPECT_FALSE(sharding.IsReplicated());
EXPECT_TRUE(sharding.IsTileMaximal());
EXPECT_FALSE(sharding.UsesDevice(0));
EXPECT_TRUE(sharding.UsesDevice(5));
EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
HloSharding other = HloSharding::Replicate();
EXPECT_NE(other, sharding);
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
/*num_devices=*/6));
EXPECT_IS_NOT_OK(
sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5));
}
TEST_F(HloShardingTest, Tile) {
{
// Test should fail because of a duplicate tile assignment.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
/*num_devices=*/4));
}
{
// Test should fail because more devices are used than are available.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
/*num_devices=*/2));
}
{
// Test should fail due to the tile being larger than the input space.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
/*num_devices=*/4));
}
{
// Test should fail due to the tile not dividing the input space into 4
// sections (even with padding).
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}),
/*num_devices=*/4));
}
{
// Test should pass.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
/*num_devices=*/5));
EXPECT_EQ(0, sharding.DeviceForTileIndex({0, 0}));
EXPECT_EQ(3, sharding.DeviceForTileIndex({0, 1}));
EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));
EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector<int64>{0, 0}));
EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector<int64>{0, 3}));
EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
EXPECT_IS_NOT_OK(sharding.UniqueDevice());
}
}
TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) {
return false;
}
return a == b;
};
{
HloSharding sharding1 = HloSharding::Replicate();
HloSharding sharding2 = HloSharding::Replicate();
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
{
HloSharding sharding1 = HloSharding::AssignDevice(1);
HloSharding sharding2 = HloSharding::AssignDevice(1);
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
{
HloSharding sharding1 = HloSharding::AssignDevice(1);
HloSharding sharding2 = HloSharding::AssignDevice(2);
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
{
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding1 =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
{
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding1 =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
}
} // namespace
} // namespace xla

View File

@@ -198,9 +198,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
if (instruction->device_assignment().has_device()) {
node_def->set_device(
GetDeviceName(instruction->device_assignment().device()));
if (instruction->has_sharding() &&
instruction->sharding().HasUniqueDevice()) {
TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
node_def->set_device(GetDeviceName(device));
}
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {

View File

@@ -1415,9 +1415,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
// proto in the above switch statement.
TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
TF_RETURN_IF_ERROR(
computation->SetOpDeviceAssignment(handle, arg->device_assignment()));
if (arg->has_sharding()) {
TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding()));
}
return tensorflow::Status::OK();
}

View File

@@ -1315,20 +1315,19 @@ Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
return Status::OK();
}
Status UserComputation::SetOpDeviceAssignment(
const ComputationDataHandle& handle,
const OpDeviceAssignment& device_assignment) {
Status UserComputation::SetOpSharding(const ComputationDataHandle& handle,
const OpSharding& sharding) {
tensorflow::mutex_lock lock(mutex_);
int64 handle_value = handle.handle();
if (session_computation_.requests().count(handle_value) == 0) {
return InvalidArgument("Invalid handle in SetOpDeviceAssignment (%lld)",
return InvalidArgument("Invalid handle in SetOpSharding (%lld)",
handle_value);
}
*session_computation_.mutable_requests()
->at(handle_value)
.mutable_request()
->mutable_device_assignment() = device_assignment;
->mutable_sharding() = sharding;
return Status::OK();
}
@@ -2518,7 +2517,9 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
if (ShapeUtil::IsScalar(operand->shape())) {
HloInstruction* broadcast = hlo_builder_.AddInstruction(
HloInstruction::CreateBroadcast(broadcast_shape, operand, {}));
broadcast->set_device_assignment(operand->device_assignment());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
return broadcast;
}
// Do explicit broadcast for degenerate broadcast.
@@ -2536,12 +2537,16 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
ShapeUtil::MakeShape(operand->shape().element_type(),
reshaped_dimensions),
operand));
reshaped_operand->set_device_assignment(operand->device_assignment());
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
// Broadcast 'reshape' up to the larger size.
HloInstruction* broadcast =
hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions));
broadcast->set_device_assignment(operand->device_assignment());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
return broadcast;
}
@@ -2556,8 +2561,11 @@ void ComputationLowerer::Visit(
HloInstruction* hlo_instruction =
hlo_builder_.AddInstruction(std::move(instruction));
hlo_instruction->set_metadata(request.request().metadata());
hlo_instruction->set_device_assignment(
request.request().device_assignment());
if (request.request().has_sharding()) {
OpSharding op_sharding = request.request().sharding();
hlo_instruction->set_sharding(
HloSharding::FromProto(op_sharding).ValueOrDie());
}
return hlo_instruction;
};
auto lookup_instruction = [&](const ComputationDataHandle& handle) {

View File

@@ -262,8 +262,8 @@ class UserComputation {
const OpMetadata& metadata);
// Sets the device assignment on the Hlo instruction referenced by 'handle'.
Status SetOpDeviceAssignment(const ComputationDataHandle& handle,
const OpDeviceAssignment& device_assignment);
Status SetOpSharding(const ComputationDataHandle& handle,
const OpSharding& sharding);
// Builds a HLO computation from the UserComputation. The parameter "resolver"
// is a function which returns a pointer to the HloComputation corresponding

View File

@@ -224,10 +224,13 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
computation.AddParameterInstruction(b_request));
OpDeviceAssignment assignment;
assignment.set_has_device(true);
assignment.set_device(7);
TF_EXPECT_OK(computation.SetOpDeviceAssignment(b_handle, assignment));
const int64 kDevice = 7;
OpSharding sharding;
sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding.add_tile_assignment_dimensions(1);
sharding.add_tile_assignment_devices(kDevice);
TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding));
BinaryOpRequest add;
add.set_binop(BINOP_ADD);
@@ -260,12 +263,10 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
const HloInstruction* broadcast =
hlo_computation->root_instruction()->operand(1);
EXPECT_TRUE(broadcast->device_assignment().has_device());
EXPECT_EQ(assignment.device(), broadcast->device_assignment().device());
EXPECT_TRUE(broadcast->has_sharding());
const HloInstruction* reshape = broadcast->operand(0);
EXPECT_TRUE(reshape->device_assignment().has_device());
EXPECT_EQ(assignment.device(), reshape->device_assignment().device());
EXPECT_TRUE(reshape->has_sharding());
}
TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {

View File

@@ -62,9 +62,16 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) {
#define EXPECT_IS_OK(expression) \
EXPECT_EQ(tensorflow::Status::OK(), \
xla::testing::internal_status::GetStatus(expression))
#define EXPECT_IS_NOT_OK(expression) \
EXPECT_NE(tensorflow::Status::OK(), \
xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_OK
#define ASSERT_IS_OK(expression) \
ASSERT_EQ(tensorflow::Status::OK(), \
xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_NOT_OK
#define ASSERT_IS_NOT_OK(expression) \
ASSERT_NE(tensorflow::Status::OK(), \
xla::testing::internal_status::GetStatus(expression))
#endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_

View File

@@ -204,6 +204,8 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(HloModule);
KEYWORD(ENTRY);
KEYWORD(ROOT);
KEYWORD(maximal);
KEYWORD(replicated);
#undef KEYWORD

View File

@@ -49,6 +49,7 @@ class HloParser {
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseSharding(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseOperands(std::vector<HloInstruction*>* operands);
// Fill parsed operands into 'operands' and expect a certain number of
@@ -409,21 +410,147 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
// Parse "device=".
// Parse "sharding=".
if (lexer_.GetKind() == TokKind::kComma) {
int64 device;
if (!ParseExtraAttribute(&device, /*expected_attribute=*/"device")) {
if (!ParseSharding(instruction)) {
return false;
}
OpDeviceAssignment assignment;
assignment.set_has_device(true);
assignment.set_device(device);
instruction->set_device_assignment(assignment);
}
return AddInstruction(name, instruction);
}
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
//     ('devices=' ('[' dims ']')* device_list)? '}'
// dims ::= int_list
// device_list ::= int_list
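//
// Examples accepted by this grammar (the tiled form mirrors the strings
// exercised in hlo_parser_test.cc):
//   sharding={replicated}
//   sharding={maximal device=1}
//   sharding={s32[2,3] devices=[2,2]0,1,2,3}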
bool HloParser::ParseSharding(HloInstruction* instruction) {
if (!ParseToken(TokKind::kComma,
"expects ',' in front of an extra attribute")) {
return false;
}
string attribute_name;
if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") {
return TokenError("expects attribute name: sharding");
}
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
bool maximal = false;
bool replicated = false;
std::vector<int64> devices;
std::vector<int64> tile_assignment_dimensions;
Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
maximal = true;
lexer_.Lex();
break;
case TokKind::kw_replicated:
replicated = true;
lexer_.Lex();
break;
case TokKind::kAttributeName: {
if (lexer_.GetStrVal() == "device") {
if (lexer_.Lex() != TokKind::kInt) {
return TokenError("device= attribute must be an integer");
}
devices = {lexer_.GetInt64Val()};
lexer_.Lex();
} else if (lexer_.GetStrVal() == "devices") {
lexer_.Lex();
if (!ParseToken(TokKind::kLsquare,
"expected '[' to start sharding devices shape")) {
return false;
}
do {
int64 dim;
if (!ParseInt64(&dim)) {
return false;
}
tile_assignment_dimensions.push_back(dim);
} while (EatIfPresent(TokKind::kComma));
if (!ParseToken(TokKind::kRsquare,
"expected ']' to start sharding devices shape")) {
return false;
}
do {
int64 device;
if (!ParseInt64(&device)) {
return false;
}
devices.push_back(device);
} while (EatIfPresent(TokKind::kComma));
} else {
return TokenError(
"unknown attribute in sharding: expected device= or devices=");
}
break;
}
case TokKind::kShape:
tile_shape = lexer_.GetShapeVal();
lexer_.Lex();
break;
case TokKind::kRbrace:
break;
default:
return TokenError("unexpected token");
}
}
OpSharding sharding;
if (replicated) {
if (!devices.empty()) {
return TokenError(
"replicated shardings should not have any devices assigned");
}
if (!ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError(
"replicated shardings should not have any tile shape set");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return TokenError(
"maximal shardings should have exactly one device assigned");
}
if (!ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError("maximal shardings should not have any tile shape set");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding.add_tile_assignment_devices(devices[0]);
} else {
if (devices.size() <= 1) {
return TokenError(
"non-maximal shardings must have more than one device assigned");
}
if (ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError("non-maximal shardings should have a tile shape set");
}
if (tile_assignment_dimensions.empty()) {
return TokenError(
"non-maximal shardings must have a tile assignment list including "
"dimensions");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
*sharding.mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment_dimensions) {
sharding.add_tile_assignment_dimensions(dim);
}
for (int64 device : devices) {
sharding.add_tile_assignment_devices(device);
}
}
instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
lexer_.Lex();
return true;
}
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
switch (shape.element_type()) {

View File

@@ -100,9 +100,9 @@ ENTRY %add_constants () -> f32[] {
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0), device=1
%v2 = f32[4]{0} parameter(1), device=1
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
@@ -164,9 +164,9 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
R"(HloModule TwoSendRecvBothWayRecvFist_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15, device=1
ROOT %constant = f32[] constant(2.1), device=0
%send = () send(f32[] %constant), channel_id=16, device=0
%recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
%send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}
}
)"
@@ -180,7 +180,7 @@ ENTRY %GetTupleElement.v4 () -> s32[] {
%constant = f32[] constant(1.23)
%constant.1 = s32[] constant(4)
%tuple = (f32[], s32[]) tuple(f32[] %constant, s32[] %constant.1)
ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, device=0
ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, sharding={maximal device=0}
}
)"
@@ -289,7 +289,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42)
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}

View File

@@ -44,6 +44,8 @@ enum class TokKind {
kw_ROOT,
kw_true,
kw_false,
kw_maximal,
kw_replicated,
// Typed tokens.
kName, // %foo

View File

@@ -812,18 +812,32 @@ message RecvRequest {
ChannelHandle channel_handle = 2;
}
message OpDeviceAssignment {
bool has_device = 1;
// Number of the device to which this operator is assigned. Ignored if
// 'has_device' is false.
int32 device = 2;
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,
// all other fields are unused).
REPLICATED = 0;
// This sharding is maximal - one device runs the entire operation.
MAXIMAL = 1;
// Neither of the above; tile_shape and tile_assignment are both used.
OTHER = 2;
}
Type type = 1;
// The shape of the sharded tile.
Shape tile_shape = 2;
// The shape of the tile assignment tensor - this must be the same rank as
// tile_shape and the product of its dimensions must equal
// tile_assignment_devices.size().
repeated int64 tile_assignment_dimensions = 3;
// Flattened list of device IDs. The order of flattening is the same as used
// by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
repeated int64 tile_assignment_devices = 4;
}
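// For illustration (not part of this proto): a 2x2 tiling of an f32[4,4]
// tensor across devices 0..3 would be encoded as
//   type: OTHER
//   tile_shape: f32[2,2]
//   tile_assignment_dimensions: [2, 2]
//   tile_assignment_devices: [0, 1, 2, 3]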
message OpRequest {
ComputationHandle computation = 1;
OpMetadata metadata = 33;
OpDeviceAssignment device_assignment = 39;
OpSharding sharding = 40;
oneof op {
BinaryOpRequest binary_op_request = 2;
@@ -862,7 +876,7 @@ message OpRequest {
BatchNormTrainingRequest batch_norm_training_request = 35;
BatchNormGradRequest batch_norm_grad_request = 37;
BatchNormInferenceRequest batch_norm_inference_request = 38;
// Next: 40
// Next: 41
}
}