Added ModuleInfo test for meta device ctx init (#105871)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105871
Approved by: https://github.com/albanD
Mikayla Gawarecki (2023-07-25 08:46:34 -07:00), committed by PyTorch MergeBot
parent 837363c72f
commit e18d53e2df
2 changed files with 42 additions and 6 deletions

test/test_modules.py

@@ -1,15 +1,17 @@
 # Owner(s): ["module: nn"]
-from itertools import product
+from itertools import chain, product
 from inspect import signature, isgenerator
 from copy import deepcopy
 import tempfile
 from operator import methodcaller
 import torch
+from torch._subclasses.meta_utils import assert_metadata_eq
 from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
+    instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
 from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
@@ -734,6 +736,36 @@ class TestModule(TestCase):
                raise e

+    @onlyCPU
+    @modules(module_db)
+    def test_device_ctx_init(self, device, dtype, module_info, training):
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=False, training=training)
+        with torch.device('meta'):
+            module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
+                                                                requires_grad=False, training=training)
+
+        for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
+            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
+            fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs
+
+            m_cpu = module_cls(*c_args, **c_kwargs)
+            with torch.device('meta'):
+                m = module_cls(*c_args_meta, **c_kwargs_meta)
+
+            for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
+                                         zip(m.buffers(), m_cpu.buffers())):
+                if torch.nn.parameter.is_lazy(p_meta):
+                    continue
+                self.assertTrue(p_meta.is_meta)
+                assert_metadata_eq(self.assertEqual, p_meta, p_cpu)
+

 instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

 if __name__ == '__main__':
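
Context, not part of the diff: the new test leans on torch.device('meta') acting as a context manager that changes the default device for factory functions, so parameters created during module construction land on the meta device while keeping the same metadata as a CPU-constructed module. A minimal standalone sketch of that behavior (nn.Linear is used here purely for illustration):

import torch

# Build the same module on CPU and under the meta device context manager.
m_cpu = torch.nn.Linear(4, 8)
with torch.device('meta'):
    m_meta = torch.nn.Linear(4, 8)

# Parameters created inside the context live on the meta device but carry
# the same shape and dtype as the CPU module, without allocating real storage.
for p_meta, p_cpu in zip(m_meta.parameters(), m_cpu.parameters()):
    assert p_meta.is_meta
    assert p_meta.shape == p_cpu.shape and p_meta.dtype == p_cpu.dtype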

torch/testing/_internal/common_modules.py

@@ -1506,7 +1506,7 @@ def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requi
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     def make_random_samples():
-        return torch.empty((1, 3, 2), dtype=torch.double).uniform_()
+        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()

     return [
         ModuleInput(
@@ -1540,7 +1540,7 @@ def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requi
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

     def make_random_samples():
-        return torch.empty((2, 4, 3), dtype=torch.double).uniform_()
+        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()

     return [
         ModuleInput(
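
Context, not part of the diff: the explicit device= argument matters because a factory call like torch.empty() otherwise follows the ambient default device, e.g. the torch.device('meta') context the new test builds its inputs under. Presumably that is why the random samples are now pinned to the requested device; a small sketch of the default-device behavior this relies on:

import torch

with torch.device('meta'):
    t_default = torch.empty((1, 3, 2), dtype=torch.double)               # follows the context -> meta
    t_pinned = torch.empty((1, 3, 2), dtype=torch.double, device='cpu')  # explicit device wins

assert t_default.is_meta
assert t_pinned.device.type == 'cpu'
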
@@ -1628,11 +1628,14 @@ def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad,

 def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
     # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
+    samples = []
     for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
             None, device, dtype, requires_grad, training):
         # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
         l_args, l_kwargs = (layer_module_input.constructor_input.args,
                             layer_module_input.constructor_input.kwargs)
+        l_kwargs['device'] = device
+        l_kwargs['dtype'] = dtype
         encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
         num_layers = 2
         # Note: TransformerEncoderLayer takes a "src_mask" while
@@ -1641,11 +1644,12 @@ def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requir
         if 'src_mask' in forward_input.kwargs:
             forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
             del forward_input.kwargs['src_mask']
-        yield ModuleInput(
+        samples.append(ModuleInput(
             constructor_input=FunctionInput(encoder_layer, num_layers),
             forward_input=forward_input,
             desc=layer_module_input.desc
-        )
+        ))
+    return samples

 def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
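
Context, not part of the diff: switching from yield to an accumulated list changes when the TransformerEncoderLayer objects are constructed. A generator defers the constructor calls to iteration time, so samples requested inside a with torch.device('meta'): block could end up being built outside it; a returned list is built eagerly at call time, and the added l_kwargs['device'] / l_kwargs['dtype'] entries keep the pre-built encoder layer on the requested device and dtype. A small sketch of the generator-versus-list distinction, assuming this is the motivation (plain nn.Linear modules stand in for the real samples):

import torch

def lazy_samples():
    # Generator: nothing runs until someone iterates over it.
    for _ in range(2):
        yield torch.nn.Linear(2, 2)

def eager_samples():
    # List: the modules are constructed right here, at call time.
    return [torch.nn.Linear(2, 2) for _ in range(2)]

with torch.device('meta'):
    gen = lazy_samples()    # no modules built yet
    lst = eager_samples()   # built under the context -> meta parameters

# Iterating the generator only now runs the constructors, outside the context.
built_later = list(gen)

assert all(p.is_meta for m in lst for p in m.parameters())
assert not any(p.is_meta for m in built_later for p in m.parameters())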