mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Implement gradient operator for GatherByKeys. (#24348)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24348 Partition + GatherByKeys pair is pretty handy for implementing strategy where part of the keys will be on local machine, while part of the keys will end up on the remote machin (for cases when there is exactly 1 id). Reviewed By: aazzolini Differential Revision: D16802988 fbshipit-source-id: 4c7ac97fc0db3ce88575fccab0c7bf69dcbef965
This commit is contained in:
parent
b0e794e6e9
commit
6f08be46b0
|
|
@ -106,10 +106,52 @@ X_0_part_0, X_1_part_0, ..., X_N-1_part_0, X_0_part_1, ..., X_N-1_part_K-1
|
|||
"Output Partitions. The number of output tensors has to be a "
|
||||
"multiple of the number of input tensors.");
|
||||
|
||||
namespace {
|
||||
|
||||
class GetGatherByKeyGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
std::vector<OperatorDef> GetGradientDefs() override {
|
||||
ArgumentHelper argsHelper(def_);
|
||||
auto pack_first_input =
|
||||
argsHelper.GetSingleArgument<int>("pack_first_input", 0);
|
||||
|
||||
Argument packArg = MakeArgument<int>("pack_first_input", pack_first_input);
|
||||
if (g_output_[0].IsDense()) {
|
||||
std::vector<std::string> inputs;
|
||||
for (int i = 1; i < g_input_.size(); ++i) {
|
||||
inputs.push_back("_" + GI(i) + "_keys");
|
||||
inputs.push_back(GI(i));
|
||||
}
|
||||
return SingleGradientDef(
|
||||
"Partition",
|
||||
"",
|
||||
std::vector<std::string>{I(0), GO(0)},
|
||||
inputs,
|
||||
std::vector<Argument>{packArg});
|
||||
} else {
|
||||
std::vector<std::string> inputs;
|
||||
for (int i = 1; i < g_input_.size(); ++i) {
|
||||
inputs.push_back("_" + GI_I(i) + "_keys");
|
||||
inputs.push_back(GI_I(i));
|
||||
inputs.push_back(GI_V(i));
|
||||
}
|
||||
return SingleGradientDef(
|
||||
"Partition",
|
||||
"",
|
||||
std::vector<std::string>{I(0), GO_I(0), GO_V(0)},
|
||||
inputs,
|
||||
std::vector<Argument>{packArg});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// This should actually have gradient, but for now nothing uses it.
|
||||
// Because gradient computation right now is not input/output aware it can't be
|
||||
// GRADIENT_NOT_IMPLEMENTEDYET
|
||||
NO_GRADIENT(Partition);
|
||||
NO_GRADIENT(LengthsPartition);
|
||||
REGISTER_GRADIENT(GatherByKey, GetGatherByKeyGradient);
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user