Summary: Previously, `torch.jit.trace` relied on AutoGrad hooks to infer the names of tensors in the computation, including those of function/method arguments. This often doesn't work out because:

- These names often do not exist.
- The tracer uses the argument name of the first tensor operation on each tensor as the inferred argument name, and these operations have programmatically generated names like `argument_1`.

This PR extracts argument names directly from the Python function and passes them down to the tracer, which then assigns them to the correct graph inputs. This way, the correct argument names are always captured in the IR, which is useful both for debugging and for supporting `InterfaceType` as a representation of traced modules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51775
Reviewed By: izdeby
Differential Revision: D26273105
Pulled By: gmagogsfm
fbshipit-source-id: 934a385041137dc3731bb6fa8657b11532fed9e5
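As a rough illustration of the idea described above, the sketch below uses only pybind11 to read a Python function's argument names, in declaration order, via `inspect.signature`, so that they could be forwarded to something like the `argument_names` parameter of `createGraphByTracing` declared in the header below. The helper `extractArgumentNames` is hypothetical and is not the PR's actual implementation; it is only a minimal, self-contained demonstration of extracting argument names directly from a Python function.

// Hypothetical sketch: read a Python function's argument names from C++.
// Not part of this header or of PyTorch; requires only pybind11's embedded
// interpreter to build and run.
#include <pybind11/embed.h>

#include <iostream>
#include <string>
#include <vector>

namespace py = pybind11;

std::vector<std::string> extractArgumentNames(const py::function& func) {
  // inspect.signature(func).parameters preserves declaration order, so the
  // names line up one-to-one with the tracer's graph inputs.
  py::object sig = py::module::import("inspect").attr("signature")(func);
  std::vector<std::string> names;
  for (py::handle param : sig.attr("parameters").attr("values")()) {
    names.push_back(param.attr("name").cast<std::string>());
  }
  return names;
}

int main() {
  py::scoped_interpreter guard{};
  py::exec("def f(x, y, scale=1.0):\n    return (x + y) * scale\n");
  py::function f = py::globals()["f"].cast<py::function>();
  for (const auto& name : extractArgumentNames(f)) {
    std::cout << name << '\n';  // prints x, y, scale on separate lines
  }
  return 0;
}

In the actual change, the names are gathered on the Python side during `torch.jit.trace` and handed to the tracer, which attaches them to the corresponding graph inputs.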
38 lines
911 B
C++
#pragma once

#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <memory>
#include <string>

namespace torch {
namespace jit {

struct Module;

namespace tracer {
// Registers the Python bindings for the tracer on the given module.
void initPythonTracerBindings(PyObject* module);

// Returns the source range of the Python code currently being interpreted.
SourceRange getPythonInterpreterSourceRange();

// Records a node for a Python operation into the trace before it executes.
Node* preRecordPythonTrace(
    THPObjectPtr pyobj,
    const std::string& arg_types,
    at::ArrayRef<autograd::Variable> inputs,
    std::vector<THPObjectPtr> scalar_args);

// Traces `func` on `inputs` and returns the resulting graph together with the
// output stack. `argument_names` carries the argument names extracted from
// the Python function so they can be assigned to the corresponding graph
// inputs.
std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
    const py::function& func,
    Stack inputs,
    const py::function& var_name_lookup_fn,
    bool strict,
    bool force_outplace,
    Module* self = nullptr,
    const std::vector<std::string>& argument_names = {});
} // namespace tracer
} // namespace jit
} // namespace torch