pytorch/torch/csrc/jit/script/schema_matching.h
Xiang Gao eae139e18f Support named tuple return from operators on JIT (#16253)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/16233

The following changes are made:
- Modify `TupleType` to store optional field names
- Modify schema matching to fill in those field names when creating `TupleType` as return type.
- Modify codegen of JIT to copy field names to schema string
- Modify `SchemaParser` to set field names of returned schema.
- Modify `SimpleValue::attr` to emit tuple indexing for named tuple.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16253

Reviewed By: ezyang

Differential Revision: D13954298

Pulled By: zdevito

fbshipit-source-id: 247d483d78a0c9c12d1ba36e1f1ec6c3f1a3007b
2019-02-10 18:15:56 -08:00
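The commit summary says two things. First, `TupleType` now stores optional field names. Second, `SimpleValue::attr` emits tuple indexing when a field of a named tuple is accessed. The sketch below is a rough illustration of that idea only. `ToyTupleType` and its members are hypothetical names, not the real `c10::TupleType` API. An unnamed tuple carries no names, while a named tuple maps each attribute name to an element index.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of a tuple type with optional field names.
// This is NOT the real c10::TupleType; it only illustrates that a plain tuple
// has no names while a named tuple stores one name per element.
struct ToyTupleType {
  std::vector<std::string> element_types;         // e.g. {"Tensor", "Tensor"}
  std::optional<std::vector<std::string>> names;  // nullopt for a plain tuple

  bool isNamed() const { return names.has_value(); }

  // Resolve an attribute access such as `out.indices` to a tuple index, so a
  // named-tuple field access can be lowered to plain tuple indexing.
  // Returns nullopt if the tuple is unnamed or the field does not exist.
  std::optional<std::size_t> fieldIndex(const std::string& field) const {
    if (!names) {
      return std::nullopt;
    }
    for (std::size_t i = 0; i < names->size(); ++i) {
      if ((*names)[i] == field) {
        return i;
      }
    }
    return std::nullopt;
  }
};
```

For a hypothetical return type named `(values, indices)`, `fieldIndex("indices")` yields `1`. The same lookup on an unnamed tuple yields `nullopt`, so attribute access only makes sense for named tuples.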


#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/named_value.h>
#include <ATen/core/function_schema.h>
namespace torch {
namespace jit {
namespace script {
// Result of a successful schema match: the flat list of positional inputs
// to the call, plus the return types and, for named-tuple returns, the
// optional field names of the returned values.
struct MatchedSchema {
  std::vector<Value*> inputs;
  std::vector<TypePtr> return_types;
  c10::OptNameList return_field_names;
};

// Try to match a list of inputs and keyword 'attributes' to this schema.
// If it works, return the MatchedSchema describing the call; if it returns
// nullopt, failure_messages contains a good error report.
// Set allow_conversions to true if implicit conversions (such as
// ImplicitTensorToNum) may be inserted to match the schema.
TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
    const ::c10::FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    c10::optional<NamedValue> self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    std::ostream& failure_messages,
    bool allow_conversions);

TORCH_API Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    const c10::optional<NamedValue>& self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    // if true, emitBuiltinCall throws an exception if this builtin does not
    // exist; otherwise it returns nullptr when the builtin is not found.
    bool required);
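
// Returns the index of the keyword argument in kwargs whose name matches
// `name`, or nullopt if there is none.
TORCH_API c10::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs);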

// Applies implicit conversions to value, trying to turn it into type
// concrete_type. The conversion succeeds if the type of the returned value
// is a subtype of concrete_type.
TORCH_API Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions);
} // namespace script
} // namespace jit
} // namespace torch
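The comment on `tryMatchSchema` describes how positional inputs and keyword 'attributes' are matched against a schema to produce a flat list of positional inputs. `findInputWithName` is available to look up a keyword argument by name. The sketch below is a simplified, hypothetical model of that flow and makes several assumptions:

- `ToyArgument`, `ToyNamedValue`, `toyFindInputWithName` and `toyTryMatchSchema` are invented names.
- Values are plain strings rather than `Value*`.
- Type checking and implicit conversions (`allow_conversions`, `tryConvertToType`) are left out.
- The lookup order (positional, then keyword, then default) is an assumed, typical design, not a description of the real implementation.

```cpp
#include <cstddef>
#include <optional>
#include <ostream>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for FunctionSchema arguments and
// NamedValue. Values are plain strings here instead of JIT Value*.
struct ToyArgument {
  std::string name;
  std::optional<std::string> default_value;  // nullopt: argument is required
};

struct ToyNamedValue {
  std::string name;  // keyword name; ignored for positional inputs
  std::string value;
};

// Analogue of findInputWithName: index of the first keyword argument whose
// name equals `name`, or nullopt if there is none.
std::optional<std::size_t> toyFindInputWithName(
    const std::string& name,
    const std::vector<ToyNamedValue>& kwargs) {
  for (std::size_t i = 0; i < kwargs.size(); ++i) {
    if (kwargs[i].name == name) {
      return i;
    }
  }
  return std::nullopt;
}

// Analogue of tryMatchSchema (no types, no conversions): fill each schema
// argument from the positional inputs, then from a same-named keyword
// argument, then from its default. On failure, write a reason to
// failure_messages and return nullopt.
std::optional<std::vector<std::string>> toyTryMatchSchema(
    const std::vector<ToyArgument>& schema,
    const std::vector<ToyNamedValue>& inputs,
    const std::vector<ToyNamedValue>& kwargs,
    std::ostream& failure_messages) {
  if (inputs.size() > schema.size()) {
    failure_messages << "expected at most " << schema.size()
                     << " positional arguments but found " << inputs.size()
                     << "\n";
    return std::nullopt;
  }
  std::vector<std::string> flat;  // the flat list of positional inputs
  std::size_t used_kwargs = 0;
  for (std::size_t i = 0; i < schema.size(); ++i) {
    const ToyArgument& arg = schema[i];
    if (i < inputs.size()) {
      // Reject an argument given both positionally and by keyword.
      if (toyFindInputWithName(arg.name, kwargs)) {
        failure_messages << "argument '" << arg.name << "' specified twice\n";
        return std::nullopt;
      }
      flat.push_back(inputs[i].value);
    } else if (auto k = toyFindInputWithName(arg.name, kwargs)) {
      flat.push_back(kwargs[*k].value);
      ++used_kwargs;
    } else if (arg.default_value) {
      flat.push_back(*arg.default_value);
    } else {
      failure_messages << "argument '" << arg.name << "' not provided\n";
      return std::nullopt;
    }
  }
  // Any keyword that did not match a schema argument is an error.
  if (used_kwargs != kwargs.size()) {
    failure_messages << "unknown or duplicate keyword argument\n";
    return std::nullopt;
  }
  return flat;
}
```

The following examples use a schema `(a, b=1)`:

- One positional input `x` yields the flat list `{x, 1}`.
- Passing `b` by keyword replaces the default.
- A missing required argument or an unknown keyword makes the function return nullopt, with a reason written to `failure_messages`. This mirrors the "nullopt plus error report" contract described in the header comment.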