diff --git a/third_party/xla/xla/hlo/analysis/symbolic_expr.cc b/third_party/xla/xla/hlo/analysis/symbolic_expr.cc index c0df20d9105..e7588404cf2 100644 --- a/third_party/xla/xla/hlo/analysis/symbolic_expr.cc +++ b/third_party/xla/xla/hlo/analysis/symbolic_expr.cc @@ -557,6 +557,12 @@ SymbolicExprContext* SymbolicExpr::GetContext() const { return impl_->ctx_; } SymbolicExprType SymbolicExpr::GetType() const { return impl_->type_; } +bool SymbolicExpr::IsBinaryOp() const { + auto type = GetType(); + return type != SymbolicExprType::kConstant && + type != SymbolicExprType::kVariable; +} + SymbolicExpr SymbolicExpr::GetLHS() const { return impl_->lhs_; } SymbolicExpr SymbolicExpr::GetRHS() const { return impl_->rhs_; } @@ -732,12 +738,11 @@ SymbolicExpr SymbolicExpr::Replace( return it->second; } - SymbolicExprType type = GetType(); - if (type == SymbolicExprType::kConstant || - type == SymbolicExprType::kVariable) { + if (!IsBinaryOp()) { return *this; } + SymbolicExprType type = GetType(); SymbolicExpr lhs = GetLHS(); SymbolicExpr rhs = GetRHS(); SymbolicExpr new_lhs = lhs.Replace(replacements); @@ -779,12 +784,11 @@ SymbolicExpr SymbolicExpr::Canonicalize() const { return *this; } - SymbolicExprType type = GetType(); - if (type == SymbolicExprType::kConstant || - type == SymbolicExprType::kVariable) { + if (!IsBinaryOp()) { return *this; } + SymbolicExprType type = GetType(); SymbolicExpr lhs = this->GetLHS().Canonicalize(); SymbolicExpr rhs = this->GetRHS().Canonicalize(); diff --git a/third_party/xla/xla/hlo/analysis/symbolic_expr.h b/third_party/xla/xla/hlo/analysis/symbolic_expr.h index 1f9a9ea27bf..ffb8269d62f 100644 --- a/third_party/xla/xla/hlo/analysis/symbolic_expr.h +++ b/third_party/xla/xla/hlo/analysis/symbolic_expr.h @@ -65,6 +65,7 @@ class SymbolicExpr { SymbolicExprContext* GetContext() const; SymbolicExprType GetType() const; + bool IsBinaryOp() const; SymbolicExpr GetLHS() const; SymbolicExpr GetRHS() const; int64_t GetValue() const;