mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This PR adds a Python/C++ API parity tracker at `test/cpp_api_parity/parity-tracker.md`, which currently shows the parity status for `torch.nn` modules. A good share of the line changes here comes from moving `new_criterion_tests` from `test_nn.py` to `common_nn.py`, so that it can be used in `test_cpp_api_parity.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/25289 Differential Revision: D17188085 Pulled By: yf225 fbshipit-source-id: 33d12fb1a4de2d9147ed09380973f361a3981fdf
110 lines
3.5 KiB
Python
110 lines
3.5 KiB
Python
import torch
|
|
|
|
from cpp_api_parity import torch_nn_modules, TorchNNModuleMetadata
|
|
|
|
'''
`SampleModule` is used by `test_cpp_api_parity.py` to test that the Python / C++ API
parity test harness works for `torch.nn.Module` subclasses.

When `SampleModule.has_parity` is true, the behavior of `reset_parameters` / `forward` /
`backward` is the same as that of the C++ equivalent.

When `SampleModule.has_parity` is false, the behavior of `reset_parameters` / `forward` /
`backward` differs from that of the C++ equivalent.
'''
|
|
|
|
class SampleModule(torch.nn.Module):
    """Sample module with a switch that makes it agree or disagree with its C++ twin.

    The module owns a ``(3, 4)`` parameter ``param``, a ``(4, 5)`` buffer
    ``buffer`` and a plain Python attribute ``attr``.

    * ``has_parity=True``: ``param`` and ``buffer`` are filled with ones,
      ``attr`` is 10 and ``forward`` returns ``x + param * 2`` plus the
      submodule's output (if any). This is what the C++ code in
      ``SAMPLE_MODULE_CPP_SOURCE`` does as well.
    * ``has_parity=False``: ``param`` and ``buffer`` are filled with 11,
      ``attr`` is 100 and ``forward`` returns ``x + param * 4 + 3`` plus the
      submodule's output (if any).

    Args:
        has_parity: selects which of the two behaviors above is used.
        has_submodule: when true, a child ``SampleModule(has_parity, False)``
            is attached as ``self.submodule``.
        int_option, double_option, bool_option, string_option, tensor_option:
            accepted but never read by this class.
    """

    # NOTE: the `tensor_option` default is evaluated once, when the class body
    # is executed; it is never read or mutated here, so sharing it is harmless.
    def __init__(self, has_parity, has_submodule, int_option=0, double_option=0.1,
                 bool_option=False, string_option='0', tensor_option=torch.empty(1)):
        super(SampleModule, self).__init__()
        self.has_parity = has_parity
        # Storage is allocated uninitialized; `reset_parameters` fills it in.
        self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4)))
        self.register_buffer('buffer', torch.empty(4, 5))
        if has_submodule:
            # The child never gets a submodule of its own (`has_submodule=False`).
            self.submodule = SampleModule(self.has_parity, False)
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``param``, ``buffer`` and ``attr`` to their initial values.

        The in-place tensor updates run under ``torch.no_grad()`` so that they
        are not recorded by autograd.
        """
        with torch.no_grad():
            self.param.fill_(1)
            self.buffer.fill_(1)
            self.attr = 10
            if not self.has_parity:
                # Shift every value so the state no longer matches the
                # has_parity=True values: tensors become 11 and attr becomes 100.
                self.param.add_(10)
                self.buffer.add_(10)
                self.attr += 90

    def forward(self, x):
        """Return the module output for ``x``.

        Args:
            x: input tensor; it is added to the ``(3, 4)`` parameter, so it
                must be broadcastable with that shape.

        Returns:
            ``x + param * 2 + sub`` when ``has_parity`` is true, otherwise
            ``x + param * 4 + sub + 3``, where ``sub`` is the submodule's
            output for ``x`` or ``0`` when there is no submodule.
        """
        submodule_forward_result = self.submodule(x) if hasattr(self, 'submodule') else 0
        if not self.has_parity:
            return x + self.param * 4 + submodule_forward_result + 3
        return x + self.param * 2 + submodule_forward_result
|
|
|
|
# C++ source for the counterpart of `SampleModule`. It is passed as
# `cpp_sources` to the `TorchNNModuleMetadata` registered for 'SampleModule'.
# Its `reset()` fills `param` / `buffer` with ones and sets `attr` to 10, and
# its `forward` computes `x + param * 2` plus the submodule's output, i.e. the
# same values the Python class uses when `has_parity` is true.
SAMPLE_MODULE_CPP_SOURCE = """\n
namespace torch {
namespace nn{
struct C10_EXPORT SampleModuleOptions {
  SampleModuleOptions(bool has_submodule) : has_submodule_(has_submodule) {}
  TORCH_ARG(bool, has_submodule);
  TORCH_ARG(int64_t, int_option);
  TORCH_ARG(double, double_option);
  TORCH_ARG(bool, bool_option);
  TORCH_ARG(std::string, string_option);
  TORCH_ARG(torch::Tensor, tensor_option);
};

struct C10_EXPORT SampleModuleImpl : public torch::nn::Cloneable<SampleModuleImpl> {
  SampleModuleImpl(bool has_submodule) : SampleModuleImpl(SampleModuleOptions(has_submodule)) {}
  explicit SampleModuleImpl(SampleModuleOptions options) {
    if (options.has_submodule_) {
      submodule = register_module("submodule", std::make_shared<SampleModuleImpl>(false));
    }
    reset();
  }
  void reset() {
    attr = 10;
    param = register_parameter("param", torch::ones({3, 4}));
    buffer = register_buffer("buffer", torch::ones({4, 5}));
  }
  torch::Tensor forward(torch::Tensor x) {
    return x + param * 2 + (submodule ? submodule->forward(x) : torch::zeros_like(x));
  }
  torch::Tensor param;
  torch::Tensor buffer;
  int attr;
  std::shared_ptr<SampleModuleImpl> submodule{nullptr};
};

TORCH_MODULE(SampleModule);
}
}
"""
|
|
|
|
# Test configurations for `SampleModule`. Both entries build the Python module
# with a submodule and use the same C++ constructor arguments; they differ only
# in whether the Python side is constructed with `has_parity` true or false.
# NOTE(review): how each key is consumed is defined by the parity test
# harness, which is not visible in this file — confirm there before changing
# key names.
module_tests = [
    dict(
        module_name='SampleModule',
        constructor_args=(True, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        desc='has_parity',
        has_parity=True,
    ),
    dict(
        module_name='SampleModule',
        constructor_args=(False, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        desc='no_parity',
        has_parity=False,
    ),
]
|
|
|
|
# Register the metadata for 'SampleModule' alongside the other `torch.nn`
# module metadata, pointing at the C++ source defined in this file.
torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata(
    cpp_default_constructor_args='(true)',
    # NOTE(review): presumably `param`, `buffer` and `attr` counted for both
    # the module and its submodule (3 + 3) — confirm against the harness.
    num_attrs_recursive=6,
    cpp_sources=SAMPLE_MODULE_CPP_SOURCE,
)

# Expose the sample class as an attribute of `torch.nn`.
# NOTE(review): presumably so the harness can resolve it by `module_name`
# like a built-in module — confirm against the harness.
torch.nn.SampleModule = SampleModule
|