pytorch/torch/nn/modules/utils.py
Elias Ellison fea618b524 [JIT] remove list with default builtin (#34171)
Summary:
I think this was added when we couldn't compile the function itself. Now we can.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34171

Differential Revision: D20269960

Pulled By: eellison

fbshipit-source-id: 0a60458d639995d9448789c249d405343881b304
2020-03-09 16:02:26 -07:00

34 lines
935 B
Python

from torch._six import container_abcs
from itertools import repeat
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
# Scalar-to-tuple converters built by ``_ntuple`` for arities 1 through 4,
# e.g. ``_pair(3) == (3, 3)`` while ``_pair((3, 5))`` is returned unchanged.
_single, _pair, _triple, _quadruple = map(_ntuple, range(1, 5))
def _repeat_tuple(t, n):
r"""Repeat each element of `t` for `n` times.
This can be used to translate padding arg used by Conv and Pooling modules
to the ones used by `F.pad`.
"""
return tuple(x for x in t for _ in range(n))
def _list_with_default(out_size, defaults):
# type: (List[int], List[int]) -> List[int]
if isinstance(out_size, int):
return out_size
if len(defaults) <= len(out_size):
raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]