mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22175 - Rename AliasAnalysisKind::DEFAULT to AliasAnalysisKind::CONSERVATIVE - Introduce AliasAnalysisKind::FROM_SCHEMA that means the alias annotations of the schema should be honored - Introduce AliasAnalysisKind::INTERNAL_SPECIAL_CASE to be able to run assertions that internal special cased ops are treated correctly - aten:: and prim:: ops are not treated as special cases anymore, but just use AliasAnalysisKind::FROM_SCHEMA - There's a set of assertions to ensure that aten:: and prim:: ops are all correctly set up to use AliasAnalysisKind::FROM_SCHEMA. Once this PR lands and passes all tests, we will remove those assertions and open up for the possibility of different AliasAnalysisKind settings for aten:: and prim:: ops Differential Revision: D15929595 fbshipit-source-id: 7c6a9d4d29e13b8c9a856062cd6fb3f8a46a2e0d
215 lines
6.8 KiB
C++
215 lines
6.8 KiB
C++
#include <torch/csrc/jit/constants.h>
|
|
#include <ATen/core/functional.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/jit/custom_operator.h>
|
|
#include <torch/csrc/jit/operator.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
namespace {
|
|
c10::OperatorOptions aliasAnalysisInternalSpecialCase() {
|
|
c10::OperatorOptions options;
|
|
options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE);
|
|
return options;
|
|
}
|
|
} // namespace
|
|
|
|
// Inserts a prim::Constant node holding `val` into `g` and returns its output.
// Delegates to tryInsertConstant (forwarding result_type, loc and scope) and
// throws constant_not_supported_error when that function reports the value
// kind as unsupported.
Value* insertConstant(
    Graph& g,
    const IValue& val,
    const c10::TypePtr& result_type,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  c10::optional<Value*> inserted =
      tryInsertConstant(g, val, result_type, loc, scope);
  if (!inserted) {
    throw constant_not_supported_error(
        "Unsupported value kind: " + val.tagKind());
  }
  return *inserted;
}
|
|
|
|
// IValue -> Constant node
|
|
c10::optional<Value*> tryInsertConstant(
|
|
Graph& g,
|
|
const IValue& val,
|
|
const c10::TypePtr& result_type,
|
|
c10::optional<SourceRange> loc,
|
|
c10::optional<ScopePtr> scope) {
|
|
Node* n = g.create(prim::Constant);
|
|
if (val.isTensor()) {
|
|
at::Tensor ref = val.toTensor();
|
|
if (!ref.defined()) {
|
|
n->destroy();
|
|
return g.insertNode(g.createNone(TensorType::get()))->output();
|
|
}
|
|
// TODO: fix all cases where we are not passing in a variable,
|
|
// and then change this to an AT_ASSERT
|
|
if (!ref.is_variable()) {
|
|
ref = autograd::make_variable(ref, /*requires_grad=*/false);
|
|
} else {
|
|
AT_ASSERT(!ref.requires_grad());
|
|
}
|
|
n->output()->inferTypeFrom(
|
|
ref); // note: before t_ because of std::move(ref)
|
|
n->t_(attr::value, std::move(ref));
|
|
} else if (val.isInt()) {
|
|
n->i_(attr::value, val.toInt());
|
|
n->output()->setType(IntType::get());
|
|
} else if (val.isDouble()) {
|
|
n->f_(attr::value, val.toDouble());
|
|
n->output()->setType(FloatType::get());
|
|
} else if (val.isBool()) {
|
|
n->i_(attr::value, val.toBool());
|
|
n->output()->setType(BoolType::get());
|
|
} else if (val.isBoolList()) {
|
|
auto bool_list = val.toBoolList();
|
|
n->is_(
|
|
attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
|
|
n->output()->setType(ListType::ofBools());
|
|
} else if (val.isIntList()) {
|
|
n->is_(attr::value, val.toIntListRef().vec());
|
|
n->output()->setType(ListType::ofInts());
|
|
} else if (val.isTensorList()) {
|
|
n->ts_(
|
|
attr::value,
|
|
fmap(val.toTensorListRef(), [](const at::Tensor& t) {
|
|
AT_ASSERT(t.is_variable() && !t.requires_grad());
|
|
return t;
|
|
}));
|
|
n->output()->setType(ListType::ofTensors());
|
|
} else if (val.isString()) {
|
|
n->s_(attr::value, val.toString()->string());
|
|
n->output()->setType(StringType::get());
|
|
} else if (val.isDevice()) {
|
|
std::stringstream ss;
|
|
ss << val.toDevice();
|
|
n->s_(attr::value, ss.str());
|
|
n->output()->setType(DeviceObjType::get());
|
|
} else if (val.isNone()) {
|
|
n->output()->setType(NoneType::get());
|
|
} else {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
}
|
|
if (loc)
|
|
n->setSourceRange(*loc);
|
|
if (scope)
|
|
n->setScope(*scope);
|
|
if (result_type) {
|
|
auto inferred_type = n->output()->type();
|
|
|
|
if (inferred_type->isSubtypeOf(NoneType::get()) &&
|
|
!inferred_type->isSubtypeOf(result_type)) {
|
|
// None doesn't subtype Optional, but an Optional can be None, so handle
|
|
// that here
|
|
TORCH_CHECK(
|
|
result_type->kind() == TypeKind::OptionalType,
|
|
"Expected OptionalType or NoneType, got ",
|
|
result_type->python_str());
|
|
n->output()->setType(result_type);
|
|
} else if (!(inferred_type->isSubtypeOf(TensorType::get()) &&
|
|
result_type->isSubtypeOf(inferred_type))) {
|
|
// Retain more type information in case of tensor constant
|
|
n->output()->setType(result_type);
|
|
}
|
|
}
|
|
return g.insertNode(n)->output();
|
|
}
|
|
|
|
RegisterOperators reg({
|
|
Operator(
|
|
FunctionSchema(
|
|
prim::Constant,
|
|
"",
|
|
{},
|
|
{},
|
|
/*is_vararg=*/false,
|
|
/*is_varret=*/true),
|
|
[](const Node* node) -> Operation {
|
|
TypePtr type = node->output()->type();
|
|
if (type->isSubtypeOf(TensorType::get())) {
|
|
auto t = node->t(attr::value);
|
|
return [t](Stack& stack) {
|
|
push(stack, t);
|
|
return 0;
|
|
};
|
|
} else if (type->isSubtypeOf(BoolType::get())) {
|
|
bool b = node->i(attr::value);
|
|
return [b](Stack& stack) {
|
|
push(stack, b);
|
|
return 0;
|
|
};
|
|
} else if (
|
|
type->isSubtypeOf(NumberType::get()) &&
|
|
node->kindOf(attr::value) == AttributeKind::i) {
|
|
auto i = node->i(attr::value);
|
|
return [i](Stack& stack) {
|
|
push(stack, i);
|
|
return 0;
|
|
};
|
|
} else if (
|
|
type->isSubtypeOf(NumberType::get()) &&
|
|
node->kindOf(attr::value) == AttributeKind::f) {
|
|
auto f = node->f(attr::value);
|
|
return [f](Stack& stack) {
|
|
push(stack, f);
|
|
return 0;
|
|
};
|
|
} else if (type->isSubtypeOf(ListType::ofInts())) {
|
|
const auto& is = node->is(attr::value);
|
|
return [is](Stack& stack) {
|
|
push(stack, is);
|
|
return 0;
|
|
};
|
|
} else if (type->isSubtypeOf(ListType::ofBools())) {
|
|
const auto bs = fmap<bool>(node->is(attr::value));
|
|
return [bs](Stack& stack) {
|
|
push(stack, bs);
|
|
return 0;
|
|
};
|
|
} else if (type->isSubtypeOf(ListType::ofTensors())) {
|
|
const auto& ts = node->ts(attr::value);
|
|
return [ts](Stack& stack) {
|
|
push(stack, ts);
|
|
return 0;
|
|
};
|
|
} else if (type == StringType::get()) {
|
|
const auto& s = node->s(attr::value);
|
|
return [s](Stack& stack) {
|
|
push(stack, s);
|
|
return 0;
|
|
};
|
|
} else if (type == DeviceObjType::get()) {
|
|
auto d = c10::Device(node->s(attr::value));
|
|
return [d](Stack& stack) {
|
|
push(stack, d);
|
|
return 0;
|
|
};
|
|
} else if (node->mustBeNone()) {
|
|
return [](Stack& stack) {
|
|
push(stack, IValue());
|
|
return 0;
|
|
};
|
|
} else {
|
|
std::stringstream ss;
|
|
ss << "constant literal not supported for: " << type->str();
|
|
throw std::runtime_error(ss.str());
|
|
}
|
|
},
|
|
aliasAnalysisInternalSpecialCase()),
|
|
});
|
|
|
|
// Extracts the constant IValue produced by `v`.
// Returns c10::nullopt unless `v` is an output of a prim::Constant node. When
// it is, the node's Operation (obtained from getOperation, i.e. the
// prim::Constant implementation) runs on an empty stack, and the value it
// pushes is returned.
c10::optional<IValue> toIValue(const Value* v) {
  const auto* producer = v->node();
  if (producer->kind() == prim::Constant) {
    Stack scratch;
    getOperation(producer)(scratch);
    return scratch.back();
  }
  return c10::nullopt;
}
|
|
} // namespace jit
|
|
} // namespace torch
|