mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Initialize PrecisionConfig for ScaledDot in composite rewriter.
Explicitly set the operand precisions to `PrecisionConfig::DEFAULT` when creating a `ScaledDot` instruction from a composite call. PiperOrigin-RevId: 823488638
This commit is contained in:
parent
a5fca6a9b5
commit
0c0947cea6
|
|
@ -120,11 +120,13 @@ absl::StatusOr<bool> CompositeRewriter::RewriteComputation(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
DotDimensionNumbers dot_dimension_numbers,
|
||||
ParseDimensionNumbers(frontend_attrs.at("composite.attributes")));
|
||||
PrecisionConfig precision{};
|
||||
precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
|
||||
auto* scaled_dot =
|
||||
computation->AddInstruction(HloInstruction::CreateScaledDot(
|
||||
call->shape(), call->mutable_operand(0), call->mutable_operand(1),
|
||||
call->mutable_operand(2), call->mutable_operand(3),
|
||||
dot_dimension_numbers, PrecisionConfig{}));
|
||||
dot_dimension_numbers, precision));
|
||||
TF_RETURN_IF_ERROR(call->ReplaceAllUsesWith(scaled_dot));
|
||||
TF_RETURN_IF_ERROR(computation->RemoveInstruction(call));
|
||||
changed = true;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user