pytorch/torch/csrc/jit/ir/graph_utils.h
Maxwell Nuyens 0d0ebcdfe5 feature: adding the ability to restore shapes after loading a traced model (#90744)
Adds the ability to store the inputs used when tracing a model as part of torch.jit.save, and to restore the input shapes with torch.jit.load when the appropriate variables are set.

Fixes [89185](https://github.com/pytorch/pytorch/issues/89185)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90744
Approved by: https://github.com/davidberard98
2023-02-10 17:12:52 +00:00

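#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <vector>

namespace torch {
namespace jit {

// Returns the JIT type describing tensor `t`. `complete` controls whether the
// full tensor type (sizes, strides, ...) is recorded or a less detailed one.
TORCH_API TypePtr getTensorType(const at::Tensor& t, bool complete);

// Computes a refined type for a graph input declared as `input_type`, reading
// the corresponding value(s) from the stack starting at `s_iter`. `s_iter` is
// taken by non-const reference, so the caller observes where it stops.
TORCH_API TypePtr inferShapeAndTypeForInput(
    TypePtr input_type,
    Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    bool complete);

// Sets the types of the inputs of graph `g` from the values in `stack`.
TORCH_API void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    bool complete,
    const std::vector<int>& param_count_list = {});

} // namespace jit
} // namespace torch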
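The header only declares these functions. Judging from the signature of `inferShapeAndTypeForInput` alone (a non-const `Stack::const_iterator&` plus an end iterator), a natural reading is that one shared iterator is advanced through a flat stack of values while a possibly nested input type is walked. The following is a self-contained toy sketch of that pattern under that assumption. Every name in it (`ToyTensor`, `ToyType`, `ToyStack`, `declTensor`, `declTuple`, `toyTensorType`, `toyInferType`) is hypothetical, and it is not PyTorch's implementation. Its `complete` flag only loosely mirrors the one in the header: full sizes are recorded when it is true, and only the rank otherwise.

```cpp
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical toy types for illustration only; these are NOT PyTorch types.
struct ToyTensor {
  std::vector<int64_t> sizes;
};

struct ToyType;
using ToyTypePtr = std::shared_ptr<ToyType>;

struct ToyType {
  enum class Kind { Tensor, Tuple };
  Kind kind = Kind::Tensor;
  int64_t rank = -1;                 // Tensor only: -1 means "unknown"
  bool sizes_known = false;          // Tensor only: true if `sizes` is valid
  std::vector<int64_t> sizes;        // Tensor only
  std::vector<ToyTypePtr> elements;  // Tuple only
};

// A flat "stack" of runtime values, analogous in spirit to a JIT Stack.
using ToyStack = std::vector<ToyTensor>;

// Declared (shape-less) types describing a hypothetical input signature.
ToyTypePtr declTensor() { return std::make_shared<ToyType>(); }

ToyTypePtr declTuple(std::vector<ToyTypePtr> elems) {
  auto t = std::make_shared<ToyType>();
  t->kind = ToyType::Kind::Tuple;
  t->elements = std::move(elems);
  return t;
}

// Toy analogue of a `complete` flag: record full sizes, or only the rank.
ToyTypePtr toyTensorType(const ToyTensor& t, bool complete) {
  auto ty = std::make_shared<ToyType>();
  ty->rank = static_cast<int64_t>(t.sizes.size());
  if (complete) {
    ty->sizes_known = true;
    ty->sizes = t.sizes;
  }
  return ty;
}

// Walks the declared type recursively. `it` is shared by reference across
// the recursion, so each tensor leaf consumes exactly one stack entry and
// the caller sees where the next input starts.
ToyTypePtr toyInferType(const ToyTypePtr& declared,
                        ToyStack::const_iterator& it,
                        const ToyStack::const_iterator& end,
                        bool complete) {
  if (declared->kind == ToyType::Kind::Tensor) {
    if (it == end) {
      throw std::runtime_error("stack has fewer values than the signature");
    }
    const ToyTensor& value = *it;
    ++it;  // consume one stack entry
    return toyTensorType(value, complete);
  }
  auto tuple = std::make_shared<ToyType>();
  tuple->kind = ToyType::Kind::Tuple;
  for (const auto& elem : declared->elements) {
    tuple->elements.push_back(toyInferType(elem, it, end, complete));
  }
  return tuple;
}
```

For example, calling `toyInferType` on the signature `(Tensor, (Tensor, Tensor))` with three stacked tensors leaves the iterator at the end of the stack; passing `complete = false` keeps only each tensor's rank, and a stack with too few values raises `std::runtime_error`.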