[JIT] Add special cases for batch_norm, instance_norm in alias_analysis (#66554)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66554

In native_functions.yaml, the schemas for batch_norm and instance_norm are incorrect: the inputs `running_mean` and `running_var` are mutated, but they are not marked as such in the function schema. Since `(a!)?` annotations are currently not working (see #65760), this instead adds a special case to `alias_analysis.cpp`. If the value of `training` or `use_input_stats` is known to be `false`, then alias analysis marks these inputs as _not_ being written to.

Test Plan:
Removed the `skip` annotation on the following test, and added a special exception in `check_alias_annotations`:
```
python test/test_ops.py -k test_variant_consistency_jit_nn_functional_batch_norm
```
Also:
```
./build/bin/test_jit --gtest_filter="*BatchAndInstanceNormFixture*"
```

Imported from OSS

Reviewed By: eellison

Differential Revision: D31612339

fbshipit-source-id: 12ca61b782b9e41e06883ba080a276209dc435bb
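For context, a minimal eager-mode sketch (not part of this change) of the mutation the schema fails to declare: with `training=True`, `batch_norm` updates `running_mean` and `running_var` in place.
```
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

# training=False: the running statistics are only read.
F.batch_norm(x, running_mean, running_var, training=False)
print(running_mean)  # still all zeros

# training=True: the running statistics are updated in place,
# even though the schema does not annotate them with (a!).
F.batch_norm(x, running_mean, running_var, training=True)
print(running_mean)  # no longer all zeros
```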
This commit is contained in:
parent cf77bd4cf4
commit e86d8323cb
test/cpp/jit/test_alias_analysis.cpp

@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/utils/memory.h>
 
 namespace torch {

@@ -481,6 +482,135 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
   EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
 }
 
+class BatchAndInstanceNormFixture
+    : public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
+};
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
+    %none : NoneType = prim::Constant()
+    %training : bool = prim::Constant[value=)IR" +
+          isTrainingStr + R"IR(]()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n) == isTraining);
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
+    %none : NoneType = prim::Constant()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_TRUE(aliasDb.hasWriters(n));
+}
+
+TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
+  auto param = GetParam();
+  auto fnName = std::get<0>(param);
+  auto nodeKind = std::get<1>(param);
+  auto isTraining = std::get<2>(param);
+  std::string isTrainingStr = std::to_string((int)isTraining);
+
+  auto graph = std::make_shared<Graph>();
+
+  parseIR(
+      R"IR(
+  graph(%input : Tensor):
+    %none : NoneType = prim::Constant()
+    %training : bool = prim::Constant[value=)IR" +
+          isTrainingStr + R"IR(]()
+    %momentum : float = prim::Constant[value=1.0]()
+    %eps : float = prim::Constant[value=1.0e-9]()
+    %cudnn_enabled : bool = prim::Constant[value=0]()
+    %res : Tensor = )IR" +
+          fnName +
+          R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
+    return (%res)
+  )IR",
+      &*graph);
+
+  graph->lint();
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == nodeKind) {
+      break;
+    }
+  }
+  EXPECT_TRUE(n != nullptr);
+
+  AliasDb aliasDb(graph);
+  EXPECT_FALSE(aliasDb.hasWriters(n));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    AliasAnalysisTest,
+    BatchAndInstanceNormFixture,
+    ::testing::Values(
+        std::make_tuple("aten::batch_norm", aten::batch_norm, false),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, false),
+        std::make_tuple("aten::batch_norm", aten::batch_norm, true),
+        std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
+
 TEST(WriteTrackingTest, Basic) {
   RegisterOperators reg({Operator(
       "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
torch/csrc/jit/ir/alias_analysis.cpp

@@ -627,6 +627,10 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::rpc_sync:
     case prim::rpc_remote:
       return analyzeRpcAsync(node);
+    case aten::batch_norm:
+      return analyzeBatchNorm(node);
+    case aten::instance_norm:
+      return analyzeInstanceNorm(node);
     case prim::GradOf:
       return analyzeGradOf(node);
     case prim::BroadcastMKLDNNTensors: {

@@ -1000,6 +1004,73 @@ void AliasDb::analyzeRpcAsync(Node* node) {
   }
 }
 
+namespace {
+c10::optional<bool> getConstantBooleanInput(
+    Node* node,
+    const std::string& inputName) {
+  TORCH_INTERNAL_ASSERT(
+      node->hasNamedInput(inputName), inputName + " input is expected");
+  auto value = node->namedInput(inputName);
+  TORCH_INTERNAL_ASSERT(
+      value->type() == BoolType::get(),
+      inputName + "training input is expected to be a bool");
+  return constant_as<bool>(value);
+}
+} // namespace
+
+// custom behavior for batch_norm because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// training
+void AliasDb::analyzeBatchNorm(Node* node) {
+  // we invoking freezing for inference, so we assume training will be folded to
+  // a constant false to avoid needing to invoke freezing multiple times in
+  // order to make batch norm weights constant
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  if (isFrozen_) {
+    return;
+  }
+
+  auto isTraining = getConstantBooleanInput(node, "training");
+
+  if (!isTraining.has_value() || *isTraining) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
+// custom behavior for instance_norm, because (a!)? annotations currently
+// aren't supported, and because behavior differs depending on the value of
+// use_input_stats
+void AliasDb::analyzeInstanceNorm(Node* node) {
+  for (Value* output : node->outputs()) {
+    giveFreshAlias(output);
+  }
+
+  auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
+
+  if (!useInputStats.has_value() || *useInputStats) {
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_mean"), "running_mean input is expected");
+    auto runningMean = node->namedInput("running_mean");
+    TORCH_INTERNAL_ASSERT(
+        node->hasNamedInput("running_var"), "running_var input is expected");
+    auto runningVar = node->namedInput("running_var");
+
+    registerWrite(runningMean, node);
+    registerWrite(runningVar, node);
+  }
+}
+
 // SetAttr: writes to the `self` field
 void AliasDb::analyzeSetAttr(Node* node) {
   const auto self = node->inputs().at(0);

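A hedged TorchScript sketch (not from this patch) of what the new special case keys on: when `training` is hard-coded to `False`, it appears in the scripted graph as a constant bool that `getConstantBooleanInput` can read, so alias analysis can prove the running statistics are never written; when `training` arrives as a graph input, its value is unknown and a write is still conservatively assumed.
```
import torch
import torch.nn.functional as F

@torch.jit.script
def bn_eval(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    # `training` is the literal False, so it becomes a bool prim::Constant in the IR.
    return F.batch_norm(x, mean, var, training=False)

@torch.jit.script
def bn_dynamic(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, training: bool) -> torch.Tensor:
    # `training` is a graph input here; its value is not known at analysis time.
    return F.batch_norm(x, mean, var, training=training)

# Inspect the IR: the first graph carries the constant bool the analysis can
# read; the second leaves `training` as an input of unknown value.
print(bn_eval.graph)
print(bn_dynamic.graph)
```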
torch/csrc/jit/ir/alias_analysis.h

@@ -215,6 +215,8 @@ class AliasDb {
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
   void analyzeRpcAsync(Node* node);
+  void analyzeBatchNorm(Node* node);
+  void analyzeInstanceNorm(Node* node);
   void analyzeGradOf(Node* node);
   void analyzeSetAttr(Node* node);
   void analyzeConservative(Node* node);
torch/csrc/jit/passes/utils/check_alias_annotation.cpp

@@ -224,6 +224,21 @@ c10::optional<IValue> toIValueProp(const Value* v) {
   }
   return c10::nullopt;
 }
+
+// batch_norm and instance_norm have incorrect annotations, because
+// (a!)? annotations aren't supported, so these checks would fail.
+// Their behavior also varies depending on the `training` and
+// `use_input_stats` arguments.
+// There are custom implementations in alias_analysis.cpp for these ops.
+bool shouldIgnoreNode(const Node* n) {
+  switch (n->kind()) {
+    case aten::batch_norm:
+    case aten::instance_norm:
+      return true;
+    default:
+      return false;
+  }
+}
 } // namespace
 
 void checkAliasAnnotation(

@@ -232,6 +247,9 @@ void checkAliasAnnotation(
     const std::string& unqualifiedOpName) {
   // Find the node that corresponds to our op name
   const auto node = findNodeForOp(*graph, unqualifiedOpName);
+  if (shouldIgnoreNode(node)) {
+    return;
+  }
 
   // Build the stack to use as input to the op
   Stack stack;
torch/testing/_internal/common_methods_invocations.py

@@ -8611,11 +8611,6 @@ op_db: List[OpInfo] = [
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            sample_inputs_func=sample_inputs_batch_norm),
     # This variant tests batch_norm with cuDNN disabled only on CUDA devices
     OpInfo('nn.functional.batch_norm',

@@ -8624,11 +8619,6 @@
            dtypesIfCPU=empty_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
-           skips=(
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED at
-               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ),
            decorators=[onlyCUDA, disablecuDNN],
            sample_inputs_func=sample_inputs_batch_norm),
     # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
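With the `skips` entries above removed, `test_variant_consistency_jit` runs for `nn.functional.batch_norm` again. A rough sketch, under the assumption that the real harness does much more, of the eager-vs-scripted agreement it relies on, including the in-place updates to the running statistics:
```
import torch
import torch.nn.functional as F

def fn(x, mean, var):
    return F.batch_norm(x, mean, var, training=True)

scripted = torch.jit.script(fn)

x = torch.randn(2, 3, 4, 4)
mean_eager, var_eager = torch.zeros(3), torch.ones(3)
mean_jit, var_jit = torch.zeros(3), torch.ones(3)

out_eager = fn(x, mean_eager, var_eager)
out_jit = scripted(x, mean_jit, var_jit)

torch.testing.assert_close(out_eager, out_jit)
# the in-place updates to the running stats should agree as well
torch.testing.assert_close(mean_eager, mean_jit)
torch.testing.assert_close(var_eager, var_jit)
```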