Added ModuleInfo test for meta device ctx init (#105871)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105871
Approved by: https://github.com/albanD
This commit is contained in:
parent 837363c72f
commit e18d53e2df
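For context, the new test exercises construction under the device context manager: creating a module inside with torch.device('meta') should place its parameters and buffers on the meta device (shape/dtype metadata only, no storage) while still matching the metadata of a normally constructed module. A minimal sketch of that behavior, using torch.nn.Linear as an arbitrary example module (not taken from the diff):

import torch

# Minimal sketch; nn.Linear is just an example module, not from the diff.
m_cpu = torch.nn.Linear(3, 4)
with torch.device('meta'):
    m_meta = torch.nn.Linear(3, 4)   # parameters allocated on the meta device

for p_meta, p_cpu in zip(m_meta.parameters(), m_cpu.parameters()):
    assert p_meta.is_meta                  # no real storage behind the tensor
    assert p_meta.shape == p_cpu.shape     # metadata still matches CPU init
    assert p_meta.dtype == p_cpu.dtype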
@@ -1,15 +1,17 @@
 # Owner(s): ["module: nn"]
-from itertools import product
+
+from itertools import chain, product
 from inspect import signature, isgenerator
 from copy import deepcopy
 import tempfile
 from operator import methodcaller
 
 import torch
 
+from torch._subclasses.meta_utils import assert_metadata_eq
 from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
+    instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
 from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
@@ -734,6 +736,36 @@ class TestModule(TestCase):
                     raise e
 
 
+    @onlyCPU
+    @modules(module_db)
+    def test_device_ctx_init(self, device, dtype, module_info, training):
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=False, training=training)
+        with torch.device('meta'):
+            module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
+                                                                requires_grad=False, training=training)
+
+        for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
+            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
+            fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs
+
+            m_cpu = module_cls(*c_args, **c_kwargs)
+
+            with torch.device('meta'):
+                m = module_cls(*c_args_meta, **c_kwargs_meta)
+
+            for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
+                                         zip(m.buffers(), m_cpu.buffers())):
+                if torch.nn.parameter.is_lazy(p_meta):
+                    continue
+                self.assertTrue(p_meta.is_meta)
+                assert_metadata_eq(self.assertEqual, p_meta, p_cpu)
+
+
 instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
 
 if __name__ == '__main__':
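A note on the torch.nn.parameter.is_lazy check in the test above: lazy modules keep UninitializedParameter placeholders until their first forward pass, so there is no shape metadata to compare yet and the test skips them. A small illustration, assuming torch.nn.LazyLinear as the example (not part of the diff):

import torch

# Illustration only: LazyLinear holds an UninitializedParameter until the
# first forward pass infers the input feature size.
lazy = torch.nn.LazyLinear(4)
print(torch.nn.parameter.is_lazy(lazy.weight))   # True, shape not known yet

lazy(torch.randn(2, 3))                          # first forward materializes it
print(torch.nn.parameter.is_lazy(lazy.weight))   # False
print(lazy.weight.shape)                         # torch.Size([4, 3])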
@@ -1506,7 +1506,7 @@ def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
     def make_random_samples():
-        return torch.empty((1, 3, 2), dtype=torch.double).uniform_()
+        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()
 
     return [
         ModuleInput(
@@ -1540,7 +1540,7 @@ def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
     def make_random_samples():
-        return torch.empty((2, 4, 3), dtype=torch.double).uniform_()
+        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()
 
     return [
         ModuleInput(
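The two make_random_samples fixes above pin device=device explicitly. One relevant property, sketched below on the assumption that this is the motivation: tensor factory calls that omit device= resolve against the ambient torch.device context, so under the new meta-device path the sample tensors could otherwise land on an unintended device.

import torch

# Sketch of device resolution for factory functions (illustration only):
# omitting device= defers to the ambient torch.device context, while an
# explicit device= always wins.
with torch.device('meta'):
    implicit = torch.empty(2, 3)                 # picks up the context -> meta
    explicit = torch.empty(2, 3, device='cpu')   # explicit device wins -> cpu

print(implicit.device)   # meta
print(explicit.device)   # cpu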
@@ -1628,11 +1628,14 @@ def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
 
 def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
     # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
+    samples = []
     for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
             None, device, dtype, requires_grad, training):
         # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
         l_args, l_kwargs = (layer_module_input.constructor_input.args,
                             layer_module_input.constructor_input.kwargs)
+        l_kwargs['device'] = device
+        l_kwargs['dtype'] = dtype
         encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
         num_layers = 2
         # Note: TransformerEncoderLayer takes a "src_mask" while
@@ -1641,11 +1644,12 @@ def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
         if 'src_mask' in forward_input.kwargs:
             forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
             del forward_input.kwargs['src_mask']
-        yield ModuleInput(
+        samples.append(ModuleInput(
             constructor_input=FunctionInput(encoder_layer, num_layers),
             forward_input=forward_input,
             desc=layer_module_input.desc
-        )
+        ))
+    return samples
 
 def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
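Two related changes in the hunks above: the TransformerEncoderLayer is now constructed with the requested device and dtype, and the sample construction switched from yield to an accumulated list. The list makes construction eager; a generator body only runs when the caller iterates, which may happen outside whatever device context was active at call time. A small sketch of that difference, with hypothetical helper names (not from the diff):

import torch

# Hypothetical helpers, illustration only: a generator defers its body until
# iteration, a list-returning function runs it at call time.
def make_lazy():
    yield torch.empty(2, 2)           # created when the generator is consumed

def make_eager():
    return [torch.empty(2, 2)]        # created immediately

with torch.device('meta'):
    gen = make_lazy()                 # nothing allocated yet
    eager = make_eager()              # tensor created under the meta context

print(next(gen).device)               # cpu: body ran after the context exited
print(eager[0].device)                # meta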