mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XTile] Modify Stable HLO check on iota to restrict it to the 1D case.
PiperOrigin-RevId: 826085272
This commit is contained in:
parent
f2b36d1780
commit
6dd75c4e8b
|
|
@ -10,3 +10,9 @@ xtile.entry_func @fails_illegal_op(%arg0: memref<2xf32>, %arg1: index) {
|
|||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @iota_2d_fails() -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{Only 1D iota is supported}}
|
||||
%0 = stablehlo.iota dim = 0 : tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,8 +83,16 @@ std::optional<absl::string_view> IsLegalTensorOp(mlir::Operation* op) {
|
|||
std::optional<absl::string_view> IsLegalStablehloOp(mlir::Operation* op) {
|
||||
if (mlir::isa<mlir::stablehlo::BroadcastInDimOp, mlir::stablehlo::ReduceOp,
|
||||
mlir::stablehlo::ReturnOp, mlir::stablehlo::TransposeOp,
|
||||
mlir::stablehlo::IotaOp, mlir::stablehlo::DotGeneralOp,
|
||||
mlir::stablehlo::ReshapeOp>(op)) {
|
||||
mlir::stablehlo::DotGeneralOp, mlir::stablehlo::ReshapeOp>(
|
||||
op)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto iota = mlir::dyn_cast<mlir::stablehlo::IotaOp>(op)) {
|
||||
if (iota.getType().getRank() != 1) {
|
||||
return "Only 1D iota is supported";
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user