mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA] don't call SetupDerivedInstruction on lhs and rhs in RewriteAsMultiplyDotWithZeroLhsContractingDim if they haven't changed.
PiperOrigin-RevId: 773743458
This commit is contained in:
parent
3736a70ce0
commit
b51a85e414
|
|
@ -414,6 +414,8 @@ xla_cc_test(
|
|||
"//xla/service:memory_annotations_hdr",
|
||||
"//xla/service:pattern_matcher",
|
||||
"//xla/service:shape_inference",
|
||||
"//xla/service/spmd/shardy:constants",
|
||||
"//xla/service/spmd/shardy:utils",
|
||||
"//xla/tests:test_utils",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
|
|
|
|||
|
|
@ -3543,8 +3543,12 @@ AlgebraicSimplifierVisitor::RewriteAsMultiplyDotWithZeroLhsContractingDim(
|
|||
}
|
||||
auto new_instruction = HloInstruction::CreateBinary(
|
||||
dot->shape(), HloOpcode::kMultiply, new_lhs, new_rhs);
|
||||
dot->SetupDerivedInstruction(new_lhs);
|
||||
dot->SetupDerivedInstruction(new_rhs);
|
||||
if (new_lhs != lhs) {
|
||||
dot->SetupDerivedInstruction(new_lhs);
|
||||
}
|
||||
if (new_rhs != rhs) {
|
||||
dot->SetupDerivedInstruction(new_rhs);
|
||||
}
|
||||
dot->SetupDerivedInstruction(new_instruction.get());
|
||||
return ReplaceWithNewInstruction(dot, std::move(new_instruction));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ limitations under the License.
|
|||
#include "xla/service/memory_annotations.h"
|
||||
#include "xla/service/pattern_matcher.h"
|
||||
#include "xla/service/shape_inference.h"
|
||||
#include "xla/service/spmd/shardy/constants.h"
|
||||
#include "xla/service/spmd/shardy/utils.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tests/test_utils.h"
|
||||
|
|
@ -12620,6 +12622,23 @@ ENTRY main.1 {
|
|||
HloOpcode::kParameter);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, PreserveSdySharding) {
|
||||
const std::string hlo_string = R"(
|
||||
HloModule jit_matmul, entry_computation_layout={(f64[8,3]{1,0}, f64[])->f64[8,3]{1,0}}, allow_spmd_sharding_propagation_to_parameters={false,true}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=2
|
||||
ENTRY %main.4 (Arg_0.1: f64[8,3], Arg_1.2: f64[]) -> f64[8,3] {
|
||||
%Arg_1.2 = f64[] parameter(1)
|
||||
%Arg_0.1 = f64[8,3]{1,0} parameter(0), frontend_attributes={xla.sdy.sharding="#sdy.sharding<@mesh, [{\"x\"}, {}]>"}
|
||||
ROOT %dot.3 = f64[8,3]{1,0} dot(f64[] %Arg_1.2, f64[8,3]{1,0} %Arg_0.1), lhs_contracting_dims={}, rhs_contracting_dims={}, metadata={op_name="jit(matmul)/jit(main)/dot_general[dimension_numbers=(((), ()), ((), ())) precision=None preferred_element_type=float64]" source_file="third_party/py/jax/tests/pjit_test.py" source_line=4021}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
|
||||
EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
|
||||
EXPECT_EQ(
|
||||
m->entry_computation()->parameter_instruction(0)->get_frontend_attribute(
|
||||
sdy::toStringView(sdy::kShardingRoundTripAttr)),
|
||||
"#sdy.sharding<@mesh, [{\"x\"}, {}]>");
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) {
|
||||
const std::string hlo_string = R"(
|
||||
HloModule test
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user