[JIT] Add special cases for batch_norm, instance_norm in alias_analysis (#66554)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66554

In native_functions.yaml, the schemas for batch_norm and instance_norm are incorrect: the inputs `running_mean` and `running_var` are mutated, but are not marked as such in the function schema. Since `(a!)?` annotations are currently not working (see #65760), this instead adds a special case to `alias_analysis.cpp`. If the value of `training` or `use_input_stats` is known to be `false`, then alias analysis will mark the input as _not_ being written to.

Test Plan: Removed the `skip` annotation on the following test, and added a special exception in `check_alias_annotations`:

```
python test/test_ops.py -k test_variant_consistency_jit_nn_functional_batch_norm
```

Also:

```
./build/bin/test_jit --gtest_filter="*BatchAndInstanceNormFixture*"
```

Imported from OSS

Reviewed By: eellison

Differential Revision: D31612339

fbshipit-source-id: 12ca61b782b9e41e06883ba080a276209dc435bb
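For context, the batch_norm entry in native_functions.yaml reads roughly as follows (paraphrased, not the literal file contents); the second form is a sketch of how `running_mean` and `running_var` would be marked as mutated if optional alias annotations of the form `(a!)?` were supported:

```
# Current schema: the running stats are not marked as written to.
batch_norm(Tensor input, Tensor? weight, Tensor? bias,
           Tensor? running_mean, Tensor? running_var,
           bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor

# Hypothetical schema if (a!)? annotations worked (tracked in #65760):
batch_norm(Tensor input, Tensor? weight, Tensor? bias,
           Tensor(a!)? running_mean, Tensor(b!)? running_var,
           bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
```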
This commit is contained in:
parent cf77bd4cf4
commit e86d8323cb
test/cpp/jit/test_alias_analysis.cpp:

```
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/utils/memory.h>
 
 namespace torch {
```
```
@@ -481,6 +482,135 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
   EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
 }
 
+class BatchAndInstanceNormFixture
+    : public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
+};
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
+  %none : NoneType = prim::Constant()
+  %training : bool = prim::Constant[value=)IR" +
+          isTrainingStr + R"IR(]()
+  %momentum : float = prim::Constant[value=1.0]()
+  %eps : float = prim::Constant[value=1.0e-9]()
+  %cudnn_enabled : bool = prim::Constant[value=0]()
+  %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+  return (%res)
+)IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n) == isTraining);
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
+  %none : NoneType = prim::Constant()
+  %momentum : float = prim::Constant[value=1.0]()
+  %eps : float = prim::Constant[value=1.0e-9]()
+  %cudnn_enabled : bool = prim::Constant[value=0]()
+  %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+  return (%res)
+)IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n));
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+graph(%input : Tensor):
+  %none : NoneType = prim::Constant()
+  %training : bool = prim::Constant[value=)IR" +
+          isTrainingStr + R"IR(]()
+  %momentum : float = prim::Constant[value=1.0]()
+  %eps : float = prim::Constant[value=1.0e-9]()
+  %cudnn_enabled : bool = prim::Constant[value=0]()
+  %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
+  return (%res)
+)IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_FALSE(aliasDb.hasWriters(n));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    AliasAnalysisTest,
+    BatchAndInstanceNormFixture,
+    ::testing::Values(
+        std::make_tuple("aten::batch_norm", aten::batch_norm, false),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, false),
+        std::make_tuple("aten::batch_norm", aten::batch_norm, true),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
+
 TEST(WriteTrackingTest, Basic) {
   RegisterOperators reg({Operator(
       "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
```
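Read together, the three tests pin down the intended contract: when `training` (or `use_input_stats`) is a constant `false`, the node has no writers; when the flag is `true` or not statically known, the running stats are conservatively treated as written to; and when no running stats are passed (the `%none` variant), there is nothing to write even in training mode.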
torch/csrc/jit/ir/alias_analysis.cpp:

```
@@ -627,6 +627,10 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::rpc_sync:
     case prim::rpc_remote:
       return analyzeRpcAsync(node);
+    case aten::batch_norm:
+      return analyzeBatchNorm(node);
+    case aten::instance_norm:
+      return analyzeInstanceNorm(node);
     case prim::GradOf:
       return analyzeGradOf(node);
     case prim::BroadcastMKLDNNTensors: {
```
```
@@ -1000,6 +1004,73 @@ void AliasDb::analyzeRpcAsync(Node* node) {
   }
 }
 
+namespace {
+c10::optional<bool> getConstantBooleanInput(
+    Node* node,
+    const std::string& inputName) {
+  TORCH_INTERNAL_ASSERT(
+      node->hasNamedInput(inputName), inputName + " input is expected");
+  auto value = node->namedInput(inputName);
+  TORCH_INTERNAL_ASSERT(
+      value->type() == BoolType::get(),
+      inputName + " input is expected to be a bool");
+  return constant_as<bool>(value);
+}
+} // namespace
+
+// custom behavior for batch_norm because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// training
+void AliasDb::analyzeBatchNorm(Node* node) {
+  // We invoke freezing for inference, so we assume training will be folded
+  // to a constant false, to avoid needing to invoke freezing multiple times
+  // in order to make batch norm weights constant.
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  if (isFrozen_) {
+    return;
+  }
+
+  auto isTraining = getConstantBooleanInput(node, "training");
+
+  if (!isTraining.has_value() || *isTraining) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
+// custom behavior for instance_norm, because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// use_input_stats
+void AliasDb::analyzeInstanceNorm(Node* node) {
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
+
+  if (!useInputStats.has_value() || *useInputStats) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
 // SetAttr: writes to the `self` field
 void AliasDb::analyzeSetAttr(Node* node) {
   const auto self = node->inputs().at(0);
```
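The net effect is observable through AliasDb's public interface. Below is a minimal standalone sketch, compiled against libtorch; it mirrors the gtest above rather than reproducing code from this commit:

```cpp
#include <memory>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>

using namespace torch::jit;

int main() {
  // batch_norm in eval mode: %training is the constant false.
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
  %none : NoneType = prim::Constant()
  %training : bool = prim::Constant[value=0]()
  %momentum : float = prim::Constant[value=0.1]()
  %eps : float = prim::Constant[value=1.0e-5]()
  %cudnn_enabled : bool = prim::Constant[value=0]()
  %res : Tensor = aten::batch_norm(%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
  return (%res)
)IR",
      &*graph);

  AliasDb aliasDb(graph);
  // The node producing %res is the batch_norm node. With training == false,
  // analyzeBatchNorm registers no writes, so the node has no writers.
  Node* bn = graph->outputs()[0]->node();
  bool hasWriters = aliasDb.hasWriters(bn); // expected: false
  return hasWriters ? 1 : 0;
}
```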
torch/csrc/jit/ir/alias_analysis.h:

```
@@ -215,6 +215,8 @@ class AliasDb {
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
   void analyzeRpcAsync(Node* node);
+  void analyzeBatchNorm(Node* node);
+  void analyzeInstanceNorm(Node* node);
   void analyzeGradOf(Node* node);
   void analyzeSetAttr(Node* node);
   void analyzeConservative(Node* node);
```
torch/csrc/jit/passes/utils/check_alias_annotation.cpp:

```
@@ -224,6 +224,21 @@ c10::optional<IValue> toIValueProp(const Value* v) {
   }
   return c10::nullopt;
 }
+
+// batch_norm and instance_norm have incorrect annotations, because
+// (a!)? annotations aren't supported, so these checks would fail.
+// Their behavior also varies depending on the `training` and
+// `use_input_stats` arguments.
+// There are custom implementations in alias_analysis.cpp for these ops.
+bool shouldIgnoreNode(const Node* n) {
+  switch (n->kind()) {
+    case aten::batch_norm:
+    case aten::instance_norm:
+      return true;
+    default:
+      return false;
+  }
+}
 } // namespace
 
 void checkAliasAnnotation(
```
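This exemption is what allows the blanket `test_variant_consistency_jit` skips to be dropped from op_db below: `checkAliasAnnotation` deep-copies the inputs and asserts that arguments not annotated as mutable come back unchanged (the `deepEquals(input.iValue, deepCopiedInput)` assert quoted in the removed skip comments), a check that batch_norm and instance_norm would necessarily fail while their schemas remain unannotated.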
```
@@ -232,6 +247,9 @@ void checkAliasAnnotation(
     const std::string& unqualifiedOpName) {
   // Find the node that corresponds to our op name
   const auto node = findNodeForOp(*graph, unqualifiedOpName);
+  if (shouldIgnoreNode(node)) {
+    return;
+  }
 
   // Build the stack to use as input to the op
   Stack stack;
```
torch/testing/_internal/common_methods_invocations.py:

```
@@ -8611,11 +8611,6 @@ op_db: List[OpInfo] = [
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            sample_inputs_func=sample_inputs_batch_norm),
     # This variant tests batch_norm with cuDNN disabled only on CUDA devices
     OpInfo('nn.functional.batch_norm',
```
```
@@ -8624,11 +8619,6 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=empty_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            decorators=[onlyCUDA, disablecuDNN],
            sample_inputs_func=sample_inputs_batch_norm),
     # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
```