# The purpose of this test is to check that we have implementation parity between # a Python `torch.nn.functional` function and its corresponding C++ `torch::nn::functional` # function. Concretely, this test does the following: # # 1. Get a test params dict from common_nn.py, run forward pass on the Python functional # created using the test params. # # 2. Serialize the Python functional's forward input arguments, deserialize them # in C++ and use them as input for the C++ functional's forward pass. # # 3. Run the forward pass on the C++ functional, and serialize the C++ functional's # forward output. # # 4. Compare Python/C++ functional's forward output. If they are the same, then we # have implementation parity between Python/C++ module. import tempfile from string import Template import re import pprint import os import torch from cpp_api_parity.utils import TorchNNFunctionalTestParams, TORCH_NN_COMMON_TEST_HARNESS, \ compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \ add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \ compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \ try_remove_folder from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE # Expected substitutions: # # ${functional_variant_name} (e.g. `BCELoss_no_reduce`) # ${cpp_args_construction_stmts} # ${cpp_function_call} TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(""" void ${functional_variant_name}_test_forward( const std::string& arg_dict_file_path, const std::string& forward_output_file_path) { pybind11::gil_scoped_release no_gil; namespace F = torch::nn::functional; // Declare arguments auto arg_dict = load_dict_from_file(arg_dict_file_path); ${cpp_args_construction_stmts}; // Some functionals (such as `F::rrelu`) create random tensors in their call path. // To make sure the random tensors created are the same in Python/C++, we need // to set the RNG seed manually. torch::manual_seed(0); // Run function with arguments auto cpp_output = ${cpp_function_call}; // Save the output into a file to be compared in Python later write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path); } """) def run_forward(unit_test_class, test_params): device = test_params.device inputs = set_python_tensors_requires_grad(move_python_tensors_to_device( [arg_value for _, arg_value in test_params.arg_dict['input']], device)) inputs += move_python_tensors_to_device( [arg_value for _, arg_value in test_params.arg_dict['target']], device) inputs += move_python_tensors_to_device( [arg_value for _, arg_value in test_params.arg_dict['extra_args']], device) # Some functionals (such as `F.rrelu`) create random tensors in their call path. # To make sure the random tensors created are the same in Python/C++, we need # to set the RNG seed manually. torch.manual_seed(0) python_output = test_params.test_instance.constructor()(*inputs) return python_output def test_forward(unit_test_class, test_params): functional_variant_name = test_params.functional_variant_name cpp_tmp_folder = test_params.cpp_tmp_folder # Remove the temporary folder if it exists already try_remove_folder(cpp_tmp_folder) os.mkdir(cpp_tmp_folder) # Run forward on Python functional python_output = run_forward(unit_test_class, test_params) # Save Python arguments to be used from C++ function arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'arg_dict') serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path) cpp_test_name = f'{test_params.functional_variant_name}_test_forward' cpp_test_fn = getattr(unit_test_class.functional_impl_check_cpp_module, cpp_test_name) def run_cpp_test_fn_and_check_output(): forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'forward_output') cpp_test_fn(arg_dict_file_path, forward_output_file_path) cpp_output = torch.load(forward_output_file_path) # Check that forward outputs are equal unit_test_class.assertEqual( python_output, cpp_output, msg=generate_error_msg("forward output", cpp_output, python_output)) run_cpp_test_fn_and_check_output() # Remove temporary folder that stores C++ outputs try_remove_folder(cpp_tmp_folder) def compute_functional_name(test_params_dict): def camel_case_to_snake_case(camel_case_str): return re.sub(r'(? 0 cpp_sources = TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE functions = [] for test_name, test_params in unit_test_class.functional_test_params_map.items(): cpp_sources += generate_test_cpp_sources(test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD) functions.append(f'{test_params.functional_variant_name}_test_forward') if print_cpp_source: print(cpp_sources) cpp_module = compile_cpp_code_inline( name='functional_impl_check', cpp_sources=cpp_sources, functions=functions) unit_test_class.functional_impl_check_cpp_module = cpp_module