mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Test Plan: revert-hammer
Differential Revision:
D25718705 (891759f860)
Original commit changeset: 6a9e3e6d17aa
fbshipit-source-id: 1a4ef0bfdec8eb8e7ce149bfbdb34a4ad8d964b6
36 lines | 1000 B | Python
import collections.abc
from itertools import repeat
from typing import List

from torch._six import container_abcs
|
|
|
|
|
|
def _ntuple(n):
|
|
def parse(x):
|
|
if isinstance(x, container_abcs.Iterable):
|
|
return x
|
|
return tuple(repeat(x, n))
|
|
return parse
|
|
|
|
# Convenience parsers that coerce a scalar (or pass through an iterable)
# into a 1-, 2-, 3-, or 4-tuple, e.g. _pair(3) == (3, 3).
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
|
|
|
|
|
|
def _reverse_repeat_tuple(t, n):
|
|
r"""Reverse the order of `t` and repeat each element for `n` times.
|
|
|
|
This can be used to translate padding arg used by Conv and Pooling modules
|
|
to the ones used by `F.pad`.
|
|
"""
|
|
return tuple(x for x in reversed(t) for _ in range(n))
|
|
|
|
|
|
def _list_with_default(out_size, defaults):
|
|
# type: (List[int], List[int]) -> List[int]
|
|
if isinstance(out_size, int):
|
|
return out_size
|
|
if len(defaults) <= len(out_size):
|
|
raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
|
|
return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
|