mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: zdevito ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/10384 Reviewed By: ezyang Differential Revision: D9272583 Pulled By: apaszke fbshipit-source-id: ed5d37b28ce9ff25800bbaa0daf066cfbf1f9921
38 lines
954 B
Python
38 lines
954 B
Python
from .backend import FunctionBackend
|
|
|
|
|
|
class THNNFunctionBackend(FunctionBackend):
    """Function backend whose copy and pickle hooks never create a new object.

    ``__copy__`` and ``__deepcopy__`` hand back the very same instance, and
    ``__reduce__`` tells pickle to rebuild the object by calling
    ``_get_thnn_function_backend()`` with no arguments.
    """

    def __copy__(self):
        """Return this instance itself rather than a shallow copy."""
        return self

    def __deepcopy__(self, memo):
        """Return this instance itself rather than a deep copy.

        The instance is also recorded in ``memo`` under its own ``id`` so
        that the ongoing deep-copy operation maps it to itself.
        """
        memo[id(self)] = self
        return self

    def __reduce__(self):
        """Describe how to pickle this object.

        Returns a ``(callable, args)`` pair: unpickling invokes
        ``_get_thnn_function_backend()`` with an empty argument tuple.
        """
        return (_get_thnn_function_backend, ())
|
|
|
|
|
|
def _get_thnn_function_backend():
    """Return the module-level ``backend`` object.

    ``THNNFunctionBackend.__reduce__`` names this function as its
    reconstruction callable, so it is what pickle calls when loading.
    """
    return backend
|
|
|
|
|
|
def _initialize_backend():
    """Register the RNN functions and every THNN function on ``backend``.

    The five RNN entries are registered first under fixed string names;
    afterwards each item of ``_all_functions`` is registered under its own
    ``__name__``.
    """
    # Imports are kept function-local, exactly as in the original code.
    # NOTE(review): presumably this avoids an import cycle - confirm.
    from .._functions.thnn import _all_functions as thnn_function_classes
    from .._functions.rnn import (
        RNN,
        RNNTanhCell,
        RNNReLUCell,
        GRUCell,
        LSTMCell,
    )

    # (registered name, function) pairs, in the original registration order.
    rnn_entries = (
        ('RNN', RNN),
        ('RNNTanhCell', RNNTanhCell),
        ('RNNReLUCell', RNNReLUCell),
        ('LSTMCell', LSTMCell),
        ('GRUCell', GRUCell),
    )
    for registered_name, function in rnn_entries:
        backend.register_function(registered_name, function)

    for function_cls in thnn_function_classes:
        backend.register_function(function_cls.__name__, function_cls)
|
|
|
|
|
|
# Shared module-level instance. It has to be bound before
# _initialize_backend() runs, because that function registers everything
# on the global name ``backend``.
backend = THNNFunctionBackend()
_initialize_backend()
|