[XTile] Modify Stable HLO check on iota to restrict it to the 1D case.

PiperOrigin-RevId: 826085272
This commit is contained in:
Will Froom 2025-10-30 10:15:25 -07:00 committed by TensorFlower Gardener
parent f2b36d1780
commit 6dd75c4e8b
2 changed files with 16 additions and 2 deletions

View File

@ -10,3 +10,9 @@ xtile.entry_func @fails_illegal_op(%arg0: memref<2xf32>, %arg1: index) {
} }
// ----- // -----
func.func @iota_2d_fails() -> tensor<2x2xi32> {
// expected-error @+1 {{Only 1D iota is supported}}
%0 = stablehlo.iota dim = 0 : tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

View File

@ -83,8 +83,16 @@ std::optional<absl::string_view> IsLegalTensorOp(mlir::Operation* op) {
std::optional<absl::string_view> IsLegalStablehloOp(mlir::Operation* op) { std::optional<absl::string_view> IsLegalStablehloOp(mlir::Operation* op) {
if (mlir::isa<mlir::stablehlo::BroadcastInDimOp, mlir::stablehlo::ReduceOp, if (mlir::isa<mlir::stablehlo::BroadcastInDimOp, mlir::stablehlo::ReduceOp,
mlir::stablehlo::ReturnOp, mlir::stablehlo::TransposeOp, mlir::stablehlo::ReturnOp, mlir::stablehlo::TransposeOp,
mlir::stablehlo::IotaOp, mlir::stablehlo::DotGeneralOp, mlir::stablehlo::DotGeneralOp, mlir::stablehlo::ReshapeOp>(
mlir::stablehlo::ReshapeOp>(op)) { op)) {
return std::nullopt;
}
if (auto iota = mlir::dyn_cast<mlir::stablehlo::IotaOp>(op)) {
if (iota.getType().getRank() != 1) {
return "Only 1D iota is supported";
}
return std::nullopt; return std::nullopt;
} }