mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31800 If we know that two constants are the same object, we can ignore other constraints and pool them together. This fixes an issue introduced by the other PR where quantization relied on constant pooling happening for correctness. Test Plan: Imported from OSS Differential Revision: D19269499 Pulled By: eellison fbshipit-source-id: 9d4396125aa6899cb081863d463d4f024135cbf4
86 lines
2.4 KiB
C++
86 lines
2.4 KiB
C++
#include <torch/csrc/jit/ir.h>
|
|
#include <torch/csrc/jit/irparser.h>
|
|
#include <torch/csrc/jit/passes/constant_pooling.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include "test/cpp/jit/test_base.h"
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testConstantPooling() {
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph():
|
|
%8 : int = prim::Constant[value=1]()
|
|
%10 : int = prim::Constant[value=1]()
|
|
return (%8, %10)
|
|
)IR",
|
|
&*graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("prim::Constant", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%cond : Tensor):
|
|
%a : str = prim::Constant[value="bcd"]()
|
|
%3 : bool = aten::Bool(%cond)
|
|
%b : str = prim::If(%3)
|
|
block0():
|
|
%b.1 : str = prim::Constant[value="abc"]()
|
|
-> (%b.1)
|
|
block1():
|
|
%b.2 : str = prim::Constant[value="abc"]()
|
|
-> (%b.2)
|
|
%7 : (str, str) = prim::TupleConstruct(%a, %b)
|
|
return (%7)
|
|
)IR",
|
|
&*graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
|
|
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph():
|
|
%2 : int = prim::Constant[value=2]()
|
|
%1 : int = prim::Constant[value=1]()
|
|
%5 : int? = prim::Constant()
|
|
%7 : Device? = prim::Constant()
|
|
%15: bool = prim::Constant[value=0]()
|
|
%10 : int = prim::Constant[value=6]()
|
|
%3 : int[] = prim::ListConstruct(%1, %2)
|
|
%x : Tensor = aten::tensor(%3, %5, %7, %15)
|
|
%y : Tensor = aten::tensor(%3, %10, %7, %15)
|
|
%9 : int[] = prim::ListConstruct(%1, %2)
|
|
%z : Tensor = aten::tensor(%9, %10, %7, %15)
|
|
%f = prim::Print(%x, %y, %z)
|
|
return (%1)
|
|
)IR",
|
|
&*graph);
|
|
// three tensors created - two different devices among the three
|
|
// don't have good support for parsing tensor constants
|
|
ConstantPropagation(graph);
|
|
ConstantPooling(graph);
|
|
testing::FileCheck()
|
|
.check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
|
|
->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
|
|
->run(*graph);
|
|
}
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|