mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[DataLoader] Typing Enforcement for DataPipe at construct-time (#54066)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54066 ## Feature - Add a decorator `construct_time_validation` to validate each input datapipe according to the corresponding type hint. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D27327236 Pulled By: ejguan fbshipit-source-id: a9d4c6edb5b05090bd5a369eee50a6fb4d7cf957
This commit is contained in:
parent
44edf8c421
commit
1535520f08
|
|
@ -13,7 +13,7 @@ from unittest import skipIf
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing._internal.common_utils import (TestCase, run_tests)
|
||||
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader
|
||||
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader, construct_time_validation
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Set, Union
|
||||
|
||||
import torch.utils.data.datapipes as dp
|
||||
|
|
@ -697,6 +697,40 @@ class TestTyping(TestCase):
|
|||
dp = DP6() # type: ignore
|
||||
self.assertTrue(dp.type.param == int)
|
||||
|
||||
def test_construct_time(self):
|
||||
class DP0(IterDataPipe[Tuple]):
|
||||
@construct_time_validation
|
||||
def __init__(self, dp: IterDataPipe):
|
||||
self.dp = dp
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
for d in self.dp:
|
||||
yield d, str(d)
|
||||
|
||||
class DP1(IterDataPipe[int]):
|
||||
@construct_time_validation
|
||||
def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
|
||||
self.dp = dp
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
for a, b in self.dp:
|
||||
yield a
|
||||
|
||||
# Non-DataPipe input with DataPipe hint
|
||||
datasource = [(1, '1'), (2, '2'), (3, '3')]
|
||||
with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
|
||||
dp = DP0(datasource)
|
||||
|
||||
dp = DP0(IDP(range(10)))
|
||||
with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
|
||||
dp = DP1(dp)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"Can not decorate"):
|
||||
class InvalidDP1(IterDataPipe[int]):
|
||||
@construct_time_validation
|
||||
def __iter__(self):
|
||||
yield 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from torch.utils.data.dataset import \
|
|||
from torch.utils.data.dataset import IterableDataset as IterDataPipe
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.dataloader import DataLoader, _DatasetKind, get_worker_info
|
||||
from torch.utils.data.decorator import functional_datapipe, guaranteed_datapipes_determinism, non_deterministic
|
||||
from torch.utils.data.decorator import \
|
||||
(functional_datapipe, guaranteed_datapipes_determinism, non_deterministic,
|
||||
construct_time_validation)
|
||||
|
||||
|
||||
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
|
||||
|
|
@ -16,7 +18,7 @@ __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
|
|||
'ConcatDataset', 'ChainDataset', 'Subset', 'random_split',
|
||||
'DataLoader', '_DatasetKind', 'get_worker_info',
|
||||
'IterDataPipe', 'functional_datapipe', 'guaranteed_datapipes_determinism',
|
||||
'non_deterministic']
|
||||
'non_deterministic', 'construct_time_validation']
|
||||
|
||||
|
||||
################################################################################
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
from typing import Any, Callable, Optional, Type, Union
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, Type, Union, get_type_hints
|
||||
from torch.utils.data import IterDataPipe
|
||||
from torch.utils.data._typing import _DataPipeMeta
|
||||
|
||||
|
||||
######################################################
|
||||
# Functional API
|
||||
######################################################
|
||||
class functional_datapipe(object):
|
||||
name: str
|
||||
|
||||
|
|
@ -23,6 +28,9 @@ class functional_datapipe(object):
|
|||
return cls
|
||||
|
||||
|
||||
######################################################
|
||||
# Determinism
|
||||
######################################################
|
||||
_determinism: bool = False
|
||||
|
||||
|
||||
|
|
@ -94,3 +102,34 @@ class non_deterministic(object):
|
|||
"for this DataPipe if that is acceptable for your application"
|
||||
.format(self.cls.__name__)) # type: ignore
|
||||
return self.cls(*args, **kwargs) # type: ignore
|
||||
|
||||
|
||||
######################################################
|
||||
# typing
|
||||
######################################################
|
||||
# Construct-time checking
|
||||
# Validate each DataPipe with hint as a subtype of the hint.
|
||||
def construct_time_validation(f):
|
||||
if f.__name__ not in ('__init__', '__new__'):
|
||||
raise TypeError("Can not decorate function {} with 'construct_time_validation'"
|
||||
.format(f.__name__))
|
||||
signature = inspect.signature(f)
|
||||
hints = get_type_hints(f)
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
bound = signature.bind(*args, **kwargs)
|
||||
for argument_name, value in bound.arguments.items():
|
||||
if argument_name in hints and isinstance(hints[argument_name], _DataPipeMeta):
|
||||
hint = hints[argument_name]
|
||||
if not isinstance(value, IterDataPipe):
|
||||
raise TypeError("Expected argument '{}' as a IterDataPipe, but found {}"
|
||||
.format(argument_name, type(value)))
|
||||
if not value.type.issubtype(hint.type):
|
||||
raise TypeError("Expected type of argument '{}' as a subtype of "
|
||||
"hint {}, but found {}"
|
||||
.format(argument_name, hint.type, value.type))
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user