Compatible factor shardings should not have overlapping axes across different tensors.

PiperOrigin-RevId: 824815687
commit 402ead44b2
parent c09d68c588
Author: Zixuan Jiang, 2025-10-27 21:02:19 -07:00, committed by TensorFlower Gardener
3 changed files with 8 additions and 5 deletions
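The three lowering tests below exercise mhlo.ragged_dot in batch mode. The group-sizes operand %arg2 is annotated with sharding [{"a"}], but mesh axis "a" already shards the batch factor i of the lhs, rhs, and result; since the same axis must not be claimed by two different factors, the updated checks expect %arg2 to be resharded to <@mesh_abcd, [{}]> (fully replicated) before being fed to the ragged dot. A minimal sketch of the no-overlap property, with hypothetical names and not the actual Shardy implementation:

# Hypothetical sketch (not the Shardy implementation): each tensor's sharding
# assigns mesh axes to the factors it uses; the factor shardings are
# compatible only if no mesh axis is claimed by two different factors.
def factor_shardings_compatible(per_tensor_factor_axes):
    axis_owner = {}  # mesh axis -> the factor that uses it
    for factor_axes in per_tensor_factor_axes:
        for factor, axes in factor_axes.items():
            for axis in axes:
                owner = axis_owner.setdefault(axis, factor)
                if owner != factor:
                    return False  # axis reused by a different factor
    return True

# Batch factor i is sharded on "a" in the lhs/rhs, so also sharding the
# group-sizes operand (factor m) on "a" is incompatible:
assert not factor_shardings_compatible([
    {"i": ("a",), "j": ("b",), "l": ("c",)},  # lhs:  [{"a"}, {"b"}, {"c"}]
    {"i": ("a",), "l": ("c",), "k": ("d",)},  # rhs:  [{"a"}, {"c"}, {"d"}]
    {"m": ("a",)},                            # group sizes: [{"a"}]
])
# After resharding the group sizes to [{}] (fully replicated), it passes:
assert factor_shardings_compatible([
    {"i": ("a",), "j": ("b",), "l": ("c",)},
    {"i": ("a",), "l": ("c",), "k": ("d",)},
    {"m": ()},
])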


@@ -68,7 +68,8 @@ func.func @ragged_dot_mode_batch(
 %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
 %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
 %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> (tensor<16x32x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>}) {
-// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD]]) <{
 // CHECK: }>
 // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
 // CHECK-SAME: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>


@@ -62,7 +62,8 @@ func.func @ragged_dot_mode_batch(
 %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
 %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
 %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> {
-// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD]]) <{
 // CHECK: }>
 // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
 // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32>


@@ -64,12 +64,13 @@ func.func @ragged_dot_mode_batch(
 %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
 %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
 %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> {
-// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+// CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+// CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD0]]) <{
 // CHECK: }>
 // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
 // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32>
-// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{}, {}, {}]> : tensor<16x32x8xf32>
-// CHECK: return %[[RESHARD]] : tensor<16x32x8xf32>
+// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{}, {}, {}]> : tensor<16x32x8xf32>
+// CHECK: return %[[RESHARD1]] : tensor<16x32x8xf32>
 %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers =
 #mhlo.ragged_dot<dot_dimension_numbers = #mhlo.dot<
 lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],