mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
JIT IR - Make valueMapPtr optional in convertNetDefToIR (#17942)
Summary: Make valueMapPtr optional in convertNetDefToIR, and add tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/17942 Differential Revision: D14429687 Pulled By: duc0 fbshipit-source-id: 3a5a72bbb5acc1bfd7144a987688c599016fbf7a
This commit is contained in:
parent
53fb9a462a
commit
5cbc1981f3
|
|
@ -230,6 +230,8 @@ void testNetDefConverter() {
|
|||
Graph graph;
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
convertNetDefToIR(net, &graph, &vmap, "caffe2::");
|
||||
// Sanity check that value map is returned and it works.
|
||||
AT_ASSERT(vmap["a"]->uniqueName() == "a");
|
||||
|
||||
caffe2::NetDef net2;
|
||||
convertIRToNetDef(&net2, graph, "caffe2::");
|
||||
|
|
@ -246,6 +248,10 @@ void testNetDefConverter() {
|
|||
AT_ASSERT(net2.external_input(0) == "a");
|
||||
AT_ASSERT(net2.external_output(0) == "c");
|
||||
AT_ASSERT(net3.external_input(0) == "a");
|
||||
|
||||
Graph graph2;
|
||||
// Test that conversion works without passing in a valueMap.
|
||||
convertNetDefToIR(net, &graph2, nullptr, "caffe2::");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,12 @@ void convertNetDefToIR(
|
|||
Graph* g,
|
||||
std::unordered_map<std::string, Value*>* valueMapPtr,
|
||||
const std::string& prefix) {
|
||||
if (!valueMapPtr) {
|
||||
std::unordered_map<std::string, Value*> localValueMap;
|
||||
// If valueMapPtr is null, we just use a local map since we don't need
|
||||
// to return the valueMap to the caller.
|
||||
return convertNetDefToIR(net, g, &localValueMap, prefix);
|
||||
}
|
||||
std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
|
||||
std::unordered_map<Value*, std::string> namesMap;
|
||||
valueMap.clear();
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ namespace jit {
|
|||
* The NetDef \p net is converted and the result is stored in the
|
||||
* torch::jit::Graph \p graph. The function also records name->value map in \p
|
||||
* valueMapPtr. If the original net had several values with the same name, the
|
||||
* map will contain the value for the last definition.
|
||||
* map will contain the value for the last definition. valueMapPtr is optional.
|
||||
* \p Prefix can be used for appending some string to every operator name (e.g.
|
||||
* we can add "caffe2::").
|
||||
*/
|
||||
void convertNetDefToIR(
|
||||
const caffe2::NetDef& net,
|
||||
Graph* graph,
|
||||
std::unordered_map<std::string, Value*>* valueMapPtr,
|
||||
std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
|
||||
const std::string& prefix = "");
|
||||
|
||||
/** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user