mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Fixes: https://github.com/pytorch/pytorch/issues/16233 The following changes are made: - Modify `TupleType` to store optional field names - Modify schema matching to fill in those field names when creating `TupleType` as return type. - Modify codegen of JIT to copy field names to schema string - Modify `SchemaParser` to set field names of returned schema. - Modify `SimpleValue::attr` to emit tuple indexing for named tuple. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16253 Reviewed By: ezyang Differential Revision: D13954298 Pulled By: zdevito fbshipit-source-id: 247d483d78a0c9c12d1ba36e1f1ec6c3f1a3007b
61 lines
1.8 KiB
C++
61 lines
1.8 KiB
C++
#pragma once
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
#include <torch/csrc/jit/ir.h>
|
|
#include <torch/csrc/jit/named_value.h>
|
|
|
|
#include <ATen/core/function_schema.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace script {
|
|
|
|
// try to match a list of inputs and keyword 'attributes' to this schema,
|
|
// if it works return the flat list of positional inputs to the call
|
|
// if it returns nullopt, then failure_messages contains a good error report
|
|
// set allow_conversions to true if ImplicitTensorToNums should be inserted
|
|
// to match the schema
|
|
|
|
struct MatchedSchema {
|
|
std::vector<Value*> inputs;
|
|
std::vector<TypePtr> return_types;
|
|
c10::OptNameList return_field_names;
|
|
};
|
|
|
|
// Attempts to match a call against `schema`.
//
//   schema            - the function schema to match against.
//   loc               - source range of the call site.
//   graph             - graph passed by mutable reference.
//                       NOTE(review): presumably conversion nodes are
//                       inserted here -- confirm in the definition.
//   self              - optional leading `self` argument.
//   inputs            - positional arguments of the call.
//   attributes        - keyword arguments of the call.
//   failure_messages  - stream that receives the error report when the
//                       match fails.
//   allow_conversions - whether implicit conversions may be used to make
//                       the arguments fit the schema.
//
// Returns a MatchedSchema on success and c10::nullopt on failure.
TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
    const ::c10::FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    c10::optional<NamedValue> self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    std::ostream& failure_messages,
    bool allow_conversions);
|
|
|
|
// Emits a call to the builtin identified by `name` and returns the
// resulting Value*.
//
//   loc        - source range of the call site.
//   graph      - graph passed by mutable reference.
//   name       - symbol naming the builtin to call.
//   self       - optional leading `self` argument.
//   inputs     - positional arguments of the call.
//   attributes - keyword arguments of the call.
//   required   - see the note on the parameter below.
TORCH_API Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    const c10::optional<NamedValue>& self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    // When true, a missing builtin makes emitBuiltinCall throw an exception;
    // when false, a missing builtin makes it return nullptr instead.
    bool required);
|
|
|
|
// Looks for an entry called `name` among `kwargs`.
// NOTE(review): presumably the returned size_t is the index of that entry in
// `kwargs`, and c10::nullopt means no entry has that name -- confirm against
// the definition.
TORCH_API c10::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs);
|
|
|
|
// applies implicit conversion from value, trying to turn it into type
|
|
// concrete_type; it succeeds if the return_value->isSubclassOf(concrete_type)
|
|
// Tries to turn `value` into `concrete_type`.
//
//   loc               - source range associated with the conversion.
//   graph             - graph passed by mutable reference.
//   concrete_type     - the type the result should have.
//   value             - the value to convert.
//   allow_conversions - whether implicit conversions may be applied.
//
// Returns a Value*; callers decide success by checking whether the returned
// value's type is a subclass of `concrete_type`.
TORCH_API Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions);
|
|
|
|
} // namespace script
|
|
} // namespace jit
|
|
} // namespace torch
|