Implement gradient operator for GatherByKeys. (#24348)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24348

The Partition + GatherByKey pair is quite handy for implementing a strategy where
part of the keys stay on the local machine, while the rest of the keys end up
on a remote machine (for cases when there is exactly 1 id).
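
For context, here is a minimal sketch (not part of this diff) of that
Partition -> GatherByKey pattern using the Caffe2 Python net-building API.
The blob names ("keys", "values", ...) and the shard count of 2 are made up
for illustration; with this change the GatherByKey step becomes
differentiable, so gradients can flow back into the per-shard values.

    from caffe2.python import core, workspace
    import numpy as np

    net = core.Net("partition_gather_sketch")
    # Shard the keys and their values into 2 partitions using the key hash.
    # Partition outputs are grouped per partition: keys_0, values_0, keys_1, values_1.
    k0, v0, k1, v1 = net.Partition(
        ["keys", "values"],
        ["keys_0", "values_0", "keys_1", "values_1"],
        partitions=2)
    # ... per-shard (possibly remote) work on values_0 / values_1 would go here ...
    # Reassemble the full value tensor in the original key order.
    gathered = net.GatherByKey(["keys", v0, v1], "values_gathered")

    workspace.FeedBlob("keys", np.array([3, 1, 4, 1, 5], dtype=np.int64))
    workspace.FeedBlob("values", np.random.rand(5, 4).astype(np.float32))
    workspace.RunNetOnce(net)
    print(workspace.FetchBlob("values_gathered").shape)  # (5, 4)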

Reviewed By: aazzolini

Differential Revision: D16802988

fbshipit-source-id: 4c7ac97fc0db3ce88575fccab0c7bf69dcbef965
Author: Andrey Malevich (committed by Facebook Github Bot)
Date:   2019-08-15 12:17:59 -07:00
parent b0e794e6e9
commit 6f08be46b0

@@ -106,10 +106,52 @@ X_0_part_0, X_1_part_0, ..., X_N-1_part_0, X_0_part_1, ..., X_N-1_part_K-1
"Output Partitions. The number of output tensors has to be a " "Output Partitions. The number of output tensors has to be a "
"multiple of the number of input tensors."); "multiple of the number of input tensors.");
namespace {
// Gradient for GatherByKey: the gradient w.r.t. each sharded value input is
// obtained by re-partitioning the output gradient with the original keys,
// i.e. by running Partition on (keys, dY). Sparse (indices/values) output
// gradients are handled in the else-branch.
class GetGatherByKeyGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper argsHelper(def_);
    auto pack_first_input =
        argsHelper.GetSingleArgument<int>("pack_first_input", 0);
    Argument packArg = MakeArgument<int>("pack_first_input", pack_first_input);
    if (g_output_[0].IsDense()) {
      // Dense output gradient: partition dY into one gradient blob per shard;
      // the partitioned keys go to throwaway "_<grad>_keys" blobs.
      std::vector<std::string> inputs;
      for (int i = 1; i < g_input_.size(); ++i) {
        inputs.push_back("_" + GI(i) + "_keys");
        inputs.push_back(GI(i));
      }
      return SingleGradientDef(
          "Partition",
          "",
          std::vector<std::string>{I(0), GO(0)},
          inputs,
          std::vector<Argument>{packArg});
    } else {
      // Sparse output gradient: partition both the indices and the values.
      std::vector<std::string> inputs;
      for (int i = 1; i < g_input_.size(); ++i) {
        inputs.push_back("_" + GI_I(i) + "_keys");
        inputs.push_back(GI_I(i));
        inputs.push_back(GI_V(i));
      }
      return SingleGradientDef(
          "Partition",
          "",
          std::vector<std::string>{I(0), GO_I(0), GO_V(0)},
          inputs,
          std::vector<Argument>{packArg});
    }
  }
};
} // namespace
// This should actually have gradient, but for now nothing uses it.
// Because gradient computation right now is not input/output aware it can't be
// GRADIENT_NOT_IMPLEMENTEDYET
NO_GRADIENT(Partition);
NO_GRADIENT(LengthsPartition);
REGISTER_GRADIENT(GatherByKey, GetGatherByKeyGradient);
} // namespace
} // namespace caffe2