"""Helpers for normalizing size-like arguments (kernel size, stride, padding, ...)."""
import collections.abc
from itertools import repeat
from typing import List


def _ntuple(n):
    """Build a parser that expands a scalar into an ``n``-tuple.

    Args:
        n: Length of the tuple produced for non-iterable inputs.

    Returns:
        A function ``parse(x)``. If ``x`` is iterable, it is returned unchanged:
        it is not copied, not converted to a tuple, and its length is not
        checked. Otherwise ``(x,) * n`` is returned as a tuple.
    """
    def parse(x):
        # ``collections.abc`` replaces the removed ``torch._six.container_abcs``
        # alias, which pointed at this same module on Python 3.
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def _reverse_repeat_tuple(t, n):
    r"""Reverse the order of `t` and repeat each element for `n` times.

    This can be used to translate padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.

    Example: ``_reverse_repeat_tuple((1, 2), 2) == (2, 2, 1, 1)``.
    """
    return tuple(x for x in reversed(t) for _ in range(n))


def _list_with_default(out_size, defaults):
    # type: (List[int], List[int]) -> List[int]
    """Fill ``None`` entries of ``out_size`` from the trailing entries of ``defaults``.

    Args:
        out_size: Either an int, which is returned unchanged, or a list whose
            ``None`` entries should be replaced.
        defaults: List of fallback values. Only its last ``len(out_size)``
            entries are used, aligned position by position with ``out_size``.
            Presumably this is the input tensor's size; confirm with callers.

    Returns:
        ``out_size`` itself when it is an int. Otherwise a new list in which
        each ``None`` is replaced by the aligned default.

    Raises:
        ValueError: If ``defaults`` is not strictly longer than ``out_size``.
    """
    # The ``# type:`` comment above must stay the first body line; TorchScript
    # reads it from that position.
    if isinstance(out_size, int):
        return out_size
    # Require at least one extra leading entry in ``defaults`` beyond the ones
    # aligned with ``out_size``.
    if len(defaults) <= len(out_size):
        raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
    return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]