mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
git-subtree-dir: torch/lib/THNN git-subtree-mainline: c3f0c1e2e0 git-subtree-split: 4fe7059a31
59 lines
1.2 KiB
C
59 lines
1.2 KiB
C
/*
 * Two-mode guard: when TH_GENERIC_FILE is not yet defined, this file only
 * defines it to its own path; the function definitions after the #else are
 * compiled only when TH_GENERIC_FILE is already defined on entry.
 * NOTE(review): presumably the TH generic-include machinery re-includes this
 * file once per `real` type -- confirm against the including header.
 */
#ifndef TH_GENERIC_FILE
|
|
#define TH_GENERIC_FILE "generic/Threshold.c"
|
|
#else
|
|
|
|
void THNN_(Threshold_updateOutput)(
|
|
THNNState *state,
|
|
THTensor *input,
|
|
THTensor *output,
|
|
real threshold,
|
|
real val,
|
|
bool inplace)
|
|
{
|
|
if (inplace)
|
|
{
|
|
TH_TENSOR_APPLY(real, input,
|
|
if (*input_data <= threshold)
|
|
*input_data = val;
|
|
);
|
|
THTensor_(set)(output, input);
|
|
}
|
|
else
|
|
{
|
|
THTensor_(resizeAs)(output, input);
|
|
TH_TENSOR_APPLY2(real, output, real, input,
|
|
*output_data = (*input_data > threshold) ? *input_data : val;
|
|
);
|
|
}
|
|
}
|
|
|
|
void THNN_(Threshold_updateGradInput)(
|
|
THNNState *state,
|
|
THTensor *input,
|
|
THTensor *gradOutput,
|
|
THTensor *gradInput,
|
|
real threshold,
|
|
bool inplace)
|
|
{
|
|
if (inplace)
|
|
{
|
|
TH_TENSOR_APPLY2(real, gradOutput, real, input,
|
|
if ((*input_data) <= threshold)
|
|
*gradOutput_data = 0;
|
|
);
|
|
THTensor_(set)(gradInput, gradOutput);
|
|
}
|
|
else
|
|
{
|
|
THTensor_(resizeAs)(gradInput, input);
|
|
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
|
|
if ((*input_data) > threshold)
|
|
*gradInput_data = *gradOutput_data;
|
|
else
|
|
*gradInput_data = 0;
|
|
);
|
|
}
|
|
}
|
|
|
|
#endif
|