[XLA:GPU] Initialize PrecisionConfig for ScaledDot in composite rewriter.

Explicitly set the operand precisions to `PrecisionConfig::DEFAULT` when creating a `ScaledDot` instruction from a composite call.

PiperOrigin-RevId: 823488638
This commit is contained in:
Ilya Tikhonovskiy 2025-10-24 05:29:10 -07:00 committed by TensorFlower Gardener
parent a5fca6a9b5
commit 0c0947cea6

View File

@ -120,11 +120,13 @@ absl::StatusOr<bool> CompositeRewriter::RewriteComputation(
TF_ASSIGN_OR_RETURN(
DotDimensionNumbers dot_dimension_numbers,
ParseDimensionNumbers(frontend_attrs.at("composite.attributes")));
PrecisionConfig precision{};
precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
auto* scaled_dot =
computation->AddInstruction(HloInstruction::CreateScaledDot(
call->shape(), call->mutable_operand(0), call->mutable_operand(1),
call->mutable_operand(2), call->mutable_operand(3),
dot_dimension_numbers, PrecisionConfig{}));
dot_dimension_numbers, precision));
TF_RETURN_IF_ERROR(call->ReplaceAllUsesWith(scaled_dot));
TF_RETURN_IF_ERROR(computation->RemoveInstruction(call));
changed = true;