[XLA] don't call SetupDerivedInstruction on lhs and rhs in RewriteAsMultiplyDotWithZeroLhsContractingDim if they haven't changed.

PiperOrigin-RevId: 773743458
This commit is contained in:
Tom Natan 2025-06-20 09:57:50 -07:00 committed by TensorFlower Gardener
parent 3736a70ce0
commit b51a85e414
3 changed files with 27 additions and 2 deletions

View File

@ -414,6 +414,8 @@ xla_cc_test(
"//xla/service:memory_annotations_hdr",
"//xla/service:pattern_matcher",
"//xla/service:shape_inference",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"//xla/tests:test_utils",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",

View File

@ -3543,8 +3543,12 @@ AlgebraicSimplifierVisitor::RewriteAsMultiplyDotWithZeroLhsContractingDim(
}
auto new_instruction = HloInstruction::CreateBinary(
dot->shape(), HloOpcode::kMultiply, new_lhs, new_rhs);
dot->SetupDerivedInstruction(new_lhs);
dot->SetupDerivedInstruction(new_rhs);
if (new_lhs != lhs) {
dot->SetupDerivedInstruction(new_lhs);
}
if (new_rhs != rhs) {
dot->SetupDerivedInstruction(new_rhs);
}
dot->SetupDerivedInstruction(new_instruction.get());
return ReplaceWithNewInstruction(dot, std::move(new_instruction));
}

View File

@ -57,6 +57,8 @@ limitations under the License.
#include "xla/service/memory_annotations.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/shape_inference.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/test_utils.h"
@ -12620,6 +12622,23 @@ ENTRY main.1 {
HloOpcode::kParameter);
}
TEST_F(AlgebraicSimplifierTest, PreserveSdySharding) {
const std::string hlo_string = R"(
HloModule jit_matmul, entry_computation_layout={(f64[8,3]{1,0}, f64[])->f64[8,3]{1,0}}, allow_spmd_sharding_propagation_to_parameters={false,true}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=2
ENTRY %main.4 (Arg_0.1: f64[8,3], Arg_1.2: f64[]) -> f64[8,3] {
%Arg_1.2 = f64[] parameter(1)
%Arg_0.1 = f64[8,3]{1,0} parameter(0), frontend_attributes={xla.sdy.sharding="#sdy.sharding<@mesh, [{\"x\"}, {}]>"}
ROOT %dot.3 = f64[8,3]{1,0} dot(f64[] %Arg_1.2, f64[8,3]{1,0} %Arg_0.1), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((), ()), ((), ())) precision=None preferred_element_type=float64]" source_file="third_party/py/jax/tests/pjit_test.py" source_line=4021}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_EQ(
m->entry_computation()->parameter_instruction(0)->get_frontend_attribute(
sdy::toStringView(sdy::kShardingRoundTripAttr)),
"#sdy.sharding<@mesh, [{\"x\"}, {}]>");
}
TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) {
const std::string hlo_string = R"(
HloModule test