[TF:XLA] Fix string to HLO opcode conversion for atan2, complex, imag and real.

Make sure that we can't forget opcodes by auto-generating the conversion
functions.

Add auto-generated functions to test HLOs for properties (like IsVariadic,
IsComparison, etc.)

This makes changing HLO more robust and easier because there are fewer places
to update when adding or removing an HLO opcode.

Also:
* Fix IsElementwiseBinary for atan2.
* Add a unit test for HLO opcode helpers.
* Express IsElementwiseBinary in terms of IsElementwise() and operand_count()
  to avoid having to keep the two in sync manually.
PiperOrigin-RevId: 174069664
This commit is contained in:
A. Unique TensorFlower 2017-10-31 11:54:57 -07:00 committed by TensorFlower Gardener
parent 3b845c80d5
commit 35939d2d37
4 changed files with 186 additions and 355 deletions

View File

@ -2514,33 +2514,7 @@ std::vector<int64> HloInstruction::OperandIndices(
}
bool HloInstruction::IsElementwiseBinary() const {
switch (opcode_) {
// Binary elementwise operations. If you update this, please update
// IsElementwise() accordingly.
case HloOpcode::kAdd:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
return true;
default:
return false;
}
return IsElementwise() && operand_count() == 2;
}
bool HloInstruction::IsElementwise() const {
@ -2551,7 +2525,6 @@ bool HloInstruction::IsElementwise() const {
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
@ -2569,11 +2542,12 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
CHECK_EQ(1, operand_count());
return true;
// Binary elementwise operations, the same as in IsElementwiseBinary().
// If you update this, please update IsElementwiseBinary() accordingly.
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
@ -2593,6 +2567,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(2, operand_count());
return true;
// Ternary elementwise operations.

View File

@ -21,243 +21,22 @@ limitations under the License.
namespace xla {
string HloOpcodeString(HloOpcode opcode) {
// Note: Do not use ':' in opcode strings. It is used as a special character
// in these places:
// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
// separate the opcode from the fusion kind
// - In fully qualified names (HloInstruction::FullyQualifiedName()), to
// separate the qualifiers (name of the computation and potentially the
// fusion instruction) from the name
switch (opcode) {
case HloOpcode::kAbs:
return "abs";
case HloOpcode::kAdd:
return "add";
case HloOpcode::kAnd:
return "and";
case HloOpcode::kAtan2:
return "atan2";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
case HloOpcode::kBatchNormInference:
return "batch-norm-inference";
case HloOpcode::kBatchNormGrad:
return "batch-norm-grad";
case HloOpcode::kBitcast:
return "bitcast";
case HloOpcode::kBroadcast:
return "broadcast";
case HloOpcode::kCall:
return "call";
case HloOpcode::kClamp:
return "clamp";
case HloOpcode::kComplex:
return "complex";
case HloOpcode::kConcatenate:
return "concatenate";
case HloOpcode::kConstant:
return "constant";
case HloOpcode::kConvert:
return "convert";
case HloOpcode::kConvolution:
return "convolution";
case HloOpcode::kCos:
return "cosine";
case HloOpcode::kCrossReplicaSum:
return "cross-replica-sum";
case HloOpcode::kCustomCall:
return "custom-call";
case HloOpcode::kCopy:
return "copy";
case HloOpcode::kDivide:
return "divide";
case HloOpcode::kDot:
return "dot";
case HloOpcode::kDynamicSlice:
return "dynamic-slice";
case HloOpcode::kDynamicUpdateSlice:
return "dynamic-update-slice";
case HloOpcode::kEq:
return "equal-to";
case HloOpcode::kExp:
return "exponential";
case HloOpcode::kFloor:
return "floor";
case HloOpcode::kCeil:
return "ceil";
case HloOpcode::kFusion:
return "fusion";
case HloOpcode::kGe:
return "greater-than-or-equal-to";
case HloOpcode::kGetTupleElement:
return "get-tuple-element";
case HloOpcode::kGt:
return "greater-than";
case HloOpcode::kImag:
return "imag";
case HloOpcode::kInfeed:
return "infeed";
case HloOpcode::kIsFinite:
return "is-finite";
case HloOpcode::kLe:
return "less-than-or-equal-to";
case HloOpcode::kLog:
return "log";
case HloOpcode::kLt:
return "less-than";
case HloOpcode::kMap:
return "map";
case HloOpcode::kMaximum:
return "maximum";
case HloOpcode::kMinimum:
return "minimum";
case HloOpcode::kMultiply:
return "multiply";
case HloOpcode::kNe:
return "not-equal-to";
case HloOpcode::kNegate:
return "negate";
case HloOpcode::kNot:
return "not";
case HloOpcode::kOr:
return "or";
case HloOpcode::kOutfeed:
return "outfeed";
case HloOpcode::kPad:
return "pad";
case HloOpcode::kParameter:
return "parameter";
case HloOpcode::kPower:
return "power";
case HloOpcode::kReal:
return "real";
case HloOpcode::kRecv:
return "recv";
case HloOpcode::kReduce:
return "reduce";
case HloOpcode::kReducePrecision:
return "reduce-precision";
case HloOpcode::kReduceWindow:
return "reduce-window";
case HloOpcode::kRemainder:
return "remainder";
case HloOpcode::kReshape:
return "reshape";
case HloOpcode::kReverse:
return "reverse";
case HloOpcode::kRng:
return "rng";
case HloOpcode::kRoundNearestAfz:
return "round-nearest-afz";
case HloOpcode::kSelectAndScatter:
return "select-and-scatter";
case HloOpcode::kSelect:
return "select";
case HloOpcode::kSend:
return "send";
case HloOpcode::kShiftLeft:
return "shift-left";
case HloOpcode::kShiftRightArithmetic:
return "shift-right-arithmetic";
case HloOpcode::kShiftRightLogical:
return "shift-right-logical";
case HloOpcode::kSign:
return "sign";
case HloOpcode::kSin:
return "sine";
case HloOpcode::kSlice:
return "slice";
case HloOpcode::kSort:
return "sort";
case HloOpcode::kSubtract:
return "subtract";
case HloOpcode::kTanh:
return "tanh";
case HloOpcode::kTrace:
return "trace";
case HloOpcode::kTranspose:
return "transpose";
case HloOpcode::kTuple:
return "tuple";
case HloOpcode::kWhile:
return "while";
#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \
case HloOpcode::enum_name: \
return opcode_name;
HLO_OPCODE_LIST(CASE_OPCODE_STRING)
#undef CASE_OPCODE_STRING
}
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>(
{{"abs", HloOpcode::kAbs},
{"add", HloOpcode::kAdd},
{"and", HloOpcode::kAnd},
{"batch-norm-training", HloOpcode::kBatchNormTraining},
{"batch-norm-inference", HloOpcode::kBatchNormInference},
{"batch-norm-grad", HloOpcode::kBatchNormGrad},
{"bitcast", HloOpcode::kBitcast},
{"broadcast", HloOpcode::kBroadcast},
{"call", HloOpcode::kCall},
{"clamp", HloOpcode::kClamp},
{"concatenate", HloOpcode::kConcatenate},
{"constant", HloOpcode::kConstant},
{"convert", HloOpcode::kConvert},
{"convolution", HloOpcode::kConvolution},
{"cosine", HloOpcode::kCos},
{"cross-replica-sum", HloOpcode::kCrossReplicaSum},
{"custom-call", HloOpcode::kCustomCall},
{"copy", HloOpcode::kCopy},
{"divide", HloOpcode::kDivide},
{"dot", HloOpcode::kDot},
{"dynamic-slice", HloOpcode::kDynamicSlice},
{"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice},
{"equal-to", HloOpcode::kEq},
{"exponential", HloOpcode::kExp},
{"floor", HloOpcode::kFloor},
{"ceil", HloOpcode::kCeil},
{"fusion", HloOpcode::kFusion},
{"greater-than-or-equal-to", HloOpcode::kGe},
{"get-tuple-element", HloOpcode::kGetTupleElement},
{"greater-than", HloOpcode::kGt},
{"infeed", HloOpcode::kInfeed},
{"is-finite", HloOpcode::kIsFinite},
{"less-than-or-equal-to", HloOpcode::kLe},
{"log", HloOpcode::kLog},
{"less-than", HloOpcode::kLt},
{"map", HloOpcode::kMap},
{"maximum", HloOpcode::kMaximum},
{"minimum", HloOpcode::kMinimum},
{"multiply", HloOpcode::kMultiply},
{"not", HloOpcode::kNot},
{"not-equal-to", HloOpcode::kNe},
{"negate", HloOpcode::kNegate},
{"or", HloOpcode::kOr},
{"outfeed", HloOpcode::kOutfeed},
{"pad", HloOpcode::kPad},
{"parameter", HloOpcode::kParameter},
{"power", HloOpcode::kPower},
{"recv", HloOpcode::kRecv},
{"reduce", HloOpcode::kReduce},
{"reduce-precision", HloOpcode::kReducePrecision},
{"reduce-window", HloOpcode::kReduceWindow},
{"remainder", HloOpcode::kRemainder},
{"reshape", HloOpcode::kReshape},
{"reverse", HloOpcode::kReverse},
{"rng", HloOpcode::kRng},
{"round-nearest-afz", HloOpcode::kRoundNearestAfz},
{"select-and-scatter", HloOpcode::kSelectAndScatter},
{"select", HloOpcode::kSelect},
{"send", HloOpcode::kSend},
{"shift-left", HloOpcode::kShiftLeft},
{"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic},
{"shift-right-logical", HloOpcode::kShiftRightLogical},
{"sign", HloOpcode::kSign},
{"sine", HloOpcode::kSin},
{"slice", HloOpcode::kSlice},
{"sort", HloOpcode::kSort},
{"subtract", HloOpcode::kSubtract},
{"tanh", HloOpcode::kTanh},
{"trace", HloOpcode::kTrace},
{"transpose", HloOpcode::kTranspose},
{"tuple", HloOpcode::kTuple},
{"while", HloOpcode::kWhile}});
static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
{opcode_name, HloOpcode::enum_name},
HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
#undef STRING_TO_OPCODE_ENTRY
});
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
@ -265,31 +44,36 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
return it->second;
}
#define CHECK_DEFAULT(property_name, opcode_name) false
#define CHECK_PROPERTY(property_name, opcode_name, value) \
(value & property_name)
#define RESOLVE(_1, _2, target, ...) target
#define HAS_PROPERTY(property, ...) \
RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
bool HloOpcodeIsComparison(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kEq:
case HloOpcode::kNe:
return true;
default:
return false;
#define CASE_IS_COMPARISON(enum_name, ...) \
case HloOpcode::enum_name: \
return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
HLO_OPCODE_LIST(CASE_IS_COMPARISON)
#undef CASE_IS_COMPARISON
}
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kCall:
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
case HloOpcode::kTuple:
return true;
default:
return false;
#define CASE_IS_VARIADIC(enum_name, ...) \
case HloOpcode::enum_name: \
return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__);
HLO_OPCODE_LIST(CASE_IS_VARIADIC)
#undef CASE_IS_VARIADIC
}
}
#undef HAS_PROPERTY
#undef RESOLVE
#undef CHECK_DEFAULT
#undef CHECK_PROPERTY
} // namespace xla

