pytorch/torch/nn/backends/thnn.py
Adam Paszke adbcb3c1dc Move dropout and alpha dropout to ATen (#10384)
Summary:
zdevito ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10384

Reviewed By: ezyang

Differential Revision: D9272583

Pulled By: apaszke

fbshipit-source-id: ed5d37b28ce9ff25800bbaa0daf066cfbf1f9921
2018-08-10 14:55:28 -07:00

38 lines
954 B
Python

from .backend import FunctionBackend
class THNNFunctionBackend(FunctionBackend):
    """Function backend that behaves as a singleton under copy and pickle.

    Copying (shallow or deep) returns the very same object. Pickling stores a
    reference to a module-level accessor, so unpickling yields this module's
    shared ``backend`` instead of a freshly constructed one.
    """

    def __copy__(self):
        # A shallow copy of the backend is the backend itself.
        return self

    def __deepcopy__(self, memo):
        # Record ourselves in the memo so later references to this object
        # within the same deepcopy resolve to the same instance, then hand
        # back ``self`` rather than a clone.
        memo[id(self)] = self
        return self

    def __reduce__(self):
        # Pickle by reference: the unpickler calls the module-level accessor
        # with no arguments, which returns the shared ``backend`` global.
        rebuild = _get_thnn_function_backend
        return rebuild, ()
def _get_thnn_function_backend():
    """Return the module-level ``backend`` instance.

    ``THNNFunctionBackend.__reduce__`` names this function as the pickle
    reconstructor, so unpickling resolves to this existing object.
    """
    shared_backend = backend
    return shared_backend
def _initialize_backend():
    """Register the backend's functions on the module-level ``backend``.

    The RNN helpers are registered first, each under an explicit name. Every
    entry of ``.._functions.thnn._all_functions`` is then registered under its
    ``__name__``. Registration order matches the original sequence.
    """
    # Imports are deferred to call time -- presumably to sidestep circular
    # imports while this module loads; TODO confirm. Their order is kept as-is.
    from .._functions.thnn import _all_functions as _thnn_functions
    from .._functions.rnn import RNN, RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell

    named_rnn_functions = (
        ('RNN', RNN),
        ('RNNTanhCell', RNNTanhCell),
        ('RNNReLUCell', RNNReLUCell),
        ('LSTMCell', LSTMCell),
        ('GRUCell', GRUCell),
    )
    for registered_name, fn in named_rnn_functions:
        backend.register_function(registered_name, fn)

    for thnn_fn in _thnn_functions:
        backend.register_function(thnn_fn.__name__, thnn_fn)
# Module-level backend instance. It is what _get_thnn_function_backend()
# returns (and therefore what unpickling produces). It must be created before
# _initialize_backend() runs, because that function registers functions on
# this global.
backend = THNNFunctionBackend()
_initialize_backend()