From 0cbd7fa46f25a4bcd85f7d0a370cd223a7a11dab Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Thu, 15 Aug 2019 13:28:01 -0700 Subject: [PATCH] remove CompleteTensorType Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24169 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D16765329 Pulled By: zdevito fbshipit-source-id: 88560cefba635c3d586a3e4dee67f9b1d901a642 --- aten/src/ATen/core/jit_type.h | 338 +++++------------- aten/src/ATen/core/type.cpp | 22 +- test/common_methods_invocations.py | 2 +- test/cpp/jit/test_autodiff.h | 2 +- test/cpp/jit/test_fuser.h | 2 +- test/cpp/jit/test_misc.h | 4 +- test/test_jit.py | 8 +- torch/csrc/jit/argument_spec.h | 2 +- torch/csrc/jit/autodiff.cpp | 6 +- torch/csrc/jit/docs/OVERVIEW.md | 1 - torch/csrc/jit/export.cpp | 11 +- torch/csrc/jit/fuser/compiler.cpp | 3 +- torch/csrc/jit/fuser/tensor_desc.h | 7 +- torch/csrc/jit/ir.cpp | 16 +- torch/csrc/jit/ir.h | 5 +- torch/csrc/jit/passes/erase_number_types.cpp | 4 +- .../csrc/jit/passes/onnx/fixup_onnx_loop.cpp | 2 +- torch/csrc/jit/passes/onnx/peephole.cpp | 27 +- .../passes/onnx/prepare_division_for_onnx.cpp | 2 +- torch/csrc/jit/passes/peephole.cpp | 6 +- torch/csrc/jit/passes/shape_analysis.cpp | 150 ++++---- torch/csrc/jit/pybind_utils.h | 14 +- torch/csrc/jit/python_ir.cpp | 26 +- torch/csrc/jit/script/init.cpp | 43 +-- torch/csrc/jit/script/schema_type_parser.cpp | 7 +- torch/csrc/jit/symbolic_variable.h | 23 +- torch/onnx/symbolic_helper.py | 8 +- torch/onnx/symbolic_opset9.py | 18 +- torch/utils/tensorboard/_pytorch_graph.py | 2 +- 29 files changed, 301 insertions(+), 460 deletions(-) diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 74e0665d061..cac47a12fbe 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -30,8 +30,6 @@ using OptNameList = c10::optional>; #define C10_FORALL_TYPES(_) \ _(TensorType) \ - _(DimensionedTensorType) \ - _(CompleteTensorType) \ _(AutogradZeroTensorType) \ _(TupleType) \ _(ListType) \ @@ -292,7 +290,7 @@ struct TensorType; using TensorTypePtr = std::shared_ptr; // This type represents a single Tensor, with an unknown shape. // Subtype hierarchy for Tensor Types (TensorType as the base type): -// CompleteTensorType <: DimensionedTensorType <: TensorType +// ProfiledTensorType <: TensorType // AutogradZeroTensorType <: TensorType struct CAFFE2_API TensorType : public Type { static TensorTypePtr create() { @@ -353,98 +351,6 @@ struct CAFFE2_API AutogradZeroTensorType : public TensorType { AutogradZeroTensorType() : TensorType(TypeKind::AutogradZeroTensorType) {} }; -struct DimensionedTensorType; -using DimensionedTensorTypePtr = std::shared_ptr; -// This type represents a single Tensor with a specific size -struct CAFFE2_API DimensionedTensorType : public TensorType { - template - static DimensionedTensorTypePtr create(T&&... 
all) { - return DimensionedTensorTypePtr(new DimensionedTensorType( - std::forward(all)...)); // NOLINT(modernize-make-shared) - } - - at::ScalarType scalarType() const { - return scalar_type_; - } - at::Device device() const { - return device_; - } - int64_t dim() const { - return dim_; - } - bool requires_grad() const override { - return requires_grad_; - } - - DimensionedTensorTypePtr toScalarType(at::ScalarType type) { - auto t = DimensionedTensorType::create(*this); - t->scalar_type_ = type; - return t; - } - DimensionedTensorTypePtr withDim(size_t new_dim) { - auto t = DimensionedTensorType::create(*this); - t->dim_ = new_dim; - return t; - } - DimensionedTensorTypePtr withRequiresGrad(bool req) { - auto t = DimensionedTensorType::create(*this); - t->requires_grad_ = req; - return t; - } - - bool operator==(const Type& rhs) const override { - if (rhs.kind() != TypeKind::DimensionedTensorType) - return false; - auto rt = rhs.expect(); - return scalarType() == rt->scalarType() && device() == rt->device() && - dim() == rt->dim(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - return rhs->kind() == TypeKind::TensorType || - (rhs->kind() == TypeKind::DimensionedTensorType && - Type::isSubtypeOf(rhs)) || - TensorType::isSubtypeOf(rhs); - } - bool isSubclass(const TypeKind kind) const override { - return kind == TypeKind::TensorType || - kind == TypeKind::DimensionedTensorType; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. - return "Tensor"; - } - - static const TypeKind Kind = TypeKind::DimensionedTensorType; - - protected: - DimensionedTensorType( - const at::Tensor& tensor, - TypeKind kind = TypeKind::DimensionedTensorType) - : DimensionedTensorType( - tensor.scalar_type(), - tensor.device(), - tensor.dim(), - tensor.is_variable() && tensor.requires_grad(), - kind) {} - DimensionedTensorType( - at::ScalarType scalar_type, - at::Device device, - int64_t dim, - bool requires_grad = true, - TypeKind kind = TypeKind::DimensionedTensorType) - : TensorType(kind), - scalar_type_(scalar_type), - requires_grad_(at::isFloatingType(scalar_type) && requires_grad), - device_(device), - dim_(dim) {} - - at::ScalarType scalar_type_; - bool requires_grad_; - at::Device device_; - int64_t dim_; -}; - template inline c10::optional merge_primitive( const c10::optional& a, @@ -463,9 +369,14 @@ struct CAFFE2_API VaryingShape { VaryingShape(const std::vector& vec) : size_(vec.size()), dims_(vec.begin(), vec.end()) {} + VaryingShape(c10::ArrayRef vec) + : size_(vec.size()), dims_(vec.begin(), vec.end()) {} + VaryingShape(c10::optional size) : size_(size), dims_(size ? 
size.value() : 0) {} + VaryingShape(size_t size) : VaryingShape(c10::optional(size)) {} + bool operator==(const VaryingShape& other) const { return size_ == other.size_ && dims_ == other.dims_; } @@ -502,6 +413,18 @@ struct CAFFE2_API VaryingShape { return sizes; } + bool isComplete() const { + if (!size_) { + return false; + } + for (auto d : dims_) { + if(!d) { + return false; + } + } + return true; + } + private: c10::optional size_; std::vector> dims_; @@ -518,21 +441,11 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { } static ProfiledTensorTypePtr create(const TypePtr& tptr) { - if (auto dtt = tptr->cast()) { - at::VaryingShape vshape(c10::optional(dtt->dim())); - return ProfiledTensorType::create( - {dtt->scalarType()}, - {dtt->device()}, - vshape, - vshape, - {dtt->requires_grad()}); - } - if (auto ptt = tptr->cast()) { return ptt; } - if (tptr->isSubclass(TypeKind::TensorType)) { + if (tptr->isSubtypeOf(TensorType::get())) { c10::optional sz; return ProfiledTensorType::create( {}, {}, VaryingShape{sz}, VaryingShape{sz}, {}); @@ -564,6 +477,34 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { requires_grad); } + // overloaded create variadic template argument as it could not distinguish + // initializer list + static ProfiledTensorTypePtr createContiguous( + at::ScalarType scalar_type, + at::Device device, + at::IntArrayRef sizes) { + return create( + scalar_type, + device, + VaryingShape(sizes), + VaryingShape(contiguousStridesOf(sizes)), + c10::nullopt); + } + static ProfiledTensorTypePtr create( + at::ScalarType scalar_type, + at::Device device, + at::IntArrayRef sizes, + at::IntArrayRef strides) { + return create( + scalar_type, + device, + VaryingShape(sizes), + c10::VaryingShape(strides), + c10::nullopt); + } + static TypePtr fromNumberType(TypePtr typ); + static TypePtr fromBoolType(); + c10::optional dim() const { return sizes().size(); } @@ -641,6 +582,20 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return copy; } + ProfiledTensorTypePtr withSizesStrides( + at::IntArrayRef sizes, + at::IntArrayRef strides) const { + auto cloned = clone(); + cloned->sizes_ = VaryingShape(sizes); + cloned->strides_ = VaryingStrides(strides); + return cloned; + } + + ProfiledTensorTypePtr withSizes(at::IntArrayRef sizes) const { + return withSizesStrides( + sizes, contiguousStridesOf(sizes)); + } + ProfiledTensorTypePtr dimensionedOnly() const { auto copy = clone(); copy->sizes_ = VaryingShape(sizes().size()); @@ -648,6 +603,16 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return copy; } + ProfiledTensorTypePtr contiguous() const { + auto cloned = clone(); + if (auto concrete_sizes = sizes().concrete_sizes()) { + cloned->strides_ = VaryingShape(contiguousStridesOf(*concrete_sizes)); + } else { + cloned->strides_ = VaryingShape(sizes().size()); + } + return cloned; + } + ProfiledTensorTypePtr merge(ProfiledTensorTypePtr other) { auto scalar_type = merge_primitive(scalarType(), other->scalarType()); auto dev = merge_primitive(device(), other->device()); @@ -656,6 +621,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { auto gr = merge_primitive(requiresGrad(), other->requiresGrad()); return ProfiledTensorType::create(scalar_type, dev, sz, srs, gr); } + // is all information about the type specified except for autograd? + // This replaces the notion of a 'CompleteTensorType' that used to exist + // in the type-hierarchy. Excluding require_grad and autogradZero allows + // this to match the old behavior. 
+ bool isComplete() const { + return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete(); + } static const TypeKind Kind = TypeKind::ProfiledTensorType; @@ -690,138 +662,6 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { scalar_type_, device_, sizes_, strides_, requires_grad_)); } - c10::optional scalar_type_; - c10::optional device_; - VaryingShape sizes_; - VaryingStrides strides_; - c10::optional requires_grad_; -}; - -struct CompleteTensorType; -using CompleteTensorTypePtr = std::shared_ptr; -// This type represents a single Tensor with a specific size -struct CAFFE2_API CompleteTensorType : public DimensionedTensorType { - template - static CompleteTensorTypePtr create(T&&... all) { - return CompleteTensorTypePtr(new CompleteTensorType( - std::forward(all)...)); // NOLINT(modernize-make-shared) - } - - // overloaded create variadic template argument as it could not distinguish - // initializer list - static CompleteTensorTypePtr create( - at::ScalarType scalar_type, - at::Device device, - at::IntArrayRef sizes) { - return CompleteTensorTypePtr(new CompleteTensorType( - scalar_type, device, sizes)); // NOLINT(modernize-make-shared) - } - static CompleteTensorTypePtr create( - at::ScalarType scalar_type, - at::Device device, - at::IntArrayRef sizes, - at::IntArrayRef strides) { - return CompleteTensorTypePtr(new CompleteTensorType( - scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared) - } - - const std::vector& sizes() const { - return sizes_; - } - const std::vector& strides() const { - return strides_; - } - - TypePtr withSizesStrides(at::IntArrayRef sizes, at::IntArrayRef strides) - const { - return CompleteTensorType::create(scalar_type_, device_, sizes, strides); - } - - TypePtr withSizes(at::IntArrayRef sizes) const { - return withSizesStrides( - sizes, CompleteTensorType::contiguousStridesOf(sizes)); - } - - CompleteTensorTypePtr contiguous() const { - auto t = CompleteTensorType::create(*this); - t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_); - return t; - } - - CompleteTensorTypePtr toScalarType(at::ScalarType type) { - auto t = CompleteTensorType::create(*this); - t->scalar_type_ = type; - return t; - } - - bool operator==(const Type& rhs) const override { - if (rhs.kind() != kind()) { - return false; - } - - auto rt = rhs.expect(); - return scalarType() == rt->scalarType() && sizes() == rt->sizes() && - strides() == rt->strides() && device() == rt->device(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - if (rhs->kind() == TypeKind::DimensionedTensorType) - return *expect() == *rhs; - return rhs->kind() == TypeKind::TensorType || TensorType::isSubtypeOf(rhs); - } - bool isSubclass(const TypeKind kind) const override { - return kind == TypeKind::TensorType || - kind == TypeKind::DimensionedTensorType || - kind == TypeKind::CompleteTensorType; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. 
- return "Tensor"; - } - size_t numel() const { - size_t prod = 1; - for (auto s : sizes()) { - prod *= s; - } - return prod; - } - - static const TypeKind Kind = TypeKind::CompleteTensorType; - - static TypePtr fromNumberType(TypePtr typ); - static TypePtr fromBoolType(); - - private: - CompleteTensorType(const at::Tensor& tensor) - : DimensionedTensorType(tensor, TypeKind::CompleteTensorType), - sizes_(tensor.sizes().vec()), - strides_(tensor.strides().vec()) {} - CompleteTensorType( - at::ScalarType scalar_type, - at::Device device, - at::IntArrayRef sizes, - bool requires_grad = true) - : CompleteTensorType( - scalar_type, - device, - sizes, - CompleteTensorType::contiguousStridesOf(sizes), - requires_grad) {} - CompleteTensorType( - at::ScalarType scalar_type, - at::Device device, - at::IntArrayRef sizes, - at::IntArrayRef strides, - bool requires_grad = true) - : DimensionedTensorType( - scalar_type, - device, - sizes.size(), - requires_grad, - TypeKind::CompleteTensorType), - sizes_(sizes.vec()), - strides_(strides.vec()) {} - static std::vector contiguousStridesOf(at::IntArrayRef sizes) { std::vector strides(sizes.size()); if (sizes.empty()) // zero-dim case @@ -832,9 +672,12 @@ struct CAFFE2_API CompleteTensorType : public DimensionedTensorType { } return strides; } - - std::vector sizes_; - std::vector strides_; + + c10::optional scalar_type_; + c10::optional device_; + VaryingShape sizes_; + VaryingStrides strides_; + c10::optional requires_grad_; }; struct ListType; @@ -1383,19 +1226,18 @@ inline TypePtr unshapedType(const TypePtr& type) { return type->withContained(fmap(type->containedTypes(), unshapedType)); } -inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { +inline TypePtr ProfiledTensorType::fromNumberType(TypePtr typ) { if (typ->isSubtypeOf(IntType::get())) { - return CompleteTensorType::create(at::kLong, at::kCPU, {}); + return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); } else if (typ->isSubtypeOf(FloatType::get())) { - return CompleteTensorType::create(at::kFloat, at::kCPU, {}); + return ProfiledTensorType::createContiguous(at::kFloat, at::kCPU, {}); } else if (typ->isSubtypeOf(BoolType::get())) { - return CompleteTensorType::create(at::kLong, at::kCPU, {}); + return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); } AT_ERROR("unknown number type", typ->str()); } - -inline TypePtr CompleteTensorType::fromBoolType() { - return CompleteTensorType::create(at::kLong, at::kCPU, {}); +inline TypePtr ProfiledTensorType::fromBoolType() { + return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); } inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index e9d2b29eee2..c7e63cae101 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -6,25 +6,7 @@ namespace c10 { std::ostream& operator<<(std::ostream & out, const Type & t) { - if(auto value = t.cast()) { - out << toString(value->scalarType()) << "("; - auto& sizes = value->sizes(); - auto& strides = value->strides(); - AT_ASSERT(sizes.size() == strides.size()); - for (size_t i = 0; i < sizes.size(); i++) { - if (i > 0) { - out << ", "; - } - // TODO: figure out a good way to output strides, or - // add a "debug" printing mode which adds the extra stuff - out << sizes[i]; // << "%" << strides[i]; - int64_t expected = i + 1 < sizes.size() ? 
sizes[i+1]*strides[i+1] : 1; - if (strides[i] != expected) { - out << "!"; //mark non-contiguous - } - } - out << ")"; - } else if (auto value = t.cast()) { + if (auto value = t.cast()) { if (value->scalarType().has_value()) { out << toString(*value->scalarType()); if (!value->sizes().size().has_value()) { @@ -151,7 +133,7 @@ ListTypePtr ListType::ofBools() { // the type, like in the tracer. TypePtr incompleteInferTypeFrom(const IValue& value) { if (value.isTensor()) { - return CompleteTensorType::create(value.toTensor()); + return ProfiledTensorType::create(value.toTensor()); } else if (value.isDouble()) { return FloatType::get(); } else if (value.isInt()) { diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index eabdc9977fd..a002161617b 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -217,7 +217,7 @@ def method_tests(): ('expand', (S, 1), (S, S, S), 'new_dim', (True,)), ('expand', (1,), (S, S, S), '1_element', (True,)), ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)), - ('expand', (), (dont_convert(()),), 'scalar_to_scalar', (True,)), + ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), ('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)), ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)), ('exp', (S, S, S), NO_ARGS, '', (True,)), diff --git a/test/cpp/jit/test_autodiff.h b/test/cpp/jit/test_autodiff.h index 7862a9c6533..a7819795eb0 100644 --- a/test/cpp/jit/test_autodiff.h +++ b/test/cpp/jit/test_autodiff.h @@ -172,7 +172,7 @@ void testADFormulas() { void testDifferentiate() { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); + auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); // Build up a fake graph auto a = SymbolicVariable::asNewInput(*graph, type); diff --git a/test/cpp/jit/test_fuser.h b/test/cpp/jit/test_fuser.h index 3b69d5a2c06..c32c66bf3d2 100644 --- a/test/cpp/jit/test_fuser.h +++ b/test/cpp/jit/test_fuser.h @@ -177,7 +177,7 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) { auto createGraphWithNames = [](std::string cname, std::string dname) { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); + auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); auto a = SymbolicVariable::asNewInput(*graph, type); auto b = SymbolicVariable::asNewInput(*graph, type); auto c = a * b; diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index f1fe383f69c..6bb2977ca59 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -360,7 +360,7 @@ void testATenNativeBatchNorm() { void testCustomFusion() { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); + auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); auto a = SymbolicVariable::asNewInput(*graph, type); auto b = SymbolicVariable::asNewInput(*graph, type); auto c = a * b; @@ -394,7 +394,7 @@ void testCustomFusion() { void testCustomFusionNestedBlocks() { auto g = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); + auto type = ProfiledTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); // test CustomFusion in nested 
blocks; auto a = SymbolicVariable::asNewInput(*g, type); diff --git a/test/test_jit.py b/test/test_jit.py index 7c8d0e6c634..50250446363 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2408,7 +2408,7 @@ graph(%Ra, %Rb): def test_onnx_transpose_incomplete_tensor_type(self): # Smoke test to get us into the state where we are attempting to export # a transpose op, where the input is a TensorType rather than a - # CompleteTensorType. This would previously not work, since we would + # ProfiledTensorType. This would previously not work, since we would # take the size of the input and use the length of its sizes as the # number of dimensions in the permutation. class Foo(torch.jit.ScriptModule): @@ -8270,9 +8270,9 @@ a") graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False) if_outputs = list(graph.findNode("prim::If").outputs()) - self.assertTrue(if_outputs[0].type().str() == "Float(2, 2)") - self.assertTrue(if_outputs[1].type().str() == "Tensor(2, *)") - self.assertTrue(if_outputs[2].type().str() == "Tensor(2, 4)") + self.assertTrue(if_outputs[0].type().str() == "Float(*, *)") + self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *)") + self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *)") def test_list_unify(self): # allowing a unififed int?[] would cause a runtime error b/c diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index 3d96324fc26..12cd876eb03 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -349,7 +349,7 @@ struct CompleteArgumentInfo { operator TypePtr() const { if (!defined()) return TensorType::get(); - return CompleteTensorType::create( + return ProfiledTensorType::create( type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides()); } diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index 591c9603812..6c24f9bf720 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -308,8 +308,10 @@ class GradientHelper { // reutrns them as a tuple auto sizes = node->namedInput(attr::self) ->type() - ->expect() - ->sizes(); + ->expect() + ->sizes() + .concrete_sizes() + .value(); return {grads.at(0).reshape(sizes), nullptr}; } else if ( diff --git a/torch/csrc/jit/docs/OVERVIEW.md b/torch/csrc/jit/docs/OVERVIEW.md index 1d14ffc8cbe..cdb9aabc4de 100644 --- a/torch/csrc/jit/docs/OVERVIEW.md +++ b/torch/csrc/jit/docs/OVERVIEW.md @@ -332,7 +332,6 @@ TorchScript, unlike Python, is statically typed, so every Value has a Type assoc * TensorType - the root type of all Tensors in the system. * ProfiledTensorType - a tensor with optionally refined information. It may optional know its device, type, requires_grad state, the number of dimensions. If it does know the number of dimensions it may optionally know the size of a particular dimension. -* CompleteTensorType - A subtype of TensorType that adds fixed sizes (e.g. a [3 x 4] cuda tensor). This only appears from tracing at the moment. * Tuples - e.g. Tuple[Tensor, Int]. Each member of the tuple is statically typed and the length of the tuple is statically known. * List[T] - e.g. List[Tensor]. Mutable lists of a particular type. * Optional[T] - e.g. Optional[Tensor], either the Tensor value or None. 
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 3ef239facd9..97e9a2f39ff 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -221,11 +221,15 @@ void EncoderBase::EncodeValueInfo( const std::unordered_map>& dynamic_axes) { std::string name = n->debugName(); v->set_name(name); - if (CompleteTensorTypePtr node_type = n->type()->cast()) { + if (ProfiledTensorTypePtr node_type = n->type()->cast()) { + if (!node_type->isComplete()) { + return; + } onnx::TypeProto* t = v->mutable_type(); onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); - const std::vector& sizes = node_type->sizes(); + std::vector sizes = + node_type->sizes().concrete_sizes().value(); for (size_t i = 0; i < sizes.size(); i++) { shape->add_dim(); if ((dynamic_axes.find(name) != dynamic_axes.end()) && @@ -236,7 +240,8 @@ void EncoderBase::EncodeValueInfo( shape->mutable_dim(i)->set_dim_value(sizes[i]); } } - tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType())); + tensor_type->set_elem_type( + ATenTypeToOnnxType(node_type->scalarType().value())); } else if (BoolTypePtr node_type = n->type()->cast()) { onnx::TypeProto* t = v->mutable_type(); onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp index e2d31a4c174..5ee0c90a496 100644 --- a/torch/csrc/jit/fuser/compiler.cpp +++ b/torch/csrc/jit/fuser/compiler.cpp @@ -256,7 +256,8 @@ std::shared_ptr compileKernel( auto scalar_type = ProfiledTensorType::create(o->type())->scalarType(); TORCH_INTERNAL_ASSERT(scalar_type); - auto type = CompleteTensorType::create(*scalar_type, device, sizes); + auto type = + ProfiledTensorType::createContiguous(*scalar_type, device, sizes); output_desc.emplace_back(type); const auto& desc = output_desc.back(); diff --git a/torch/csrc/jit/fuser/tensor_desc.h b/torch/csrc/jit/fuser/tensor_desc.h index 736a3396003..1037570b91e 100644 --- a/torch/csrc/jit/fuser/tensor_desc.h +++ b/torch/csrc/jit/fuser/tensor_desc.h @@ -41,8 +41,11 @@ struct TORCH_API TensorDesc { TensorDesc(const at::Tensor& t) : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} - TensorDesc(const c10::CompleteTensorTypePtr& type) - : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} + TensorDesc(const c10::ProfiledTensorTypePtr& type) + : TensorDesc( + type->scalarType().value(), + type->sizes().concrete_sizes().value(), + type->strides().concrete_sizes().value()) {} // number of dimensions after contiguity compression size_t nDim() const { diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 98afe5e2321..32634f387f2 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -314,10 +314,10 @@ static void checkSameDevice(const Node* node) { bool has_device = false; c10::optional device = c10::nullopt; auto checkValue = [&](const Value* v) { - if (CompleteTensorTypePtr type = v->type()->cast()) { - if (!has_device) { + if (ProfiledTensorTypePtr type = v->type()->cast()) { + if (type->device() && !has_device) { has_device = true; - device = type->device(); + device = *type->device(); } else { AT_ASSERT(device == type->device()); } @@ -669,13 +669,7 @@ void Graph::remapTypes(const std::function& type_map) { } void Value::inferTypeFrom(const at::Tensor& output) { - if (output.is_mkldnn()) { - // mkldnn tensor as opaque tensor doesn't have strides, so we can - // not create a CompleteTensorType - 
setType(ProfiledTensorType::create(output)); - return; - } - setType(CompleteTensorType::create(output)); + setType(ProfiledTensorType::create(output)); } bool Value::mustBeNone() const { @@ -1427,7 +1421,7 @@ Node* Graph::createDict( Node* Graph::createNumToTensor(Value* value) { auto typ = value->type(); Node* result = create(prim::NumToTensor, {value}); - result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ))); + result->output()->setType(ProfiledTensorType::fromNumberType(std::move(typ))); return result; } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index b887acf0bd5..23d6dd48dc3 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -168,7 +168,10 @@ struct Value { return type()->requires_grad(); } bool isCompleteTensor() const { - return type()->kind() == TypeKind::CompleteTensorType; + if (auto pt = type()->cast()) { + return pt->isComplete(); + } + return false; } TORCH_API bool mustBeNone() const; TORCH_API bool mustNotBeNone() const; diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 70debca8612..c365533bd15 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -47,9 +47,9 @@ static void EraseNumberTypesOnBlock(Block* block) { default: { for (auto o : it->outputs()) { if (o->type()->isSubtypeOf(NumberType::get())) { - o->setType(CompleteTensorType::fromNumberType(o->type())); + o->setType(ProfiledTensorType::fromNumberType(o->type())); } else if (o->type()->isSubtypeOf(BoolType::get())) { - o->setType(CompleteTensorType::fromBoolType()); + o->setType(ProfiledTensorType::fromBoolType()); } } } break; diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp index a50f4ed080c..29f867e2f5a 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp @@ -55,7 +55,7 @@ void FixupONNXLoops(Block* block) { cond->setType(BoolType::create()); Value* i = sub_block->inputs()[0]; - i->setType(CompleteTensorType::fromNumberType(IntType::get())); + i->setType(ProfiledTensorType::fromNumberType(IntType::get())); // add cast to condition input inside the loop. Value* next_cond_val = sub_block->outputs()[0]; diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 5e4ccca3a54..3dae2145053 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -129,9 +129,16 @@ void fuseBroadcast(Block* b) { // Not all broadcasts are supported by ONNX broadcast. c10::optional axis = fusibleExpandTo( unexpanded_input->type() - ->expect() - ->sizes(), // from - n->output()->type()->expect()->sizes()); // to + ->expect() + ->sizes() + .concrete_sizes() + .value(), // from + n->output() + ->type() + ->expect() + ->sizes() + .concrete_sizes() + .value()); // to if (axis == c10::nullopt) continue; @@ -289,15 +296,15 @@ void pushPackingPastRnn(Block* b) { // unhygenic way, Pytorch ends up propagating an incorrect type. // Until a long-term cleanup comes around, we can fix this by // resetting the size to the correct value. 
- CompleteTensorTypePtr oldType = - rnn->inputs().at(0)->type()->cast(); - if (oldType) { + ProfiledTensorTypePtr oldType = + rnn->inputs().at(0)->type()->cast(); + if (oldType && oldType->isComplete()) { std::vector new_sizes; - new_sizes.push_back(oldType->sizes().at(0)); - new_sizes.push_back(oldType->sizes().at(1)); + new_sizes.push_back(*oldType->sizes()[0]); + new_sizes.push_back(*oldType->sizes()[1]); new_sizes.push_back(rnn->i(attr::hidden_size)); - CompleteTensorTypePtr newType = CompleteTensorType::create( - oldType->scalarType(), oldType->device(), new_sizes); + ProfiledTensorTypePtr newType = ProfiledTensorType::createContiguous( + *oldType->scalarType(), *oldType->device(), new_sizes); next->outputs().at(0)->setType(newType); } diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp index 10e07c5b866..7f2bf8d10f2 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -28,7 +28,7 @@ static void PrepareDivisionForONNXOnBlock(Block* block) { it->replaceInput(0, floattensor_inputs[0]); it->replaceInput(1, floattensor_inputs[1]); it->output()->setType( - CompleteTensorType::fromNumberType(FloatType::get())); + ProfiledTensorType::fromNumberType(FloatType::get())); } } } diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 1feeb199b1b..54d1d4e1aa7 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -46,9 +46,11 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { // x.expand(x.size()) == x if (auto input_type = node->namedInput(attr::self) ->type() - ->cast()) { + ->cast()) { auto expanded_sizes = node->get>(attr::size); - if (!expanded_sizes.has_value() || c10::impl::toVector(*expanded_sizes) == input_type->sizes()) { + auto input_type_sizes = input_type->sizes().concrete_sizes(); + if (expanded_sizes.has_value() && input_type_sizes && + c10::impl::toVector(*expanded_sizes) == *input_type_sizes) { GRAPH_UPDATE( *node, " (x.expand(x.size()) == x) is replaced with ", diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 115c9d2dff9..3692b311ed5 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -38,8 +38,11 @@ bool isValidArgumentForRunning(Value* v) { // allow constants if (toIValue(v)) return true; - if (CompleteTensorTypePtr tt = v->type()->cast()) { - return !at::isIntegralType(tt->scalarType(), /*includeBool=*/false); + if (ProfiledTensorTypePtr tt = v->type()->cast()) { + if (!tt->scalarType()) { + return false; + } + return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false); } return v->type()->isSubtypeOf(FloatType::get()); } @@ -155,14 +158,19 @@ class ShapePropagator { if (auto iv = toIValue(v)) { return *iv; } - if (CompleteTensorTypePtr type = type_->cast()) { - auto attype = type->device().is_cpu() ? at::CPU(type->scalarType()) - : at::CUDA(type->scalarType()); - at::DeviceGuard device_guard(type->device()); - auto t = - at::empty_strided(type->sizes(), type->strides(), attype.options()) - .zero_(); - return autograd::make_variable(t, /*requires_grad=*/false); + if (ProfiledTensorTypePtr type = type_->cast()) { + if (type->isComplete()) { + auto attype = type->device()->is_cpu() ? 
at::CPU(*type->scalarType()) + : at::CUDA(*type->scalarType()); + at::DeviceGuard device_guard(*type->device()); + auto t = at::empty_strided( + *type->sizes().concrete_sizes(), + *type->strides().concrete_sizes(), + attype.options()) + .zero_(); + return autograd::make_variable(t, /*requires_grad=*/false); + } + // fallthrough } else if (type_->isSubtypeOf(FloatType::get())) { return 0.f; } @@ -177,9 +185,10 @@ class ShapePropagator { // for each node in the schema with type Tensor, extract the T type // returns c10::nullopt if any Tensor in the schema does not have a known // shape ignores non-tensor in the list of inputs - template - c10::optional>> gatherTensorTypes(Node* node) { - std::vector> tensor_types; + c10::optional> gatherTensorTypes( + Node* node, + bool complete = false) { + std::vector tensor_types; auto& schema = node->schema(); auto& args = schema.arguments(); @@ -192,7 +201,10 @@ class ShapePropagator { if (args[i].type()->isSubtypeOf(ListType::ofTensors())) { return c10::nullopt; } else if (args[i].type()->isSubtypeOf(TensorType::get())) { - if (auto type = node->input(i)->type()->cast()) { + if (auto type = node->input(i)->type()->cast()) { + if (complete && !type->isComplete()) { + return c10::nullopt; + } tensor_types.push_back(type); } else { return c10::nullopt; @@ -224,13 +236,14 @@ class ShapePropagator { void broadcastBinary( Node* node, - std::vector& types, + std::vector& types, size_t idx1, size_t idx2) { - auto expected_size = - at::infer_size(types[idx1]->sizes(), types[idx2]->sizes()); + auto expected_size = at::infer_size( + *types[idx1]->sizes().concrete_sizes(), + *types[idx2]->sizes().concrete_sizes()); auto broadcast = [&](size_t input_idx) { - CompleteTensorTypePtr input_type = types.at(input_idx); + ProfiledTensorTypePtr input_type = types.at(input_idx); if (input_type->sizes() == expected_size) return; auto graph = node->owningGraph(); @@ -247,8 +260,8 @@ class ShapePropagator { }; broadcast(idx1); broadcast(idx2); - types[0] = node->inputs().at(idx1)->type()->expect(); - types[1] = node->inputs().at(idx2)->type()->expect(); + types[0] = node->inputs().at(idx1)->type()->expect(); + types[1] = node->inputs().at(idx2)->type()->expect(); } OperatorSet cannot_propagate_shape_by_running_it = { @@ -357,17 +370,19 @@ class ShapePropagator { static const auto propagate_complete = [this](Node* node, at::ArrayRef tensors) -> bool { auto input_types = fmap(tensors, [](Value* v) { - return v->type()->cast(); + return v->type()->cast(); }); if (!std::all_of( input_types.begin(), input_types.end(), - [](const CompleteTensorTypePtr& tp) { return tp != nullptr; })) { + [](const ProfiledTensorTypePtr& tp) { + return tp != nullptr && tp->isComplete(); + })) { return false; } if (!node->is_constant(attr::dim)) return false; - std::vector sizes = input_types[0]->sizes(); + std::vector sizes = *input_types[0]->sizes().concrete_sizes(); const int64_t dim = wrapDim(node->get(attr::dim).value(), sizes); const int64_t ndim = sizes.size(); @@ -376,7 +391,7 @@ class ShapePropagator { sizes[dim] = 0; for (auto& tp : input_types) { - auto& tp_sizes = tp->sizes(); + auto tp_sizes = tp->sizes().concrete_sizes().value(); if (sizes.size() != tp_sizes.size()) return false; for (int64_t i = 0; i < ndim; ++i) { @@ -621,7 +636,7 @@ class ShapePropagator { } if (auto maybe_complete_types = - gatherTensorTypes(node)) { + gatherTensorTypes(node, /*complete=*/true)) { if (PropagateCompleteShapeOnNode( node, insert_expands, std::move(*maybe_complete_types))) { return; @@ -840,8 +855,7 @@ 
class ShapePropagator { "aten::atan2(Tensor self, Tensor other) -> Tensor", }, [this](Node* node) -> type_vec_t { - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { AT_ASSERT(maybe_tensor_types->size() >= 2); auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType(); auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType(); @@ -865,8 +879,7 @@ class ShapePropagator { "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor", }, [this](Node* node) -> type_vec_t { - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { return {broadcast(*maybe_tensor_types, 0)}; } return {}; @@ -897,8 +910,7 @@ class ShapePropagator { "aten::__irshift__(Tensor self, Scalar other) -> Tensor", }, [this](Node* node) -> type_vec_t { - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { return {broadcast(*maybe_tensor_types, 0)}; } return {}; @@ -911,8 +923,7 @@ class ShapePropagator { "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", }, [this](Node* node) -> type_vec_t { - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { return {broadcast(*maybe_tensor_types, 1)}; } return {}; @@ -971,9 +982,9 @@ class ShapePropagator { "aten::ne(Tensor self, Scalar other) -> Tensor", }, [this](Node* node) -> type_vec_t { - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { - return {broadcast(*maybe_tensor_types, 0)->withScalarType(at::kBool)}; + if (auto maybe_tensor_types = gatherTensorTypes(node)) { + return { + broadcast(*maybe_tensor_types, 0)->withScalarType(at::kBool)}; } return {}; }}; @@ -1694,8 +1705,7 @@ class ShapePropagator { } return nullptr; }; - if (auto maybe_tensor_types = - gatherTensorTypes(node)) { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { tensor_types = std::move(*maybe_tensor_types); } else { return false; @@ -1712,7 +1722,7 @@ class ShapePropagator { bool PropagateCompleteShapeOnNode( Node* node, bool insert_expands, - std::vector tensor_types) { + std::vector tensor_types) { // For expensive ops we can directly encode their shape propagation // here, otherwise we fallback to running a fake version of the op // to get a quick and dirty propagation. 
@@ -1761,17 +1771,19 @@ class ShapePropagator { } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { auto lhs_type = tensor_types.at(0); auto rhs_type = tensor_types.at(1); + auto lhs_sizes = lhs_type->sizes().concrete_sizes().value(); + auto rhs_sizes = rhs_type->sizes().concrete_sizes().value(); SHAPE_ASSERT( - lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2); - node->output()->setType(CompleteTensorType::create( - lhs_type->scalarType(), - lhs_type->device(), - at::IntArrayRef{lhs_type->sizes().at(0), rhs_type->sizes().at(1)})); + *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2); + node->output()->setType(ProfiledTensorType::createContiguous( + *lhs_type->scalarType(), + *lhs_type->device(), + at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]})); return true; } else if (node->matches("aten::t(Tensor self) -> Tensor")) { auto tp = tensor_types.at(0); - auto sizes = tp->sizes(); - auto strides = tp->strides(); + auto sizes = tp->sizes().concrete_sizes().value(); + auto strides = tp->strides().concrete_sizes().value(); SHAPE_ASSERT(sizes.size() == 2); std::swap(sizes.at(0), sizes.at(1)); std::swap(strides.at(0), strides.at(1)); @@ -1782,12 +1794,13 @@ class ShapePropagator { "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor", /*const_inputs=*/{attr::dim, attr::length})) { auto tp = tensor_types.at(0); - auto sizes = tp->sizes(); + auto sizes = tp->sizes().concrete_sizes().value(); int64_t dim = node->get(attr::dim).value(); int64_t length = node->get(attr::length).value(); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < sizes.size()); sizes.at(dim) = length; - node->output()->setType(tp->withSizesStrides(sizes, tp->strides())); + node->output()->setType( + tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value())); return true; } else if (node->matches("aten::sum(Tensor self, *, int? dtype) -> Tensor")) { node->output()->setType(tensor_types.at(0)->withSizes({})); @@ -1796,7 +1809,7 @@ class ShapePropagator { "aten::sum(Tensor self, int[] dim, bool keepdim, *, int? 
dtype) -> Tensor", /*const_inputs=*/{attr::dim, attr::keepdim})) { auto& tp = tensor_types.at(0); - auto sizes = tp->sizes(); + auto sizes = tp->sizes().concrete_sizes().value(); auto dims = node->get>(attr::dim).value(); bool keepdim = node->get(attr::keepdim).value(); std::reverse(dims.begin(), dims.end()); @@ -1814,8 +1827,8 @@ class ShapePropagator { "aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { auto& tp = tensor_types.at(0); - auto sizes = tp->sizes(); - auto strides = tp->strides(); + auto sizes = tp->sizes().concrete_sizes().value(); + auto strides = tp->strides().concrete_sizes().value(); int64_t dim = wrapDim(node->get(attr::dim).value(), sizes); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < sizes.size()); if (sizes.at(dim) == 1) { @@ -1828,8 +1841,8 @@ class ShapePropagator { "aten::unsqueeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { auto& tp = tensor_types.at(0); - auto sizes = tp->sizes(); - auto strides = tp->strides(); + auto sizes = tp->sizes().concrete_sizes().value(); + auto strides = tp->strides().concrete_sizes().value(); int64_t dim = wrapDim(node->get(attr::dim).value(), sizes); SHAPE_ASSERT(dim >= 0 && static_cast(dim) <= sizes.size()); int64_t new_stride = dim >= static_cast(sizes.size()) @@ -1860,7 +1873,7 @@ class ShapePropagator { if (inferred) { SHAPE_ASSERT(size_product != 0); size_t numel = 1; - for (int64_t s : tensor_types.at(0)->sizes()) + for (int64_t s : tensor_types.at(0)->sizes().concrete_sizes().value()) numel *= s; int64_t inferred_size = numel / size_product; sizes[inferred_idx] = inferred_size; @@ -1874,8 +1887,8 @@ class ShapePropagator { node->output()->setType(node->namedInput(attr::self)->type()); } else { // This will be a copy, so the result will be contiguous - node->output()->setType( - tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes())); + node->output()->setType(tensor_types.at(1)->withSizes( + tensor_types.at(0)->sizes().concrete_sizes().value())); } return true; } else if ( @@ -1885,9 +1898,10 @@ class ShapePropagator { auto tp = tensor_types.at(0); std::vector sizes, strides; std::tie(sizes, strides) = at::inferExpandGeometry( - tp->sizes(), - tp->strides(), - c10::impl::toVector(node->get>(attr::size).value())); + tp->sizes().concrete_sizes().value(), + tp->strides().concrete_sizes().value(), + c10::impl::toVector( + node->get>(attr::size).value())); node->output()->setType(tp->withSizesStrides(sizes, strides)); return true; } else if ( @@ -1897,26 +1911,26 @@ class ShapePropagator { auto ten = tensor_types.at(0); auto index = tensor_types.at(1); int64_t dim = node->get(attr::dim).value(); - SHAPE_ASSERT(index->sizes().size() == 1); + SHAPE_ASSERT(*index->sizes().size() == 1); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < ten->sizes().size()); - std::vector sizes = ten->sizes(); - sizes[dim] = index->sizes()[0]; + std::vector sizes = ten->sizes().concrete_sizes().value(); + sizes[dim] = index->sizes()[0].value(); node->output()->setType(ten->withSizes(sizes)); return true; } else if (node->matches( "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]", /*const_inputs=*/{attr::chunks, attr::dim})) { auto input_type = tensor_types.at(0); - auto sizes = input_type->sizes(); - const auto& strides = input_type->strides(); + auto sizes = input_type->sizes().concrete_sizes().value(); + auto strides = input_type->strides().concrete_sizes().value(); int64_t dim = node->get(attr::dim).value(); int64_t chunks = node->get(attr::chunks).value(); sizes[dim] /= chunks; for (Value* 
output : node->outputs()) { output->setType(input_type->withSizesStrides(sizes, strides)); } - if (input_type->sizes().at(dim) % chunks != 0) { - sizes[dim] = input_type->sizes().at(dim) % chunks; + if (*input_type->sizes()[dim] % chunks != 0) { + sizes[dim] = *input_type->sizes()[dim] % chunks; node->outputs().back()->setType( input_type->withSizesStrides(sizes, strides)); } @@ -1924,10 +1938,10 @@ class ShapePropagator { } else if (node->kind() == ::c10::onnx::Shape) { SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1); std::vector dim_vec = { - (int64_t)tensor_types.at(0)->sizes().size()}; + (int64_t)*tensor_types.at(0)->sizes().size()}; at::IntArrayRef dims(dim_vec); node->output()->setType( - CompleteTensorType::create(at::kLong, at::kCPU, dims)); + ProfiledTensorType::createContiguous(at::kLong, at::kCPU, dims)); return true; } else if (node->kind() == ::c10::onnx::Reshape) { setUnshapedType(node); diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 7da143b3f8a..b0798ac16d7 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -93,15 +93,7 @@ inline MatchTypeReturn tryToInferType(py::handle input) { // Try tensor types if (THPVariable_Check(input.ptr())) { auto tensor = py::cast(input); - if (tensor.is_mkldnn()) { - // mkldnn tensor as opaque tensor doesn't have strides, so we can - // not create a CompleteTensorType - return MatchTypeReturn(ProfiledTensorType::create(tensor)); - } - - // TODO: maybe unshape this type if this is used for script instead of - // tracing - return MatchTypeReturn(CompleteTensorType::create(tensor)); + return MatchTypeReturn(ProfiledTensorType::create(tensor)); } if (input.is(py::none())) { @@ -320,9 +312,7 @@ inline IValue toIValue( switch (type->kind()) { case TypeKind::TensorType: case TypeKind::AutogradZeroTensorType: - case TypeKind::ProfiledTensorType: - case TypeKind::DimensionedTensorType: - case TypeKind::CompleteTensorType: { + case TypeKind::ProfiledTensorType: { auto var = py::cast(obj); if (var.is_sparse()) { AT_WARN( diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 3836a45e9fd..5c9df2f20a0 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -627,9 +627,7 @@ void initPythonIRBindings(PyObject* module_) { s << t; return s.str(); }) - .def("kind", [](const Type& t) { - return typeKindToString(t.kind()); - }) + .def("kind", [](const Type& t) { return typeKindToString(t.kind()); }) .def( "dim", [](Type& t) { @@ -640,15 +638,29 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "sizes", - [](Type& t) { return t.expect()->sizes(); }) + [](Type& t) -> py::object { + if (auto ptt = t.expect()) { + if (auto cs = ptt->sizes().concrete_sizes()) { + return py::cast(*cs); + } + } + return py::none(); + }) .def( - "strides", - [](Type& t) { return t.expect()->strides(); }) + "sizes", + [](Type& t) -> py::object { + if (auto ptt = t.expect()) { + if (auto cs = ptt->strides().concrete_sizes()) { + return py::cast(*cs); + } + } + return py::none(); + }) .def( "contiguous", [](Type& t) { return std::static_pointer_cast( - t.expect()->contiguous()); + t.expect()->contiguous()); }) .def( "scalarType", diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 219067d96b3..4aa2bdf44b1 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -279,26 +279,19 @@ struct VISIBILITY_HIDDEN ModuleSelf : public Self { const py::object& pyModule_; }; -static TypePtr 
getTensorType(const at::Tensor& t, const TypeKind type_kind) { - switch (type_kind) { - case TypeKind::ProfiledTensorType: - return ProfiledTensorType::create(t); - case TypeKind::CompleteTensorType: { - auto scalar_type = t.scalar_type(); - auto sizes = t.sizes(); - return CompleteTensorType::create(scalar_type, at::kCPU, sizes); - } - default: - throw std::runtime_error( - "Attempted to call getTensorType for type kind other than ProfiledTensorType or CompleteTensorType."); +static TypePtr getTensorType(const at::Tensor& t, bool complete) { + auto r = ProfiledTensorType::create(t); + if (!complete) { + r = r->dimensionedOnly(); } + return r; } static TupleTypePtr getTupleTensorType( const Stack::const_iterator& s_iter, const Stack::const_iterator& s_iter_end, const TypePtr& tupleType, - const TypeKind type_kind) { + bool complete) { AT_ASSERT(tupleType->kind() == TupleType::Kind); AT_ASSERT(s_iter != s_iter_end); @@ -306,27 +299,24 @@ static TupleTypePtr getTupleTensorType( for (const auto& subType : tupleType->containedTypes()) { if (subType->kind() == TupleType::Kind) { types.push_back( - getTupleTensorType(s_iter + 1, s_iter_end, subType, type_kind)); + getTupleTensorType(s_iter + 1, s_iter_end, subType, complete)); } else { - types.push_back(getTensorType(s_iter->toTensor(), type_kind)); + types.push_back(getTensorType(s_iter->toTensor(), complete)); } } return TupleType::create(types); } -static void setInputTensorTypes( - Graph& g, - const Stack& stack, - const TypeKind type_kind = TypeKind::ProfiledTensorType) { +static void setInputTensorTypes(Graph& g, const Stack& stack, bool complete) { at::ArrayRef input_values = g.inputs(); auto s_iter = stack.begin(); for (auto v : input_values) { AT_ASSERT(s_iter != stack.end()); if (v->type()->kind() == TupleType::Kind) { AT_ASSERT(v->node()->kind() == prim::Param); - v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), type_kind)); + v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), complete)); } else { - v->setType(getTensorType(s_iter->toTensor(), type_kind)); + v->setType(getTensorType(s_iter->toTensor(), complete)); s_iter++; } } @@ -338,7 +328,7 @@ static std::shared_ptr _propagate_shapes( bool with_grad = false) { Stack stack(inputs.begin(), inputs.end()); auto retval = graph.copy(); - setInputTensorTypes(*retval, stack); + setInputTensorTypes(*retval, stack, /*complete=*/false); PropagateInputShapes(retval); return retval; } @@ -349,13 +339,10 @@ static std::shared_ptr _propagate_and_assign_input_shapes( bool with_grad = false, bool propagate = true) { auto retval = graph.copy(); + setInputTensorTypes(*retval, fmap(inputs), /*complete=*/true); if (propagate) { - setInputTensorTypes(*retval, fmap(inputs), TypeKind::ProfiledTensorType); PropagateInputShapes(retval); } - setInputTensorTypes( - *retval, fmap(inputs), TypeKind::CompleteTensorType); - return retval; } @@ -367,8 +354,8 @@ static std::shared_ptr _assign_output_shapes( for (size_t i = 0; i < outputs.size(); ++i) { auto scalar_type = outputs[i].scalar_type(); auto sizes = outputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); + auto type = torch::jit::ProfiledTensorType::createContiguous( + scalar_type, at::kCPU, sizes); retval->outputs()[i]->setType(type); } return retval; diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index b1aca366f0c..6017642b128 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ 
b/torch/csrc/jit/script/schema_type_parser.cpp @@ -9,7 +9,7 @@ using c10::AliasInfo; using c10::BoolType; -using c10::CompleteTensorType; +using c10::CapsuleType; using c10::DeviceObjType; using c10::DictType; using c10::FloatType; @@ -18,7 +18,6 @@ using c10::GeneratorType; using c10::IntType; using c10::ListType; using c10::NoneType; -using c10::CapsuleType; using c10::NumberType; using c10::OptionalType; using c10::StringType; @@ -167,8 +166,8 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { dims.push_back(dim); }); at::IntArrayRef dims_ref(dims); - ptr = - CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref, false); + ptr = at::ProfiledTensorType::create( + dtype, at::DeviceType::CPU, dims_ref, false); } return ptr; } diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index 50db692f663..ec35397f8e4 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -20,8 +20,12 @@ struct SymbolicVariable { static SymbolicVariable asNewInput(Graph& g, TypePtr type) { return g.addInput()->setType(std::move(type)); } - const std::vector& sizes() const { - return v->type()->expect()->sizes(); + std::vector sizes() const { + return v->type() + ->expect() + ->sizes() + .concrete_sizes() + .value(); } void addAsOutput() const { v->owningGraph()->registerOutput(v); @@ -313,7 +317,7 @@ struct SymbolicVariable { return v->owningGraph()->insertConstant(std::move(value)); } SymbolicVariable typeLike(SymbolicVariable other) const { - if (auto other_type = other.v->type()->cast()) + if (auto other_type = other.v->type()->cast()) v->setType(other_type->contiguous()); return *this; } @@ -336,8 +340,8 @@ struct SymbolicVariable { SymbolicVariable typeLikeWithScalarType( SymbolicVariable other, at::ScalarType type) const { - if (auto other_type = other.v->type()->cast()) { - auto new_type = other_type->toScalarType(type)->contiguous(); + if (auto other_type = other.v->type()->cast()) { + auto new_type = other_type->withScalarType(type)->contiguous(); v->setType(new_type); } return *this; @@ -345,11 +349,12 @@ struct SymbolicVariable { SymbolicVariable typeLikeWithRhsScalarType( SymbolicVariable other, SymbolicVariable rhs) const { - auto other_type = other.v->type()->cast(); - auto rhs_type = rhs.v->type()->cast(); - if (other_type && rhs_type) { + auto other_type = other.v->type()->cast(); + auto rhs_type = rhs.v->type()->cast(); + if (other_type && rhs_type && other_type->isComplete() && + rhs_type->isComplete()) { auto new_type = - other_type->toScalarType(rhs_type->scalarType())->contiguous(); + other_type->withScalarType(rhs_type->scalarType())->contiguous(); v->setType(new_type); } return *this; diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index a65314dcd73..472bbbea7ab 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -44,17 +44,11 @@ from functools import wraps # contained in TensorTyper. This adds a sizes() # method which can be used to retrieve the # concrete sizes. -# @deprecated -# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the -# concrete sizes in addition to the information -# contained in TensorTyper. This adds a sizes() -# method which can be used to retrieve the -# concrete sizes. # # In general, we should prefer to rely on the least specific information possible. 
# For example, not relying on tensor properties at all is better than relying # on the number of dimensions which is better than relying on -# concrete shapes (CompleteTensorType). Doing so will make the export symbolics +# concrete shapes. Doing so will make the export symbolics # more robust to different graphs. # --------------------------------------------------------------------------------- diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 70c28a4b101..4d92fd7a594 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -330,7 +330,7 @@ def transpose(g, self, dim0, dim1): return self # NB: Transpose in ONNX is actually a Permute - if self.type().kind() == "CompleteTensorType": + if self.isCompleteTensor(): axes = list(range(self.type().dim())) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) @@ -550,7 +550,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): def _max_pool(name, tuple_fn, ndims, return_indices): @parse_args('v', 'is', 'is', 'is', 'is', 'i') def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): - if ceil_mode and input.type().kind() != "CompleteTensorType": + if ceil_mode and not input.isCompleteTensor(): return _unimplemented(name, "input size not accessible") if set(tuple_fn(dilation)) != {1}: return _unimplemented(name, "dilation") @@ -608,7 +608,7 @@ max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, retur def _avg_pool(name, tuple_fn): @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none') def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None): - if ceil_mode and input.type().kind() != "CompleteTensorType": + if ceil_mode and not input.isCompleteTensor(): return _unimplemented(name, "input size not accessible") if divisor_override and divisor_override.node().kind() != 'prim::Constant': return _unimplemented(name, "divisor_override") @@ -650,11 +650,11 @@ def _adaptive_pool(name, type, tuple_fn, fn=None): # the same dimension, which makes it possible to export it to ONNX. 
# for MaxPool, GlobalMaxPool does not return indices, # so we try using max_poolxd_with_indices, and if it is not possible - # (input is not CompleteTensorType or output size not factor of input size) + # (input is not a complete tensor or output size not factor of input size) # then we call GlobalAveragePool and return None for the indices if output_size == [1] * len(output_size) and type == "AveragePool": return g.op("GlobalAveragePool", input) - if input.type().kind() != "CompleteTensorType": + if not input.isCompleteTensor(): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None return _unimplemented(name, 'input size not accessible') @@ -1251,7 +1251,7 @@ def unsqueeze(g, self, dim): def sort(g, self, dim, decending, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported for sort") - if self.type().kind() != "CompleteTensorType": + if not self.isCompleteTensor(): return _unimplemented("Sort", "input size not accessible") return g.op("TopK", self, k_i=self.type().sizes()[dim], axis_i=dim, outputs=2) @@ -1598,7 +1598,7 @@ def flatten(g, input, start_dim, end_dim): if start_dim == 0 and end_dim == dim - 2 : return g.op("Flatten", input, axis_i=end_dim + 1) # use Reshape for cases where the output shape is not 2D - if input.type().kind() != "CompleteTensorType": + if not input.isCompleteTensor(): return _unimplemented("flatten", "input size not accessible") input_dims = input.type().sizes() output_dims = [] @@ -1655,7 +1655,7 @@ def scatter(g, self, dim, index, src): @parse_args('v', 'i', 'v', 'v') def scatter_add(g, self, dim, index, src): - if self.type().kind() != "CompleteTensorType": + if not self.isCompleteTensor(): return _unimplemented("scatter_add", "input size not accessible") dtype = self.type().scalarType() dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) @@ -1689,7 +1689,7 @@ def gather(g, self, dim, index, sparse_grad=False): @parse_args('v', 'is', 'b', 'i') def _std(g, input, dim, unbiased, keepdim): - if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType": + if input.isCompleteTensor(): sqrd = g.op("Mul", input, input) if dim is None: sqrdmean = g.op("ReduceMean", sqrd, keepdims_i=0) diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index 11a4c1349cb..b71893b5f7d 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -52,7 +52,7 @@ class NodePy(NodeBase): io_tensor_sizes = [] for n in list_of_node: io_unique_names.append(n.debugName()) - if n.type().kind() == 'CompleteTensorType': + if n.isCompleteTensor(): io_tensor_sizes.append(n.type().sizes()) else: io_tensor_sizes.append(None)
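
For illustration only (not part of the patch): a minimal sketch of how a call site that previously relied on CompleteTensorType can query the same information through ProfiledTensorType after this change. The API calls (createContiguous, isComplete, sizes().concrete_sizes()) are the ones introduced in aten/src/ATen/core/jit_type.h above; the helper function name and the example shape are hypothetical.

// Sketch assuming the post-patch ProfiledTensorType API from jit_type.h.
// printShapeIfComplete is a hypothetical helper, not part of the patch.
#include <ATen/core/jit_type.h>
#include <iostream>
#include <vector>

void printShapeIfComplete(const c10::TypePtr& t) {
  auto ptt = t->cast<c10::ProfiledTensorType>();
  // Every field (dtype, device, sizes, strides) is now optional;
  // isComplete() replaces the old "is this a CompleteTensorType?" check.
  if (!ptt || !ptt->isComplete()) {
    return;
  }
  std::vector<int64_t> sizes = ptt->sizes().concrete_sizes().value();
  for (int64_t s : sizes) {
    std::cout << s << " ";
  }
  std::cout << "\n";
}

int main() {
  // Replaces the removed CompleteTensorType::create(scalar_type, device, sizes):
  auto type = c10::ProfiledTensorType::createContiguous(
      at::kFloat, at::kCPU, {2, 3, 4});
  printShapeIfComplete(type);
  return 0;
}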