Merge ProfiledTensorType and TensorType (#24284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24284

This PR finishes the unification of all Tensor types into a single object.
ProfiledTensorType is renamed to TensorType and the old TensorType is
deleted.

Notes:
* Fixes a bug in merge for VaryingShape by changing its representation to an
  optional list of optional ints.
* Removes ProfiledTensorType::create(type) invocations that can now
  simply be expect calls on tensor type.
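
As an illustration of the two notes above, here is a minimal sketch of the call-site rewrite (the helper name `scalarTypeOf` is made up, the include path is where jit_type.h lives at the time of this PR, and the VaryingShape comment paraphrases the first note):

```cpp
#include <ATen/core/jit_type.h>  // c10::TensorType after this PR

// Hypothetical helper showing the call-site pattern this PR rewrites:
//   before: ProfiledTensorType::create(tptr)->scalarType()
//   after:  tptr->expect<TensorType>()->scalarType()
c10::optional<at::ScalarType> scalarTypeOf(const c10::TypePtr& tptr) {
  return tptr->expect<c10::TensorType>()->scalarType();  // asserts tptr is a tensor type
}

// VaryingShape is now, roughly, an optional list of optional dimension sizes
// (c10::optional<std::vector<c10::optional<int64_t>>>), so merge() can keep a
// known rank even when individual sizes are unknown.
```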

Test Plan: Imported from OSS

Differential Revision: D16794034

Pulled By: zdevito

fbshipit-source-id: 10362398d0bb166d0d385d74801e95d9b87d9dfc
Authored by Zachary DeVito on 2019-08-20 12:57:40 -07:00; committed by Facebook Github Bot
parent 6824c9018d
commit bdc57d3833
41 changed files with 271 additions and 351 deletions

View File

@ -43,7 +43,6 @@ using OptNameList = c10::optional<std::vector<std::string>>;
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(ProfiledTensorType) \
_(DeviceObjType) \
_(FunctionType) \
_(ClassType) \
@ -251,34 +250,6 @@ struct CAFFE2_API OptionalType
OptionalType(TypePtr elem) : SingleElementType(elem) {}
};
struct TensorType;
using TensorTypePtr = std::shared_ptr<TensorType>;
// This type represents a single Tensor, with an unknown shape.
// Subtype hierarchy for Tensor Types (TensorType as the base type):
// ProfiledTensorType <: TensorType
struct CAFFE2_API TensorType : public Type {
static TensorTypePtr create() {
return TensorTypePtr(new TensorType()); // NOLINT(modernize-make-shared)
}
bool requires_grad() const override {
return true;
}
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return "Tensor";
}
static const TypeKind Kind = TypeKind::TensorType;
// global singleton
static TensorTypePtr get();
protected:
TensorType(TypeKind kind = TypeKind::TensorType) : Type(kind) {}
};
template <typename T>
inline c10::optional<T> merge_primitive(
const c10::optional<T>& a,
@ -369,43 +340,31 @@ struct CAFFE2_API VaryingShape {
using VaryingStrides = VaryingShape;
struct ProfiledTensorType;
using ProfiledTensorTypePtr = std::shared_ptr<ProfiledTensorType>;
struct TensorType;
using TensorTypePtr = std::shared_ptr<TensorType>;
// This type represents a single Tensor with a specific size
struct CAFFE2_API ProfiledTensorType : public TensorType {
static ProfiledTensorTypePtr create(const at::Tensor& t) {
return ProfiledTensorTypePtr(new ProfiledTensorType(t));
struct CAFFE2_API TensorType : public Type {
static TensorTypePtr create(const at::Tensor& t) {
return TensorTypePtr(new TensorType(t));
}
static ProfiledTensorTypePtr create(const TypePtr& tptr) {
if (auto ptt = tptr->cast<ProfiledTensorType>()) {
return ptt;
}
if (tptr->isSubtypeOf(TensorType::get())) {
return ProfiledTensorType::get();
}
TORCH_INTERNAL_ASSERT(false, "Expected a tensor type");
}
static ProfiledTensorTypePtr create(
static TensorTypePtr create(
c10::optional<at::ScalarType> scalar_type,
c10::optional<Device> device,
const VaryingShape& sizes,
const VaryingStrides& strides,
c10::optional<bool> requires_grad,
c10::optional<bool> autograd_zero=c10::nullopt) {
return ProfiledTensorTypePtr(new ProfiledTensorType(
return TensorTypePtr(new TensorType(
scalar_type, device, sizes, strides, requires_grad));
}
static ProfiledTensorTypePtr create(
static TensorTypePtr create(
c10::optional<at::ScalarType> scalar_type,
c10::optional<Device> device,
c10::optional<size_t> dim,
c10::optional<bool> requires_grad) {
return ProfiledTensorType::create(
return TensorType::create(
scalar_type,
device,
VaryingShape(dim),
@ -415,7 +374,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
// overloaded create variadic template argument as it could not distinguish
// initializer list
static ProfiledTensorTypePtr createContiguous(
static TensorTypePtr createContiguous(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes) {
@ -426,7 +385,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
VaryingShape(contiguousStridesOf(sizes)),
c10::nullopt);
}
static ProfiledTensorTypePtr create(
static TensorTypePtr create(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes,
@ -470,21 +429,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return false;
}
auto rt = rhs.expect<ProfiledTensorType>();
auto rt = rhs.expect<TensorType>();
return scalar_type_ == rt->scalarType() && sizes() == rt->sizes() &&
strides() == rt->strides() && device() == rt->device() &&
requiresGrad() == rt->requiresGrad() && autogradZero() == rt->autogradZero();
}
bool isSubtypeOf(const TypePtr rhs) const override {
if (auto rhs_p = rhs->cast<ProfiledTensorType>()) {
// if we have the same pointer, avoid computing the merge
if (this == rhs_p.get()) {
return true;
}
return *merge(rhs_p) == *rhs_p;
}
return rhs->kind() == TypeKind::TensorType || TensorType::isSubtypeOf(rhs);
}
bool isSubtypeOf(const TypePtr rhs) const override;
std::string str() const override;
c10::optional<size_t> numel() const {
@ -500,27 +451,27 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return prod;
}
ProfiledTensorTypePtr withRequiresGrad(c10::optional<bool> s) {
TensorTypePtr withRequiresGrad(c10::optional<bool> s) {
auto copy = clone();
copy->requires_grad_ = s;
return copy;
}
ProfiledTensorTypePtr withScalarType(c10::optional<ScalarType> st) {
TensorTypePtr withScalarType(c10::optional<ScalarType> st) {
auto copy = clone();
copy->scalar_type_ = st;
return copy;
}
ProfiledTensorTypePtr withDim(c10::optional<size_t> d) {
TensorTypePtr withDim(c10::optional<size_t> d) {
auto copy = clone();
copy->sizes_ = VaryingShape(d);
copy->strides_ = VaryingShape(d);
return copy;
}
ProfiledTensorTypePtr withSizesStrides(
TensorTypePtr withSizesStrides(
at::IntArrayRef sizes,
at::IntArrayRef strides) const {
auto cloned = clone();
@ -529,19 +480,19 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return cloned;
}
ProfiledTensorTypePtr withSizes(at::IntArrayRef sizes) const {
TensorTypePtr withSizes(at::IntArrayRef sizes) const {
return withSizesStrides(
sizes, contiguousStridesOf(sizes));
}
ProfiledTensorTypePtr dimensionedOnly() const {
TensorTypePtr dimensionedOnly() const {
auto copy = clone();
copy->sizes_ = VaryingShape(sizes().size());
copy->strides_ = VaryingShape(sizes().size());
return copy;
}
ProfiledTensorTypePtr contiguous() const {
TensorTypePtr contiguous() const {
auto cloned = clone();
if (auto concrete_sizes = sizes().concrete_sizes()) {
cloned->strides_ = VaryingShape(contiguousStridesOf(*concrete_sizes));
@ -551,14 +502,14 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return cloned;
}
ProfiledTensorTypePtr merge(ProfiledTensorTypePtr other) const {
TensorTypePtr merge(TensorTypePtr other) const {
auto scalar_type = merge_primitive(scalarType(), other->scalarType());
auto dev = merge_primitive(device(), other->device());
auto sz = sizes().merge(other->sizes());
auto srs = strides().merge(other->strides());
auto gr = merge_primitive(requiresGrad(), other->requiresGrad());
auto zero = merge_primitive(autogradZero(), other->autogradZero());
return ProfiledTensorType::create(scalar_type, dev, sz, srs, gr, zero);
return TensorType::create(scalar_type, dev, sz, srs, gr, zero);
}
// is all information about the type specified except for autograd?
// This replaces the notion of a 'CompleteTensorType' that used to exist
@ -568,7 +519,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
}
ProfiledTensorTypePtr withAutogradZero() {
TensorTypePtr withAutogradZero() {
auto r = clone();
r->autograd_zero_ = true;
return r;
@ -578,13 +529,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return autograd_zero_;
}
static ProfiledTensorTypePtr get();
static TensorTypePtr get();
static const TypeKind Kind = TypeKind::ProfiledTensorType;
static const TypeKind Kind = TypeKind::TensorType;
private:
ProfiledTensorType(const at::Tensor& tensor)
: TensorType(TypeKind::ProfiledTensorType),
TensorType(const at::Tensor& tensor)
: Type(TypeKind::TensorType),
scalar_type_(tensor.scalar_type()),
device_(tensor.device()),
sizes_(tensor.sizes().size()),
@ -595,14 +546,14 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
strides_ = tensor.strides().vec();
}
}
ProfiledTensorType(
TensorType(
c10::optional<at::ScalarType> scalar_type,
c10::optional<Device> device,
const VaryingShape& sizes,
const VaryingStrides& strides,
c10::optional<bool> requires_grad,
c10::optional<bool> autograd_zero=c10::nullopt)
: TensorType(TypeKind::ProfiledTensorType),
: Type(TypeKind::TensorType),
scalar_type_(scalar_type),
device_(device),
sizes_(sizes),
@ -610,8 +561,8 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
requires_grad_(requires_grad),
autograd_zero_(autograd_zero) {}
ProfiledTensorTypePtr clone() const {
return ProfiledTensorTypePtr(new ProfiledTensorType(
TensorTypePtr clone() const {
return TensorTypePtr(new TensorType(
scalar_type_, device_, sizes_, strides_, requires_grad_, autograd_zero_));
}
@ -1170,18 +1121,18 @@ inline TypePtr unshapedType(const TypePtr& type) {
return type->withContained(fmap(type->containedTypes(), unshapedType));
}
inline TypePtr ProfiledTensorType::fromNumberType(TypePtr typ) {
inline TypePtr TensorType::fromNumberType(TypePtr typ) {
if (typ->isSubtypeOf(IntType::get())) {
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
return TensorType::createContiguous(at::kLong, at::kCPU, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
return ProfiledTensorType::createContiguous(at::kFloat, at::kCPU, {});
return TensorType::createContiguous(at::kFloat, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
return TensorType::createContiguous(at::kLong, at::kCPU, {});
}
AT_ERROR("unknown number type", typ->str());
}
inline TypePtr ProfiledTensorType::fromBoolType() {
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
inline TypePtr TensorType::fromBoolType() {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
}
inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {

View File

@ -6,7 +6,7 @@
namespace c10 {
std::ostream& operator<<(std::ostream & out, const Type & t) {
if (auto value = t.cast<ProfiledTensorType>()) {
if (auto value = t.cast<TensorType>()) {
if (value->scalarType().has_value()) {
out << toString(*value->scalarType());
if (!value->sizes().size().has_value()) {
@ -29,6 +29,9 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
}
out << ")";
}
if (value->autogradZero() && *value->autogradZero()) {
out << "[AutogradZero]";
}
} else if(t.kind() == TypeKind::ListType) {
auto prim = t.cast<ListType>()->getElementType();
out << *prim << "[]";
@ -61,12 +64,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
}
TensorTypePtr TensorType::get() {
static auto value = TensorType::create();
return value;
}
ProfiledTensorTypePtr ProfiledTensorType::get() {
static auto value = ProfiledTensorType::create(
static auto value = TensorType::create(
{},
{},
VaryingShape{c10::optional<size_t>()},
@ -140,7 +138,7 @@ ListTypePtr ListType::ofBools() {
// the type, like in the tracer.
TypePtr incompleteInferTypeFrom(const IValue& value) {
if (value.isTensor()) {
return ProfiledTensorType::create(value.toTensor());
return TensorType::create(value.toTensor());
} else if (value.isDouble()) {
return FloatType::get();
} else if (value.isInt()) {
@ -262,8 +260,8 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
// NB: we do not return NumberType because there is not currently enough
// operator support for it
if (t1->kind() == ProfiledTensorType::Kind && t2->kind() == ProfiledTensorType::Kind) {
return t1->expect<ProfiledTensorType>()->merge(t2->expect<ProfiledTensorType>());
if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
return t1->expect<TensorType>()->merge(t2->expect<TensorType>());
}
if (t1->isSubtypeOf(TensorType::get()) && t2->isSubtypeOf(TensorType::get())) {
@ -479,7 +477,7 @@ bool Type::isSubtypeOf(const TypePtr rhs) const {
return false;
}
std::string ProfiledTensorType::str() const {
std::string TensorType::str() const {
return "Tensor";
}
@ -625,4 +623,15 @@ std::string TupleType::python_str() const {
return ss.str();
}
bool TensorType::isSubtypeOf(const TypePtr rhs) const {
if (auto rhs_p = rhs->cast<TensorType>()) {
// if we have the same pointer, avoid computing the merge
if (this == rhs_p.get()) {
return true;
}
return *merge(rhs_p) == *rhs_p;
}
return Type::isSubtypeOf(rhs);
}
} // namespace c10

View File

@ -91,40 +91,37 @@ void testCompleteArgumentSpec() {
ASSERT_EQ(with_const.at(2).sizes().size(), 2);
}
size_t hashCode(const ProfiledTensorTypePtr& ptr) {
return std::hash<ProfiledTensorType>()(*ptr.get());
size_t hashCode(const TensorTypePtr& ptr) {
return std::hash<TensorType>()(*ptr.get());
}
void testProfiledTensorTypeHashing() {
c10::VaryingShape vs(c10::optional<size_t>{});
auto ptt_empty1 = ProfiledTensorType::create({}, {}, vs, vs, false);
auto ptt_empty2 = ProfiledTensorType::create({}, {}, vs, vs, false);
auto ptt_empty1 = TensorType::create({}, {}, vs, vs, false);
auto ptt_empty2 = TensorType::create({}, {}, vs, vs, false);
ASSERT_EQ(hashCode(ptt_empty1), hashCode(ptt_empty2));
c10::VaryingShape vs22(std::vector<int64_t>{2, 2});
auto ptt_vs22_1 = ProfiledTensorType::create({}, {}, vs22, vs, false);
auto ptt_vs22_2 = ProfiledTensorType::create({}, {}, vs22, vs, false);
auto ptt_vs22_1 = TensorType::create({}, {}, vs22, vs, false);
auto ptt_vs22_2 = TensorType::create({}, {}, vs22, vs, false);
ASSERT_EQ(hashCode(ptt_vs22_1), hashCode(ptt_vs22_2));
c10::VaryingShape vs23(std::vector<int64_t>{2, 3});
auto ptt_vs23_1 = ProfiledTensorType::create({}, {}, vs23, vs, false);
auto ptt_vs23_1 = TensorType::create({}, {}, vs23, vs, false);
ASSERT_NE(hashCode(ptt_vs22_1), hashCode(ptt_vs23_1));
auto ptt_vs22_vs22_1 = ProfiledTensorType::create({}, {}, vs22, vs22, false);
auto ptt_vs22_vs22_2 = ProfiledTensorType::create({}, {}, vs22, vs22, false);
auto ptt_vs22_vs22_1 = TensorType::create({}, {}, vs22, vs22, false);
auto ptt_vs22_vs22_2 = TensorType::create({}, {}, vs22, vs22, false);
ASSERT_EQ(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs22_2));
auto ptt_vs22_vs23_2 = ProfiledTensorType::create({}, {}, vs22, vs23, false);
auto ptt_vs22_vs23_2 = TensorType::create({}, {}, vs22, vs23, false);
ASSERT_NE(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs23_2));
auto ptt_vs22_vs22_1_true =
ProfiledTensorType::create({}, {}, vs22, vs22, true);
auto ptt_vs22_vs22_2_true =
ProfiledTensorType::create({}, {}, vs22, vs22, true);
auto ptt_vs22_vs22_1_true = TensorType::create({}, {}, vs22, vs22, true);
auto ptt_vs22_vs22_2_true = TensorType::create({}, {}, vs22, vs22, true);
ASSERT_EQ(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_2_true));
auto ptt_vs22_vs22_1_false =
ProfiledTensorType::create({}, {}, vs22, vs22, false);
auto ptt_vs22_vs22_1_false = TensorType::create({}, {}, vs22, vs22, false);
ASSERT_NE(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_1_false));
}

View File

@ -169,7 +169,7 @@ void testADFormulas() {
void testDifferentiate() {
// Note: can't use IRParser for this test due to issue #23989
auto graph = std::make_shared<Graph>();
const auto type = ProfiledTensorType::create(at::ScalarType::Float, at::kCPU, {2, 3, 4}, {12, 4, 1});
const auto type = TensorType::create(at::ScalarType::Float, at::kCPU, {2, 3, 4}, {12, 4, 1});
// Builds graph a * b * a + b
auto* a = graph->addInput()->setType(type);

View File

@ -1011,7 +1011,7 @@ static void checkShape(
bool prev = true) {
auto profile = (prev) ? n->inputs().at(0)->node() : n;
auto tp = profile->output()->type();
auto ptp = tp->expect<ProfiledTensorType>();
auto ptp = tp->expect<TensorType>();
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
}
@ -1044,7 +1044,9 @@ void testInsertAndEliminateRedundantGuards() {
return n->kind() == prim::Guard;
});
ASSERT_NE(guard, nodes.end());
ASSERT_EQ(guard->input()->type()->cast<ProfiledTensorType>(), nullptr);
ASSERT_EQ(
guard->input()->type()->expect<TensorType>()->sizes().size(),
c10::nullopt);
checkShape(*guard, {2, 3}, false);
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);

View File

@ -2150,8 +2150,8 @@ graph(%Ra, %Rb):
graph = f.graph_for(t, "hi")
input_types = list(next(graph.inputs()).type().elements())
w = input_types[0]
self.assertEqual(input_types[0].kind(), 'ProfiledTensorType')
self.assertEqual(input_types[1].elements()[1].kind(), 'ProfiledTensorType')
self.assertEqual(input_types[0].kind(), 'TensorType')
self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType')
def test_constant_prop_simple(self):
@torch.jit.script
@ -2412,8 +2412,8 @@ graph(%Ra, %Rb):
def test_onnx_transpose_incomplete_tensor_type(self):
# Smoke test to get us into the state where we are attempting to export
# a transpose op, where the input is a TensorType rather than a
# ProfiledTensorType. This would previously not work, since we would
# a transpose op, where the input is a TensorType without size information.
# This would previously not work, since we would
# take the size of the input and use the length of its sizes as the
# number of dimensions in the permutation.
class Foo(torch.jit.ScriptModule):
@ -4561,8 +4561,7 @@ a")
x = torch.randn(3, 1, 5, requires_grad=True)
fn = torch.jit.script(fn)
graph = _propagate_shapes(fn.graph, (x,), False)
a = next(graph.outputs()).type().kind()
self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
self.assertTrue(next(graph.outputs()).type().scalarType() == 'Double')
def test_shape_prop_promotion(self):
@torch.jit.script
@ -5168,7 +5167,7 @@ a")
res = fn(t, 1)
self.assertEqual(res, 0)
g = torch.jit.last_executed_optimized_graph()
self.assertEqual(next(g.inputs()).type().kind(), 'ProfiledTensorType')
self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')
@torch.jit.script
def fn(x, y, b):

View File

@ -216,7 +216,7 @@ void ArgumentSpecCreator::specializeTypes(
auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
if (!arg.defined()) {
result_stack.back().emplace_back(
ProfiledTensorType::get()->withAutogradZero());
TensorType::get()->withAutogradZero());
} else {
result_stack.back().emplace_back(arg.toType());
}

View File

@ -44,11 +44,12 @@ struct ArgumentInfo {
TypePtr toType() const {
if (!defined())
return TensorType::get();
return ProfiledTensorType::create(type(),
ConvertIntToCPUOrCUDA(device()),
c10::VaryingShape(dim()),
c10::VaryingShape(dim()),
requires_grad());
return TensorType::create(
type(),
ConvertIntToCPUOrCUDA(device()),
c10::VaryingShape(dim()),
c10::VaryingShape(dim()),
requires_grad());
}
operator TypePtr() const {
return toType();
@ -349,7 +350,7 @@ struct CompleteArgumentInfo {
operator TypePtr() const {
if (!defined())
return TensorType::get();
return ProfiledTensorType::create(
return TensorType::create(
type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
}
@ -447,8 +448,8 @@ struct hash<c10::VaryingShape> {
};
template <>
struct hash<c10::ProfiledTensorType> {
size_t operator()(const c10::ProfiledTensorType& ptt) const {
struct hash<c10::TensorType> {
size_t operator()(const c10::TensorType& ptt) const {
return torch::get_hash<
c10::optional<int8_t>,
c10::VaryingShape,

View File

@ -329,8 +329,7 @@ Values are abstract representation of data in the program. When executing, the a
TorchScript, unlike Python, is statically typed, so every Value has a Type associated with it, and every FunctionSchema has a list of argument types and a return type for a function. Type is the base class of a hierarchy of C++ objects that represent the built-in types of TorchScript. Types provide methods such as `Type::isSubtypeOf` that describe the typing relationships. Common types are:
* TensorType - the root type of all Tensors in the system.
* ProfiledTensorType - a tensor with optionally refined information. It may optional know its device, type, requires_grad state, the number of dimensions.
* TensorType - a tensor with optionally refined information. It may optionally know its device, type, requires_grad state, and the number of dimensions (see the sketch after this list).
If it does know the number of dimensions, it may optionally know the size of a particular dimension.
* Tuples - e.g. Tuple[Tensor, Int]. Each member of the tuple is statically typed and the length of the tuple is statically known.
* List[T] - e.g. List[Tensor]. Mutable lists of a particular type.
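
A minimal sketch of querying that refined information (the helper `describeTensorType` and its printing are illustrative, not part of the codebase; the accessors are the ones this PR consolidates onto `TensorType`):

```cpp
#include <ATen/core/jit_type.h>  // defines c10::TensorType (aten/src/ATen/core/jit_type.h)
#include <iostream>

// Hypothetical helper: report whatever refined information a TensorType carries.
void describeTensorType(const c10::TypePtr& t) {
  if (auto tt = t->cast<c10::TensorType>()) {
    if (auto st = tt->scalarType()) {
      std::cout << "dtype: " << c10::toString(*st) << "\n";  // scalar type, if known
    }
    if (auto rank = tt->sizes().size()) {
      std::cout << "rank: " << *rank << "\n";                // number of dimensions, if known
    }
    if (tt->sizes().concrete_sizes()) {                      // every dimension size is known
      std::cout << "numel: " << *tt->numel() << "\n";
    }
    std::cout << "complete: " << tt->isComplete() << "\n";   // dtype, device, sizes, strides all known
  }
}
```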
@ -811,7 +810,7 @@ Execution starts in `GraphExecutor::run`, which takes a Stack of inputs.
The ArgumentSpec object is used as a key into a cache that holds pre-optimized Code objects (held in an ExecutionPlan object). On a cache hit, an InterpreterState is created and the Code in the cache is run.
*Pre-derivative Optimization* On a code cache miss, we generate a new optimized Graph on the fly (`compileSpec`). It starts by creating a copy of the initial Graph and setting the input types to the specialized Tensor types observed in this specialization. TensorType inputs to the Graph will get replaced with ProfiledTensorTypes that know the device, number of dimensions, and requires grad state.
*Pre-derivative Optimization* On a code cache miss, we generate a new optimized Graph on the fly (`compileSpec`). It starts by creating a copy of the initial Graph and setting the input types to the specialized Tensor types observed in this specialization. TensorType inputs to the Graph will get refined with types that know the device, number of dimensions, and requires grad state.
```
# post specialization, inputs are now specialized types
@ -1148,7 +1147,7 @@ TODO: fusion, operators
# Profiling Programs
`prim::profile` nodes are inserted on every **use** of a value by `ProfilingRecord::instrumentBlock`. Every `prim::profile` node runs a lambda that uses a captured, initial type value and the type of an incoming tensor and merges the two into `ProfiledTensorType`
`prim::profile` nodes are inserted on every **use** of a value by `ProfilingRecord::instrumentBlock`. Every `prim::profile` node runs a lambda that uses a captured, initial type value and the type of an incoming tensor and merges the two into a refined `TensorType`.
`prim::profile` nodes are replaced with `prim::Guard` nodes by `InsertGuards`. `prim::Guard` nodes are inserted to guarantee that beyond the guard a guarded tensor will always be of the profiled shape. This guarantee will enable optimizations and codegens to generate more efficient code.
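
Roughly, a guard check at runtime amounts to re-profiling the incoming tensor and comparing types, as in this illustrative sketch (the function name is made up; it mirrors the interpreter's GUARD case changed elsewhere in this PR):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/jit_type.h>

// Illustrative only: compare the type baked into a prim::Guard against the
// type profiled from the tensor actually flowing through at runtime.
bool guardHolds(const c10::TypePtr& expected, const at::Tensor& actual_tensor) {
  auto actual = c10::TensorType::create(actual_tensor);  // dtype/device/sizes/strides of the live tensor
  return *expected == *actual;  // a mismatch diverts execution to the bailout path
}
```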

View File

@ -221,7 +221,7 @@ void EncoderBase::EncodeValueInfo(
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
std::string name = n->debugName();
v->set_name(name);
if (ProfiledTensorTypePtr node_type = n->type()->cast<ProfiledTensorType>()) {
if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
if (!node_type->isComplete()) {
return;
}

View File

@ -91,7 +91,7 @@ static std::string variableType(const std::shared_ptr<c10::Type>& t) {
return "double";
} else if (t->kind() == TypeKind::BoolType) {
return "bool";
} else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) {
} else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
return calcScalarTypeName(*scalar_type);
}
// something went wrong with the type analysis during shape propagation
@ -115,7 +115,7 @@ static std::string typeCastedValueName(
// cast here, which may end up being a no-op if the tensor's scalar type
// is `double`.
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
} else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) {
} else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
if (*scalar_type != outtype) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
@ -260,8 +260,7 @@ static std::string encodeRHS(const Node* n) {
} else {
size_t i = 0;
auto outtype =
ProfiledTensorType::create(n->output()->type())->scalarType();
auto outtype = n->output()->type()->expect<TensorType>()->scalarType();
TORCH_INTERNAL_ASSERT(outtype);
for (auto in : n->inputs()) {

View File

@ -204,10 +204,10 @@ std::shared_ptr<FusedKernel> compileKernel(
for (size_t i = 0; i < input_desc.size(); i++) {
const auto& desc = input_desc[i];
// TODO: can't get rid of this use of ProfiledTensorType
// TODO: can't get rid of this use of TensorType
// until we switch to ProfilingGraphExecutor, so we don't have to
// run PropagateInputShapes below
graph->inputs()[i]->setType(ProfiledTensorType::create(
graph->inputs()[i]->setType(TensorType::create(
desc.scalar_type,
device,
c10::VaryingShape(desc.nDim()),
@ -254,10 +254,9 @@ std::shared_ptr<FusedKernel> compileKernel(
sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
}
auto scalar_type = ProfiledTensorType::create(o->type())->scalarType();
auto scalar_type = o->type()->expect<TensorType>()->scalarType();
TORCH_INTERNAL_ASSERT(scalar_type);
auto type =
ProfiledTensorType::createContiguous(*scalar_type, device, sizes);
auto type = TensorType::createContiguous(*scalar_type, device, sizes);
output_desc.emplace_back(type);
const auto& desc = output_desc.back();

View File

@ -41,7 +41,7 @@ struct TORCH_API TensorDesc {
TensorDesc(const at::Tensor& t)
: TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {}
TensorDesc(const c10::ProfiledTensorTypePtr& type)
TensorDesc(const c10::TensorTypePtr& type)
: TensorDesc(
type->scalarType().value(),
type->sizes().concrete_sizes().value(),

View File

@ -1001,7 +1001,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
++af.pc;
} break;
case GUARD: {
auto actual = ProfiledTensorType::create(stack.back().toTensor());
auto actual = TensorType::create(stack.back().toTensor());
const TypePtr& expected = af.types[inst.X];
push(stack, *expected == *actual);
++af.pc;

View File

@ -315,7 +315,7 @@ static void checkSameDevice(const Node* node) {
bool has_device = false;
c10::optional<at::Device> device = c10::nullopt;
auto checkValue = [&](const Value* v) {
if (ProfiledTensorTypePtr type = v->type()->cast<ProfiledTensorType>()) {
if (TensorTypePtr type = v->type()->cast<TensorType>()) {
if (type->device() && !has_device) {
has_device = true;
device = *type->device();
@ -670,7 +670,7 @@ void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
}
void Value::inferTypeFrom(const at::Tensor& output) {
setType(ProfiledTensorType::create(output));
setType(TensorType::create(output));
}
bool Value::mustBeNone() const {
@ -1439,7 +1439,7 @@ Node* Graph::createDict(
Node* Graph::createNumToTensor(Value* value) {
auto typ = value->type();
Node* result = create(prim::NumToTensor, {value});
result->output()->setType(ProfiledTensorType::fromNumberType(std::move(typ)));
result->output()->setType(TensorType::fromNumberType(std::move(typ)));
return result;
}

View File

@ -169,7 +169,7 @@ struct Value {
return type()->requires_grad();
}
bool isCompleteTensor() const {
if (auto pt = type()->cast<ProfiledTensorType>()) {
if (auto pt = type()->cast<TensorType>()) {
return pt->isComplete();
}
return false;

View File

@ -172,7 +172,7 @@ struct BailOutInserter {
for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
if (it->kind() == prim::Guard) {
// this will need to be profiled again
it->input()->setType(TensorType::create());
it->input()->setType(TensorType::get());
// destroy the guard
it->output()->replaceAllUsesWith(it->input());
it.destroyCurrent();
@ -274,7 +274,7 @@ static void removeBailouts(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
if (it->kind() == prim::BailOut) {
// clear profiling information
it->inputs().at(0)->setType(TensorType::create());
it->inputs().at(0)->setType(TensorType::get());
it->output()->replaceAllUsesWith(it->inputs().at(0));
it.destroyCurrent();
} else {

View File

@ -13,8 +13,6 @@
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
// Replaces prim::Guard nodes with prim::BailOut nodes and
// computes sets of inputs needed to resume execution at
// bailout points

View File

@ -40,7 +40,7 @@ bool isDecomposableNorm(Node* normalize_op) {
if (!input->type()->isSubtypeOf(TensorType::get())) {
return false;
}
auto device = ProfiledTensorType::create(input->type())->device();
auto device = input->type()->expect<TensorType>()->device();
// As of now, we do the decomposition for batchnorm/layernorm on GPU device
// only
if (!device || (*device).is_cpu()) {

View File

@ -47,9 +47,9 @@ static void EraseNumberTypesOnBlock(Block* block) {
default: {
for (auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(ProfiledTensorType::fromNumberType(o->type()));
o->setType(TensorType::fromNumberType(o->type()));
} else if (o->type()->isSubtypeOf(BoolType::get())) {
o->setType(ProfiledTensorType::fromBoolType());
o->setType(TensorType::fromBoolType());
}
}
} break;

View File

@ -178,7 +178,7 @@ struct GraphFuser {
if (!v->type()->isSubtypeOf(TensorType::get())) {
return true;
}
auto device = ProfiledTensorType::create(v->type())->device();
auto device = v->type()->expect<TensorType>()->device();
if (!device) {
return true;
}

View File

@ -134,7 +134,7 @@ struct GuardElimination {
input->node()->kind() == prim::Constant) {
AT_ASSERT(
input->node()->kind() != prim::Guard ||
input->type()->expect<ProfiledTensorType>());
input->type()->expect<TensorType>());
} else {
all_inputs_guarded = false;
break;

View File

@ -13,8 +13,6 @@
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
TORCH_API void EliminateRedundantGuards(std::shared_ptr<Graph> graph);
} // namespace jit

View File

@ -30,15 +30,11 @@ struct GuardInserter {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (n->kind() == prim::profile && n->outputs().size() == 1) {
auto pttp = n->output()->type()->cast<ProfiledTensorType>();
auto pttp = n->output()->type()->cast<TensorType>();
if (pttp) {
// make a *copy* of ProfilingTensorType, in case we'd like
// to make changes to it independently from the one being
// profiled
auto guard = graph_->create(prim::Guard, {n->input()}, 1);
auto go = guard->output();
auto copy = ProfiledTensorType::create(pttp);
go->setType(copy);
go->setType(pttp);
guard->insertBefore(n);
n->output()->replaceAllUsesWith(go);
} else {

View File

@ -13,8 +13,6 @@
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
TORCH_API void InsertGuards(std::shared_ptr<Graph> graph);
} // namespace jit

View File

@ -13,7 +13,6 @@
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
using SparseBitVector = ::c10::SparseBitVector<256>;
// BuildLivenessSets computes "bailout" liveness which is equivalent to

View File

@ -29,8 +29,8 @@ Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
bool IsCondCastRequired(Value* cond_val) {
const auto& type = cond_val->type();
if (type->isSubtypeOf(TensorType::get())) {
if (auto scalar_type = ProfiledTensorType::create(type)->scalarType()) {
if (auto tt = type->cast<TensorType>()) {
if (auto scalar_type = tt->scalarType()) {
return *scalar_type != c10::kBool;
}
}
@ -55,7 +55,7 @@ void FixupONNXLoops(Block* block) {
cond->setType(BoolType::create());
Value* i = sub_block->inputs()[0];
i->setType(ProfiledTensorType::fromNumberType(IntType::get()));
i->setType(TensorType::fromNumberType(IntType::get()));
// add cast to condition input inside the loop.
Value* next_cond_val = sub_block->outputs()[0];

View File

@ -129,13 +129,13 @@ void fuseBroadcast(Block* b) {
// Not all broadcasts are supported by ONNX broadcast.
c10::optional<size_t> axis = fusibleExpandTo(
unexpanded_input->type()
->expect<ProfiledTensorType>()
->expect<TensorType>()
->sizes()
.concrete_sizes()
.value(), // from
n->output()
->type()
->expect<ProfiledTensorType>()
->expect<TensorType>()
->sizes()
.concrete_sizes()
.value()); // to
@ -296,14 +296,13 @@ void pushPackingPastRnn(Block* b) {
// unhygienic way, PyTorch ends up propagating an incorrect type.
// Until a long-term cleanup comes around, we can fix this by
// resetting the size to the correct value.
ProfiledTensorTypePtr oldType =
rnn->inputs().at(0)->type()->cast<ProfiledTensorType>();
TensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<TensorType>();
if (oldType && oldType->isComplete()) {
std::vector<int64_t> new_sizes;
new_sizes.push_back(*oldType->sizes()[0]);
new_sizes.push_back(*oldType->sizes()[1]);
new_sizes.push_back(rnn->i(attr::hidden_size));
ProfiledTensorTypePtr newType = ProfiledTensorType::createContiguous(
TensorTypePtr newType = TensorType::createContiguous(
*oldType->scalarType(), *oldType->device(), new_sizes);
next->outputs().at(0)->setType(newType);
}

View File

@ -27,8 +27,7 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
it->replaceInput(0, floattensor_inputs[0]);
it->replaceInput(1, floattensor_inputs[1]);
it->output()->setType(
ProfiledTensorType::fromNumberType(FloatType::get()));
it->output()->setType(TensorType::fromNumberType(FloatType::get()));
}
}
}

View File

@ -44,9 +44,8 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
"aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*const_inputs=*/attr::size)) {
// x.expand(x.size()) == x
if (auto input_type = node->namedInput(attr::self)
->type()
->cast<ProfiledTensorType>()) {
if (auto input_type =
node->namedInput(attr::self)->type()->cast<TensorType>()) {
auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size);
auto input_type_sizes = input_type->sizes().concrete_sizes();
if (expanded_sizes.has_value() && input_type_sizes &&
@ -71,8 +70,8 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
} else if (node->matches(
"aten::type_as(Tensor self, Tensor other) -> Tensor")) {
// x.type_as(y) == x iff x.type() == y.type()
auto self_type = ProfiledTensorType::create(node->input(0)->type());
auto other_type = ProfiledTensorType::create(node->input(1)->type());
auto self_type = node->input(0)->type()->expect<TensorType>();
auto other_type = node->input(1)->type()->expect<TensorType>();
if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) &&
mustBeEqual(self_type->device(), other_type->device())) {
GRAPH_UPDATE(
@ -108,7 +107,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
// type_as conditional on the tensor shape being a scalar, but that
// might add overhead, and make analysis harder.
auto add_mat_type =
ProfiledTensorType::create(node->input(1 - mm_side)->type());
node->input(1 - mm_side)->type()->expect<TensorType>();
// if we don't have the rank, we can't tell if the bias is a scalar
if (!add_mat_type->sizes().size()) {
continue;
@ -123,10 +122,10 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
SymbolicVariable mat1(mm_node->input(0));
SymbolicVariable mat2(mm_node->input(1));
auto mat1_type = ProfiledTensorType::create(mat1.value()->type());
auto mat1_type = mat1.value()->type()->expect<TensorType>();
auto mat_scalar_type = mat1_type->scalarType();
if (!mat_scalar_type) {
auto mat2_type = ProfiledTensorType::create(mat2.value()->type());
auto mat2_type = mat2.value()->type()->expect<TensorType>();
mat_scalar_type = mat2_type->scalarType();
}
@ -282,7 +281,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
node->output()->replaceAllUsesWith(node->input());
}
} else if (node->matches("prim::dtype(Tensor a) -> int")) {
auto ptt = ProfiledTensorType::create(node->input()->type());
auto ptt = node->input()->type()->expect<TensorType>();
if (ptt->scalarType()) {
WithInsertPoint guard(node);
auto output = node->owningGraph()->insertConstant(
@ -292,7 +291,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
node->output()->replaceAllUsesWith(output);
}
} else if (node->matches("prim::device(Tensor a) -> Device")) {
auto ptt = ProfiledTensorType::create(node->input()->type());
auto ptt = node->input()->type()->expect<TensorType>();
if (ptt->device()) {
WithInsertPoint guard(node);
auto output = node->owningGraph()->insertConstant(*ptt->device());
@ -304,7 +303,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
node->output()->replaceAllUsesWith(output);
}
} else if (node->matches("aten::dim(Tensor self) -> int")) {
auto ptt = ProfiledTensorType::create(node->input()->type());
auto ptt = node->input()->type()->expect<TensorType>();
if (auto dim = ptt->sizes().size()) {
WithInsertPoint guard(node);
auto output =
@ -317,7 +316,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
node->output()->replaceAllUsesWith(output);
}
} else if (node->matches("prim::is_cuda(Tensor a) -> bool")) {
auto ptt = ProfiledTensorType::create(node->input()->type());
auto ptt = node->input()->type()->expect<TensorType>();
if (ptt->device()) {
WithInsertPoint guard(node);
auto output =

View File

@ -16,7 +16,7 @@ bool getRequiresGrad(Value* value) {
}
void setRequiresGrad(Value* value, bool req_value) {
if (auto type = value->type()->cast<ProfiledTensorType>()) {
if (auto type = value->type()->cast<TensorType>()) {
value->setType(type->withRequiresGrad(req_value));
}
}
@ -72,7 +72,7 @@ void PropagateRequiresGradSimpleNode(Node* node) {
return setRequiresGrad(node->output(), *const_arg);
}
}
if (auto type = node->output()->type()->cast<ProfiledTensorType>()) {
if (auto type = node->output()->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(node->output(), at::isFloatingType(*type->scalarType()));
}
@ -85,7 +85,7 @@ void PropagateRequiresGradSimpleNode(Node* node) {
bool should_require =
std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
for (Value* output : outputs) {
if (auto type = output->type()->cast<ProfiledTensorType>()) {
if (auto type = output->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(
output, should_require && at::isFloatingType(*type->scalarType()));

View File

@ -38,7 +38,7 @@ bool isValidArgumentForRunning(Value* v) {
// allow constants
if (toIValue(v))
return true;
if (ProfiledTensorTypePtr tt = v->type()->cast<ProfiledTensorType>()) {
if (TensorTypePtr tt = v->type()->cast<TensorType>()) {
if (!tt->scalarType()) {
return false;
}
@ -158,7 +158,7 @@ class ShapePropagator {
if (auto iv = toIValue(v)) {
return *iv;
}
if (ProfiledTensorTypePtr type = type_->cast<ProfiledTensorType>()) {
if (TensorTypePtr type = type_->cast<TensorType>()) {
if (type->isComplete()) {
auto attype = type->device()->is_cpu() ? at::CPU(*type->scalarType())
: at::CUDA(*type->scalarType());
@ -185,10 +185,10 @@ class ShapePropagator {
// for each node in the schema with type Tensor, extract the T type
// returns c10::nullopt if any Tensor in the schema does not have a known
// shape ignores non-tensor in the list of inputs
c10::optional<std::vector<ProfiledTensorTypePtr>> gatherTensorTypes(
c10::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
Node* node,
bool complete = false) {
std::vector<ProfiledTensorTypePtr> tensor_types;
std::vector<TensorTypePtr> tensor_types;
auto& schema = node->schema();
auto& args = schema.arguments();
@ -201,7 +201,7 @@ class ShapePropagator {
if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
return c10::nullopt;
} else if (args[i].type()->isSubtypeOf(TensorType::get())) {
if (auto type = node->input(i)->type()->cast<ProfiledTensorType>()) {
if (auto type = node->input(i)->type()->cast<TensorType>()) {
if (complete && !type->isComplete()) {
return c10::nullopt;
}
@ -236,14 +236,14 @@ class ShapePropagator {
void broadcastBinary(
Node* node,
std::vector<ProfiledTensorTypePtr>& types,
std::vector<TensorTypePtr>& types,
size_t idx1,
size_t idx2) {
auto expected_size = at::infer_size(
*types[idx1]->sizes().concrete_sizes(),
*types[idx2]->sizes().concrete_sizes());
auto broadcast = [&](size_t input_idx) {
ProfiledTensorTypePtr input_type = types.at(input_idx);
TensorTypePtr input_type = types.at(input_idx);
if (input_type->sizes() == expected_size)
return;
auto graph = node->owningGraph();
@ -260,8 +260,8 @@ class ShapePropagator {
};
broadcast(idx1);
broadcast(idx2);
types[0] = node->inputs().at(idx1)->type()->expect<ProfiledTensorType>();
types[1] = node->inputs().at(idx2)->type()->expect<ProfiledTensorType>();
types[0] = node->inputs().at(idx1)->type()->expect<TensorType>();
types[1] = node->inputs().at(idx2)->type()->expect<TensorType>();
}
OperatorSet cannot_propagate_shape_by_running_it = {
@ -369,13 +369,12 @@ class ShapePropagator {
void PropagateCatShape(Node* cat_node) {
static const auto propagate_complete =
[this](Node* node, at::ArrayRef<Value*> tensors) -> bool {
auto input_types = fmap(tensors, [](Value* v) {
return v->type()->cast<ProfiledTensorType>();
});
auto input_types =
fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
if (!std::all_of(
input_types.begin(),
input_types.end(),
[](const ProfiledTensorTypePtr& tp) {
[](const TensorTypePtr& tp) {
return tp != nullptr && tp->isComplete();
})) {
return false;
@ -407,7 +406,7 @@ class ShapePropagator {
static const auto propagate = [](Node* node,
at::ArrayRef<Value*> tensors) -> bool {
for (Value* v : tensors) {
if (auto type = v->type()->cast<ProfiledTensorType>()) {
if (auto type = v->type()->cast<TensorType>()) {
node->output()->setType(type->dimensionedOnly());
return true;
}
@ -463,7 +462,7 @@ class ShapePropagator {
default_device = inp->toDevice();
}
}
node->output()->setType(ProfiledTensorType::create(
node->output()->setType(TensorType::create(
default_type, default_device, dims, /*requires_grad=*/c10::nullopt));
}
@ -535,10 +534,10 @@ class ShapePropagator {
TypePtr typ = node->input()->type();
if (typ->isSubtypeOf(IntType::get()) ||
typ->isSubtypeOf(BoolType::get())) {
node->output()->setType(ProfiledTensorType::create(
node->output()->setType(TensorType::create(
at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
} else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
node->output()->setType(ProfiledTensorType::create(
node->output()->setType(TensorType::create(
at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
}
return;
@ -591,7 +590,7 @@ class ShapePropagator {
}
case prim::ConstantChunk: {
Value* tensor = node->input();
if (auto type = tensor->type()->cast<ProfiledTensorType>()) {
if (auto type = tensor->type()->cast<TensorType>()) {
type = type->dimensionedOnly();
for (Value* output : node->outputs()) {
output->setType(type);
@ -683,9 +682,8 @@ class ShapePropagator {
// primitive/tensor outputs.
bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
static const auto broadcast =
[](std::vector<ProfiledTensorTypePtr>& tensor_types,
size_t arg_for_type) -> ProfiledTensorTypePtr {
static const auto broadcast = [](std::vector<TensorTypePtr>& tensor_types,
size_t arg_for_type) -> TensorTypePtr {
if (tensor_types.size() == 1) {
return tensor_types[0]->dimensionedOnly();
}
@ -699,14 +697,14 @@ class ShapePropagator {
max_dims = std::max(*max_dims, *type->dim());
}
}
return ProfiledTensorType::create(
return TensorType::create(
any_type->scalarType(),
any_type->device(),
max_dims,
/*requires_grad=*/c10::nullopt);
};
using type_vec_t = std::vector<ProfiledTensorTypePtr>;
using type_vec_t = std::vector<TensorTypePtr>;
// Formula is expected to return a vector of length equal to the number of
// tensor outputs of the node, or an empty vector which implies that it
// failed to propagate.
@ -815,9 +813,9 @@ class ShapePropagator {
"aten::zeros_like(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
auto input_type =
node->input(0)->type()->cast<ProfiledTensorType>();
return input_type ? type_vec_t{input_type->dimensionedOnly()} : type_vec_t{};
auto input_type = node->input(0)->type()->cast<TensorType>();
return input_type ? type_vec_t{input_type->dimensionedOnly()}
: type_vec_t{};
}};
// Requirements:
@ -929,10 +927,9 @@ class ShapePropagator {
return {};
}};
static const auto any_tensor_type =
[](Node* node) -> ProfiledTensorTypePtr {
static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
for (Value* input : node->inputs()) {
if (auto type = input->type()->cast<ProfiledTensorType>()) {
if (auto type = input->type()->cast<TensorType>()) {
if (type->dim().has_value()) {
return type;
}
@ -1035,8 +1032,7 @@ class ShapePropagator {
"aten::prelu(Tensor self, Tensor weight) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type =
node->input(0)->type()->cast<ProfiledTensorType>()) {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
return {type->dimensionedOnly()};
}
return {};
@ -1065,8 +1061,7 @@ class ShapePropagator {
"aten::any(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type =
node->input(0)->type()->cast<ProfiledTensorType>()) {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
return {type->withDim(0)};
}
return {};
@ -1086,7 +1081,7 @@ class ShapePropagator {
},
[](Node* node) -> type_vec_t {
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (auto type = node->input(0)->type()->cast<ProfiledTensorType>()) {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
auto ret = type->withDim(0);
if(maybe_dtype_option && !maybe_dtype_option->isNone()) {
return {ret->withScalarType(maybe_dtype_option->toScalarType())};
@ -1105,51 +1100,53 @@ class ShapePropagator {
// tensor outputs : 1
// Additionally:
// - First input should be the only tensor input
static const register_formula_for all_reduce_ops_with_integer_upcast_and_dtype{
{
"aten::sum(Tensor self, *, int? dtype) -> Tensor",
"aten::prod(Tensor self, *, int? dtype) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type =
node->input(0)->type()->cast<ProfiledTensorType>()) {
type = type->withDim(0);
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if( maybe_dtype_option && ! maybe_dtype_option->isNone() ) {
return {type->withScalarType(maybe_dtype_option->toScalarType())};
}
if (type->scalarType()) {
return { at::isFloatingType(*type->scalarType())
? type
: type->withScalarType(at::kLong)};
} else {
return { type };
}
}
return {};
}};
static const auto reduce_op_handler =
[](Node* node,
int64_t num_reduced_dim = 0,
bool upcast_integer = false,
c10::optional<IValue> opt_dtype = c10::nullopt) -> type_vec_t {
if(auto type = node->input(0)->type()->cast<ProfiledTensorType>()) {
if (!type->scalarType() || !type->dim()) {
static const register_formula_for
all_reduce_ops_with_integer_upcast_and_dtype{
{
"aten::sum(Tensor self, *, int? dtype) -> Tensor",
"aten::prod(Tensor self, *, int? dtype) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
type = type->withDim(0);
at::optional<IValue> maybe_dtype_option =
node->get(attr::dtype);
if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
return {
type->withScalarType(maybe_dtype_option->toScalarType())};
}
if (type->scalarType()) {
return {at::isFloatingType(*type->scalarType())
? type
: type->withScalarType(at::kLong)};
} else {
return {type};
}
}
return {};
}
if( opt_dtype && ! opt_dtype->isNone() ) {
type = type->withScalarType(opt_dtype->toScalarType());
} else if(upcast_integer && !at::isFloatingType(*type->scalarType())) {
type = type->withScalarType(at::kLong);
}
if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) {
return {type->withDim(*type->dim() - num_reduced_dim)};
} else {
return {type};
}
}
}};
static const auto reduce_op_handler = [](Node* node,
int64_t num_reduced_dim = 0,
bool upcast_integer = false,
c10::optional<IValue> opt_dtype =
c10::nullopt) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
if (!type->scalarType() || !type->dim()) {
return {};
}
if (opt_dtype && !opt_dtype->isNone()) {
type = type->withScalarType(opt_dtype->toScalarType());
} else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
type = type->withScalarType(at::kLong);
}
if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) {
return {type->withDim(*type->dim() - num_reduced_dim)};
} else {
return {type};
}
}
return {};
};
static const auto multidim_reduce_with_keepdim =
@ -1175,8 +1172,7 @@ class ShapePropagator {
"aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type =
node->input(0)->type()->cast<ProfiledTensorType>()) {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
return {type->withDim(0)};
} else {
@ -1300,7 +1296,7 @@ class ShapePropagator {
(maybe_dtype_option->isNone() ? at::kDouble
: maybe_dtype_option->toScalarType());
return {ProfiledTensorType::create(
return {TensorType::create(
dtype, device, dim, /*requires_grad=*/c10::nullopt)};
};
@ -1324,9 +1320,8 @@ class ShapePropagator {
"aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type = node->namedInput(attr::self)
->type()
->cast<ProfiledTensorType>()) {
if (auto type =
node->namedInput(attr::self)->type()->cast<TensorType>()) {
if (type->dim()) {
return factory_with_ndim(node, *type->dim());
}
@ -1398,9 +1393,8 @@ class ShapePropagator {
"aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type = node->namedInput(attr::self)
->type()
->cast<ProfiledTensorType>()) {
if (auto type =
node->namedInput(attr::self)->type()->cast<TensorType>()) {
return {type->withScalarType(get_cast_scalar_type(node))};
}
return {};
@ -1428,7 +1422,7 @@ class ShapePropagator {
// This section implements shape prop for an assorted set of nodes that only
// need partial information about their input types.
const auto input_type = [node](size_t index) {
auto result = node->input(index)->type()->cast<ProfiledTensorType>();
auto result = node->input(index)->type()->cast<TensorType>();
if (result) {
result = result->dimensionedOnly();
}
@ -1562,12 +1556,11 @@ class ShapePropagator {
// The code below implements formulas that need type information for all
// their tensor inputs, and have exactly one output.
std::vector<ProfiledTensorTypePtr> tensor_types;
std::vector<TensorTypePtr> tensor_types;
static const auto reshape_prop =
[](Node* node,
Symbol shape_input,
const std::vector<ProfiledTensorTypePtr>& tensor_types)
-> ProfiledTensorTypePtr {
const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
if (auto list_size = determineListSize(node->namedInput(shape_input))) {
return tensor_types.at(0)->withDim(*list_size);
}
@ -1593,7 +1586,7 @@ class ShapePropagator {
return reshape_prop(node, attr::size, tensor_types);
} else if (node->matches("aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
TypePtr input_type = node->inputs().at(0)->type();
if (auto type = input_type->cast<ProfiledTensorType>()) {
if (auto type = input_type->cast<TensorType>()) {
if (type->scalarType() && type->device()) {
at::ScalarType default_type = *type->scalarType();
c10::Device default_device = *type->device();
@ -1615,7 +1608,7 @@ class ShapePropagator {
default_device = inp->toDevice();
}
}
node->output()->setType(ProfiledTensorType::create(
node->output()->setType(TensorType::create(
default_type,
default_device,
type->dim(),
@ -1722,7 +1715,7 @@ class ShapePropagator {
bool PropagateCompleteShapeOnNode(
Node* node,
bool insert_expands,
std::vector<ProfiledTensorTypePtr> tensor_types) {
std::vector<TensorTypePtr> tensor_types) {
// For expensive ops we can directly encode their shape propagation
// here, otherwise we fallback to running a fake version of the op
// to get a quick and dirty propagation.
@ -1775,7 +1768,7 @@ class ShapePropagator {
auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
SHAPE_ASSERT(
*lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
node->output()->setType(ProfiledTensorType::createContiguous(
node->output()->setType(TensorType::createContiguous(
*lhs_type->scalarType(),
*lhs_type->device(),
at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
@ -1941,7 +1934,7 @@ class ShapePropagator {
(int64_t)*tensor_types.at(0)->sizes().size()};
at::IntArrayRef dims(dim_vec);
node->output()->setType(
ProfiledTensorType::createContiguous(at::kLong, at::kCPU, dims));
TensorType::createContiguous(at::kLong, at::kCPU, dims));
return true;
} else if (node->kind() == ::c10::onnx::Reshape) {
setUnshapedType(node);

View File

@ -16,7 +16,7 @@ void specializeAutogradZero(Graph& g) {
for (Value* input : g.inputs()) {
const auto& tp = input->type();
if (auto tt = tp->cast<ProfiledTensorType>()) {
if (auto tt = tp->cast<TensorType>()) {
if (tt->autogradZero() && *tt->autogradZero()) {
state[input] = State::Zero;
} else {

View File

@ -29,22 +29,26 @@ void ProfilingRecord::instrumentBlock(Block* block) {
auto pn = createProfileNode(nullptr, {i});
auto pno = pn->addOutput();
bool first = true;
pno->setType(TensorType::get());
std::function<void(Stack&)> shape_profiler = [this, pno](Stack& stack) {
IValue t;
pop(stack, t);
if (t.isTensor()) {
auto pttp = ProfiledTensorType::create(t.toTensor());
std::lock_guard<std::mutex> lock(this->mutex_);
if (auto type = pno->type()->cast<ProfiledTensorType>()) {
pno->setType(type->merge(pttp));
} else {
pno->setType(pttp);
}
}
// passing t through
push(stack, t);
};
std::function<void(Stack&)> shape_profiler =
[this, pno, first](Stack& stack) mutable {
IValue t;
pop(stack, t);
if (t.isTensor()) {
auto pttp = TensorType::create(t.toTensor());
std::lock_guard<std::mutex> lock(this->mutex_);
if (auto type = pno->type()->cast<TensorType>()) {
if (!first) {
pttp = pttp->merge(type);
}
pno->setType(pttp);
first = false;
}
}
// passing t through
push(stack, t);
};
pn->setCallback(shape_profiler);
pn->insertBefore(n);
@ -77,11 +81,10 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
return pr;
}
ProfiledTensorTypePtr ProfilingRecord::toProfiledTensorTypePtr(
const IValue& ival) {
TensorTypePtr ProfilingRecord::toTensorTypePtr(const IValue& ival) {
if (ival.isTensor()) {
auto tensor = ival.toTensor();
return ProfiledTensorType::create(tensor);
return TensorType::create(tensor);
}
return {nullptr};

View File

@ -13,7 +13,7 @@
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
using ::c10::TensorTypePtr;
struct ProfilingRecord {
// N.B. ProfilingRecord's copy and move c-tor are disabled, so we won't
@ -21,7 +21,7 @@ struct ProfilingRecord {
// are captured in callbacks_
ProfilingRecord(const ProfilingRecord&) = delete;
ProfilingRecord(ProfilingRecord&&) noexcept = delete;
static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival);
static TensorTypePtr toTensorTypePtr(const IValue& ival);
TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
const std::shared_ptr<Graph>& graph);

View File

@ -93,7 +93,7 @@ inline MatchTypeReturn tryToInferType(py::handle input) {
// Try tensor types
if (THPVariable_Check(input.ptr())) {
auto tensor = py::cast<at::Tensor>(input);
return MatchTypeReturn(ProfiledTensorType::create(tensor));
return MatchTypeReturn(TensorType::create(tensor));
}
if (input.is(py::none())) {
@ -310,8 +310,7 @@ inline IValue toIValue(
const TypePtr& type,
c10::optional<int32_t> N) {
switch (type->kind()) {
case TypeKind::TensorType:
case TypeKind::ProfiledTensorType: {
case TypeKind::TensorType: {
auto var = py::cast<autograd::Variable>(obj);
if (var.is_sparse()) {
AT_WARN(
@ -391,7 +390,6 @@ inline IValue toIValue(
}
return repeated;
}
case TypeKind::ProfiledTensorType:
case TypeKind::TensorType:
return c10::impl::toList(py::cast<std::vector<at::Tensor>>(obj));
default:

View File

@ -631,15 +631,14 @@ void initPythonIRBindings(PyObject* module_) {
.def(
"dim",
[](Type& t) {
auto vshape =
ProfiledTensorType::create(t.shared_from_this())->sizes();
auto vshape = t.shared_from_this()->expect<TensorType>()->sizes();
return vshape.size() ? py::cast(*vshape.size())
: py::cast<py::none>(Py_None);
})
.def(
"sizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<ProfiledTensorType>()) {
if (auto ptt = t.expect<TensorType>()) {
if (auto cs = ptt->sizes().concrete_sizes()) {
return py::cast(*cs);
}
@ -649,7 +648,7 @@ void initPythonIRBindings(PyObject* module_) {
.def(
"sizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<ProfiledTensorType>()) {
if (auto ptt = t.expect<TensorType>()) {
if (auto cs = ptt->strides().concrete_sizes()) {
return py::cast(*cs);
}
@ -660,13 +659,13 @@ void initPythonIRBindings(PyObject* module_) {
"contiguous",
[](Type& t) {
return std::static_pointer_cast<Type>(
t.expect<ProfiledTensorType>()->contiguous());
t.expect<TensorType>()->contiguous());
})
.def(
"scalarType",
[](Type& t) {
auto scalar_type =
ProfiledTensorType::create(t.shared_from_this())->scalarType();
t.shared_from_this()->expect<TensorType>()->scalarType();
return (scalar_type) ? toString(*scalar_type) : nullptr;
})
.def(

View File

@ -280,7 +280,7 @@ struct VISIBILITY_HIDDEN ModuleSelf : public Self {
};
static TypePtr getTensorType(const at::Tensor& t, bool complete) {
auto r = ProfiledTensorType::create(t);
auto r = TensorType::create(t);
if (!complete) {
r = r->dimensionedOnly();
}
@ -354,8 +354,8 @@ static std::shared_ptr<Graph> _assign_output_shapes(
for (size_t i = 0; i < outputs.size(); ++i) {
auto scalar_type = outputs[i].scalar_type();
auto sizes = outputs[i].sizes();
auto type = torch::jit::ProfiledTensorType::createContiguous(
scalar_type, at::kCPU, sizes);
auto type =
torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes);
retval->outputs()[i]->setType(type);
}
return retval;

View File

@ -147,7 +147,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
L.expect('*');
num_dims++;
});
ptr = at::ProfiledTensorType::create(
ptr = at::TensorType::create(
dtype,
at::DeviceType::CPU,
c10::VaryingShape(num_dims),
@ -166,8 +166,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
dims.push_back(dim);
});
at::IntArrayRef dims_ref(dims);
ptr = at::ProfiledTensorType::create(
dtype, at::DeviceType::CPU, dims_ref, false);
ptr = at::TensorType::create(dtype, at::DeviceType::CPU, dims_ref, false);
}
return ptr;
}

View File

@ -21,11 +21,7 @@ struct SymbolicVariable {
return g.addInput()->setType(std::move(type));
}
std::vector<int64_t> sizes() const {
return v->type()
->expect<ProfiledTensorType>()
->sizes()
.concrete_sizes()
.value();
return v->type()->expect<TensorType>()->sizes().concrete_sizes().value();
}
void addAsOutput() const {
v->owningGraph()->registerOutput(v);
@ -317,7 +313,7 @@ struct SymbolicVariable {
return v->owningGraph()->insertConstant(std::move(value));
}
SymbolicVariable typeLike(SymbolicVariable other) const {
if (auto other_type = other.v->type()->cast<ProfiledTensorType>())
if (auto other_type = other.v->type()->cast<TensorType>())
v->setType(other_type->contiguous());
return *this;
}
@ -340,7 +336,7 @@ struct SymbolicVariable {
SymbolicVariable typeLikeWithScalarType(
SymbolicVariable other,
at::ScalarType type) const {
if (auto other_type = other.v->type()->cast<ProfiledTensorType>()) {
if (auto other_type = other.v->type()->cast<TensorType>()) {
auto new_type = other_type->withScalarType(type)->contiguous();
v->setType(new_type);
}
@ -349,8 +345,8 @@ struct SymbolicVariable {
SymbolicVariable typeLikeWithRhsScalarType(
SymbolicVariable other,
SymbolicVariable rhs) const {
auto other_type = other.v->type()->cast<ProfiledTensorType>();
auto rhs_type = rhs.v->type()->cast<ProfiledTensorType>();
auto other_type = other.v->type()->cast<TensorType>();
auto rhs_type = rhs.v->type()->cast<TensorType>();
if (other_type && rhs_type && other_type->isComplete() &&
rhs_type->isComplete()) {
auto new_type =

View File

@ -32,18 +32,8 @@ from functools import wraps
#
# In general, we should avoid depending on the type of Tensor Values contained
# within the trace graph. However, this is sometimes unavoidable (due to ONNX
# spec requirements, etc). If you are implementing a symbolic and need Tensor
# type information, note that there are several levels of Tensor types, defined
# in aten/src/ATen/core/jit_type.h:
#
# TensorType - This is a Tensor, but we don't know anything about its
# properties (e.g. scalar type, # dims, shapes).
# Appears as `Tensor` in graph print-outs.
# ProfiledTensorType <: TensorType - Denotes a Tensor for which we know the
# concrete sizes in addition to the information
# contained in TensorTyper. This adds a sizes()
# method which can be used to retrieve the
# concrete sizes.
# spec requirements, etc). The TensorType object has accessors for these properties
# that return the property if it is statically known and return nullopt otherwise.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying