remove CompleteTensorType

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24169

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D16765329

Pulled By: zdevito

fbshipit-source-id: 88560cefba635c3d586a3e4dee67f9b1d901a642
This commit is contained in:
Zachary DeVito 2019-08-15 13:28:01 -07:00 committed by Facebook Github Bot
parent 5ca612b55e
commit 0cbd7fa46f
29 changed files with 301 additions and 460 deletions

View File

@ -30,8 +30,6 @@ using OptNameList = c10::optional<std::vector<std::string>>;
#define C10_FORALL_TYPES(_) \
_(TensorType) \
_(DimensionedTensorType) \
_(CompleteTensorType) \
_(AutogradZeroTensorType) \
_(TupleType) \
_(ListType) \
@ -292,7 +290,7 @@ struct TensorType;
using TensorTypePtr = std::shared_ptr<TensorType>;
// This type represents a single Tensor, with an unknown shape.
// Subtype hierarchy for Tensor Types (TensorType as the base type):
// CompleteTensorType <: DimensionedTensorType <: TensorType
// ProfiledTensorType <: TensorType
// AutogradZeroTensorType <: TensorType
struct CAFFE2_API TensorType : public Type {
static TensorTypePtr create() {
@ -353,98 +351,6 @@ struct CAFFE2_API AutogradZeroTensorType : public TensorType {
AutogradZeroTensorType() : TensorType(TypeKind::AutogradZeroTensorType) {}
};
struct DimensionedTensorType;
using DimensionedTensorTypePtr = std::shared_ptr<DimensionedTensorType>;
// This type represents a single Tensor with a specific size
struct CAFFE2_API DimensionedTensorType : public TensorType {
template <typename... T>
static DimensionedTensorTypePtr create(T&&... all) {
return DimensionedTensorTypePtr(new DimensionedTensorType(
std::forward<T>(all)...)); // NOLINT(modernize-make-shared)
}
at::ScalarType scalarType() const {
return scalar_type_;
}
at::Device device() const {
return device_;
}
int64_t dim() const {
return dim_;
}
bool requires_grad() const override {
return requires_grad_;
}
DimensionedTensorTypePtr toScalarType(at::ScalarType type) {
auto t = DimensionedTensorType::create(*this);
t->scalar_type_ = type;
return t;
}
DimensionedTensorTypePtr withDim(size_t new_dim) {
auto t = DimensionedTensorType::create(*this);
t->dim_ = new_dim;
return t;
}
DimensionedTensorTypePtr withRequiresGrad(bool req) {
auto t = DimensionedTensorType::create(*this);
t->requires_grad_ = req;
return t;
}
bool operator==(const Type& rhs) const override {
if (rhs.kind() != TypeKind::DimensionedTensorType)
return false;
auto rt = rhs.expect<DimensionedTensorType>();
return scalarType() == rt->scalarType() && device() == rt->device() &&
dim() == rt->dim();
}
bool isSubtypeOf(const TypePtr rhs) const override {
return rhs->kind() == TypeKind::TensorType ||
(rhs->kind() == TypeKind::DimensionedTensorType &&
Type::isSubtypeOf(rhs)) ||
TensorType::isSubtypeOf(rhs);
}
bool isSubclass(const TypeKind kind) const override {
return kind == TypeKind::TensorType ||
kind == TypeKind::DimensionedTensorType;
}
std::string str() const override {
// str is used for user-facing error messages, where we
// don't want to reveal underlying size information.
return "Tensor";
}
static const TypeKind Kind = TypeKind::DimensionedTensorType;
protected:
DimensionedTensorType(
const at::Tensor& tensor,
TypeKind kind = TypeKind::DimensionedTensorType)
: DimensionedTensorType(
tensor.scalar_type(),
tensor.device(),
tensor.dim(),
tensor.is_variable() && tensor.requires_grad(),
kind) {}
DimensionedTensorType(
at::ScalarType scalar_type,
at::Device device,
int64_t dim,
bool requires_grad = true,
TypeKind kind = TypeKind::DimensionedTensorType)
: TensorType(kind),
scalar_type_(scalar_type),
requires_grad_(at::isFloatingType(scalar_type) && requires_grad),
device_(device),
dim_(dim) {}
at::ScalarType scalar_type_;
bool requires_grad_;
at::Device device_;
int64_t dim_;
};
template <typename T>
inline c10::optional<T> merge_primitive(
const c10::optional<T>& a,
@ -463,9 +369,14 @@ struct CAFFE2_API VaryingShape {
VaryingShape(const std::vector<int64_t>& vec)
: size_(vec.size()), dims_(vec.begin(), vec.end()) {}
VaryingShape(c10::ArrayRef<int64_t> vec)
: size_(vec.size()), dims_(vec.begin(), vec.end()) {}
VaryingShape(c10::optional<size_t> size)
: size_(size), dims_(size ? size.value() : 0) {}
VaryingShape(size_t size) : VaryingShape(c10::optional<size_t>(size)) {}
bool operator==(const VaryingShape& other) const {
return size_ == other.size_ && dims_ == other.dims_;
}
@ -502,6 +413,18 @@ struct CAFFE2_API VaryingShape {
return sizes;
}
bool isComplete() const {
if (!size_) {
return false;
}
for (auto d : dims_) {
if(!d) {
return false;
}
}
return true;
}
private:
c10::optional<size_t> size_;
std::vector<c10::optional<int64_t>> dims_;
@ -518,21 +441,11 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
}
static ProfiledTensorTypePtr create(const TypePtr& tptr) {
if (auto dtt = tptr->cast<DimensionedTensorType>()) {
at::VaryingShape vshape(c10::optional<size_t>(dtt->dim()));
return ProfiledTensorType::create(
{dtt->scalarType()},
{dtt->device()},
vshape,
vshape,
{dtt->requires_grad()});
}
if (auto ptt = tptr->cast<ProfiledTensorType>()) {
return ptt;
}
if (tptr->isSubclass(TypeKind::TensorType)) {
if (tptr->isSubtypeOf(TensorType::get())) {
c10::optional<size_t> sz;
return ProfiledTensorType::create(
{}, {}, VaryingShape{sz}, VaryingShape{sz}, {});
@ -564,6 +477,34 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
requires_grad);
}
// 'create' is overloaded here because its variadic template argument could
// not distinguish an initializer list
static ProfiledTensorTypePtr createContiguous(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes) {
return create(
scalar_type,
device,
VaryingShape(sizes),
VaryingShape(contiguousStridesOf(sizes)),
c10::nullopt);
}
static ProfiledTensorTypePtr create(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes,
at::IntArrayRef strides) {
return create(
scalar_type,
device,
VaryingShape(sizes),
c10::VaryingShape(strides),
c10::nullopt);
}
static TypePtr fromNumberType(TypePtr typ);
static TypePtr fromBoolType();
c10::optional<size_t> dim() const {
return sizes().size();
}
@ -641,6 +582,20 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return copy;
}
ProfiledTensorTypePtr withSizesStrides(
at::IntArrayRef sizes,
at::IntArrayRef strides) const {
auto cloned = clone();
cloned->sizes_ = VaryingShape(sizes);
cloned->strides_ = VaryingStrides(strides);
return cloned;
}
ProfiledTensorTypePtr withSizes(at::IntArrayRef sizes) const {
return withSizesStrides(
sizes, contiguousStridesOf(sizes));
}
ProfiledTensorTypePtr dimensionedOnly() const {
auto copy = clone();
copy->sizes_ = VaryingShape(sizes().size());
@ -648,6 +603,16 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
return copy;
}
ProfiledTensorTypePtr contiguous() const {
auto cloned = clone();
if (auto concrete_sizes = sizes().concrete_sizes()) {
cloned->strides_ = VaryingShape(contiguousStridesOf(*concrete_sizes));
} else {
cloned->strides_ = VaryingShape(sizes().size());
}
return cloned;
}
ProfiledTensorTypePtr merge(ProfiledTensorTypePtr other) {
auto scalar_type = merge_primitive(scalarType(), other->scalarType());
auto dev = merge_primitive(device(), other->device());
@ -656,6 +621,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
auto gr = merge_primitive(requiresGrad(), other->requiresGrad());
return ProfiledTensorType::create(scalar_type, dev, sz, srs, gr);
}
// Is all information about the type specified except for autograd?
// This replaces the notion of a 'CompleteTensorType' that used to exist
// in the type hierarchy. Excluding requires_grad and autogradZero allows
// this to match the old behavior.
bool isComplete() const {
return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
}
static const TypeKind Kind = TypeKind::ProfiledTensorType;
@ -690,138 +662,6 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
scalar_type_, device_, sizes_, strides_, requires_grad_));
}
c10::optional<at::ScalarType> scalar_type_;
c10::optional<at::Device> device_;
VaryingShape sizes_;
VaryingStrides strides_;
c10::optional<bool> requires_grad_;
};
struct CompleteTensorType;
using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>;
// This type represents a single Tensor with a specific size
struct CAFFE2_API CompleteTensorType : public DimensionedTensorType {
template <typename... T>
static CompleteTensorTypePtr create(T&&... all) {
return CompleteTensorTypePtr(new CompleteTensorType(
std::forward<T>(all)...)); // NOLINT(modernize-make-shared)
}
// overloaded create variadic template argument as it could not distinguish
// initializer list
static CompleteTensorTypePtr create(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes) {
return CompleteTensorTypePtr(new CompleteTensorType(
scalar_type, device, sizes)); // NOLINT(modernize-make-shared)
}
static CompleteTensorTypePtr create(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes,
at::IntArrayRef strides) {
return CompleteTensorTypePtr(new CompleteTensorType(
scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared)
}
const std::vector<int64_t>& sizes() const {
return sizes_;
}
const std::vector<int64_t>& strides() const {
return strides_;
}
TypePtr withSizesStrides(at::IntArrayRef sizes, at::IntArrayRef strides)
const {
return CompleteTensorType::create(scalar_type_, device_, sizes, strides);
}
TypePtr withSizes(at::IntArrayRef sizes) const {
return withSizesStrides(
sizes, CompleteTensorType::contiguousStridesOf(sizes));
}
CompleteTensorTypePtr contiguous() const {
auto t = CompleteTensorType::create(*this);
t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_);
return t;
}
CompleteTensorTypePtr toScalarType(at::ScalarType type) {
auto t = CompleteTensorType::create(*this);
t->scalar_type_ = type;
return t;
}
bool operator==(const Type& rhs) const override {
if (rhs.kind() != kind()) {
return false;
}
auto rt = rhs.expect<CompleteTensorType>();
return scalarType() == rt->scalarType() && sizes() == rt->sizes() &&
strides() == rt->strides() && device() == rt->device();
}
bool isSubtypeOf(const TypePtr rhs) const override {
if (rhs->kind() == TypeKind::DimensionedTensorType)
return *expect<DimensionedTensorType>() == *rhs;
return rhs->kind() == TypeKind::TensorType || TensorType::isSubtypeOf(rhs);
}
bool isSubclass(const TypeKind kind) const override {
return kind == TypeKind::TensorType ||
kind == TypeKind::DimensionedTensorType ||
kind == TypeKind::CompleteTensorType;
}
std::string str() const override {
// str is used for user-facing error messages, where we
// don't want to reveal underlying size information.
return "Tensor";
}
size_t numel() const {
size_t prod = 1;
for (auto s : sizes()) {
prod *= s;
}
return prod;
}
static const TypeKind Kind = TypeKind::CompleteTensorType;
static TypePtr fromNumberType(TypePtr typ);
static TypePtr fromBoolType();
private:
CompleteTensorType(const at::Tensor& tensor)
: DimensionedTensorType(tensor, TypeKind::CompleteTensorType),
sizes_(tensor.sizes().vec()),
strides_(tensor.strides().vec()) {}
CompleteTensorType(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes,
bool requires_grad = true)
: CompleteTensorType(
scalar_type,
device,
sizes,
CompleteTensorType::contiguousStridesOf(sizes),
requires_grad) {}
CompleteTensorType(
at::ScalarType scalar_type,
at::Device device,
at::IntArrayRef sizes,
at::IntArrayRef strides,
bool requires_grad = true)
: DimensionedTensorType(
scalar_type,
device,
sizes.size(),
requires_grad,
TypeKind::CompleteTensorType),
sizes_(sizes.vec()),
strides_(strides.vec()) {}
static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
if (sizes.empty()) // zero-dim case
@ -832,9 +672,12 @@ struct CAFFE2_API CompleteTensorType : public DimensionedTensorType {
}
return strides;
}
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
c10::optional<at::ScalarType> scalar_type_;
c10::optional<at::Device> device_;
VaryingShape sizes_;
VaryingStrides strides_;
c10::optional<bool> requires_grad_;
};
struct ListType;
@ -1383,19 +1226,18 @@ inline TypePtr unshapedType(const TypePtr& type) {
return type->withContained(fmap(type->containedTypes(), unshapedType));
}
inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
inline TypePtr ProfiledTensorType::fromNumberType(TypePtr typ) {
if (typ->isSubtypeOf(IntType::get())) {
return CompleteTensorType::create(at::kLong, at::kCPU, {});
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
return CompleteTensorType::create(at::kFloat, at::kCPU, {});
return ProfiledTensorType::createContiguous(at::kFloat, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
return CompleteTensorType::create(at::kLong, at::kCPU, {});
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
}
AT_ERROR("unknown number type", typ->str());
}
inline TypePtr CompleteTensorType::fromBoolType() {
return CompleteTensorType::create(at::kLong, at::kCPU, {});
inline TypePtr ProfiledTensorType::fromBoolType() {
return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {});
}
inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {

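Taken together, the createContiguous, withSizesStrides and isComplete members added above cover what call sites previously obtained from CompleteTensorType. A minimal usage sketch against that API (the sizes and strides below are hypothetical, purely for illustration):

// fully specified type; strides are derived from the sizes, so it is contiguous
auto t = c10::ProfiledTensorType::createContiguous(at::kFloat, at::kCPU, {2, 3, 4});
AT_ASSERT(t->isComplete()); // dtype, device, sizes and strides are all known
// shape propagation can refine a type with explicit sizes and strides
auto u = t->withSizesStrides({6, 4}, {4, 1}); // hypothetical refinement
AT_ASSERT(u->sizes().concrete_sizes().has_value());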
View File

@ -6,25 +6,7 @@
namespace c10 {
std::ostream& operator<<(std::ostream & out, const Type & t) {
if(auto value = t.cast<CompleteTensorType>()) {
out << toString(value->scalarType()) << "(";
auto& sizes = value->sizes();
auto& strides = value->strides();
AT_ASSERT(sizes.size() == strides.size());
for (size_t i = 0; i < sizes.size(); i++) {
if (i > 0) {
out << ", ";
}
// TODO: figure out a good way to output strides, or
// add a "debug" printing mode which adds the extra stuff
out << sizes[i]; // << "%" << strides[i];
int64_t expected = i + 1 < sizes.size() ? sizes[i+1]*strides[i+1] : 1;
if (strides[i] != expected) {
out << "!"; //mark non-contiguous
}
}
out << ")";
} else if (auto value = t.cast<ProfiledTensorType>()) {
if (auto value = t.cast<ProfiledTensorType>()) {
if (value->scalarType().has_value()) {
out << toString(*value->scalarType());
if (!value->sizes().size().has_value()) {
@ -151,7 +133,7 @@ ListTypePtr ListType::ofBools() {
// the type, like in the tracer.
TypePtr incompleteInferTypeFrom(const IValue& value) {
if (value.isTensor()) {
return CompleteTensorType::create(value.toTensor());
return ProfiledTensorType::create(value.toTensor());
} else if (value.isDouble()) {
return FloatType::get();
} else if (value.isInt()) {

View File

@ -217,7 +217,7 @@ def method_tests():
('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
('expand', (1,), (S, S, S), '1_element', (True,)),
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
('expand', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
('exp', (S, S, S), NO_ARGS, '', (True,)),

View File

@ -172,7 +172,7 @@ void testADFormulas() {
void testDifferentiate() {
auto graph = std::make_shared<Graph>();
at::ScalarType s = at::ScalarType::Float;
auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
// Build up a fake graph
auto a = SymbolicVariable::asNewInput(*graph, type);

View File

@ -177,7 +177,7 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
auto createGraphWithNames = [](std::string cname, std::string dname) {
auto graph = std::make_shared<Graph>();
at::ScalarType s = at::ScalarType::Float;
auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto a = SymbolicVariable::asNewInput(*graph, type);
auto b = SymbolicVariable::asNewInput(*graph, type);
auto c = a * b;

View File

@ -360,7 +360,7 @@ void testATenNativeBatchNorm() {
void testCustomFusion() {
auto graph = std::make_shared<Graph>();
at::ScalarType s = at::ScalarType::Float;
auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto a = SymbolicVariable::asNewInput(*graph, type);
auto b = SymbolicVariable::asNewInput(*graph, type);
auto c = a * b;
@ -394,7 +394,7 @@ void testCustomFusion() {
void testCustomFusionNestedBlocks() {
auto g = std::make_shared<Graph>();
at::ScalarType s = at::ScalarType::Float;
auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
// test CustomFusion in nested blocks;
auto a = SymbolicVariable::asNewInput(*g, type);

View File

@ -2408,7 +2408,7 @@ graph(%Ra, %Rb):
def test_onnx_transpose_incomplete_tensor_type(self):
# Smoke test to get us into the state where we are attempting to export
# a transpose op, where the input is a TensorType rather than a
# CompleteTensorType. This would previously not work, since we would
# ProfiledTensorType. This would previously not work, since we would
# take the size of the input and use the length of its sizes as the
# number of dimensions in the permutation.
class Foo(torch.jit.ScriptModule):
@ -8270,9 +8270,9 @@ a")
graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
if_outputs = list(graph.findNode("prim::If").outputs())
self.assertTrue(if_outputs[0].type().str() == "Float(2, 2)")
self.assertTrue(if_outputs[1].type().str() == "Tensor(2, *)")
self.assertTrue(if_outputs[2].type().str() == "Tensor(2, 4)")
self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *)")
self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *)")
def test_list_unify(self):
# allowing a unified int?[] would cause a runtime error b/c

View File

@ -349,7 +349,7 @@ struct CompleteArgumentInfo {
operator TypePtr() const {
if (!defined())
return TensorType::get();
return CompleteTensorType::create(
return ProfiledTensorType::create(
type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
}

View File

@ -308,8 +308,10 @@ class GradientHelper {
// returns them as a tuple
auto sizes = node->namedInput(attr::self)
->type()
->expect<CompleteTensorType>()
->sizes();
->expect<ProfiledTensorType>()
->sizes()
.concrete_sizes()
.value();
return {grads.at(0).reshape(sizes), nullptr};
} else if (

View File

@ -332,7 +332,6 @@ TorchScript, unlike Python, is statically typed, so every Value has a Type assoc
* TensorType - the root type of all Tensors in the system.
* ProfiledTensorType - a tensor with optionally refined information. It may optionally know its device, type, requires_grad state, and the number of dimensions.
If it does know the number of dimensions, it may optionally know the size of a particular dimension (see the sketch after this list).
* CompleteTensorType - A subtype of TensorType that adds fixed sizes (e.g. a [3 x 4] cuda tensor). This only appears from tracing at the moment.
* Tuples - e.g. Tuple[Tensor, Int]. Each member of the tuple is statically typed and the length of the tuple is statically known.
* List[T] - e.g. List[Tensor]. Mutable lists of a particular type.
* Optional[T] - e.g. Optional[Tensor], either the Tensor value or None.
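A rough C++ illustration of this refinement, using the API introduced in this change (the tensor below is hypothetical and only for illustration):

// infer a fully specified type from a concrete tensor
auto full = c10::ProfiledTensorType::create(at::ones({3, 4}));
// keep only the number of dimensions, dropping the concrete sizes and strides
auto ranked = full->dimensionedOnly();
// both refinements remain Tensors as far as the type system is concerned
bool ok = full->isSubtypeOf(c10::TensorType::get()) &&
          ranked->isSubtypeOf(c10::TensorType::get());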

View File

@ -221,11 +221,15 @@ void EncoderBase::EncodeValueInfo(
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
std::string name = n->debugName();
v->set_name(name);
if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
if (ProfiledTensorTypePtr node_type = n->type()->cast<ProfiledTensorType>()) {
if (!node_type->isComplete()) {
return;
}
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
const std::vector<std::int64_t>& sizes = node_type->sizes();
std::vector<std::int64_t> sizes =
node_type->sizes().concrete_sizes().value();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
@ -236,7 +240,8 @@ void EncoderBase::EncodeValueInfo(
shape->mutable_dim(i)->set_dim_value(sizes[i]);
}
}
tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
tensor_type->set_elem_type(
ATenTypeToOnnxType(node_type->scalarType().value()));
} else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
onnx::TypeProto* t = v->mutable_type();
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();

View File

@ -256,7 +256,8 @@ std::shared_ptr<FusedKernel> compileKernel(
auto scalar_type = ProfiledTensorType::create(o->type())->scalarType();
TORCH_INTERNAL_ASSERT(scalar_type);
auto type = CompleteTensorType::create(*scalar_type, device, sizes);
auto type =
ProfiledTensorType::createContiguous(*scalar_type, device, sizes);
output_desc.emplace_back(type);
const auto& desc = output_desc.back();

View File

@ -41,8 +41,11 @@ struct TORCH_API TensorDesc {
TensorDesc(const at::Tensor& t)
: TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {}
TensorDesc(const c10::CompleteTensorTypePtr& type)
: TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
TensorDesc(const c10::ProfiledTensorTypePtr& type)
: TensorDesc(
type->scalarType().value(),
type->sizes().concrete_sizes().value(),
type->strides().concrete_sizes().value()) {}
// number of dimensions after contiguity compression
size_t nDim() const {

View File

@ -314,10 +314,10 @@ static void checkSameDevice(const Node* node) {
bool has_device = false;
c10::optional<at::Device> device = c10::nullopt;
auto checkValue = [&](const Value* v) {
if (CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
if (!has_device) {
if (ProfiledTensorTypePtr type = v->type()->cast<ProfiledTensorType>()) {
if (type->device() && !has_device) {
has_device = true;
device = type->device();
device = *type->device();
} else {
AT_ASSERT(device == type->device());
}
@ -669,13 +669,7 @@ void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
}
void Value::inferTypeFrom(const at::Tensor& output) {
if (output.is_mkldnn()) {
// mkldnn tensor as opaque tensor doesn't have strides, so we can
// not create a CompleteTensorType
setType(ProfiledTensorType::create(output));
return;
}
setType(CompleteTensorType::create(output));
setType(ProfiledTensorType::create(output));
}
bool Value::mustBeNone() const {
@ -1427,7 +1421,7 @@ Node* Graph::createDict(
Node* Graph::createNumToTensor(Value* value) {
auto typ = value->type();
Node* result = create(prim::NumToTensor, {value});
result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ)));
result->output()->setType(ProfiledTensorType::fromNumberType(std::move(typ)));
return result;
}

View File

@ -168,7 +168,10 @@ struct Value {
return type()->requires_grad();
}
bool isCompleteTensor() const {
return type()->kind() == TypeKind::CompleteTensorType;
if (auto pt = type()->cast<ProfiledTensorType>()) {
return pt->isComplete();
}
return false;
}
TORCH_API bool mustBeNone() const;
TORCH_API bool mustNotBeNone() const;

View File

@ -47,9 +47,9 @@ static void EraseNumberTypesOnBlock(Block* block) {
default: {
for (auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(CompleteTensorType::fromNumberType(o->type()));
o->setType(ProfiledTensorType::fromNumberType(o->type()));
} else if (o->type()->isSubtypeOf(BoolType::get())) {
o->setType(CompleteTensorType::fromBoolType());
o->setType(ProfiledTensorType::fromBoolType());
}
}
} break;

View File

@ -55,7 +55,7 @@ void FixupONNXLoops(Block* block) {
cond->setType(BoolType::create());
Value* i = sub_block->inputs()[0];
i->setType(CompleteTensorType::fromNumberType(IntType::get()));
i->setType(ProfiledTensorType::fromNumberType(IntType::get()));
// add cast to condition input inside the loop.
Value* next_cond_val = sub_block->outputs()[0];

View File

@ -129,9 +129,16 @@ void fuseBroadcast(Block* b) {
// Not all broadcasts are supported by ONNX broadcast.
c10::optional<size_t> axis = fusibleExpandTo(
unexpanded_input->type()
->expect<CompleteTensorType>()
->sizes(), // from
n->output()->type()->expect<CompleteTensorType>()->sizes()); // to
->expect<ProfiledTensorType>()
->sizes()
.concrete_sizes()
.value(), // from
n->output()
->type()
->expect<ProfiledTensorType>()
->sizes()
.concrete_sizes()
.value()); // to
if (axis == c10::nullopt)
continue;
@ -289,15 +296,15 @@ void pushPackingPastRnn(Block* b) {
// unhygienic way, PyTorch ends up propagating an incorrect type.
// Until a long-term cleanup comes around, we can fix this by
// resetting the size to the correct value.
CompleteTensorTypePtr oldType =
rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
if (oldType) {
ProfiledTensorTypePtr oldType =
rnn->inputs().at(0)->type()->cast<ProfiledTensorType>();
if (oldType && oldType->isComplete()) {
std::vector<int64_t> new_sizes;
new_sizes.push_back(oldType->sizes().at(0));
new_sizes.push_back(oldType->sizes().at(1));
new_sizes.push_back(*oldType->sizes()[0]);
new_sizes.push_back(*oldType->sizes()[1]);
new_sizes.push_back(rnn->i(attr::hidden_size));
CompleteTensorTypePtr newType = CompleteTensorType::create(
oldType->scalarType(), oldType->device(), new_sizes);
ProfiledTensorTypePtr newType = ProfiledTensorType::createContiguous(
*oldType->scalarType(), *oldType->device(), new_sizes);
next->outputs().at(0)->setType(newType);
}

View File

@ -28,7 +28,7 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
it->replaceInput(0, floattensor_inputs[0]);
it->replaceInput(1, floattensor_inputs[1]);
it->output()->setType(
CompleteTensorType::fromNumberType(FloatType::get()));
ProfiledTensorType::fromNumberType(FloatType::get()));
}
}
}

View File

@ -46,9 +46,11 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
// x.expand(x.size()) == x
if (auto input_type = node->namedInput(attr::self)
->type()
->cast<CompleteTensorType>()) {
->cast<ProfiledTensorType>()) {
auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size);
if (!expanded_sizes.has_value() || c10::impl::toVector(*expanded_sizes) == input_type->sizes()) {
auto input_type_sizes = input_type->sizes().concrete_sizes();
if (expanded_sizes.has_value() && input_type_sizes &&
c10::impl::toVector(*expanded_sizes) == *input_type_sizes) {
GRAPH_UPDATE(
*node,
" (x.expand(x.size()) == x) is replaced with ",

View File

@ -38,8 +38,11 @@ bool isValidArgumentForRunning(Value* v) {
// allow constants
if (toIValue(v))
return true;
if (CompleteTensorTypePtr tt = v->type()->cast<CompleteTensorType>()) {
return !at::isIntegralType(tt->scalarType(), /*includeBool=*/false);
if (ProfiledTensorTypePtr tt = v->type()->cast<ProfiledTensorType>()) {
if (!tt->scalarType()) {
return false;
}
return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
}
return v->type()->isSubtypeOf(FloatType::get());
}
@ -155,14 +158,19 @@ class ShapePropagator {
if (auto iv = toIValue(v)) {
return *iv;
}
if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
auto attype = type->device().is_cpu() ? at::CPU(type->scalarType())
: at::CUDA(type->scalarType());
at::DeviceGuard device_guard(type->device());
auto t =
at::empty_strided(type->sizes(), type->strides(), attype.options())
.zero_();
return autograd::make_variable(t, /*requires_grad=*/false);
if (ProfiledTensorTypePtr type = type_->cast<ProfiledTensorType>()) {
if (type->isComplete()) {
auto attype = type->device()->is_cpu() ? at::CPU(*type->scalarType())
: at::CUDA(*type->scalarType());
at::DeviceGuard device_guard(*type->device());
auto t = at::empty_strided(
*type->sizes().concrete_sizes(),
*type->strides().concrete_sizes(),
attype.options())
.zero_();
return autograd::make_variable(t, /*requires_grad=*/false);
}
// fallthrough
} else if (type_->isSubtypeOf(FloatType::get())) {
return 0.f;
}
@ -177,9 +185,10 @@ class ShapePropagator {
// for each node input in the schema with type Tensor, extract its tensor type;
// returns c10::nullopt if any Tensor in the schema does not have a known
// shape; ignores non-tensors in the list of inputs
template <typename T>
c10::optional<std::vector<std::shared_ptr<T>>> gatherTensorTypes(Node* node) {
std::vector<std::shared_ptr<T>> tensor_types;
c10::optional<std::vector<ProfiledTensorTypePtr>> gatherTensorTypes(
Node* node,
bool complete = false) {
std::vector<ProfiledTensorTypePtr> tensor_types;
auto& schema = node->schema();
auto& args = schema.arguments();
@ -192,7 +201,10 @@ class ShapePropagator {
if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
return c10::nullopt;
} else if (args[i].type()->isSubtypeOf(TensorType::get())) {
if (auto type = node->input(i)->type()->cast<T>()) {
if (auto type = node->input(i)->type()->cast<ProfiledTensorType>()) {
if (complete && !type->isComplete()) {
return c10::nullopt;
}
tensor_types.push_back(type);
} else {
return c10::nullopt;
@ -224,13 +236,14 @@ class ShapePropagator {
void broadcastBinary(
Node* node,
std::vector<CompleteTensorTypePtr>& types,
std::vector<ProfiledTensorTypePtr>& types,
size_t idx1,
size_t idx2) {
auto expected_size =
at::infer_size(types[idx1]->sizes(), types[idx2]->sizes());
auto expected_size = at::infer_size(
*types[idx1]->sizes().concrete_sizes(),
*types[idx2]->sizes().concrete_sizes());
auto broadcast = [&](size_t input_idx) {
CompleteTensorTypePtr input_type = types.at(input_idx);
ProfiledTensorTypePtr input_type = types.at(input_idx);
if (input_type->sizes() == expected_size)
return;
auto graph = node->owningGraph();
@ -247,8 +260,8 @@ class ShapePropagator {
};
broadcast(idx1);
broadcast(idx2);
types[0] = node->inputs().at(idx1)->type()->expect<CompleteTensorType>();
types[1] = node->inputs().at(idx2)->type()->expect<CompleteTensorType>();
types[0] = node->inputs().at(idx1)->type()->expect<ProfiledTensorType>();
types[1] = node->inputs().at(idx2)->type()->expect<ProfiledTensorType>();
}
OperatorSet cannot_propagate_shape_by_running_it = {
@ -357,17 +370,19 @@ class ShapePropagator {
static const auto propagate_complete =
[this](Node* node, at::ArrayRef<Value*> tensors) -> bool {
auto input_types = fmap(tensors, [](Value* v) {
return v->type()->cast<CompleteTensorType>();
return v->type()->cast<ProfiledTensorType>();
});
if (!std::all_of(
input_types.begin(),
input_types.end(),
[](const CompleteTensorTypePtr& tp) { return tp != nullptr; })) {
[](const ProfiledTensorTypePtr& tp) {
return tp != nullptr && tp->isComplete();
})) {
return false;
}
if (!node->is_constant(attr::dim))
return false;
std::vector<int64_t> sizes = input_types[0]->sizes();
std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
const int64_t ndim = sizes.size();
@ -376,7 +391,7 @@ class ShapePropagator {
sizes[dim] = 0;
for (auto& tp : input_types) {
auto& tp_sizes = tp->sizes();
auto tp_sizes = tp->sizes().concrete_sizes().value();
if (sizes.size() != tp_sizes.size())
return false;
for (int64_t i = 0; i < ndim; ++i) {
@ -621,7 +636,7 @@ class ShapePropagator {
}
if (auto maybe_complete_types =
gatherTensorTypes<CompleteTensorType>(node)) {
gatherTensorTypes(node, /*complete=*/true)) {
if (PropagateCompleteShapeOnNode(
node, insert_expands, std::move(*maybe_complete_types))) {
return;
@ -840,8 +855,7 @@ class ShapePropagator {
"aten::atan2(Tensor self, Tensor other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
AT_ASSERT(maybe_tensor_types->size() >= 2);
auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
@ -865,8 +879,7 @@ class ShapePropagator {
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
return {broadcast(*maybe_tensor_types, 0)};
}
return {};
@ -897,8 +910,7 @@ class ShapePropagator {
"aten::__irshift__(Tensor self, Scalar other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
return {broadcast(*maybe_tensor_types, 0)};
}
return {};
@ -911,8 +923,7 @@ class ShapePropagator {
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
return {broadcast(*maybe_tensor_types, 1)};
}
return {};
@ -971,9 +982,9 @@ class ShapePropagator {
"aten::ne(Tensor self, Scalar other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
return {broadcast(*maybe_tensor_types, 0)->withScalarType(at::kBool)};
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
return {
broadcast(*maybe_tensor_types, 0)->withScalarType(at::kBool)};
}
return {};
}};
@ -1694,8 +1705,7 @@ class ShapePropagator {
}
return nullptr;
};
if (auto maybe_tensor_types =
gatherTensorTypes<ProfiledTensorType>(node)) {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
tensor_types = std::move(*maybe_tensor_types);
} else {
return false;
@ -1712,7 +1722,7 @@ class ShapePropagator {
bool PropagateCompleteShapeOnNode(
Node* node,
bool insert_expands,
std::vector<CompleteTensorTypePtr> tensor_types) {
std::vector<ProfiledTensorTypePtr> tensor_types) {
// For expensive ops we can directly encode their shape propagation
// here, otherwise we fall back to running a fake version of the op
// to get a quick and dirty propagation.
@ -1761,17 +1771,19 @@ class ShapePropagator {
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
auto lhs_type = tensor_types.at(0);
auto rhs_type = tensor_types.at(1);
auto lhs_sizes = lhs_type->sizes().concrete_sizes().value();
auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
SHAPE_ASSERT(
lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2);
node->output()->setType(CompleteTensorType::create(
lhs_type->scalarType(),
lhs_type->device(),
at::IntArrayRef{lhs_type->sizes().at(0), rhs_type->sizes().at(1)}));
*lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
node->output()->setType(ProfiledTensorType::createContiguous(
*lhs_type->scalarType(),
*lhs_type->device(),
at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
return true;
} else if (node->matches("aten::t(Tensor self) -> Tensor")) {
auto tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto strides = tp->strides();
auto sizes = tp->sizes().concrete_sizes().value();
auto strides = tp->strides().concrete_sizes().value();
SHAPE_ASSERT(sizes.size() == 2);
std::swap(sizes.at(0), sizes.at(1));
std::swap(strides.at(0), strides.at(1));
@ -1782,12 +1794,13 @@ class ShapePropagator {
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
/*const_inputs=*/{attr::dim, attr::length})) {
auto tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto sizes = tp->sizes().concrete_sizes().value();
int64_t dim = node->get<int64_t>(attr::dim).value();
int64_t length = node->get<int64_t>(attr::length).value();
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
sizes.at(dim) = length;
node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
node->output()->setType(
tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value()));
return true;
} else if (node->matches("aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
node->output()->setType(tensor_types.at(0)->withSizes({}));
@ -1796,7 +1809,7 @@ class ShapePropagator {
"aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
/*const_inputs=*/{attr::dim, attr::keepdim})) {
auto& tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto sizes = tp->sizes().concrete_sizes().value();
auto dims = node->get<c10::List<int64_t>>(attr::dim).value();
bool keepdim = node->get<bool>(attr::keepdim).value();
std::reverse(dims.begin(), dims.end());
@ -1814,8 +1827,8 @@ class ShapePropagator {
"aten::squeeze(Tensor self, int dim) -> Tensor",
/*const_inputs=*/attr::dim)) {
auto& tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto strides = tp->strides();
auto sizes = tp->sizes().concrete_sizes().value();
auto strides = tp->strides().concrete_sizes().value();
int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
if (sizes.at(dim) == 1) {
@ -1828,8 +1841,8 @@ class ShapePropagator {
"aten::unsqueeze(Tensor self, int dim) -> Tensor",
/*const_inputs=*/attr::dim)) {
auto& tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto strides = tp->strides();
auto sizes = tp->sizes().concrete_sizes().value();
auto strides = tp->strides().concrete_sizes().value();
int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
@ -1860,7 +1873,7 @@ class ShapePropagator {
if (inferred) {
SHAPE_ASSERT(size_product != 0);
size_t numel = 1;
for (int64_t s : tensor_types.at(0)->sizes())
for (int64_t s : tensor_types.at(0)->sizes().concrete_sizes().value())
numel *= s;
int64_t inferred_size = numel / size_product;
sizes[inferred_idx] = inferred_size;
@ -1874,8 +1887,8 @@ class ShapePropagator {
node->output()->setType(node->namedInput(attr::self)->type());
} else {
// This will be a copy, so the result will be contiguous
node->output()->setType(
tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
node->output()->setType(tensor_types.at(1)->withSizes(
tensor_types.at(0)->sizes().concrete_sizes().value()));
}
return true;
} else if (
@ -1885,9 +1898,10 @@ class ShapePropagator {
auto tp = tensor_types.at(0);
std::vector<int64_t> sizes, strides;
std::tie(sizes, strides) = at::inferExpandGeometry(
tp->sizes(),
tp->strides(),
c10::impl::toVector(node->get<c10::List<int64_t>>(attr::size).value()));
tp->sizes().concrete_sizes().value(),
tp->strides().concrete_sizes().value(),
c10::impl::toVector(
node->get<c10::List<int64_t>>(attr::size).value()));
node->output()->setType(tp->withSizesStrides(sizes, strides));
return true;
} else if (
@ -1897,26 +1911,26 @@ class ShapePropagator {
auto ten = tensor_types.at(0);
auto index = tensor_types.at(1);
int64_t dim = node->get<int64_t>(attr::dim).value();
SHAPE_ASSERT(index->sizes().size() == 1);
SHAPE_ASSERT(*index->sizes().size() == 1);
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
std::vector<int64_t> sizes = ten->sizes();
sizes[dim] = index->sizes()[0];
std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value();
sizes[dim] = index->sizes()[0].value();
node->output()->setType(ten->withSizes(sizes));
return true;
} else if (node->matches(
"aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
/*const_inputs=*/{attr::chunks, attr::dim})) {
auto input_type = tensor_types.at(0);
auto sizes = input_type->sizes();
const auto& strides = input_type->strides();
auto sizes = input_type->sizes().concrete_sizes().value();
auto strides = input_type->strides().concrete_sizes().value();
int64_t dim = node->get<int64_t>(attr::dim).value();
int64_t chunks = node->get<int64_t>(attr::chunks).value();
sizes[dim] /= chunks;
for (Value* output : node->outputs()) {
output->setType(input_type->withSizesStrides(sizes, strides));
}
if (input_type->sizes().at(dim) % chunks != 0) {
sizes[dim] = input_type->sizes().at(dim) % chunks;
if (*input_type->sizes()[dim] % chunks != 0) {
sizes[dim] = *input_type->sizes()[dim] % chunks;
node->outputs().back()->setType(
input_type->withSizesStrides(sizes, strides));
}
@ -1924,10 +1938,10 @@ class ShapePropagator {
} else if (node->kind() == ::c10::onnx::Shape) {
SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
std::vector<int64_t> dim_vec = {
(int64_t)tensor_types.at(0)->sizes().size()};
(int64_t)*tensor_types.at(0)->sizes().size()};
at::IntArrayRef dims(dim_vec);
node->output()->setType(
CompleteTensorType::create(at::kLong, at::kCPU, dims));
ProfiledTensorType::createContiguous(at::kLong, at::kCPU, dims));
return true;
} else if (node->kind() == ::c10::onnx::Reshape) {
setUnshapedType(node);

View File

@ -93,15 +93,7 @@ inline MatchTypeReturn tryToInferType(py::handle input) {
// Try tensor types
if (THPVariable_Check(input.ptr())) {
auto tensor = py::cast<at::Tensor>(input);
if (tensor.is_mkldnn()) {
// mkldnn tensor as opaque tensor doesn't have strides, so we can
// not create a CompleteTensorType
return MatchTypeReturn(ProfiledTensorType::create(tensor));
}
// TODO: maybe unshape this type if this is used for script instead of
// tracing
return MatchTypeReturn(CompleteTensorType::create(tensor));
return MatchTypeReturn(ProfiledTensorType::create(tensor));
}
if (input.is(py::none())) {
@ -320,9 +312,7 @@ inline IValue toIValue(
switch (type->kind()) {
case TypeKind::TensorType:
case TypeKind::AutogradZeroTensorType:
case TypeKind::ProfiledTensorType:
case TypeKind::DimensionedTensorType:
case TypeKind::CompleteTensorType: {
case TypeKind::ProfiledTensorType: {
auto var = py::cast<autograd::Variable>(obj);
if (var.is_sparse()) {
AT_WARN(

View File

@ -627,9 +627,7 @@ void initPythonIRBindings(PyObject* module_) {
s << t;
return s.str();
})
.def("kind", [](const Type& t) {
return typeKindToString(t.kind());
})
.def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
.def(
"dim",
[](Type& t) {
@ -640,15 +638,29 @@ void initPythonIRBindings(PyObject* module_) {
})
.def(
"sizes",
[](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
[](Type& t) -> py::object {
if (auto ptt = t.expect<ProfiledTensorType>()) {
if (auto cs = ptt->sizes().concrete_sizes()) {
return py::cast(*cs);
}
}
return py::none();
})
.def(
"strides",
[](Type& t) { return t.expect<CompleteTensorType>()->strides(); })
"sizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<ProfiledTensorType>()) {
if (auto cs = ptt->strides().concrete_sizes()) {
return py::cast(*cs);
}
}
return py::none();
})
.def(
"contiguous",
[](Type& t) {
return std::static_pointer_cast<Type>(
t.expect<CompleteTensorType>()->contiguous());
t.expect<ProfiledTensorType>()->contiguous());
})
.def(
"scalarType",

View File

@ -279,26 +279,19 @@ struct VISIBILITY_HIDDEN ModuleSelf : public Self {
const py::object& pyModule_;
};
static TypePtr getTensorType(const at::Tensor& t, const TypeKind type_kind) {
switch (type_kind) {
case TypeKind::ProfiledTensorType:
return ProfiledTensorType::create(t);
case TypeKind::CompleteTensorType: {
auto scalar_type = t.scalar_type();
auto sizes = t.sizes();
return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
}
default:
throw std::runtime_error(
"Attempted to call getTensorType for type kind other than ProfiledTensorType or CompleteTensorType.");
static TypePtr getTensorType(const at::Tensor& t, bool complete) {
auto r = ProfiledTensorType::create(t);
if (!complete) {
r = r->dimensionedOnly();
}
return r;
}
static TupleTypePtr getTupleTensorType(
const Stack::const_iterator& s_iter,
const Stack::const_iterator& s_iter_end,
const TypePtr& tupleType,
const TypeKind type_kind) {
bool complete) {
AT_ASSERT(tupleType->kind() == TupleType::Kind);
AT_ASSERT(s_iter != s_iter_end);
@ -306,27 +299,24 @@ static TupleTypePtr getTupleTensorType(
for (const auto& subType : tupleType->containedTypes()) {
if (subType->kind() == TupleType::Kind) {
types.push_back(
getTupleTensorType(s_iter + 1, s_iter_end, subType, type_kind));
getTupleTensorType(s_iter + 1, s_iter_end, subType, complete));
} else {
types.push_back(getTensorType(s_iter->toTensor(), type_kind));
types.push_back(getTensorType(s_iter->toTensor(), complete));
}
}
return TupleType::create(types);
}
static void setInputTensorTypes(
Graph& g,
const Stack& stack,
const TypeKind type_kind = TypeKind::ProfiledTensorType) {
static void setInputTensorTypes(Graph& g, const Stack& stack, bool complete) {
at::ArrayRef<Value*> input_values = g.inputs();
auto s_iter = stack.begin();
for (auto v : input_values) {
AT_ASSERT(s_iter != stack.end());
if (v->type()->kind() == TupleType::Kind) {
AT_ASSERT(v->node()->kind() == prim::Param);
v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), complete));
} else {
v->setType(getTensorType(s_iter->toTensor(), type_kind));
v->setType(getTensorType(s_iter->toTensor(), complete));
s_iter++;
}
}
@ -338,7 +328,7 @@ static std::shared_ptr<Graph> _propagate_shapes(
bool with_grad = false) {
Stack stack(inputs.begin(), inputs.end());
auto retval = graph.copy();
setInputTensorTypes(*retval, stack);
setInputTensorTypes(*retval, stack, /*complete=*/false);
PropagateInputShapes(retval);
return retval;
}
@ -349,13 +339,10 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
bool with_grad = false,
bool propagate = true) {
auto retval = graph.copy();
setInputTensorTypes(*retval, fmap<IValue>(inputs), /*complete=*/true);
if (propagate) {
setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::ProfiledTensorType);
PropagateInputShapes(retval);
}
setInputTensorTypes(
*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);
return retval;
}
@ -367,8 +354,8 @@ static std::shared_ptr<Graph> _assign_output_shapes(
for (size_t i = 0; i < outputs.size(); ++i) {
auto scalar_type = outputs[i].scalar_type();
auto sizes = outputs[i].sizes();
auto type =
torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
auto type = torch::jit::ProfiledTensorType::createContiguous(
scalar_type, at::kCPU, sizes);
retval->outputs()[i]->setType(type);
}
return retval;

View File

@ -9,7 +9,7 @@
using c10::AliasInfo;
using c10::BoolType;
using c10::CompleteTensorType;
using c10::CapsuleType;
using c10::DeviceObjType;
using c10::DictType;
using c10::FloatType;
@ -18,7 +18,6 @@ using c10::GeneratorType;
using c10::IntType;
using c10::ListType;
using c10::NoneType;
using c10::CapsuleType;
using c10::NumberType;
using c10::OptionalType;
using c10::StringType;
@ -167,8 +166,8 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
dims.push_back(dim);
});
at::IntArrayRef dims_ref(dims);
ptr =
CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref, false);
ptr = at::ProfiledTensorType::create(
dtype, at::DeviceType::CPU, dims_ref, false);
}
return ptr;
}

View File

@ -20,8 +20,12 @@ struct SymbolicVariable {
static SymbolicVariable asNewInput(Graph& g, TypePtr type) {
return g.addInput()->setType(std::move(type));
}
const std::vector<int64_t>& sizes() const {
return v->type()->expect<CompleteTensorType>()->sizes();
std::vector<int64_t> sizes() const {
return v->type()
->expect<ProfiledTensorType>()
->sizes()
.concrete_sizes()
.value();
}
void addAsOutput() const {
v->owningGraph()->registerOutput(v);
@ -313,7 +317,7 @@ struct SymbolicVariable {
return v->owningGraph()->insertConstant(std::move(value));
}
SymbolicVariable typeLike(SymbolicVariable other) const {
if (auto other_type = other.v->type()->cast<CompleteTensorType>())
if (auto other_type = other.v->type()->cast<ProfiledTensorType>())
v->setType(other_type->contiguous());
return *this;
}
@ -336,8 +340,8 @@ struct SymbolicVariable {
SymbolicVariable typeLikeWithScalarType(
SymbolicVariable other,
at::ScalarType type) const {
if (auto other_type = other.v->type()->cast<CompleteTensorType>()) {
auto new_type = other_type->toScalarType(type)->contiguous();
if (auto other_type = other.v->type()->cast<ProfiledTensorType>()) {
auto new_type = other_type->withScalarType(type)->contiguous();
v->setType(new_type);
}
return *this;
@ -345,11 +349,12 @@ struct SymbolicVariable {
SymbolicVariable typeLikeWithRhsScalarType(
SymbolicVariable other,
SymbolicVariable rhs) const {
auto other_type = other.v->type()->cast<CompleteTensorType>();
auto rhs_type = rhs.v->type()->cast<CompleteTensorType>();
if (other_type && rhs_type) {
auto other_type = other.v->type()->cast<ProfiledTensorType>();
auto rhs_type = rhs.v->type()->cast<ProfiledTensorType>();
if (other_type && rhs_type && other_type->isComplete() &&
rhs_type->isComplete()) {
auto new_type =
other_type->toScalarType(rhs_type->scalarType())->contiguous();
other_type->withScalarType(rhs_type->scalarType())->contiguous();
v->setType(new_type);
}
return *this;

View File

@ -44,17 +44,11 @@ from functools import wraps
# contained in TensorType. This adds a sizes()
# method which can be used to retrieve the
# concrete sizes.
# @deprecated
# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the
# concrete sizes in addition to the information
# contained in TensorType. This adds a sizes()
# method which can be used to retrieve the
# concrete sizes.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
# on the number of dimensions which is better than relying on
# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
# concrete shapes. Doing so will make the export symbolics
# more robust to different graphs.
# ---------------------------------------------------------------------------------

View File

@ -330,7 +330,7 @@ def transpose(g, self, dim0, dim1):
return self
# NB: Transpose in ONNX is actually a Permute
if self.type().kind() == "CompleteTensorType":
if self.isCompleteTensor():
axes = list(range(self.type().dim()))
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
return g.op("Transpose", self, perm_i=axes)
@ -550,7 +550,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
def _max_pool(name, tuple_fn, ndims, return_indices):
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if ceil_mode and input.type().kind() != "CompleteTensorType":
if ceil_mode and not input.isCompleteTensor():
return _unimplemented(name, "input size not accessible")
if set(tuple_fn(dilation)) != {1}:
return _unimplemented(name, "dilation")
@ -608,7 +608,7 @@ max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, retur
def _avg_pool(name, tuple_fn):
@parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
if ceil_mode and input.type().kind() != "CompleteTensorType":
if ceil_mode and not input.isCompleteTensor():
return _unimplemented(name, "input size not accessible")
if divisor_override and divisor_override.node().kind() != 'prim::Constant':
return _unimplemented(name, "divisor_override")
@ -650,11 +650,11 @@ def _adaptive_pool(name, type, tuple_fn, fn=None):
# the same dimension, which makes it possible to export it to ONNX.
# for MaxPool, GlobalMaxPool does not return indices,
# so we try using max_poolxd_with_indices, and if it is not possible
# (input is not CompleteTensorType or output size not factor of input size)
# (input is not a complete tensor or output size not factor of input size)
# then we call GlobalAveragePool and return None for the indices
if output_size == [1] * len(output_size) and type == "AveragePool":
return g.op("GlobalAveragePool", input)
if input.type().kind() != "CompleteTensorType":
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op("GlobalMaxPool", input), None
return _unimplemented(name, 'input size not accessible')
@ -1251,7 +1251,7 @@ def unsqueeze(g, self, dim):
def sort(g, self, dim, decending, out=None):
if out is not None:
_unimplemented("Sort", "Out parameter is not supported for sort")
if self.type().kind() != "CompleteTensorType":
if not self.isCompleteTensor():
return _unimplemented("Sort", "input size not accessible")
return g.op("TopK", self, k_i=self.type().sizes()[dim], axis_i=dim, outputs=2)
@ -1598,7 +1598,7 @@ def flatten(g, input, start_dim, end_dim):
if start_dim == 0 and end_dim == dim - 2 :
return g.op("Flatten", input, axis_i=end_dim + 1)
# use Reshape for cases where the output shape is not 2D
if input.type().kind() != "CompleteTensorType":
if not input.isCompleteTensor():
return _unimplemented("flatten", "input size not accessible")
input_dims = input.type().sizes()
output_dims = []
@ -1655,7 +1655,7 @@ def scatter(g, self, dim, index, src):
@parse_args('v', 'i', 'v', 'v')
def scatter_add(g, self, dim, index, src):
if self.type().kind() != "CompleteTensorType":
if not self.isCompleteTensor():
return _unimplemented("scatter_add", "input size not accessible")
dtype = self.type().scalarType()
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
@ -1689,7 +1689,7 @@ def gather(g, self, dim, index, sparse_grad=False):
@parse_args('v', 'is', 'b', 'i')
def _std(g, input, dim, unbiased, keepdim):
if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
if input.isCompleteTensor():
sqrd = g.op("Mul", input, input)
if dim is None:
sqrdmean = g.op("ReduceMean", sqrd, keepdims_i=0)

View File

@ -52,7 +52,7 @@ class NodePy(NodeBase):
io_tensor_sizes = []
for n in list_of_node:
io_unique_names.append(n.debugName())
if n.type().kind() == 'CompleteTensorType':
if n.isCompleteTensor():
io_tensor_sizes.append(n.type().sizes())
else:
io_tensor_sizes.append(None)