diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp
index c87cbb04f42..dca6e3dce49 100644
--- a/test/cpp/jit/test_alias_analysis.cpp
+++ b/test/cpp/jit/test_alias_analysis.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 
 namespace torch {
@@ -481,6 +482,135 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
   EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
 }
 
+class BatchAndInstanceNormFixture
+    : public ::testing::TestWithParam<std::tuple<std::string, c10::Symbol, bool>> {
+};
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
+    %none : NoneType = prim::Constant()
+    %training : bool = prim::Constant[value=)IR" +
+      isTrainingStr + R"IR(]()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+      fnName +
+      R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n) == isTraining);
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
+    %none : NoneType = prim::Constant()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+      fnName +
+      R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n));
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor):
+    %none : NoneType = prim::Constant()
+    %training : bool = prim::Constant[value=)IR" +
+      isTrainingStr + R"IR(]()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+      fnName +
+      R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_FALSE(aliasDb.hasWriters(n));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    AliasAnalysisTest,
+    BatchAndInstanceNormFixture,
+    ::testing::Values(
+        std::make_tuple("aten::batch_norm", aten::batch_norm, false),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, false),
+        std::make_tuple("aten::batch_norm", aten::batch_norm, true),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
+
 TEST(WriteTrackingTest, Basic) {
   RegisterOperators reg({Operator(
       "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 8f24237c60c..865635f5a1f 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -627,6 +627,10 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::rpc_sync:
     case prim::rpc_remote:
       return analyzeRpcAsync(node);
+    case aten::batch_norm:
+      return analyzeBatchNorm(node);
+    case aten::instance_norm:
+      return analyzeInstanceNorm(node);
     case prim::GradOf:
       return analyzeGradOf(node);
     case prim::BroadcastMKLDNNTensors: {
@@ -1000,6 +1004,73 @@ void AliasDb::analyzeRpcAsync(Node* node) {
   }
 }
 
+namespace {
+c10::optional<bool> getConstantBooleanInput(
+    Node* node,
+    const std::string& inputName) {
+  TORCH_INTERNAL_ASSERT(
+      node->hasNamedInput(inputName), inputName + " input is expected");
+  auto value = node->namedInput(inputName);
+  TORCH_INTERNAL_ASSERT(
+      value->type() == BoolType::get(),
+      inputName + " input is expected to be a bool");
+  return constant_as<bool>(value);
+}
+} // namespace
+
+// custom behavior for batch_norm because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// training
+void AliasDb::analyzeBatchNorm(Node* node) {
+  // Freezing is invoked for inference, so we assume training will be folded
+  // to a constant false; this avoids having to invoke freezing multiple times
+  // just to make the batch norm weights constant.
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  if (isFrozen_) {
+    return;
+  }
+
+  auto isTraining = getConstantBooleanInput(node, "training");
+
+  if (!isTraining.has_value() || *isTraining) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
+// custom behavior for instance_norm, because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// use_input_stats
+void AliasDb::analyzeInstanceNorm(Node* node) {
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
+
+  if (!useInputStats.has_value() || *useInputStats) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
 // SetAttr: writes to the `self` field
 void AliasDb::analyzeSetAttr(Node* node) {
   const auto self = node->inputs().at(0);
diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h
index d128f2829fc..c2211a09ec5 100644
--- a/torch/csrc/jit/ir/alias_analysis.h
+++ b/torch/csrc/jit/ir/alias_analysis.h
@@ -215,6 +215,8 @@ class AliasDb {
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
   void analyzeRpcAsync(Node* node);
+  void analyzeBatchNorm(Node* node);
+  void analyzeInstanceNorm(Node* node);
   void analyzeGradOf(Node* node);
   void analyzeSetAttr(Node* node);
   void analyzeConservative(Node* node);
diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
index 0d5dde95e5a..d538e33a213 100644
--- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
+++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
@@ -224,6 +224,21 @@ c10::optional<IValue> toIValueProp(const Value* v) {
   }
   return c10::nullopt;
 }
+
+// batch_norm and instance_norm have incorrect annotations, because
+// (a!)? annotations aren't supported, so these checks would fail.
+// Their behavior also varies depending on the `training` and
+// `use_input_stats` arguments.
+// There are custom implementations in alias_analysis.cpp for these ops.
+bool shouldIgnoreNode(const Node* n) {
+  switch (n->kind()) {
+    case aten::batch_norm:
+    case aten::instance_norm:
+      return true;
+    default:
+      return false;
+  }
+}
 } // namespace
 
 void checkAliasAnnotation(
@@ -232,6 +247,9 @@ void checkAliasAnnotation(
     const std::string& unqualifiedOpName) {
   // Find the node that corresponds to our op name
   const auto node = findNodeForOp(*graph, unqualifiedOpName);
+  if (shouldIgnoreNode(node)) {
+    return;
+  }
 
   // Build the stack to use as input to the op
   Stack stack;
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ffccbb0d72f..2057db7812f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8611,11 +8611,6 @@ op_db: List[OpInfo] = [
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            sample_inputs_func=sample_inputs_batch_norm),
     # This variant tests batch_norm with cuDNN disabled only on CUDA devices
     OpInfo('nn.functional.batch_norm',
@@ -8624,11 +8619,6 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=empty_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            decorators=[onlyCUDA, disablecuDNN],
            sample_inputs_func=sample_inputs_batch_norm),
     # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
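
For context (not part of the patch): the special-casing above mirrors the eager semantics of these ops, where the running statistics are mutated in place only when `training` / `use_input_stats` is true. A minimal Python sketch of that behavior, assuming the standard `torch.nn.functional.batch_norm` signature; the tensor shapes here are arbitrary and only for illustration:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 3, 8, 8)
    running_mean = torch.zeros(3)
    running_var = torch.ones(3)

    # training=True updates the running stats in place, which is why the
    # alias analysis must register writes to these inputs.
    F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)
    assert not torch.equal(running_mean, torch.zeros(3))

    # training=False only reads the stats; no writes need to be registered.
    before = running_mean.clone()
    F.batch_norm(x, running_mean, running_var, training=False)
    assert torch.equal(running_mean, before)

When `training` cannot be resolved to a constant (as exercised by the BatchAndInstanceNormTrainingUnknown test above), the analysis conservatively assumes the writes happen.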