# `torch._six` was a private Python 2/3 compatibility shim and is not available
# in recent PyTorch releases. On Python 3 the standard-library module
# `collections.abc` provides the same `Iterable` ABC; it is bound to the
# historical name `container_abcs` so the module namespace is unchanged.
import collections.abc as container_abcs
from itertools import repeat


def _ntuple(n):
    r"""Build a parser that expands a scalar argument into an `n`-tuple.

    The returned function ``parse(x)`` behaves as follows:

    * if `x` is an instance of ``collections.abc.Iterable`` it is returned
      unchanged (no length check and no conversion is performed; note that
      strings are iterables too);
    * otherwise `x` is repeated `n` times and returned as a tuple.
    """
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            # Already a collection: hand it back as-is.
            return x
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def _repeat_tuple(t, n):
    r"""Repeat each element of `t` for `n` times.

    This can be used to translate padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.

    Example: ``_repeat_tuple((1, 2), 2)`` returns ``(1, 1, 2, 2)``.
    """
    return tuple(x for x in t for _ in range(n))


def _list_with_default(out_size, defaults):
    r"""Fill the ``None`` entries of `out_size` from the tail of `defaults`.

    Args:
        out_size: an ``int``, or a sequence whose entries may be ``None``.
        defaults: a sequence of fallback values; it must be strictly longer
            than `out_size`.

    Returns:
        `out_size` itself when it is an ``int``; otherwise a new list in which
        each ``None`` entry of `out_size` is replaced by the entry at the same
        position in the last ``len(out_size)`` elements of `defaults`.

    Raises:
        ValueError: if ``len(defaults) <= len(out_size)``.
    """
    if isinstance(out_size, int):
        return out_size
    if len(defaults) <= len(out_size):
        raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
    # Align `out_size` with the trailing entries of `defaults`.
    return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]