mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Android] print type name for IValues (#64602)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64602 print type name in error message for easier debugging. Test Plan: Example: java.lang.IllegalStateException: Expected IValue type Tensor, actual type TensorList Reviewed By: beback4u Differential Revision: D30782318 fbshipit-source-id: 60d88a659e7b4bb2b574b12c7652a28f0d5ad0d2
This commit is contained in:
parent
11ef68938c
commit
1b5b210f2c
|
|
@ -40,6 +40,24 @@ public class IValue {
|
|||
private static final int TYPE_CODE_DICT_STRING_KEY = 13;
|
||||
private static final int TYPE_CODE_DICT_LONG_KEY = 14;
|
||||
|
||||
private String[] TYPE_NAMES = {
|
||||
"Unknown",
|
||||
"Null",
|
||||
"Tensor",
|
||||
"Bool",
|
||||
"Long",
|
||||
"Double",
|
||||
"String",
|
||||
"Tuple",
|
||||
"BoolList",
|
||||
"LongList",
|
||||
"DoubleList",
|
||||
"TensorList",
|
||||
"GenericList",
|
||||
"DictStringKey",
|
||||
"DictLongKey",
|
||||
};
|
||||
|
||||
@DoNotStrip private final int mTypeCode;
|
||||
@DoNotStrip private Object mData;
|
||||
|
||||
|
|
@ -312,7 +330,14 @@ public class IValue {
|
|||
if (typeCode != typeCodeExpected) {
|
||||
throw new IllegalStateException(
|
||||
String.format(
|
||||
Locale.US, "Expected IValue type %d, actual type %d", typeCodeExpected, typeCode));
|
||||
Locale.US,
|
||||
"Expected IValue type %s, actual type %s",
|
||||
getTypeName(typeCodeExpected),
|
||||
getTypeName(typeCode)));
|
||||
}
|
||||
}
|
||||
|
||||
private String getTypeName(int typeCode) {
|
||||
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user