mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR enables C++ frontend modules to be bound into Python and added as submodules of Python modules. For this, I added many pybind11 bindings for the `torch::nn::Module` class, and modified the `torch.nn.Module` class in Python to have a new metaclass that makes `isinstance(m, torch.nn.Module)` return true when `m` is a C++ frontend module. The methods and fields of C++ modules are bound in such a way that they work seamlessly as submodules of Python modules for most operations (one exception I know of: calling `.to()` ends up calling `.apply()` on each submodule with a Python lambda, which cannot be used in C++ -- this may require small changes on the Python side). I've added quite a few tests to verify the bindings and their equivalence with Python. I think I should also try adding a C++ module as part of some large PyTorch module, like a WLM, and see if everything works smoothly. The next step for inter-op across our system is ScriptModule <-> C++ frontend module inter-op, which I think will also allow using C++ frontend modules from TorchScript. apaszke zdevito CC dzhulgakov Pull Request resolved: https://github.com/pytorch/pytorch/pull/13481 Differential Revision: D12981996 Pulled By: goldsborough fbshipit-source-id: 147370d3596ebb0e94c82cec92993a148fee50a7
78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
"""Functionality for Python <-> C++ frontend inter-op."""
|
|
|
|
from torch import nn
|
|
|
|
|
|
class OrderedDictWrapper(object):
    """
    Live, read-only view of one dict-valued attribute of a bound C++ module.

    Every operation re-reads ``getattr(cpp_module, attr)``, so changes made on
    the C++ side after this wrapper was created are picked up. Reading e.g.
    ``cpp_module._parameters`` once and keeping the result would instead
    freeze a copy taken at the time of access. A ``property`` cannot be used
    for this purpose, because ``torch.nn.Module`` reads ``_parameters``,
    ``_buffers`` and ``_modules`` directly out of ``self.__dict__``.
    """

    def __init__(self, cpp_module, attr):
        # Object whose attribute is mirrored, and the attribute's name
        # (e.g. "_parameters").
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Current value of ``cpp_module.<attr>``, fetched on every access."""
        return getattr(self.cpp_module, self.attr)

    # Implicit special-method lookup (``len(w)``, ``k in w``, ``w[k]``,
    # ``for k in w``) goes straight to the type and never reaches
    # ``__getattr__``, so every delegating method is written out by hand.

    def items(self):
        current = self.cpp_dict
        return current.items()

    def keys(self):
        current = self.cpp_dict
        return current.keys()

    def values(self):
        current = self.cpp_dict
        return current.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
|
|
|
|
|
|
class ModuleWrapper(nn.Module):
    """
    A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and
    delegates all access.

    At construction time every public (non-underscore) attribute of the C++
    module is copied onto the wrapper. ``_parameters``, ``_buffers`` and
    ``_modules`` are replaced by ``OrderedDictWrapper`` views, so the
    Python-side ``nn.Module`` machinery always reads the C++ module's current
    state.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor (the ``training`` setter
        # below forwards to ``self.cpp_module``).
        self.cpp_module = cpp_module
        super(ModuleWrapper, self).__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")
        self._buffers = OrderedDictWrapper(cpp_module, "_buffers")
        self._modules = OrderedDictWrapper(cpp_module, "_modules")
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """
        Apply ``fn`` in place to every parameter, gradient and buffer.

        This method is reached when a parent Python module calls e.g. ``.to()``
        or ``.double()``. The inherited ``nn.Module._apply`` cannot be used
        here, for two reasons:

        - It recurses into ``children()``, whose C++ submodules cannot run a
          Python callable.
        - It writes results back with item assignment on ``_parameters`` /
          ``_buffers``, and ``OrderedDictWrapper`` does not support item
          assignment.

        Instead, this override replaces the ``.data`` of each tensor, so the
        C++ module keeps referencing the same tensor objects.

        Args:
            fn: Callable mapping a tensor to a tensor.
            recurse: Accepted for signature compatibility with
                ``nn.Module._apply`` and ignored. NOTE(review): this assumes
                the bound ``parameters()``/``buffers()`` already cover
                submodules -- confirm against the C++ bindings.

        Returns:
            ``self``.
        """
        for param in self.parameters():
            # Module tensors are graph leaves; swapping ``.data`` avoids
            # recording a copy in the autograd graph.
            param.data = fn(param.data)
            if param.grad is not None:
                param.grad.data = fn(param.grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # ``training`` is re-declared as a property so that reads come from the
    # C++ module and writes go through its ``train(mode)`` method.
    @property
    def training(self):
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate to the C++ module's own representation.
        return self.cpp_module.__repr__()
|