View File

@ -28,83 +28,112 @@ namespace xla {
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
//
// Each entry has the format:
// (enum_name, opcode_name)
// or
// (enum_name, opcode_name, p1 | p2 | ...)
//
// with p1, p2, ... are members of HloOpcodeProperty. They are combined
// using bitwise-or.
//
// Note: Do not use ':' in opcode names. It is used as a special character
// in these places:
// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
// separate the opcode from the fusion kind
// - In fully qualified names (HloInstruction::FullyQualifiedName()), to
// separate the qualifiers (name of the computation and potentially the
// fusion instruction) from the name
#define HLO_OPCODE_LIST(V) \
V(kAbs, "abs") \
V(kAdd, "add") \
V(kAtan2, "atan2") \
V(kBatchNormGrad, "batch-norm-grad") \
V(kBatchNormInference, "batch-norm-inference") \
V(kBatchNormTraining, "batch-norm-training") \
V(kBitcast, "bitcast") \
V(kBroadcast, "broadcast") \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
V(kComplex, "complex") \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
V(kConstant, "constant") \
V(kConvert, "convert") \
V(kConvolution, "convolution") \
V(kCopy, "copy") \
V(kCos, "cosine") \
V(kCrossReplicaSum, "cross-replica-sum") \
V(kCustomCall, "custom-call") \
V(kDivide, "divide") \
V(kDot, "dot") \
V(kDynamicSlice, "dynamic-slice") \
V(kDynamicUpdateSlice, "dynamic-update-slice") \
V(kEq, "equal-to", kHloOpcodeIsComparison) \
V(kExp, "exponential") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kImag, "imag") \
V(kInfeed, "infeed") \
V(kIsFinite, "is-finite") \
V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \
V(kLog, "log") \
V(kAnd, "and") \
V(kNot, "not") \
V(kOr, "or") \
V(kLt, "less-than", kHloOpcodeIsComparison) \
V(kMap, "map", kHloOpcodeIsVariadic) \
V(kMaximum, "maximum") \
V(kMinimum, "minimum") \
V(kMultiply, "multiply") \
V(kNe, "not-equal-to", kHloOpcodeIsComparison) \
V(kNegate, "negate") \
V(kOutfeed, "outfeed") \
V(kPad, "pad") \
V(kParameter, "parameter") \
V(kPower, "power") \
V(kReal, "real") \
V(kRecv, "recv") \
V(kReduce, "reduce") \
V(kReducePrecision, "reduce-precision") \
V(kReduceWindow, "reduce-window") \
V(kRemainder, "remainder") \
V(kReshape, "reshape") \
V(kReverse, "reverse") \
V(kRng, "rng") \
V(kRoundNearestAfz, "round-nearest-afz") \
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \
V(kShiftLeft, "shift-left") \
V(kShiftRightArithmetic, "shift-right-arithmetic") \
V(kShiftRightLogical, "shift-right-logical") \
V(kSign, "sign") \
V(kSin, "sine") \
V(kSlice, "slice") \
V(kSort, "sort") \
V(kSubtract, "subtract") \
V(kTanh, "tanh") \
V(kTrace, "trace") \
V(kTranspose, "transpose") \
V(kTuple, "tuple", kHloOpcodeIsVariadic) \
V(kWhile, "while")
enum class HloOpcode {
kAbs,
kAdd,
kAtan2,
kBatchNormGrad,
kBatchNormInference,
kBatchNormTraining,
kBitcast,
kBroadcast,
kCall,
kCeil,
kClamp,
kComplex,
kConcatenate,
kConstant,
kConvert,
kConvolution,
kCopy,
kCos,
kCrossReplicaSum,
kCustomCall,
kDivide,
kDot,
kDynamicSlice,
kDynamicUpdateSlice,
kEq,
kExp,
kFloor,
kFusion,
kGe,
kGetTupleElement,
kGt,
kImag,
kInfeed,
kIsFinite,
kLe,
kLog,
kAnd,
kNot,
kOr,
kLt,
kMap,
kMaximum,
kMinimum,
kMultiply,
kNe,
kNegate,
kOutfeed,
kPad,
kParameter,
kPower,
kReal,
kRecv,
kReduce,
kReducePrecision,
kReduceWindow,
kRemainder,
kReshape,
kReverse,
kRng,
kRoundNearestAfz,
kSelect,
kSelectAndScatter,
kSend,
kShiftLeft,
kShiftRightArithmetic,
kShiftRightLogical,
kSign,
kSin,
kSlice,
kSort,
kSubtract,
kTanh,
kTrace,
kTranspose,
kTuple,
kWhile,
#define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name,
HLO_OPCODE_LIST(DECLARE_ENUM)
#undef DECLARE_ENUM
};
// List of properties associated with opcodes.
// Properties are defined as increasing powers of two, so that we can use
// bitwise-or to combine properties, and bitwise-and to test for them.
enum HloOpcodeProperty {
kHloOpcodeIsComparison = 1 << 0,
kHloOpcodeIsVariadic = 1 << 1,
};
// Returns a string representation of the opcode.
@ -125,7 +154,9 @@ bool HloOpcodeIsVariadic(HloOpcode opcode);
// Returns the number of HloOpcode values.
inline const uint32_t HloOpcodeCount() {
return static_cast<uint32_t>(HloOpcode::kWhile) + 1;
#define HLO_COUNT_ONE(...) +1
#define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE)
return HLO_XLIST_LENGTH(HLO_OPCODE_LIST);
}
} // namespace xla

