diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir
index 923d9409898..170056aaead 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_extensions_test.mlir
@@ -68,7 +68,8 @@ func.func @ragged_dot_mode_batch(
     %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
     %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
     %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> (tensor<16x32x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>}) {
-  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD]]) <{
   // CHECK: }>
   // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
   // CHECK-SAME: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, l, k], [m])->([i, j, k]) {i=16, j=32, k=8, l=64, m=1} reduction={l}>
diff --git a/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards.mlir b/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards.mlir
index 9b9c97aa1fb..ad893647423 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards.mlir
@@ -62,7 +62,8 @@ func.func @ragged_dot_mode_batch(
     %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
     %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
     %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> {
-  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD]]) <{
   // CHECK: }>
   // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
   // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32>
diff --git a/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards_enable_full_version_true.mlir b/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards_enable_full_version_true.mlir
index 1f401e10b2b..df81b84f627 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards_enable_full_version_true.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/ragged_dot_insert_explicit_reshards_enable_full_version_true.mlir
@@ -64,12 +64,13 @@ func.func @ragged_dot_mode_batch(
     %arg0: tensor<16x32x64xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"b"}, {"c"}]>},
     %arg1: tensor<16x64x8xf32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}, {"c"}, {"d"}]>},
     %arg2: tensor<4xi32> {sdy.sharding=#sdy.sharding<@mesh_abcd, [{"a"}]>}) -> tensor<16x32x8xf32> {
-  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{
+  // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg2 <@mesh_abcd, [{}]> : tensor<4xi32>
+  // CHECK: %[[RAGGED_DOT:.*]] = "mhlo.ragged_dot"(%arg0, %arg1, %[[RESHARD0]]) <{
   // CHECK: }>
   // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abcd, [{"a"}, {"b"}, {"d"}]>]>
   // CHECK: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"c"} %[[RAGGED_DOT]] out_sharding=<@mesh_abcd, [{"a"}, {"b"}, {"d"}]> : tensor<16x32x8xf32>
-  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{}, {}, {}]> : tensor<16x32x8xf32>
-  // CHECK: return %[[RESHARD]] : tensor<16x32x8xf32>
+  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_abcd, [{}, {}, {}]> : tensor<16x32x8xf32>
+  // CHECK: return %[[RESHARD1]] : tensor<16x32x8xf32>
   %0 = "mhlo.ragged_dot"(%arg0, %arg1, %arg2) <{ragged_dot_dimension_numbers = #mhlo.ragged_dot