mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
* add reduce=True arg to MarginRankingLoss * make default margin arg match for legacy * remove accidentally added test * fix test * fix native_functions.yaml alphabetical order
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
from .backend import FunctionBackend
|
|
|
|
|
|
class THNNFunctionBackend(FunctionBackend):
    """Function backend whose instance is never duplicated by pickle or copy.

    ``copy.copy`` and ``copy.deepcopy`` both hand back ``self``. Pickling
    stores a reference to the zero-argument accessor
    ``_get_thnn_function_backend``, so unpickling yields the module-level
    ``backend`` object rather than a freshly reconstructed instance.
    """

    def __copy__(self):
        # A shallow copy is simply the same object.
        return self

    def __deepcopy__(self, memo):
        # Register ourselves in the memo so further references to this object
        # within the same deepcopy call resolve straight to ``self``.
        memo[id(self)] = self
        return self

    def __reduce__(self):
        # Pickle by reference: on load the unpickler calls ``rebuild()`` with
        # no arguments instead of restoring any instance state.
        rebuild = _get_thnn_function_backend
        no_args = ()
        return rebuild, no_args
|
|
|
|
|
|
def _get_thnn_function_backend():
    """Return the module-level ``backend`` instance.

    Serves as the reconstruction callable returned by
    ``THNNFunctionBackend.__reduce__``, so unpickled references point back to
    this same object.
    """
    singleton = backend
    return singleton
|
|
|
|
|
|
def _initialize_backend():
    """Register every function implementation on the module-level ``backend``.

    Hand-written RNN and dropout functions are registered first under explicit
    names, followed by each class in the THNN ``_all_functions`` collection
    under its own ``__name__``. The imports are deferred to call time.
    NOTE(review): deferring them presumably avoids an import cycle with
    ``.._functions`` — confirm.
    """
    from .._functions.thnn import _all_functions as _thnn_functions
    from .._functions.rnn import (RNN, RNNTanhCell, RNNReLUCell, GRUCell,
                                  LSTMCell)
    from .._functions.dropout import Dropout, FeatureDropout

    # Explicitly named registrations, in their original order.
    # Dropout2d and Dropout3d both map to the same FeatureDropout class.
    named_functions = (
        ('RNN', RNN),
        ('RNNTanhCell', RNNTanhCell),
        ('RNNReLUCell', RNNReLUCell),
        ('LSTMCell', LSTMCell),
        ('GRUCell', GRUCell),
        ('Dropout', Dropout),
        ('Dropout2d', FeatureDropout),
        ('Dropout3d', FeatureDropout),
    )
    for fn_name, fn_class in named_functions:
        backend.register_function(fn_name, fn_class)

    # THNN-backed functions are registered under their class names, after the
    # explicit entries above.
    for fn_class in _thnn_functions:
        backend.register_function(fn_class.__name__, fn_class)
|
|
|
|
|
|
# Module-level backend instance. THNNFunctionBackend's __reduce__ resolves to
# this object through _get_thnn_function_backend.
backend = THNNFunctionBackend()
# Must run after the assignment above: _initialize_backend registers its
# functions on the global ``backend``.
_initialize_backend()
|