[DataLoader] Typing Enforcement for DataPipe at construct-time (#54066)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54066

## Feature
- Add a decorator `construct_time_validation` to validate each input datapipe according to the corresponding type hint.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D27327236

Pulled By: ejguan

fbshipit-source-id: a9d4c6edb5b05090bd5a369eee50a6fb4d7cf957
This commit is contained in:
Erjia Guan 2021-04-02 15:19:06 -07:00 committed by Facebook GitHub Bot
parent 44edf8c421
commit 1535520f08
3 changed files with 79 additions and 4 deletions

View File

@ -13,7 +13,7 @@ from unittest import skipIf
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader, construct_time_validation
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Set, Union
import torch.utils.data.datapipes as dp
@ -697,6 +697,40 @@ class TestTyping(TestCase):
dp = DP6() # type: ignore
self.assertTrue(dp.type.param == int)
def test_construct_time(self):
class DP0(IterDataPipe[Tuple]):
@construct_time_validation
def __init__(self, dp: IterDataPipe):
self.dp = dp
def __iter__(self) -> Iterator[Tuple]:
for d in self.dp:
yield d, str(d)
class DP1(IterDataPipe[int]):
@construct_time_validation
def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
self.dp = dp
def __iter__(self) -> Iterator[int]:
for a, b in self.dp:
yield a
# Non-DataPipe input with DataPipe hint
datasource = [(1, '1'), (2, '2'), (3, '3')]
with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
dp = DP0(datasource)
dp = DP0(IDP(range(10)))
with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
dp = DP1(dp)
with self.assertRaisesRegex(TypeError, r"Can not decorate"):
class InvalidDP1(IterDataPipe[int]):
@construct_time_validation
def __iter__(self):
yield 0
if __name__ == '__main__':
run_tests()

View File

@ -7,7 +7,9 @@ from torch.utils.data.dataset import \
from torch.utils.data.dataset import IterableDataset as IterDataPipe
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.dataloader import DataLoader, _DatasetKind, get_worker_info
from torch.utils.data.decorator import functional_datapipe, guaranteed_datapipes_determinism, non_deterministic
from torch.utils.data.decorator import \
(functional_datapipe, guaranteed_datapipes_determinism, non_deterministic,
construct_time_validation)
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
@ -16,7 +18,7 @@ __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
'ConcatDataset', 'ChainDataset', 'Subset', 'random_split',
'DataLoader', '_DatasetKind', 'get_worker_info',
'IterDataPipe', 'functional_datapipe', 'guaranteed_datapipes_determinism',
'non_deterministic']
'non_deterministic', 'construct_time_validation']
################################################################################

View File

@ -1,8 +1,13 @@
from typing import Any, Callable, Optional, Type, Union
import inspect
from functools import wraps
from typing import Any, Callable, Optional, Type, Union, get_type_hints
from torch.utils.data import IterDataPipe
from torch.utils.data._typing import _DataPipeMeta
######################################################
# Functional API
######################################################
class functional_datapipe(object):
name: str
@ -23,6 +28,9 @@ class functional_datapipe(object):
return cls
######################################################
# Determinism
######################################################
_determinism: bool = False
@ -94,3 +102,34 @@ class non_deterministic(object):
"for this DataPipe if that is acceptable for your application"
.format(self.cls.__name__)) # type: ignore
return self.cls(*args, **kwargs) # type: ignore
######################################################
# typing
######################################################
# Construct-time checking
# Validate each DataPipe with hint as a subtype of the hint.
def construct_time_validation(f):
if f.__name__ not in ('__init__', '__new__'):
raise TypeError("Can not decorate function {} with 'construct_time_validation'"
.format(f.__name__))
signature = inspect.signature(f)
hints = get_type_hints(f)
@wraps(f)
def wrapper(*args, **kwargs):
bound = signature.bind(*args, **kwargs)
for argument_name, value in bound.arguments.items():
if argument_name in hints and isinstance(hints[argument_name], _DataPipeMeta):
hint = hints[argument_name]
if not isinstance(value, IterDataPipe):
raise TypeError("Expected argument '{}' as a IterDataPipe, but found {}"
.format(argument_name, type(value)))
if not value.type.issubtype(hint.type):
raise TypeError("Expected type of argument '{}' as a subtype of "
"hint {}, but found {}"
.format(argument_name, hint.type, value.type))
return f(*args, **kwargs)
return wrapper