JIT IR - Make valueMapPtr optional in convertNetDefToIR (#17942)

Summary:
Make valueMapPtr optional in convertNetDefToIR, and add tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17942

Differential Revision: D14429687

Pulled By: duc0

fbshipit-source-id: 3a5a72bbb5acc1bfd7144a987688c599016fbf7a
This commit is contained in:
Duc Ngo 2019-03-14 12:16:12 -07:00 committed by Facebook Github Bot
parent 53fb9a462a
commit 5cbc1981f3
3 changed files with 14 additions and 2 deletions

View File

@ -230,6 +230,8 @@ void testNetDefConverter() {
Graph graph;
std::unordered_map<std::string, Value*> vmap;
convertNetDefToIR(net, &graph, &vmap, "caffe2::");
// Sanity check that value map is returned and it works.
AT_ASSERT(vmap["a"]->uniqueName() == "a");
caffe2::NetDef net2;
convertIRToNetDef(&net2, graph, "caffe2::");
@ -246,6 +248,10 @@ void testNetDefConverter() {
AT_ASSERT(net2.external_input(0) == "a");
AT_ASSERT(net2.external_output(0) == "c");
AT_ASSERT(net3.external_input(0) == "a");
Graph graph2;
// Test that conversion works without passing in a valueMap.
convertNetDefToIR(net, &graph2, nullptr, "caffe2::");
}
}

View File

@ -76,6 +76,12 @@ void convertNetDefToIR(
Graph* g,
std::unordered_map<std::string, Value*>* valueMapPtr,
const std::string& prefix) {
if (!valueMapPtr) {
std::unordered_map<std::string, Value*> localValueMap;
// If valueMapPtr is null, we just use a local map since we don't need
// to return the valueMap to the caller.
return convertNetDefToIR(net, g, &localValueMap, prefix);
}
std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
std::unordered_map<Value*, std::string> namesMap;
valueMap.clear();

View File

@ -11,14 +11,14 @@ namespace jit {
* The NetDef \p net is converted and the result is stored in the
* torch::jit::Graph \p graph. The function also records name->value map in \p
* valueMapPtr. If the original net had several values with the same name, the
* map will contain the value for the last definition.
* map will contain the value for the last definition. valueMapPtr is optional.
* \p Prefix can be used for appending some string to every operator name (e.g.
* we can add "caffe2::").
*/
void convertNetDefToIR(
const caffe2::NetDef& net,
Graph* graph,
std::unordered_map<std::string, Value*>* valueMapPtr,
std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
const std::string& prefix = "");
/** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.