mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Replaces https://github.com/pytorch/pytorch/pull/21501 because ghimport had errors when I tried to import the stack that I couldn't figure out :'( This has the two commits that were previously accepted and the merge commit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/22561 Differential Revision: D16135743 Pulled By: eellison fbshipit-source-id: f0a98842ccb334c7ceab04d1437e09dc76be0eb1
115 lines · 3.3 KiB · C++
#pragma once
|
|
|
|
#include <torch/csrc/jit/ir.h>
|
|
#include <torch/csrc/jit/irparser.h>
|
|
#include <torch/csrc/jit/passes/constant_pooling.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include "test/cpp/jit/test_base.h"
|
|
#include "torch/csrc/jit/custom_operator.h"
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Exercises the ConstantPropagation pass, focusing on prim::TupleConstruct /
// prim::TupleIndex: indexing into a tuple of constants should fold away,
// while values that cannot be embedded as IR constants (e.g. float[] lists)
// must block propagation through ops that consume them.
void testConstantPropagation() {
  {
    // Case 1: a tuple built entirely from int constants. TupleIndex(%x, 0)
    // yields the constant 0, so aten::add(0, 0) folds to a constant as well;
    // the FileCheck below requires the folded graph to be just
    // "constant 0 -> return" with nothing in between.
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph():
  %1 : int = prim::Constant[value=1]()
  %0 : int = prim::Constant[value=0]()
  %x : (int, int) = prim::TupleConstruct(%0, %1)
  %y : int = prim::TupleIndex(%x, %0)
  %5 : int = aten::add(%y, %y)
  return (%5)
)IR",
        &*graph);
    // optimize through tuple construct and indexing
    ConstantPropagation(graph);
    testing::FileCheck()
        .check("graph")
        ->check_next("prim::Constant[value=0]")
        ->check_next("return")
        ->run(*graph);
  }
  {
    // Case 2: nested tuples whose outer value (%y) is also consumed by
    // prim::Print, so the TupleConstructs cannot be removed. Only the
    // indexing chain should fold: %6 = y[1] = (3, 1), %z = %6[0] = 3,
    // and aten::add(3, 3) folds to the constant 6 checked for below.
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph():
  %10 : None = prim::Constant()
  %7 : int = prim::Constant[value=0]()
  %1 : int = prim::Constant[value=1]()
  %0 : int = prim::Constant[value=3]()
  %x : (int, int) = prim::TupleConstruct(%0, %1)
  %y : (int, (int, int)) = prim::TupleConstruct(%1, %x)
  %6 : (int, int) = prim::TupleIndex(%y, %1)
  %z : int = prim::TupleIndex(%6, %7)
  %9 : int = aten::add(%z, %z)
  %ign = prim::Print(%y, %9)
  return (%10) )IR",
        &*graph);
    ConstantPropagation(graph);
    // The index should be optimized away, with a computed value of 6,
    // and the TupleConstructs should still remain
    testing::FileCheck()
        .check_count("TupleConstruct", 2)
        ->check_not("TupleIndex")
        ->check("value=6")
        ->run(*graph);
  }
  {
    // Case 3: register two test-only ops so we can produce a tuple holding
    // a float[] value. prim::test_tuple pushes ((float[]),) onto the stack;
    // prim::run_float_list pops its input and pushes 1.
    RegisterOperators reg({
        Operator(
            "prim::test_tuple() -> (float[])",
            [](const Node* node) {
              return [](Stack& stack) {
                c10::List<double> list;
                auto li = IValue(list);
                std::vector<IValue> tup = {li};
                push(
                    stack,
                    c10::ivalue::Tuple::create(
                        tup, TupleType::create({ListType::ofFloats()})));
                return 0;
              };
            }),
        Operator(
            "prim::run_float_list(float[] a) -> (int)",
            [](const Node* node) {
              return [](Stack& stack) {
                pop(stack);
                push(stack, 1);
                return 0;
              };
            }),
    });
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph():
  %2 : (float[]) = prim::test_tuple()
  %1 : int = prim::Constant[value=0]()
  %y : float[] = prim::TupleIndex(%2, %1)
  %z : int = prim::run_float_list(%y)
  return (%z)
)IR",
        &*graph);
    ConstantPropagation(graph);
    // float[] are not embeddable as constants, so we should not
    // run the run_float_list op.
    // this logic prevents e.g. running a tensor with grad in constant prop
    testing::FileCheck()
        .check("test_tuple")
        ->check("TupleIndex")
        ->check("run_float_list")
        ->run(*graph);
  }
}
|
|
} // namespace jit
|
|
} // namespace torch
|