mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[StableHLO->HLO] Only lower MHLO constants in MHLO prepare for export pass.
PiperOrigin-RevId: 822198262
This commit is contained in:
parent
7524326efd
commit
fe624fe9ce
|
|
@ -130,7 +130,8 @@ void prepareExplicitCapturedConstants(Operation* op) {
|
|||
// it explicit and replace uses within the block
|
||||
Operation *definingOp = input.getDefiningOp();
|
||||
mlir::DenseElementsAttr attr;
|
||||
if (matchPattern(input, m_Constant(&attr))) {
|
||||
if (mlir::isa_and_present<ConstantOp>(input.getDefiningOp()) &&
|
||||
matchPattern(input, m_Constant(&attr))) {
|
||||
Operation *clonedOp = builder.clone(*definingOp);
|
||||
// Find which uses belong to the block and replace
|
||||
// with the cloned/explicit one
|
||||
|
|
@ -146,9 +147,10 @@ void prepareExplicitCapturedConstants(Operation* op) {
|
|||
} // namespace
|
||||
|
||||
void PrepareForExportPass::runOnOperation() {
|
||||
getOperation().walk([&](Operation *op) {
|
||||
getOperation().walk([&](Operation* op) {
|
||||
mlir::SplatElementsAttr attr;
|
||||
if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);
|
||||
if (isa<ConstantOp>(op) && matchPattern(op, m_Constant(&attr)))
|
||||
return prepareConstantOp(op, attr);
|
||||
|
||||
if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
|
||||
return prepareBroadcastInDim(bcastOp);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,16 @@ func.func @splat_constants() -> tensor<1x64x224x224xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @non_mhlo_constant
|
||||
func.func @non_mhlo_constant() -> tensor<128x1014x508xcomplex<f64>> {
|
||||
// CHECK: arith.constant dense<(1.000000e+00,2.000000e+00)> : tensor<128x1014x508xcomplex<f64>>
|
||||
// CHECK-NOT: mhlo.broadcast_in_dim
|
||||
%0 = arith.constant dense<(1.000000e+00,2.000000e+00)> : tensor<128x1014x508xcomplex<f64>>
|
||||
func.return %0 : tensor<128x1014x508xcomplex<f64>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @splat_constant_complex_float
|
||||
func.func @splat_constant_complex_float() -> tensor<128x1014x508xcomplex<f64>> {
|
||||
// CHECK: %[[CST:.*]] = mhlo.constant dense<(1.000000e+00,2.000000e+00)> : tensor<complex<f64>>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user