mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Adds the ability, when calling `torch.jit.save`, to store the inputs that were used to trace a model, and to restore the input shapes with `torch.jit.load` if the appropriate variables are set. Fixes [#89185](https://github.com/pytorch/pytorch/issues/89185) Pull Request resolved: https://github.com/pytorch/pytorch/pull/90744 Approved by: https://github.com/davidberard98
26 lines
529 B
C++
26 lines
529 B
C++
#pragma once
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
TORCH_API TypePtr getTensorType(const at::Tensor& t, bool complete);
|
|
|
|
TORCH_API TypePtr inferShapeAndTypeForInput(
|
|
TypePtr input_type,
|
|
Stack::const_iterator& s_iter,
|
|
const Stack::const_iterator& s_iter_end,
|
|
bool complete);
|
|
|
|
TORCH_API void setInputTensorTypes(
|
|
Graph& g,
|
|
const Stack& stack,
|
|
bool complete,
|
|
const std::vector<int>& param_count_list = {});
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|