Layout delegates most of the functionality to C++.

A few unused methods are removed to slightly reduce the API surface.

PiperOrigin-RevId: 516617922

This commit is contained in:
parent 559bf4b692
commit bf65880ef4
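The user-facing construction and inspection of layouts is meant to look the same after this change; only the backing implementation moves behind the pybind11 class. A minimal usage sketch, assuming a TensorFlow build with DTensor enabled and at least two local devices (the mesh and device names are illustrative):

from tensorflow.experimental import dtensor

# Illustrative 1-D mesh over two local devices.
mesh = dtensor.create_mesh([('x', 2)], devices=['CPU:0', 'CPU:1'])

# Layout construction in Python now delegates to the pybind11-backed class,
# but call sites are unchanged.
layout = dtensor.Layout(['x', dtensor.UNSHARDED], mesh)
print(layout.rank)            # 2
print(layout.sharding_specs)  # ['x', 'unsharded']

# Round-tripping through the human-readable string still works.
restored = dtensor.Layout.from_string(layout.to_string())
assert restored == layout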
@@ -121,7 +121,9 @@
 * Deprecated `dtensor.run_on` in favor of `dtensor.default_mesh` to
   correctly indicate that the context does not override the mesh that the
   ops and functions will run on; it only sets a fallback default mesh.
+* The list of members of `dtensor.Layout` and `dtensor.Mesh` has slightly
+  changed as part of an effort to consolidate the C++ and Python source code
+  with pybind11. Most notably, `Layout.serialized_string` is removed.
 
 ## Thanks to our Contributors
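To make the two notes above concrete, a hedged sketch of the replacement API surface (the mesh construction is illustrative):

import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', 2)], devices=['CPU:0', 'CPU:1'])
layout = dtensor.Layout.replicated(mesh, rank=2)

# default_mesh only supplies a fallback for ops that are not already bound to
# a mesh; it does not override the mesh carried by existing DTensor inputs.
with dtensor.default_mesh(mesh):  # previously: with dtensor.run_on(mesh)
  ones = dtensor.call_with_layout(tf.ones, layout, shape=[4, 4])

# Layout.serialized_string() is gone; the string and proto forms remain.
text = layout.to_string()                     # human-readable form
blob = layout.as_proto().SerializeToString()  # binary protobuf form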
@@ -969,6 +969,10 @@ size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const {
   return mesh().dim_size(name).value();
 }
 
+size_t Layout::num_shards_for_dim(int dim) const {
+  return num_shards_for_dim(sharding_specs_[dim]);
+}
+
 bool Layout::IsFullyReplicated() const {
   for (const auto& sharding_spec : sharding_specs_) {
     if (num_shards_for_dim(sharding_spec) > 1) {
@@ -353,6 +353,7 @@ class Layout {
 
   int64 rank() const { return sharding_specs_.size(); }
   size_t num_shards_for_dim(const ShardingSpec& dim) const;
+  size_t num_shards_for_dim(int) const;
   std::vector<int32> num_shards() const;
 
   const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; }
@@ -367,9 +368,7 @@ class Layout {
   std::vector<std::string> sharding_spec_strs() const;
 
   int64 num_devices() const { return mesh_.num_devices(); }
-  StatusOr<const DeviceLocation> device_location(int64 device_id) const {
-    return mesh_.device_location(device_id);
-  }
 
   // Map hosts to shards.
   std::map<std::string, ShardVector> HostShardMap() const;
@@ -53,7 +53,7 @@ StatusOr<AllReducePartitions> GetAllReducePartitionsFromReducedDims(
   AllReducePartitions partitions;
   for (int64 device = 0; device < output_layout.num_devices(); ++device) {
     TF_ASSIGN_OR_RETURN(const DeviceLocation device_loc,
-                        output_layout.device_location(device));
+                        output_layout.mesh().device_location(device));
     DeviceLocation kept_dims;
     for (int64 dim_idx = 0; dim_idx < device_loc.size(); ++dim_idx) {
       if (!reduced_dims.contains(output_layout.mesh().dim_name(dim_idx))) {
@@ -928,7 +928,7 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) {
   // For sharded dimensions, the slice range is [step * device_id, step *
   // (device_id + 1)), where step = dim_size / num_of_shards.
   StatusOr<DeviceLocation> device_loc_or_status =
-      src_layout.device_location(device_id);
+      src_layout.mesh().device_location(device_id);
   if (!device_loc_or_status.ok())
     return all_gather.emitOpError()
            << device_loc_or_status.status().error_message();
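A quick numeric check of the slice-range comment in the hunk above, with illustrative values:

# Each device owns the half-open range [step * device_id, step * (device_id + 1)).
dim_size, num_of_shards, device_id = 8, 4, 3
step = dim_size // num_of_shards
assert (step * device_id, step * (device_id + 1)) == (6, 8)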
@@ -361,7 +361,7 @@ class Mesh(_pywrap_dtensor_device.Mesh):
 
 # TODO(hthu): Consider making this class immutable.
 @tf_export('experimental.dtensor.Layout', v1=[])
-class Layout(object):
+class Layout(_pywrap_dtensor_device.Layout):
   """Represents the layout information of a DTensor.
 
   A layout describes how a distributed tensor is partitioned across a mesh (and
@@ -438,38 +438,44 @@ class Layout(object):
               'valid mesh dimension or UNSHARDED.').format(
                   dim_sharding=dim_sharding))
 
-    # Set object's state
-    self.sharding_specs = sharding_specs
-    self.rank = len(sharding_specs)
-    self.mesh = mesh
-    self.shape = [self.num_shards(i) for i in range(self.rank)]
-
-  def __eq__(self, other) -> bool:
-    return self.serialized_string() == other.serialized_string()
+    super().__init__(sharding_specs, mesh)
 
   def __repr__(self) -> str:
     return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})'
 
-  def __hash__(self) -> int:
-    return hash(self.serialized_string())
+  def __hash__(self):
+    return hash(self.as_proto().SerializeToString(deterministic=True))
 
-  def as_proto(self) -> layout_pb2.LayoutProto:
-    """Create a proto representation of a layout."""
-    layout_proto = layout_pb2.LayoutProto()
+  # TODO(panzf): change to pybind11 pickle implementation in the last step
+  def __reduce__(self):
+    return Layout.from_string, (self.to_string(),)
 
-    for dim_sharding in self.sharding_specs:
-      tensor_dim = layout_proto.sharding_specs.add()
-      tensor_dim.sharding_spec = dim_sharding
+  # TODO(b/242201545): Find a way to return Mesh object from the pywrap module.
+  @property
+  def mesh(self):
+    return Mesh.from_proto(super().mesh.as_proto())
 
-    layout_proto.mesh_config.CopyFrom(self.mesh_proto())
-
-    return layout_proto
+  @property
+  def shape(self):
+    return self.mesh.shape()
 
   @staticmethod
-  def batch_sharded(mesh: Mesh, batch_dim: str, rank: int) -> 'Layout':
+  def batch_sharded(
+      mesh: Mesh, batch_dim: str, rank: int, axis: int = 0
+  ) -> 'Layout':
     """Returns a layout sharded on batch dimension."""
-    return Layout([batch_dim] + [UNSHARDED] * (rank - 1), mesh)
+    layout_obj = _pywrap_dtensor_device.Layout.__new__(Layout)
+    _pywrap_dtensor_device.Layout.__init__(
+        # Watch out for the different argument ordering.
+        layout_obj,
+        mesh=mesh,
+        rank=rank,
+        batch_dim=batch_dim,
+        axis=axis,
+    )
+    return layout_obj
 
   # TODO(b/242201545): Move this to C++ / find the corresponding function there.
   def delete(self, dims: List[int]) -> 'Layout':
     """Returns the layout with the given dimensions deleted."""
     if not isinstance(dims, list):
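A hedged sketch of the widened batch_sharded signature above (the mesh construction is illustrative): the new axis argument, defaulting to 0, selects which tensor dimension is sharded over batch_dim, and inner_sharded now delegates to it with axis=rank - 1:

from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', 2)], devices=['CPU:0', 'CPU:1'])

# Shard tensor dimension 1 over mesh dimension 'x'; the rest stay unsharded.
layout = dtensor.Layout.batch_sharded(mesh, batch_dim='x', rank=3, axis=1)
print(layout.sharding_specs)  # ['unsharded', 'x', 'unsharded']

# inner_sharded is now just batch_sharded on the last axis.
inner = dtensor.Layout.inner_sharded(mesh, inner_dim='x', rank=3)
print(inner.sharding_specs)   # ['unsharded', 'unsharded', 'x']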
@@ -480,57 +486,25 @@ class Layout(object):
     return Layout(new_specs, self.mesh)
 
   @staticmethod
-  def from_str(layout_str: bytes) -> 'Layout':
-    """Creates an instance from a serialized Protobuf binary string."""
-    layout_proto = layout_pb2.LayoutProto()
-    layout_proto.ParseFromString(layout_str)
-    sharding_specs = [
-        sharding_spec.sharding_spec
-        for sharding_spec in layout_proto.sharding_specs
-    ]
-    mesh = Mesh.from_proto(layout_proto.mesh_config)
-    return Layout(sharding_specs, mesh)
+  def from_proto(layout_proto: layout_pb2.LayoutProto) -> 'Layout':
+    """Creates an instance from a LayoutProto."""
+    layout_obj = _pywrap_dtensor_device.Layout.__new__(Layout)
+    _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_proto)
+    return layout_obj
 
   @staticmethod
   def from_string(layout_str: str) -> 'Layout':
     """Creates an instance from a human-readable string."""
-    layout_parts = layout_str.split(' ')
-    if len(layout_parts) != 2:
-      raise ValueError(
-          'layout string must contain two parts: specs and mesh. But got {}.'
-          .format(layout_str))
-
-    sharding_specs_str = layout_parts[0].replace('sharding_specs:', '')
-    mesh_str = layout_parts[1].replace('mesh:', '')
-
-    sharding_specs = sharding_specs_str.split(',')[:-1]
-
-    mesh = Mesh.from_string(mesh_str)
-    layout = Layout(sharding_specs, mesh)
-    return layout
+    layout_obj = _pywrap_dtensor_device.Layout.__new__(Layout)
+    _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_str)
+    return layout_obj
 
   @staticmethod
   def inner_sharded(mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
     """Returns a layout sharded on inner dimension."""
-    return Layout([UNSHARDED] * (rank - 1) + [inner_dim], mesh)
-
-  def is_fully_replicated(self) -> bool:
-    """Returns True if all tensor axes are replicated."""
-    return all([self.num_shards(i) == 1 for i in range(self.rank)])
-
-  def mesh_proto(self) -> layout_pb2.MeshProto:
-    """Returns the underlying mesh in Protobuf format."""
-    return self.mesh.as_proto()
-
-  def num_shards(self, idx: int) -> int:
-    """Returns the number of shards for tensor dimension `idx`."""
-    dim_sharding = self.sharding_specs[idx]
-    if dim_sharding == UNSHARDED:
-      return 1
-    if dim_sharding == MATCH:
-      return -1
-    return self.mesh.dim_size(dim_sharding)
+    return Layout.batch_sharded(mesh, inner_dim, rank, axis=rank - 1)
 
   # TODO(b/242201545): Move this to C++ / find the corresponding function there.
   def offset_to_shard(self):
     """Mapping from offset in a flattened list to shard index."""
     unravel_index = self.mesh.unravel_index()
@@ -543,8 +517,10 @@ class Layout(object):
       else:
         loc.append(mesh_loc[dim_sharding])
       locations[offset] = tuple(loc)
 
     return locations
 
+  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
   def offset_tuple_to_global_index(self, offset_tuple):
     """Mapping from offset to index in global tensor."""
     index = 0
@@ -558,20 +534,6 @@ class Layout(object):
   @staticmethod
   def replicated(mesh: Mesh, rank: int) -> 'Layout':
     """Returns a replicated layout of rank `rank`."""
-    return Layout([UNSHARDED] * rank, mesh)
-
-  def serialized_string(self) -> bytes:
-    """Returns a serialized Protobuf binary string representation."""
-    return self.as_proto().SerializeToString(deterministic=True)
-
-  # A layout with no sharding specs is acceptable, therefore we only check the
-  # mesh.
-  def to_string(self) -> str:
-    """Returns a human-readable string representation."""
-    sharding_spec_str = 'sharding_specs:'
-    # Add comma after each instruction.
-    for spec in self.sharding_specs:
-      sharding_spec_str += spec + ','
-
-    mesh_str = 'mesh:' + self.mesh.to_string()
-    return sharding_spec_str + ' ' + mesh_str
+    layout_obj = _pywrap_dtensor_device.Layout.__new__(Layout)
+    _pywrap_dtensor_device.Layout.__init__(layout_obj, mesh=mesh, rank=rank)
+    return layout_obj
@@ -199,17 +199,17 @@ TEST_F(LayoutTest, IsReplicated) {
   EXPECT_FALSE(BatchLayout().IsFullyReplicated());
 }
 
-TEST_F(LayoutTest, LayoutDimLocations) {
+TEST_F(LayoutTest, MeshDeviceLocations) {
   Layout layout = BatchLayout();
   absl::InlinedVector<int64, 4> offset = {1, 2};
-  EXPECT_THAT(layout.device_location(10), IsOkAndHolds(offset));
+  EXPECT_THAT(layout.mesh().device_location(10), IsOkAndHolds(offset));
   offset = {2, 2};
-  EXPECT_THAT(layout.device_location(18), IsOkAndHolds(offset));
+  EXPECT_THAT(layout.mesh().device_location(18), IsOkAndHolds(offset));
   offset = {3, 7};
-  EXPECT_THAT(layout.device_location(31), IsOkAndHolds(offset));
+  EXPECT_THAT(layout.mesh().device_location(31), IsOkAndHolds(offset));
 
-  EXPECT_FALSE(layout.device_location(32).ok());
-  EXPECT_FALSE(layout.device_location(-1).ok());
+  EXPECT_FALSE(layout.mesh().device_location(32).ok());
+  EXPECT_FALSE(layout.mesh().device_location(-1).ok());
 }
 
 TEST_F(LayoutTest, ScalarLayout) {
@@ -42,6 +42,7 @@ using tensorflow::dtensor::FetchLayout;
 using tensorflow::dtensor::GetFunctionCacheStats;
 using tensorflow::dtensor::IsDTensor;
 using tensorflow::dtensor::IsSparseDTensor;
+using tensorflow::dtensor::Layout;
 using tensorflow::dtensor::Mesh;
 using tensorflow::dtensor::Pack;
 using tensorflow::dtensor::SetIteratorElementLayouts;
@@ -438,5 +439,59 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) {
            "Returns True if Mesh will use XLA for SPMD "
            "instead of DTensor SPMD.")
       .def("as_proto", &Mesh::ToProto,
-           "Returns the MeshProto protobuf message.");
+           "Returns the MeshProto protobuf message.")
+      .def("device_location", [](const Mesh& mesh, int device_id) {
+        auto location = mesh.device_location(device_id);
+        if (!location.ok()) {
+          throw py::value_error(location.status().error_message());
+        }
+        return std::vector<int64_t>(location->begin(), location->end());
+      });
+
+  py::class_<Layout>(m, "Layout")
+      .def(py::init(
+          [](const std::vector<std::string>& sharding_specs, const Mesh& mesh) {
+            auto layout = Layout::GetLayout(sharding_specs, mesh);
+            if (!layout.ok()) {
+              throw py::value_error(layout.status().error_message());
+            }
+            return *layout;
+          }))
+      .def(py::init([](const tensorflow::dtensor::LayoutProto& proto) {
+             auto layout = Layout::FromProto(proto);
+             if (!layout.ok()) {
+               throw py::value_error(layout.status().error_message());
+             }
+             return *layout;
+           }),
+           "Returns a Layout from a LayoutProto.")
+      .def(py::init([](std::string_view layout_str) {
+             auto layout = Layout::FromString(layout_str);
+             if (!layout.ok()) {
+               throw py::value_error(layout.status().error_message());
+             }
+             return *layout;
+           }),
+           "Returns a Layout from a string.")
+      .def(py::init(&Layout::ReplicatedOnMesh), py::arg("mesh"),
+           py::arg("rank"), "Returns a replicated layout.")
+      .def(py::init(&Layout::BatchShardedOnMesh), py::arg("mesh"),
+           py::arg("rank"), py::arg("batch_dim"), py::arg("axis"),
+           "Returns a batch sharded layout.")
+      .def("__eq__", &Layout::operator==)
+      .def("as_proto", &Layout::ToProto)
+      .def("to_string", &Layout::ToString)
+      .def_property_readonly("sharding_specs", &Layout::sharding_spec_strs)
+      .def_property_readonly("rank", &Layout::rank)
+      .def_property_readonly("mesh", &Layout::mesh)
+      .def("is_fully_replicated", &Layout::IsFullyReplicated,
+           "Returns True if all tensor axes are replicated.")
+      .def("is_batch_parallel",
+           [](const Layout& layout) { return layout.IsBatchParallel(); })
+      .def(
+          "num_shards",
+          [](const Layout& layout, int dim) {
+            return layout.num_shards_for_dim(dim);
+          },
+          py::arg("idx"),
+          "Returns the number of shards for tensor dimension `idx`.");
 }
@@ -1,7 +1,24 @@
 path: "tensorflow.experimental.dtensor.Layout"
 tf_class {
   is_instance: "<class \'tensorflow.dtensor.python.layout.Layout\'>"
-  is_instance: "<type \'object\'>"
+  is_instance: "<class \'tensorflow.python._pywrap_dtensor_device.Layout\'>"
+  is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
+  member {
+    name: "mesh"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "rank"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "sharding_specs"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\', \'sharding_specs\', \'mesh\'], varargs=None, keywords=None, defaults=None"
@@ -12,15 +29,15 @@ tf_class {
   }
   member_method {
     name: "batch_sharded"
-    argspec: "args=[\'mesh\', \'batch_dim\', \'rank\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'mesh\', \'batch_dim\', \'rank\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
   }
   member_method {
     name: "delete"
     argspec: "args=[\'self\', \'dims\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "from_str"
-    argspec: "args=[\'layout_str\'], varargs=None, keywords=None, defaults=None"
+    name: "from_proto"
+    argspec: "args=[\'layout_proto\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "from_string"
@@ -31,11 +48,11 @@ tf_class {
     argspec: "args=[\'mesh\', \'inner_dim\', \'rank\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "is_fully_replicated"
+    name: "is_batch_parallel"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "mesh_proto"
+    name: "is_fully_replicated"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
@@ -54,10 +71,6 @@ tf_class {
     name: "replicated"
     argspec: "args=[\'mesh\', \'rank\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "serialized_string"
-    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "to_string"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
@@ -35,6 +35,10 @@ tf_class {
     name: "coords"
     argspec: "args=[\'self\', \'device_idx\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "device_location"
+    argspec: "args=[\'self\', \'arg0\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "device_type"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
@@ -14,7 +14,7 @@ tf_module {
   }
   member {
     name: "Layout"
-    mtype: "<type \'type\'>"
+    mtype: "<class \'pybind11_builtins.pybind11_type\'>"
   }
   member {
     name: "MATCH"
@@ -568,3 +568,4 @@ tensorflow::dtensor::SetIteratorElementLayouts
 
 [//tensorflow/dtensor/cc:tensor_layout] # DTensor
 tensorflow::dtensor::Mesh
+tensorflow::dtensor::Layout