mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Summary: In this diff we implement a way to ensure the internal thrift schema from cfgr (configerator/structs/caffe2/torch/export/schema.thrift) and the schema in OSS (torch/_export/serde/schema.thrift) are in sync, by adding a unittest to reflect on the type names and fields from each schema and compare them field by field. When we detect new fields/types from torch/_export/serde/schema.thrift, there'll be a test failure on the trunk and the error message hints people to add the missing field/type to the thrift schema from cfgr, so that they are always in sync in practice. Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_thrift_schema_in_sync Differential Revision: D66716834 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141989 Approved by: https://github.com/yiming0416
315 lines
5.8 KiB
Thrift
315 lines
5.8 KiB
Thrift
// @generated by update_schema.py
|
|
// checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>>
|
|
|
|
namespace py3 torch._export
|
|
namespace cpp2 torch._export.schema
|
|
|
|
enum Layout {
|
|
Unknown = 0,
|
|
SparseCoo = 1,
|
|
SparseCsr = 2,
|
|
SparseCsc = 3,
|
|
SparseBsr = 4,
|
|
SparseBsc = 5,
|
|
_mkldnn = 6,
|
|
Strided = 7,
|
|
}
|
|
|
|
|
|
enum MemoryFormat {
|
|
Unknown = 0,
|
|
ContiguousFormat = 1,
|
|
ChannelsLast = 2,
|
|
ChannelsLast3d = 3,
|
|
PreserveFormat = 4,
|
|
}
|
|
|
|
|
|
enum ScalarType {
|
|
UNKNOWN = 0,
|
|
BYTE = 1,
|
|
CHAR = 2,
|
|
SHORT = 3,
|
|
INT = 4,
|
|
LONG = 5,
|
|
HALF = 6,
|
|
FLOAT = 7,
|
|
DOUBLE = 8,
|
|
COMPLEXHALF = 9,
|
|
COMPLEXFLOAT = 10,
|
|
COMPLEXDOUBLE = 11,
|
|
BOOL = 12,
|
|
BFLOAT16 = 13,
|
|
UINT16 = 28,
|
|
}
|
|
|
|
|
|
struct Device {
|
|
10: string type;
|
|
20: optional i64 index;
|
|
}
|
|
|
|
union SymExprHint {
|
|
10: i64 as_int;
|
|
20: bool as_bool;
|
|
30: double as_float;
|
|
}
|
|
|
|
struct SymExpr {
|
|
10: string expr_str;
|
|
20: optional SymExprHint hint;
|
|
}
|
|
|
|
union SymInt {
|
|
10: SymExpr as_expr;
|
|
20: i64 as_int;
|
|
}
|
|
|
|
union SymFloat {
|
|
10: SymExpr as_expr;
|
|
20: double as_float;
|
|
}
|
|
|
|
union SymBool {
|
|
10: SymExpr as_expr;
|
|
20: bool as_bool;
|
|
}
|
|
|
|
struct TensorMeta {
|
|
10: ScalarType dtype;
|
|
20: list<SymInt> sizes;
|
|
30: bool requires_grad;
|
|
40: Device device;
|
|
50: list<SymInt> strides;
|
|
60: SymInt storage_offset;
|
|
70: Layout layout;
|
|
}
|
|
|
|
union SymIntArgument {
|
|
10: string as_name;
|
|
20: i64 as_int;
|
|
}
|
|
|
|
union SymFloatArgument {
|
|
10: string as_name;
|
|
20: double as_float;
|
|
}
|
|
|
|
union SymBoolArgument {
|
|
10: string as_name;
|
|
20: bool as_bool;
|
|
}
|
|
|
|
struct TensorArgument {
|
|
10: string name;
|
|
}
|
|
|
|
struct TokenArgument {
|
|
10: string name;
|
|
}
|
|
|
|
union OptionalTensorArgument {
|
|
20: TensorArgument as_tensor;
|
|
10: bool as_none;
|
|
}
|
|
|
|
struct GraphArgument {
|
|
10: string name;
|
|
20: Graph graph;
|
|
}
|
|
|
|
struct CustomObjArgument {
|
|
10: string name;
|
|
20: string class_fqn;
|
|
}
|
|
|
|
union Argument {
|
|
10: bool as_none;
|
|
20: TensorArgument as_tensor;
|
|
30: list<TensorArgument> as_tensors;
|
|
50: i64 as_int;
|
|
70: list<i64> as_ints;
|
|
80: double as_float;
|
|
90: list<double> as_floats;
|
|
100: string as_string;
|
|
101: list<string> as_strings;
|
|
110: SymIntArgument as_sym_int;
|
|
120: list<SymIntArgument> as_sym_ints;
|
|
130: ScalarType as_scalar_type;
|
|
140: MemoryFormat as_memory_format;
|
|
150: Layout as_layout;
|
|
160: Device as_device;
|
|
170: bool as_bool;
|
|
180: list<bool> as_bools;
|
|
182: SymBoolArgument as_sym_bool;
|
|
184: list<SymBoolArgument> as_sym_bools;
|
|
200: GraphArgument as_graph;
|
|
190: list<OptionalTensorArgument> as_optional_tensors;
|
|
210: CustomObjArgument as_custom_obj;
|
|
220: string as_operator;
|
|
230: SymFloatArgument as_sym_float;
|
|
240: list<SymFloatArgument> as_sym_floats;
|
|
}
|
|
|
|
struct NamedArgument {
|
|
10: string name;
|
|
20: Argument arg;
|
|
}
|
|
|
|
struct Node {
|
|
10: string target;
|
|
20: list<NamedArgument> inputs;
|
|
30: list<Argument> outputs;
|
|
40: map<string, string> metadata;
|
|
}
|
|
|
|
struct Graph {
|
|
10: list<Argument> inputs;
|
|
20: list<Argument> outputs;
|
|
30: list<Node> nodes;
|
|
40: map<string, TensorMeta> tensor_values;
|
|
50: map<string, SymInt> sym_int_values;
|
|
60: map<string, SymBool> sym_bool_values;
|
|
70: bool is_single_tensor_return;
|
|
80: map<string, CustomObjArgument> custom_obj_values;
|
|
90: map<string, SymFloat> sym_float_values;
|
|
}
|
|
|
|
struct UserInputSpec {
|
|
10: Argument arg;
|
|
}
|
|
|
|
union ConstantValue {
|
|
10: bool as_none;
|
|
20: i64 as_int;
|
|
30: double as_float;
|
|
40: string as_string;
|
|
50: bool as_bool;
|
|
}
|
|
|
|
struct InputToConstantInputSpec {
|
|
10: string name;
|
|
20: ConstantValue value;
|
|
}
|
|
|
|
struct InputToParameterSpec {
|
|
10: TensorArgument arg;
|
|
20: string parameter_name;
|
|
}
|
|
|
|
struct InputToBufferSpec {
|
|
10: TensorArgument arg;
|
|
20: string buffer_name;
|
|
30: bool persistent;
|
|
}
|
|
|
|
struct InputToTensorConstantSpec {
|
|
10: TensorArgument arg;
|
|
20: string tensor_constant_name;
|
|
}
|
|
|
|
struct InputToCustomObjSpec {
|
|
10: CustomObjArgument arg;
|
|
20: string custom_obj_name;
|
|
}
|
|
|
|
struct InputTokenSpec {
|
|
10: TokenArgument arg;
|
|
}
|
|
|
|
union InputSpec {
|
|
10: UserInputSpec user_input;
|
|
20: InputToParameterSpec parameter;
|
|
30: InputToBufferSpec buffer;
|
|
40: InputToTensorConstantSpec tensor_constant;
|
|
50: InputToCustomObjSpec custom_obj;
|
|
70: InputTokenSpec token;
|
|
60: InputToConstantInputSpec constant_input;
|
|
}
|
|
|
|
struct UserOutputSpec {
|
|
10: Argument arg;
|
|
}
|
|
|
|
struct LossOutputSpec {
|
|
10: TensorArgument arg;
|
|
}
|
|
|
|
struct BufferMutationSpec {
|
|
10: TensorArgument arg;
|
|
20: string buffer_name;
|
|
}
|
|
|
|
struct GradientToParameterSpec {
|
|
10: TensorArgument arg;
|
|
20: string parameter_name;
|
|
}
|
|
|
|
struct GradientToUserInputSpec {
|
|
10: TensorArgument arg;
|
|
20: string user_input_name;
|
|
}
|
|
|
|
struct UserInputMutationSpec {
|
|
10: TensorArgument arg;
|
|
20: string user_input_name;
|
|
}
|
|
|
|
struct OutputTokenSpec {
|
|
10: TokenArgument arg;
|
|
}
|
|
|
|
union OutputSpec {
|
|
10: UserOutputSpec user_output;
|
|
20: LossOutputSpec loss_output;
|
|
30: BufferMutationSpec buffer_mutation;
|
|
40: GradientToParameterSpec gradient_to_parameter;
|
|
50: GradientToUserInputSpec gradient_to_user_input;
|
|
60: UserInputMutationSpec user_input_mutation;
|
|
70: OutputTokenSpec token;
|
|
}
|
|
|
|
struct GraphSignature {
|
|
10: list<InputSpec> input_specs;
|
|
20: list<OutputSpec> output_specs;
|
|
}
|
|
|
|
struct RangeConstraint {
|
|
10: optional i64 min_val;
|
|
20: optional i64 max_val;
|
|
}
|
|
|
|
struct ModuleCallSignature {
|
|
10: list<Argument> inputs;
|
|
20: list<Argument> outputs;
|
|
30: string in_spec;
|
|
40: string out_spec;
|
|
50: optional list<string> forward_arg_names;
|
|
}
|
|
|
|
struct ModuleCallEntry {
|
|
10: string fqn;
|
|
30: optional ModuleCallSignature signature;
|
|
}
|
|
|
|
struct GraphModule {
|
|
10: Graph graph;
|
|
50: GraphSignature signature;
|
|
60: list<ModuleCallEntry> module_call_graph;
|
|
40: map<string, string> metadata;
|
|
}
|
|
|
|
struct SchemaVersion {
|
|
10: i64 major;
|
|
20: i64 minor;
|
|
}
|
|
|
|
struct ExportedProgram {
|
|
10: GraphModule graph_module;
|
|
20: map<string, i64> opset_version;
|
|
30: map<string, RangeConstraint> range_constraints;
|
|
60: SchemaVersion schema_version;
|
|
70: list<string> verifiers;
|
|
80: string torch_version;
|
|
}
|