View File

@ -26,5 +26,46 @@ TEST(HloOpcodeTest, StringifyMultiply) {
ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply));
}
TEST(HloOpcodeTest, OpcodeProperties) {
// Test counting macro.
#define SOME_LIST(X) \
X(One) \
X(Two) \
X(Three)
EXPECT_EQ(3, HLO_XLIST_LENGTH(SOME_LIST));
#undef SOME_LIST
for (int i = 0; i < HloOpcodeCount(); ++i) {
auto opcode = static_cast<HloOpcode>(i);
// Test round-trip conversion to and from string.
EXPECT_EQ(opcode, StringToHloOpcode(HloOpcodeString(opcode)).ValueOrDie());
// Test some properties.
switch (opcode) {
case HloOpcode::kEq:
case HloOpcode::kNe:
case HloOpcode::kGt:
case HloOpcode::kLt:
case HloOpcode::kGe:
case HloOpcode::kLe:
EXPECT_TRUE(HloOpcodeIsComparison(opcode));
break;
default:
EXPECT_FALSE(HloOpcodeIsComparison(opcode));
}
switch (opcode) {
case HloOpcode::kCall:
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
case HloOpcode::kTuple:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
default:
EXPECT_FALSE(HloOpcodeIsVariadic(opcode));
}
}
}
} // namespace
} // namespace xla