[JIT] Add special cases for batch_norm, instance_norm in alias_analysis (#66554)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66554

In native_functions.yaml, the schemas for batch_norm and instance_norm
are incorrect: the inputs `running_mean` and `running_var` are mutated,
but are not marked as such in the function schema. Since `(a!)?`
annotations are currently not working (see #65760), this instead adds a
special case to `alias_anaysis.cpp`. If the value of `training` or
`use_input_stats` is known to be `false`, then `alias_analysis` will
mark the input as _not_ being written to.

Test Plan:
Removed the `skip` annotation on the following test, and added a special
exception in `check_alias_annotations`:
```
python test/test_ops.py -k test_variant_consistency_jit_nn_functional_batch_norm
```

Also:
```
./build/bin/test_jit --gtest_filter="*BatchAndInstanceNormFixture*"
```

Imported from OSS

Reviewed By: eellison

Differential Revision: D31612339

fbshipit-source-id: 12ca61b782b9e41e06883ba080a276209dc435bb
This commit is contained in:
David Berard 2021-10-20 10:19:52 -07:00 committed by Facebook GitHub Bot
parent cf77bd4cf4
commit e86d8323cb
5 changed files with 221 additions and 10 deletions

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
@ -481,6 +482,135 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
}
class BatchAndInstanceNormFixture
: public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
};
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto isTraining = std::get<2>(param);
std::string isTrainingStr = std::to_string((int)isTraining);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
%none : NoneType = prim::Constant()
%training : bool = prim::Constant[value=)IR" +
isTrainingStr + R"IR(]()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.hasWriters(n) == isTraining);
}
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
%none : NoneType = prim::Constant()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.hasWriters(n));
}
TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto isTraining = std::get<2>(param);
std::string isTrainingStr = std::to_string((int)isTraining);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor):
%none : NoneType = prim::Constant()
%training : bool = prim::Constant[value=)IR" +
isTrainingStr + R"IR(]()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.hasWriters(n));
}
INSTANTIATE_TEST_SUITE_P(
AliasAnalysisTest,
BatchAndInstanceNormFixture,
::testing::Values(
std::make_tuple("aten::batch_norm", aten::batch_norm, false),
std::make_tuple("aten::instance_norm", aten::instance_norm, false),
std::make_tuple("aten::batch_norm", aten::batch_norm, true),
std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
TEST(WriteTrackingTest, Basic) {
RegisterOperators reg({Operator(
"prim::creates_alias(Tensor(a) x) -> Tensor(a)",

View File

@ -627,6 +627,10 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::rpc_sync:
case prim::rpc_remote:
return analyzeRpcAsync(node);
case aten::batch_norm:
return analyzeBatchNorm(node);
case aten::instance_norm:
return analyzeInstanceNorm(node);
case prim::GradOf:
return analyzeGradOf(node);
case prim::BroadcastMKLDNNTensors: {
@ -1000,6 +1004,73 @@ void AliasDb::analyzeRpcAsync(Node* node) {
}
}
namespace {
c10::optional<bool> getConstantBooleanInput(
Node* node,
const std::string& inputName) {
TORCH_INTERNAL_ASSERT(
node->hasNamedInput(inputName), inputName + " input is expected");
auto value = node->namedInput(inputName);
TORCH_INTERNAL_ASSERT(
value->type() == BoolType::get(),
inputName + "training input is expected to be a bool");
return constant_as<bool>(value);
}
} // namespace
// custom behavior for batch_norm because (a!)? annotations currently
// aren't supported, and because behavior differs depending on the value of
// training
void AliasDb::analyzeBatchNorm(Node* node) {
// we invoking freezing for inference, so we assume training will be folded to
// a constant false to avoid needing to invoke freezing multiple times in
// order to make batch norm weights constant
for (Value* output : node->outputs()) {
giveFreshAlias(output);
}
if (isFrozen_) {
return;
}
auto isTraining = getConstantBooleanInput(node, "training");
if (!isTraining.has_value() || *isTraining) {
TORCH_INTERNAL_ASSERT(
node->hasNamedInput("running_mean"), "running_mean input is expected");
auto runningMean = node->namedInput("running_mean");
TORCH_INTERNAL_ASSERT(
node->hasNamedInput("running_var"), "running_var input is expected");
auto runningVar = node->namedInput("running_var");
registerWrite(runningMean, node);
registerWrite(runningVar, node);
}
}
// custom behavior for instance_norm, because (a!)? annotations currently
// aren't supported, and because behavior differs depending on the value of
// use_input_stats
void AliasDb::analyzeInstanceNorm(Node* node) {
for (Value* output : node->outputs()) {
giveFreshAlias(output);
}
auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
if (!useInputStats.has_value() || *useInputStats) {
TORCH_INTERNAL_ASSERT(
node->hasNamedInput("running_mean"), "running_mean input is expected");
auto runningMean = node->namedInput("running_mean");
TORCH_INTERNAL_ASSERT(
node->hasNamedInput("running_var"), "running_var input is expected");
auto runningVar = node->namedInput("running_var");
registerWrite(runningMean, node);
registerWrite(runningVar, node);
}
}
// SetAttr: writes to the `self` field
void AliasDb::analyzeSetAttr(Node* node) {
const auto self = node->inputs().at(0);

View File

@ -215,6 +215,8 @@ class AliasDb {
void analyzeFork(Node* node);
void analyzeWait(Node* node);
void analyzeRpcAsync(Node* node);
void analyzeBatchNorm(Node* node);
void analyzeInstanceNorm(Node* node);
void analyzeGradOf(Node* node);
void analyzeSetAttr(Node* node);
void analyzeConservative(Node* node);

View File

@ -224,6 +224,21 @@ c10::optional<IValue> toIValueProp(const Value* v) {
}
return c10::nullopt;
}
// batch_norm and instance_norm have incorrect annotations, because
// (a!)? annotations aren't supported, so these checks would fail.
// Their behavior also varies depending on the `training` and
// `use_input_stats` arguments.
// There are custom implementations in alias_analysis.cpp for these ops.
bool shouldIgnoreNode(const Node* n) {
switch (n->kind()) {
case aten::batch_norm:
case aten::instance_norm:
return true;
default:
return false;
}
}
} // namespace
void checkAliasAnnotation(
@ -232,6 +247,9 @@ void checkAliasAnnotation(
const std::string& unqualifiedOpName) {
// Find the node that corresponds to our op name
const auto node = findNodeForOp(*graph, unqualifiedOpName);
if (shouldIgnoreNode(node)) {
return;
}
// Build the stack to use as input to the op
Stack stack;

View File

@ -8611,11 +8611,6 @@ op_db: List[OpInfo] = [
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
skips=(
# RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
# "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_batch_norm),
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
OpInfo('nn.functional.batch_norm',
@ -8624,11 +8619,6 @@ op_db: List[OpInfo] = [
dtypesIfCPU=empty_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
skips=(
# RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
# "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
decorators=[onlyCUDA, disablecuDNN],
sample_inputs_func=sample_inputs_batch_norm),
# We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the