pytorch/torch/csrc/jit/python/module_python.h
Meghan Lele 6c1c1111de [JIT] Add reference semantics to TorchScript classes (#44324)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44324

**Summary**
This commit adds reference semantics to TorchScript class types;
modifications made to them within TorchScript will be visible in Python.

**Test Plan**
This commit adds a unit test to `TestClassType` that checks that
modifications made to a class type instance passed into TorchScript are
visible in Python after executing the scripted function or module.

**Fixes**
This commit closes #41421.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D24912807

Pulled By: SplitInfinity

fbshipit-source-id: d64ac6211012425b040b987e3358253016e84ca0
2021-06-30 14:27:17 -07:00

29 lines
685 B
C++

#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/api/module.h>
namespace py = pybind11;
namespace torch {
namespace jit {
// Converts a Python object into a TorchScript Module when possible.
//
// If `obj` is an instance of the Python class `torch.jit.ScriptModule`, the
// value stored in its `_c` attribute is cast to a C++ `Module` and returned.
// For any other object the result is `c10::nullopt`.
//
// The `torch.jit` module and its `ScriptModule` attribute are looked up on
// every call.
inline c10::optional<Module> as_module(const py::object& obj) {
  const py::object script_module_type =
      py::module::import("torch.jit").attr("ScriptModule");
  if (!py::isinstance(obj, script_module_type)) {
    return c10::nullopt;
  }
  return py::cast<Module>(obj.attr("_c"));
}
// Converts a Python object into a TorchScript Object when possible.
//
// If `obj` is an instance of the Python class
// `torch.jit.RecursiveScriptClass`, the value stored in its `_c` attribute is
// cast to a C++ `Object` and returned. For any other object the result is
// `c10::nullopt`.
//
// The `torch.jit` module and its `RecursiveScriptClass` attribute are looked
// up on every call.
inline c10::optional<Object> as_object(const py::object& obj) {
  const py::object script_class_type =
      py::module::import("torch.jit").attr("RecursiveScriptClass");
  if (!py::isinstance(obj, script_class_type)) {
    return c10::nullopt;
  }
  return py::cast<Object>(obj.attr("_c"));
}
} // namespace jit
} // namespace torch