diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index cfe7c5e6131..7bb3a25f92a 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -43,7 +43,6 @@ using OptNameList = c10::optional>; _(BoolType) \ _(OptionalType) \ _(VarType) \ - _(ProfiledTensorType) \ _(DeviceObjType) \ _(FunctionType) \ _(ClassType) \ @@ -251,34 +250,6 @@ struct CAFFE2_API OptionalType OptionalType(TypePtr elem) : SingleElementType(elem) {} }; -struct TensorType; -using TensorTypePtr = std::shared_ptr; -// This type represents a single Tensor, with an unknown shape. -// Subtype hierarchy for Tensor Types (TensorType as the base type): -// ProfiledTensorType <: TensorType -struct CAFFE2_API TensorType : public Type { - static TensorTypePtr create() { - return TensorTypePtr(new TensorType()); // NOLINT(modernize-make-shared) - } - - bool requires_grad() const override { - return true; - } - - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Tensor"; - } - static const TypeKind Kind = TypeKind::TensorType; - // global singleton - static TensorTypePtr get(); - - protected: - TensorType(TypeKind kind = TypeKind::TensorType) : Type(kind) {} -}; - template inline c10::optional merge_primitive( const c10::optional& a, @@ -369,43 +340,31 @@ struct CAFFE2_API VaryingShape { using VaryingStrides = VaryingShape; -struct ProfiledTensorType; -using ProfiledTensorTypePtr = std::shared_ptr; +struct TensorType; +using TensorTypePtr = std::shared_ptr; // This type represents a single Tensor with a specific size -struct CAFFE2_API ProfiledTensorType : public TensorType { - static ProfiledTensorTypePtr create(const at::Tensor& t) { - return ProfiledTensorTypePtr(new ProfiledTensorType(t)); +struct CAFFE2_API TensorType : public Type { + static TensorTypePtr create(const at::Tensor& t) { + return TensorTypePtr(new TensorType(t)); } - static ProfiledTensorTypePtr create(const TypePtr& tptr) { - if (auto ptt = tptr->cast()) { - return ptt; - } - - if (tptr->isSubtypeOf(TensorType::get())) { - return ProfiledTensorType::get(); - } - - TORCH_INTERNAL_ASSERT(false, "Expected a tensor type"); - } - - static ProfiledTensorTypePtr create( + static TensorTypePtr create( c10::optional scalar_type, c10::optional device, const VaryingShape& sizes, const VaryingStrides& strides, c10::optional requires_grad, c10::optional autograd_zero=c10::nullopt) { - return ProfiledTensorTypePtr(new ProfiledTensorType( + return TensorTypePtr(new TensorType( scalar_type, device, sizes, strides, requires_grad)); } - static ProfiledTensorTypePtr create( + static TensorTypePtr create( c10::optional scalar_type, c10::optional device, c10::optional dim, c10::optional requires_grad) { - return ProfiledTensorType::create( + return TensorType::create( scalar_type, device, VaryingShape(dim), @@ -415,7 +374,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { // overloaded create variadic template argument as it could not distinguish // initializer list - static ProfiledTensorTypePtr createContiguous( + static TensorTypePtr createContiguous( at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes) { @@ -426,7 +385,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { VaryingShape(contiguousStridesOf(sizes)), c10::nullopt); } - static ProfiledTensorTypePtr create( + static TensorTypePtr create( at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes, @@ -470,21 +429,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return false; } - auto rt = rhs.expect(); + auto rt = rhs.expect(); return scalar_type_ == rt->scalarType() && sizes() == rt->sizes() && strides() == rt->strides() && device() == rt->device() && requiresGrad() == rt->requiresGrad() && autogradZero() == rt->autogradZero(); } - bool isSubtypeOf(const TypePtr rhs) const override { - if (auto rhs_p = rhs->cast()) { - // if we have the same pointer, avoid computing the merge - if (this == rhs_p.get()) { - return true; - } - return *merge(rhs_p) == *rhs_p; - } - return rhs->kind() == TypeKind::TensorType || TensorType::isSubtypeOf(rhs); - } + bool isSubtypeOf(const TypePtr rhs) const override; + std::string str() const override; c10::optional numel() const { @@ -500,27 +451,27 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return prod; } - ProfiledTensorTypePtr withRequiresGrad(c10::optional s) { + TensorTypePtr withRequiresGrad(c10::optional s) { auto copy = clone(); copy->requires_grad_ = s; return copy; } - ProfiledTensorTypePtr withScalarType(c10::optional st) { + TensorTypePtr withScalarType(c10::optional st) { auto copy = clone(); copy->scalar_type_ = st; return copy; } - ProfiledTensorTypePtr withDim(c10::optional d) { + TensorTypePtr withDim(c10::optional d) { auto copy = clone(); copy->sizes_ = VaryingShape(d); copy->strides_ = VaryingShape(d); return copy; } - ProfiledTensorTypePtr withSizesStrides( + TensorTypePtr withSizesStrides( at::IntArrayRef sizes, at::IntArrayRef strides) const { auto cloned = clone(); @@ -529,19 +480,19 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return cloned; } - ProfiledTensorTypePtr withSizes(at::IntArrayRef sizes) const { + TensorTypePtr withSizes(at::IntArrayRef sizes) const { return withSizesStrides( sizes, contiguousStridesOf(sizes)); } - ProfiledTensorTypePtr dimensionedOnly() const { + TensorTypePtr dimensionedOnly() const { auto copy = clone(); copy->sizes_ = VaryingShape(sizes().size()); copy->strides_ = VaryingShape(sizes().size()); return copy; } - ProfiledTensorTypePtr contiguous() const { + TensorTypePtr contiguous() const { auto cloned = clone(); if (auto concrete_sizes = sizes().concrete_sizes()) { cloned->strides_ = VaryingShape(contiguousStridesOf(*concrete_sizes)); @@ -551,14 +502,14 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return cloned; } - ProfiledTensorTypePtr merge(ProfiledTensorTypePtr other) const { + TensorTypePtr merge(TensorTypePtr other) const { auto scalar_type = merge_primitive(scalarType(), other->scalarType()); auto dev = merge_primitive(device(), other->device()); auto sz = sizes().merge(other->sizes()); auto srs = strides().merge(other->strides()); auto gr = merge_primitive(requiresGrad(), other->requiresGrad()); auto zero = merge_primitive(autogradZero(), other->autogradZero()); - return ProfiledTensorType::create(scalar_type, dev, sz, srs, gr, zero); + return TensorType::create(scalar_type, dev, sz, srs, gr, zero); } // is all information about the type specified except for autograd? // This replaces the notion of a 'CompleteTensorType' that used to exist @@ -568,7 +519,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete(); } - ProfiledTensorTypePtr withAutogradZero() { + TensorTypePtr withAutogradZero() { auto r = clone(); r->autograd_zero_ = true; return r; @@ -578,13 +529,13 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { return autograd_zero_; } - static ProfiledTensorTypePtr get(); + static TensorTypePtr get(); - static const TypeKind Kind = TypeKind::ProfiledTensorType; + static const TypeKind Kind = TypeKind::TensorType; private: - ProfiledTensorType(const at::Tensor& tensor) - : TensorType(TypeKind::ProfiledTensorType), + TensorType(const at::Tensor& tensor) + : Type(TypeKind::TensorType), scalar_type_(tensor.scalar_type()), device_(tensor.device()), sizes_(tensor.sizes().size()), @@ -595,14 +546,14 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { strides_ = tensor.strides().vec(); } } - ProfiledTensorType( + TensorType( c10::optional scalar_type, c10::optional device, const VaryingShape& sizes, const VaryingStrides& strides, c10::optional requires_grad, c10::optional autograd_zero=c10::nullopt) - : TensorType(TypeKind::ProfiledTensorType), + : Type(TypeKind::TensorType), scalar_type_(scalar_type), device_(device), sizes_(sizes), @@ -610,8 +561,8 @@ struct CAFFE2_API ProfiledTensorType : public TensorType { requires_grad_(requires_grad), autograd_zero_(autograd_zero) {} - ProfiledTensorTypePtr clone() const { - return ProfiledTensorTypePtr(new ProfiledTensorType( + TensorTypePtr clone() const { + return TensorTypePtr(new TensorType( scalar_type_, device_, sizes_, strides_, requires_grad_, autograd_zero_)); } @@ -1170,18 +1121,18 @@ inline TypePtr unshapedType(const TypePtr& type) { return type->withContained(fmap(type->containedTypes(), unshapedType)); } -inline TypePtr ProfiledTensorType::fromNumberType(TypePtr typ) { +inline TypePtr TensorType::fromNumberType(TypePtr typ) { if (typ->isSubtypeOf(IntType::get())) { - return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); + return TensorType::createContiguous(at::kLong, at::kCPU, {}); } else if (typ->isSubtypeOf(FloatType::get())) { - return ProfiledTensorType::createContiguous(at::kFloat, at::kCPU, {}); + return TensorType::createContiguous(at::kFloat, at::kCPU, {}); } else if (typ->isSubtypeOf(BoolType::get())) { - return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); + return TensorType::createContiguous(at::kLong, at::kCPU, {}); } AT_ERROR("unknown number type", typ->str()); } -inline TypePtr ProfiledTensorType::fromBoolType() { - return ProfiledTensorType::createContiguous(at::kLong, at::kCPU, {}); +inline TypePtr TensorType::fromBoolType() { + return TensorType::createContiguous(at::kLong, at::kCPU, {}); } inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 2960536f5bb..8b013521240 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -6,7 +6,7 @@ namespace c10 { std::ostream& operator<<(std::ostream & out, const Type & t) { - if (auto value = t.cast()) { + if (auto value = t.cast()) { if (value->scalarType().has_value()) { out << toString(*value->scalarType()); if (!value->sizes().size().has_value()) { @@ -29,6 +29,9 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } out << ")"; } + if (value->autogradZero() && *value->autogradZero()) { + out << "[AutogradZero]"; + } } else if(t.kind() == TypeKind::ListType) { auto prim = t.cast()->getElementType(); out << *prim << "[]"; @@ -61,12 +64,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } TensorTypePtr TensorType::get() { - static auto value = TensorType::create(); - return value; -} - -ProfiledTensorTypePtr ProfiledTensorType::get() { - static auto value = ProfiledTensorType::create( + static auto value = TensorType::create( {}, {}, VaryingShape{c10::optional()}, @@ -140,7 +138,7 @@ ListTypePtr ListType::ofBools() { // the type, like in the tracer. TypePtr incompleteInferTypeFrom(const IValue& value) { if (value.isTensor()) { - return ProfiledTensorType::create(value.toTensor()); + return TensorType::create(value.toTensor()); } else if (value.isDouble()) { return FloatType::get(); } else if (value.isInt()) { @@ -262,8 +260,8 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { // NB: we do not return NumberType because there is not currently enough // operator support for it - if (t1->kind() == ProfiledTensorType::Kind && t2->kind() == ProfiledTensorType::Kind) { - return t1->expect()->merge(t2->expect()); + if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) { + return t1->expect()->merge(t2->expect()); } if (t1->isSubtypeOf(TensorType::get()) && t2->isSubtypeOf(TensorType::get())) { @@ -479,7 +477,7 @@ bool Type::isSubtypeOf(const TypePtr rhs) const { return false; } -std::string ProfiledTensorType::str() const { +std::string TensorType::str() const { return "Tensor"; } @@ -625,4 +623,15 @@ std::string TupleType::python_str() const { return ss.str(); } +bool TensorType::isSubtypeOf(const TypePtr rhs) const { + if (auto rhs_p = rhs->cast()) { + // if we have the same pointer, avoid computing the merge + if (this == rhs_p.get()) { + return true; + } + return *merge(rhs_p) == *rhs_p; + } + return Type::isSubtypeOf(rhs); +} + } // namespace c10 diff --git a/test/cpp/jit/test_argument_spec.cpp b/test/cpp/jit/test_argument_spec.cpp index 3a5cadaea53..0dfe92cc48b 100644 --- a/test/cpp/jit/test_argument_spec.cpp +++ b/test/cpp/jit/test_argument_spec.cpp @@ -91,40 +91,37 @@ void testCompleteArgumentSpec() { ASSERT_EQ(with_const.at(2).sizes().size(), 2); } -size_t hashCode(const ProfiledTensorTypePtr& ptr) { - return std::hash()(*ptr.get()); +size_t hashCode(const TensorTypePtr& ptr) { + return std::hash()(*ptr.get()); } void testProfiledTensorTypeHashing() { c10::VaryingShape vs(c10::optional{}); - auto ptt_empty1 = ProfiledTensorType::create({}, {}, vs, vs, false); - auto ptt_empty2 = ProfiledTensorType::create({}, {}, vs, vs, false); + auto ptt_empty1 = TensorType::create({}, {}, vs, vs, false); + auto ptt_empty2 = TensorType::create({}, {}, vs, vs, false); ASSERT_EQ(hashCode(ptt_empty1), hashCode(ptt_empty2)); c10::VaryingShape vs22(std::vector{2, 2}); - auto ptt_vs22_1 = ProfiledTensorType::create({}, {}, vs22, vs, false); - auto ptt_vs22_2 = ProfiledTensorType::create({}, {}, vs22, vs, false); + auto ptt_vs22_1 = TensorType::create({}, {}, vs22, vs, false); + auto ptt_vs22_2 = TensorType::create({}, {}, vs22, vs, false); ASSERT_EQ(hashCode(ptt_vs22_1), hashCode(ptt_vs22_2)); c10::VaryingShape vs23(std::vector{2, 3}); - auto ptt_vs23_1 = ProfiledTensorType::create({}, {}, vs23, vs, false); + auto ptt_vs23_1 = TensorType::create({}, {}, vs23, vs, false); ASSERT_NE(hashCode(ptt_vs22_1), hashCode(ptt_vs23_1)); - auto ptt_vs22_vs22_1 = ProfiledTensorType::create({}, {}, vs22, vs22, false); - auto ptt_vs22_vs22_2 = ProfiledTensorType::create({}, {}, vs22, vs22, false); + auto ptt_vs22_vs22_1 = TensorType::create({}, {}, vs22, vs22, false); + auto ptt_vs22_vs22_2 = TensorType::create({}, {}, vs22, vs22, false); ASSERT_EQ(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs22_2)); - auto ptt_vs22_vs23_2 = ProfiledTensorType::create({}, {}, vs22, vs23, false); + auto ptt_vs22_vs23_2 = TensorType::create({}, {}, vs22, vs23, false); ASSERT_NE(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs23_2)); - auto ptt_vs22_vs22_1_true = - ProfiledTensorType::create({}, {}, vs22, vs22, true); - auto ptt_vs22_vs22_2_true = - ProfiledTensorType::create({}, {}, vs22, vs22, true); + auto ptt_vs22_vs22_1_true = TensorType::create({}, {}, vs22, vs22, true); + auto ptt_vs22_vs22_2_true = TensorType::create({}, {}, vs22, vs22, true); ASSERT_EQ(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_2_true)); - auto ptt_vs22_vs22_1_false = - ProfiledTensorType::create({}, {}, vs22, vs22, false); + auto ptt_vs22_vs22_1_false = TensorType::create({}, {}, vs22, vs22, false); ASSERT_NE(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_1_false)); } diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp index 2820ea0fa24..6bde65e1942 100644 --- a/test/cpp/jit/test_autodiff.cpp +++ b/test/cpp/jit/test_autodiff.cpp @@ -169,7 +169,7 @@ void testADFormulas() { void testDifferentiate() { // Note: can't use IRParser for this test due to issue #23989 auto graph = std::make_shared(); - const auto type = ProfiledTensorType::create(at::ScalarType::Float, at::kCPU, {2, 3, 4}, {12, 4, 1}); + const auto type = TensorType::create(at::ScalarType::Float, at::kCPU, {2, 3, 4}, {12, 4, 1}); // Builds graph a * b * a + b auto* a = graph->addInput()->setType(type); diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 731d6317455..7b92aa7697b 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1011,7 +1011,7 @@ static void checkShape( bool prev = true) { auto profile = (prev) ? n->inputs().at(0)->node() : n; auto tp = profile->output()->type(); - auto ptp = tp->expect(); + auto ptp = tp->expect(); ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected); } @@ -1044,7 +1044,9 @@ void testInsertAndEliminateRedundantGuards() { return n->kind() == prim::Guard; }); ASSERT_NE(guard, nodes.end()); - ASSERT_EQ(guard->input()->type()->cast(), nullptr); + ASSERT_EQ( + guard->input()->type()->expect()->sizes().size(), + c10::nullopt); checkShape(*guard, {2, 3}, false); auto is_guard = [](Node* n) { return n->kind() == prim::Guard; }; int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard); diff --git a/test/test_jit.py b/test/test_jit.py index c376934914f..9bcc6179113 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2150,8 +2150,8 @@ graph(%Ra, %Rb): graph = f.graph_for(t, "hi") input_types = list(next(graph.inputs()).type().elements()) w = input_types[0] - self.assertEqual(input_types[0].kind(), 'ProfiledTensorType') - self.assertEqual(input_types[1].elements()[1].kind(), 'ProfiledTensorType') + self.assertEqual(input_types[0].kind(), 'TensorType') + self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType') def test_constant_prop_simple(self): @torch.jit.script @@ -2412,8 +2412,8 @@ graph(%Ra, %Rb): def test_onnx_transpose_incomplete_tensor_type(self): # Smoke test to get us into the state where we are attempting to export - # a transpose op, where the input is a TensorType rather than a - # ProfiledTensorType. This would previously not work, since we would + # a transpose op, where the input is a TensorType without size information. + # This would previously not work, since we would # take the size of the input and use the length of its sizes as the # number of dimensions in the permutation. class Foo(torch.jit.ScriptModule): @@ -4561,8 +4561,7 @@ a") x = torch.randn(3, 1, 5, requires_grad=True) fn = torch.jit.script(fn) graph = _propagate_shapes(fn.graph, (x,), False) - a = next(graph.outputs()).type().kind() - self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType') + self.assertTrue(next(graph.outputs()).type().scalarType() == 'Double') def test_shape_prop_promotion(self): @torch.jit.script @@ -5168,7 +5167,7 @@ a") res = fn(t, 1) self.assertEqual(res, 0) g = torch.jit.last_executed_optimized_graph() - self.assertEqual(next(g.inputs()).type().kind(), 'ProfiledTensorType') + self.assertEqual(next(g.inputs()).type().kind(), 'TensorType') @torch.jit.script def fn(x, y, b): diff --git a/torch/csrc/jit/argument_spec.cpp b/torch/csrc/jit/argument_spec.cpp index 9343ad08d98..19aed9619a6 100644 --- a/torch/csrc/jit/argument_spec.cpp +++ b/torch/csrc/jit/argument_spec.cpp @@ -216,7 +216,7 @@ void ArgumentSpecCreator::specializeTypes( auto& arg = spec.tensorAt(tensor_arg_spec_offset++); if (!arg.defined()) { result_stack.back().emplace_back( - ProfiledTensorType::get()->withAutogradZero()); + TensorType::get()->withAutogradZero()); } else { result_stack.back().emplace_back(arg.toType()); } diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index b1c5ca92c09..52e9158ee81 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -44,11 +44,12 @@ struct ArgumentInfo { TypePtr toType() const { if (!defined()) return TensorType::get(); - return ProfiledTensorType::create(type(), - ConvertIntToCPUOrCUDA(device()), - c10::VaryingShape(dim()), - c10::VaryingShape(dim()), - requires_grad()); + return TensorType::create( + type(), + ConvertIntToCPUOrCUDA(device()), + c10::VaryingShape(dim()), + c10::VaryingShape(dim()), + requires_grad()); } operator TypePtr() const { return toType(); @@ -349,7 +350,7 @@ struct CompleteArgumentInfo { operator TypePtr() const { if (!defined()) return TensorType::get(); - return ProfiledTensorType::create( + return TensorType::create( type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides()); } @@ -447,8 +448,8 @@ struct hash { }; template <> -struct hash { - size_t operator()(const c10::ProfiledTensorType& ptt) const { +struct hash { + size_t operator()(const c10::TensorType& ptt) const { return torch::get_hash< c10::optional, c10::VaryingShape, diff --git a/torch/csrc/jit/docs/OVERVIEW.md b/torch/csrc/jit/docs/OVERVIEW.md index cdb9aabc4de..4f54f21f510 100644 --- a/torch/csrc/jit/docs/OVERVIEW.md +++ b/torch/csrc/jit/docs/OVERVIEW.md @@ -329,8 +329,7 @@ Values are abstract representation of data in the program. When executing, the a TorchScript, unlike Python, is statically typed, so every Value has a Type associated with it, and every FunctionSchema has a list of argument types and a return type for a function. Type is the base class of a hierarchy of C++ objects that represent the built-in types of TorchScript. Types provide methods such as `Type::isSubtypeOf` that describe the typing relationships. Common type are: -* TensorType - the root type of all Tensors in the system. -* ProfiledTensorType - a tensor with optionally refined information. It may optional know its device, type, requires_grad state, the number of dimensions. +* TensorType - a tensor with optionally refined information. It may optional know its device, type, requires_grad state, the number of dimensions. If it does know the number of dimensions it may optionally know the size of a particular dimension. * Tuples - e.g. Tuple[Tensor, Int]. Each member of the tuple is statically typed and the length of the tuple is statically known. * List[T] - e.g. List[Tensor]. Mutable lists of a particular type. @@ -811,7 +810,7 @@ Execution starts in `GraphExecutor::run`, which takes takes a Stack of inputs. The ArgumentSpec object is used as a key into a cache that holds pre-optimized Code objects (held in an ExecutionPlan object). On a cache hit, an InterpreterState is created and the Code in the cache is run. -*Pre-derivative Optimization* On a code cache miss, we generate a new optimized Graph on the fly (`compileSpec`). It starts by creating a copy of the initial Graph and setting the input types to the specialized Tensor types observed in this specialization. TensorType inputs to the Graph will get replaced with ProfiledTensorTypes that know the device, number of dimensions, and requires grad state. +*Pre-derivative Optimization* On a code cache miss, we generate a new optimized Graph on the fly (`compileSpec`). It starts by creating a copy of the initial Graph and setting the input types to the specialized Tensor types observed in this specialization. TensorType inputs to the Graph will get refined with types that know the device, number of dimensions, and requires grad state. ``` # post specialization, inputs are now specialized types @@ -1148,7 +1147,7 @@ TODO: fusion, operators # Profiling Programs -`prim::profile` nodes are inserted on every **use** of a value by `ProfilingRecord::instrumentBlock`. Every `prim::profile` node runs a lambda that uses a captured, initial type value and the type of an incoming tensor and merges the two into `ProfiledTensorType` +`prim::profile` nodes are inserted on every **use** of a value by `ProfilingRecord::instrumentBlock`. Every `prim::profile` node runs a lambda that uses a captured, initial type value and the type of an incoming tensor and merges the two into a refined `TensorType` `prim::profile` nodes are replaced with `prim::Guard` nodes by `InsertGuards`. `prim::Guard` nodes are inserted to guarantee that beyond the guard a guarded tensor will always be of the profiled shape. This guarantee will enable optimizations and codegens to generate more efficient code. diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 97e9a2f39ff..05a73f31744 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -221,7 +221,7 @@ void EncoderBase::EncodeValueInfo( const std::unordered_map>& dynamic_axes) { std::string name = n->debugName(); v->set_name(name); - if (ProfiledTensorTypePtr node_type = n->type()->cast()) { + if (TensorTypePtr node_type = n->type()->cast()) { if (!node_type->isComplete()) { return; } diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp index a0579dc9fbb..850aef4f72a 100644 --- a/torch/csrc/jit/fuser/codegen.cpp +++ b/torch/csrc/jit/fuser/codegen.cpp @@ -91,7 +91,7 @@ static std::string variableType(const std::shared_ptr& t) { return "double"; } else if (t->kind() == TypeKind::BoolType) { return "bool"; - } else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) { + } else if (auto scalar_type = t->expect()->scalarType()) { return calcScalarTypeName(*scalar_type); } // something went wrong with the type analysis during shape propagation @@ -115,7 +115,7 @@ static std::string typeCastedValueName( // cast here, which may end up being a no-op if the tensor's scalar type // is `double`. return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; - } else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) { + } else if (auto scalar_type = t->expect()->scalarType()) { if (*scalar_type != outtype) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; } @@ -260,8 +260,7 @@ static std::string encodeRHS(const Node* n) { } else { size_t i = 0; - auto outtype = - ProfiledTensorType::create(n->output()->type())->scalarType(); + auto outtype = n->output()->type()->expect()->scalarType(); TORCH_INTERNAL_ASSERT(outtype); for (auto in : n->inputs()) { diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp index 5ee0c90a496..93e9b96c11c 100644 --- a/torch/csrc/jit/fuser/compiler.cpp +++ b/torch/csrc/jit/fuser/compiler.cpp @@ -204,10 +204,10 @@ std::shared_ptr compileKernel( for (size_t i = 0; i < input_desc.size(); i++) { const auto& desc = input_desc[i]; - // TODO: can't get rid of this use of ProfiledTensorType + // TODO: can't get rid of this use of TensorType // until we switch to ProfilingGraphExecutor, so we don't have to // run PropagateInputShapes below - graph->inputs()[i]->setType(ProfiledTensorType::create( + graph->inputs()[i]->setType(TensorType::create( desc.scalar_type, device, c10::VaryingShape(desc.nDim()), @@ -254,10 +254,9 @@ std::shared_ptr compileKernel( sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size(); } - auto scalar_type = ProfiledTensorType::create(o->type())->scalarType(); + auto scalar_type = o->type()->expect()->scalarType(); TORCH_INTERNAL_ASSERT(scalar_type); - auto type = - ProfiledTensorType::createContiguous(*scalar_type, device, sizes); + auto type = TensorType::createContiguous(*scalar_type, device, sizes); output_desc.emplace_back(type); const auto& desc = output_desc.back(); diff --git a/torch/csrc/jit/fuser/tensor_desc.h b/torch/csrc/jit/fuser/tensor_desc.h index 1037570b91e..4c0a5df3de1 100644 --- a/torch/csrc/jit/fuser/tensor_desc.h +++ b/torch/csrc/jit/fuser/tensor_desc.h @@ -41,7 +41,7 @@ struct TORCH_API TensorDesc { TensorDesc(const at::Tensor& t) : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} - TensorDesc(const c10::ProfiledTensorTypePtr& type) + TensorDesc(const c10::TensorTypePtr& type) : TensorDesc( type->scalarType().value(), type->sizes().concrete_sizes().value(), diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 33679a332b8..cacd404d70b 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -1001,7 +1001,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { ++af.pc; } break; case GUARD: { - auto actual = ProfiledTensorType::create(stack.back().toTensor()); + auto actual = TensorType::create(stack.back().toTensor()); const TypePtr& expected = af.types[inst.X]; push(stack, *expected == *actual); ++af.pc; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 5bc3f95aa27..213fa2edc30 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -315,7 +315,7 @@ static void checkSameDevice(const Node* node) { bool has_device = false; c10::optional device = c10::nullopt; auto checkValue = [&](const Value* v) { - if (ProfiledTensorTypePtr type = v->type()->cast()) { + if (TensorTypePtr type = v->type()->cast()) { if (type->device() && !has_device) { has_device = true; device = *type->device(); @@ -670,7 +670,7 @@ void Graph::remapTypes(const std::function& type_map) { } void Value::inferTypeFrom(const at::Tensor& output) { - setType(ProfiledTensorType::create(output)); + setType(TensorType::create(output)); } bool Value::mustBeNone() const { @@ -1439,7 +1439,7 @@ Node* Graph::createDict( Node* Graph::createNumToTensor(Value* value) { auto typ = value->type(); Node* result = create(prim::NumToTensor, {value}); - result->output()->setType(ProfiledTensorType::fromNumberType(std::move(typ))); + result->output()->setType(TensorType::fromNumberType(std::move(typ))); return result; } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 206ccce2ee2..6c6e63820b1 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -169,7 +169,7 @@ struct Value { return type()->requires_grad(); } bool isCompleteTensor() const { - if (auto pt = type()->cast()) { + if (auto pt = type()->cast()) { return pt->isComplete(); } return false; diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index aced90f8bc8..13a2462e2a7 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -172,7 +172,7 @@ struct BailOutInserter { for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) { if (it->kind() == prim::Guard) { // this will need to be profiled again - it->input()->setType(TensorType::create()); + it->input()->setType(TensorType::get()); // destroy the guard it->output()->replaceAllUsesWith(it->input()); it.destroyCurrent(); @@ -274,7 +274,7 @@ static void removeBailouts(Block* b) { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { if (it->kind() == prim::BailOut) { // clear profiling information - it->inputs().at(0)->setType(TensorType::create()); + it->inputs().at(0)->setType(TensorType::get()); it->output()->replaceAllUsesWith(it->inputs().at(0)); it.destroyCurrent(); } else { diff --git a/torch/csrc/jit/passes/bailout_graph.h b/torch/csrc/jit/passes/bailout_graph.h index 08a65c96122..9ad7fe9adcd 100644 --- a/torch/csrc/jit/passes/bailout_graph.h +++ b/torch/csrc/jit/passes/bailout_graph.h @@ -13,8 +13,6 @@ namespace torch { namespace jit { -using ::c10::ProfiledTensorTypePtr; - // Replaces prim::Guard nodes with prim::BailOut nodes and // computes sets of inputs needed to resume execution at // bailout points diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp index c9d40e1f953..cacc37adfeb 100644 --- a/torch/csrc/jit/passes/decompose_ops.cpp +++ b/torch/csrc/jit/passes/decompose_ops.cpp @@ -40,7 +40,7 @@ bool isDecomposableNorm(Node* normalize_op) { if (!input->type()->isSubtypeOf(TensorType::get())) { return false; } - auto device = ProfiledTensorType::create(input->type())->device(); + auto device = input->type()->expect()->device(); // As of now, we do the decomposition for batchnorm/layernorm on GPU device // only if (!device || (*device).is_cpu()) { diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index c365533bd15..397ff790f32 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -47,9 +47,9 @@ static void EraseNumberTypesOnBlock(Block* block) { default: { for (auto o : it->outputs()) { if (o->type()->isSubtypeOf(NumberType::get())) { - o->setType(ProfiledTensorType::fromNumberType(o->type())); + o->setType(TensorType::fromNumberType(o->type())); } else if (o->type()->isSubtypeOf(BoolType::get())) { - o->setType(ProfiledTensorType::fromBoolType()); + o->setType(TensorType::fromBoolType()); } } } break; diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index cf4058c0127..0d6727e8cb8 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -178,7 +178,7 @@ struct GraphFuser { if (!v->type()->isSubtypeOf(TensorType::get())) { return true; } - auto device = ProfiledTensorType::create(v->type())->device(); + auto device = v->type()->expect()->device(); if (!device) { return true; } diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index bc7639ad039..cb61f2edeed 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -134,7 +134,7 @@ struct GuardElimination { input->node()->kind() == prim::Constant) { AT_ASSERT( input->node()->kind() != prim::Guard || - input->type()->expect()); + input->type()->expect()); } else { all_inputs_guarded = false; break; diff --git a/torch/csrc/jit/passes/guard_elimination.h b/torch/csrc/jit/passes/guard_elimination.h index fd522d5d099..539aaa586af 100644 --- a/torch/csrc/jit/passes/guard_elimination.h +++ b/torch/csrc/jit/passes/guard_elimination.h @@ -13,8 +13,6 @@ namespace torch { namespace jit { -using ::c10::ProfiledTensorTypePtr; - TORCH_API void EliminateRedundantGuards(std::shared_ptr graph); } // namespace jit diff --git a/torch/csrc/jit/passes/insert_guards.cpp b/torch/csrc/jit/passes/insert_guards.cpp index b506d0fd2d5..c935d9f8451 100644 --- a/torch/csrc/jit/passes/insert_guards.cpp +++ b/torch/csrc/jit/passes/insert_guards.cpp @@ -30,15 +30,11 @@ struct GuardInserter { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { auto n = *it; if (n->kind() == prim::profile && n->outputs().size() == 1) { - auto pttp = n->output()->type()->cast(); + auto pttp = n->output()->type()->cast(); if (pttp) { - // make a *copy* of ProfilingTensorType, in case we'd like - // to make changes to it independently from the one being - // profiled auto guard = graph_->create(prim::Guard, {n->input()}, 1); auto go = guard->output(); - auto copy = ProfiledTensorType::create(pttp); - go->setType(copy); + go->setType(pttp); guard->insertBefore(n); n->output()->replaceAllUsesWith(go); } else { diff --git a/torch/csrc/jit/passes/insert_guards.h b/torch/csrc/jit/passes/insert_guards.h index e6df760be92..02011b6e2f3 100644 --- a/torch/csrc/jit/passes/insert_guards.h +++ b/torch/csrc/jit/passes/insert_guards.h @@ -13,8 +13,6 @@ namespace torch { namespace jit { -using ::c10::ProfiledTensorTypePtr; - TORCH_API void InsertGuards(std::shared_ptr graph); } // namespace jit diff --git a/torch/csrc/jit/passes/liveness.h b/torch/csrc/jit/passes/liveness.h index 456448fcf4f..f526d0f0b24 100644 --- a/torch/csrc/jit/passes/liveness.h +++ b/torch/csrc/jit/passes/liveness.h @@ -13,7 +13,6 @@ namespace torch { namespace jit { -using ::c10::ProfiledTensorTypePtr; using SparseBitVector = ::c10::SparseBitVector<256>; // BuildLivenessSets computes "bailout" liveness which is equivalent to diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp index 1735a237320..47a7bd998ef 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp @@ -29,8 +29,8 @@ Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) { bool IsCondCastRequired(Value* cond_val) { const auto& type = cond_val->type(); - if (type->isSubtypeOf(TensorType::get())) { - if (auto scalar_type = ProfiledTensorType::create(type)->scalarType()) { + if (auto tt = type->cast()) { + if (auto scalar_type = tt->scalarType()) { return *scalar_type != c10::kBool; } } @@ -55,7 +55,7 @@ void FixupONNXLoops(Block* block) { cond->setType(BoolType::create()); Value* i = sub_block->inputs()[0]; - i->setType(ProfiledTensorType::fromNumberType(IntType::get())); + i->setType(TensorType::fromNumberType(IntType::get())); // add cast to condition input inside the loop. Value* next_cond_val = sub_block->outputs()[0]; diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 3dae2145053..0ac61547309 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -129,13 +129,13 @@ void fuseBroadcast(Block* b) { // Not all broadcasts are supported by ONNX broadcast. c10::optional axis = fusibleExpandTo( unexpanded_input->type() - ->expect() + ->expect() ->sizes() .concrete_sizes() .value(), // from n->output() ->type() - ->expect() + ->expect() ->sizes() .concrete_sizes() .value()); // to @@ -296,14 +296,13 @@ void pushPackingPastRnn(Block* b) { // unhygenic way, Pytorch ends up propagating an incorrect type. // Until a long-term cleanup comes around, we can fix this by // resetting the size to the correct value. - ProfiledTensorTypePtr oldType = - rnn->inputs().at(0)->type()->cast(); + TensorTypePtr oldType = rnn->inputs().at(0)->type()->cast(); if (oldType && oldType->isComplete()) { std::vector new_sizes; new_sizes.push_back(*oldType->sizes()[0]); new_sizes.push_back(*oldType->sizes()[1]); new_sizes.push_back(rnn->i(attr::hidden_size)); - ProfiledTensorTypePtr newType = ProfiledTensorType::createContiguous( + TensorTypePtr newType = TensorType::createContiguous( *oldType->scalarType(), *oldType->device(), new_sizes); next->outputs().at(0)->setType(newType); } diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp index 7f2bf8d10f2..7fb595027bb 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -27,8 +27,7 @@ static void PrepareDivisionForONNXOnBlock(Block* block) { it->replaceInput(0, floattensor_inputs[0]); it->replaceInput(1, floattensor_inputs[1]); - it->output()->setType( - ProfiledTensorType::fromNumberType(FloatType::get())); + it->output()->setType(TensorType::fromNumberType(FloatType::get())); } } } diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 54d1d4e1aa7..dfbf021a42c 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -44,9 +44,8 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*const_inputs=*/attr::size)) { // x.expand(x.size()) == x - if (auto input_type = node->namedInput(attr::self) - ->type() - ->cast()) { + if (auto input_type = + node->namedInput(attr::self)->type()->cast()) { auto expanded_sizes = node->get>(attr::size); auto input_type_sizes = input_type->sizes().concrete_sizes(); if (expanded_sizes.has_value() && input_type_sizes && @@ -71,8 +70,8 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { } else if (node->matches( "aten::type_as(Tensor self, Tensor other) -> Tensor")) { // x.type_as(y) == x iff x.type() == y.type() - auto self_type = ProfiledTensorType::create(node->input(0)->type()); - auto other_type = ProfiledTensorType::create(node->input(1)->type()); + auto self_type = node->input(0)->type()->expect(); + auto other_type = node->input(1)->type()->expect(); if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) && mustBeEqual(self_type->device(), other_type->device())) { GRAPH_UPDATE( @@ -108,7 +107,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { // type_as conditional on the tensor shape being a scalar, but that // might add overhead, and make analysis harder. auto add_mat_type = - ProfiledTensorType::create(node->input(1 - mm_side)->type()); + node->input(1 - mm_side)->type()->expect(); // if we don't have the rank, we can't tell if the bias is a scalar if (!add_mat_type->sizes().size()) { continue; @@ -123,10 +122,10 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { SymbolicVariable mat1(mm_node->input(0)); SymbolicVariable mat2(mm_node->input(1)); - auto mat1_type = ProfiledTensorType::create(mat1.value()->type()); + auto mat1_type = mat1.value()->type()->expect(); auto mat_scalar_type = mat1_type->scalarType(); if (!mat_scalar_type) { - auto mat2_type = ProfiledTensorType::create(mat2.value()->type()); + auto mat2_type = mat2.value()->type()->expect(); mat_scalar_type = mat2_type->scalarType(); } @@ -282,7 +281,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { node->output()->replaceAllUsesWith(node->input()); } } else if (node->matches("prim::dtype(Tensor a) -> int")) { - auto ptt = ProfiledTensorType::create(node->input()->type()); + auto ptt = node->input()->type()->expect(); if (ptt->scalarType()) { WithInsertPoint guard(node); auto output = node->owningGraph()->insertConstant( @@ -292,7 +291,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { node->output()->replaceAllUsesWith(output); } } else if (node->matches("prim::device(Tensor a) -> Device")) { - auto ptt = ProfiledTensorType::create(node->input()->type()); + auto ptt = node->input()->type()->expect(); if (ptt->device()) { WithInsertPoint guard(node); auto output = node->owningGraph()->insertConstant(*ptt->device()); @@ -304,7 +303,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { node->output()->replaceAllUsesWith(output); } } else if (node->matches("aten::dim(Tensor self) -> int")) { - auto ptt = ProfiledTensorType::create(node->input()->type()); + auto ptt = node->input()->type()->expect(); if (auto dim = ptt->sizes().size()) { WithInsertPoint guard(node); auto output = @@ -317,7 +316,7 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { node->output()->replaceAllUsesWith(output); } } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) { - auto ptt = ProfiledTensorType::create(node->input()->type()); + auto ptt = node->input()->type()->expect(); if (ptt->device()) { WithInsertPoint guard(node); auto output = diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp index eda1adf8dc8..04aac5b8d19 100644 --- a/torch/csrc/jit/passes/requires_grad_analysis.cpp +++ b/torch/csrc/jit/passes/requires_grad_analysis.cpp @@ -16,7 +16,7 @@ bool getRequiresGrad(Value* value) { } void setRequiresGrad(Value* value, bool req_value) { - if (auto type = value->type()->cast()) { + if (auto type = value->type()->cast()) { value->setType(type->withRequiresGrad(req_value)); } } @@ -72,7 +72,7 @@ void PropagateRequiresGradSimpleNode(Node* node) { return setRequiresGrad(node->output(), *const_arg); } } - if (auto type = node->output()->type()->cast()) { + if (auto type = node->output()->type()->cast()) { if (type->scalarType()) { setRequiresGrad(node->output(), at::isFloatingType(*type->scalarType())); } @@ -85,7 +85,7 @@ void PropagateRequiresGradSimpleNode(Node* node) { bool should_require = std::any_of(inputs.begin(), inputs.end(), getRequiresGrad); for (Value* output : outputs) { - if (auto type = output->type()->cast()) { + if (auto type = output->type()->cast()) { if (type->scalarType()) { setRequiresGrad( output, should_require && at::isFloatingType(*type->scalarType())); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 3692b311ed5..19be42eea93 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -38,7 +38,7 @@ bool isValidArgumentForRunning(Value* v) { // allow constants if (toIValue(v)) return true; - if (ProfiledTensorTypePtr tt = v->type()->cast()) { + if (TensorTypePtr tt = v->type()->cast()) { if (!tt->scalarType()) { return false; } @@ -158,7 +158,7 @@ class ShapePropagator { if (auto iv = toIValue(v)) { return *iv; } - if (ProfiledTensorTypePtr type = type_->cast()) { + if (TensorTypePtr type = type_->cast()) { if (type->isComplete()) { auto attype = type->device()->is_cpu() ? at::CPU(*type->scalarType()) : at::CUDA(*type->scalarType()); @@ -185,10 +185,10 @@ class ShapePropagator { // for each node in the schema with type Tensor, extract the T type // returns c10::nullopt if any Tensor in the schema does not have a known // shape ignores non-tensor in the list of inputs - c10::optional> gatherTensorTypes( + c10::optional> gatherTensorTypes( Node* node, bool complete = false) { - std::vector tensor_types; + std::vector tensor_types; auto& schema = node->schema(); auto& args = schema.arguments(); @@ -201,7 +201,7 @@ class ShapePropagator { if (args[i].type()->isSubtypeOf(ListType::ofTensors())) { return c10::nullopt; } else if (args[i].type()->isSubtypeOf(TensorType::get())) { - if (auto type = node->input(i)->type()->cast()) { + if (auto type = node->input(i)->type()->cast()) { if (complete && !type->isComplete()) { return c10::nullopt; } @@ -236,14 +236,14 @@ class ShapePropagator { void broadcastBinary( Node* node, - std::vector& types, + std::vector& types, size_t idx1, size_t idx2) { auto expected_size = at::infer_size( *types[idx1]->sizes().concrete_sizes(), *types[idx2]->sizes().concrete_sizes()); auto broadcast = [&](size_t input_idx) { - ProfiledTensorTypePtr input_type = types.at(input_idx); + TensorTypePtr input_type = types.at(input_idx); if (input_type->sizes() == expected_size) return; auto graph = node->owningGraph(); @@ -260,8 +260,8 @@ class ShapePropagator { }; broadcast(idx1); broadcast(idx2); - types[0] = node->inputs().at(idx1)->type()->expect(); - types[1] = node->inputs().at(idx2)->type()->expect(); + types[0] = node->inputs().at(idx1)->type()->expect(); + types[1] = node->inputs().at(idx2)->type()->expect(); } OperatorSet cannot_propagate_shape_by_running_it = { @@ -369,13 +369,12 @@ class ShapePropagator { void PropagateCatShape(Node* cat_node) { static const auto propagate_complete = [this](Node* node, at::ArrayRef tensors) -> bool { - auto input_types = fmap(tensors, [](Value* v) { - return v->type()->cast(); - }); + auto input_types = + fmap(tensors, [](Value* v) { return v->type()->cast(); }); if (!std::all_of( input_types.begin(), input_types.end(), - [](const ProfiledTensorTypePtr& tp) { + [](const TensorTypePtr& tp) { return tp != nullptr && tp->isComplete(); })) { return false; @@ -407,7 +406,7 @@ class ShapePropagator { static const auto propagate = [](Node* node, at::ArrayRef tensors) -> bool { for (Value* v : tensors) { - if (auto type = v->type()->cast()) { + if (auto type = v->type()->cast()) { node->output()->setType(type->dimensionedOnly()); return true; } @@ -463,7 +462,7 @@ class ShapePropagator { default_device = inp->toDevice(); } } - node->output()->setType(ProfiledTensorType::create( + node->output()->setType(TensorType::create( default_type, default_device, dims, /*requires_grad=*/c10::nullopt)); } @@ -535,10 +534,10 @@ class ShapePropagator { TypePtr typ = node->input()->type(); if (typ->isSubtypeOf(IntType::get()) || typ->isSubtypeOf(BoolType::get())) { - node->output()->setType(ProfiledTensorType::create( + node->output()->setType(TensorType::create( at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); } else if (node->input()->type()->isSubtypeOf(FloatType::get())) { - node->output()->setType(ProfiledTensorType::create( + node->output()->setType(TensorType::create( at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); } return; @@ -591,7 +590,7 @@ class ShapePropagator { } case prim::ConstantChunk: { Value* tensor = node->input(); - if (auto type = tensor->type()->cast()) { + if (auto type = tensor->type()->cast()) { type = type->dimensionedOnly(); for (Value* output : node->outputs()) { output->setType(type); @@ -683,9 +682,8 @@ class ShapePropagator { // primitive/tensor outputs. bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) { - static const auto broadcast = - [](std::vector& tensor_types, - size_t arg_for_type) -> ProfiledTensorTypePtr { + static const auto broadcast = [](std::vector& tensor_types, + size_t arg_for_type) -> TensorTypePtr { if (tensor_types.size() == 1) { return tensor_types[0]->dimensionedOnly(); } @@ -699,14 +697,14 @@ class ShapePropagator { max_dims = std::max(*max_dims, *type->dim()); } } - return ProfiledTensorType::create( + return TensorType::create( any_type->scalarType(), any_type->device(), max_dims, /*requires_grad=*/c10::nullopt); }; - using type_vec_t = std::vector; + using type_vec_t = std::vector; // Formula is expected to return a vector of length equal to the number of // tensor outputs of the node, or an empty vector which implies that it // failed to propagate. @@ -815,9 +813,9 @@ class ShapePropagator { "aten::zeros_like(Tensor self) -> Tensor", }, [](Node* node) -> type_vec_t { - auto input_type = - node->input(0)->type()->cast(); - return input_type ? type_vec_t{input_type->dimensionedOnly()} : type_vec_t{}; + auto input_type = node->input(0)->type()->cast(); + return input_type ? type_vec_t{input_type->dimensionedOnly()} + : type_vec_t{}; }}; // Requirements: @@ -929,10 +927,9 @@ class ShapePropagator { return {}; }}; - static const auto any_tensor_type = - [](Node* node) -> ProfiledTensorTypePtr { + static const auto any_tensor_type = [](Node* node) -> TensorTypePtr { for (Value* input : node->inputs()) { - if (auto type = input->type()->cast()) { + if (auto type = input->type()->cast()) { if (type->dim().has_value()) { return type; } @@ -1035,8 +1032,7 @@ class ShapePropagator { "aten::prelu(Tensor self, Tensor weight) -> Tensor", }, [](Node* node) -> type_vec_t { - if (auto type = - node->input(0)->type()->cast()) { + if (auto type = node->input(0)->type()->cast()) { return {type->dimensionedOnly()}; } return {}; @@ -1065,8 +1061,7 @@ class ShapePropagator { "aten::any(Tensor self) -> Tensor", }, [](Node* node) -> type_vec_t { - if (auto type = - node->input(0)->type()->cast()) { + if (auto type = node->input(0)->type()->cast()) { return {type->withDim(0)}; } return {}; @@ -1086,7 +1081,7 @@ class ShapePropagator { }, [](Node* node) -> type_vec_t { at::optional maybe_dtype_option = node->get(attr::dtype); - if (auto type = node->input(0)->type()->cast()) { + if (auto type = node->input(0)->type()->cast()) { auto ret = type->withDim(0); if(maybe_dtype_option && !maybe_dtype_option->isNone()) { return {ret->withScalarType(maybe_dtype_option->toScalarType())}; @@ -1105,51 +1100,53 @@ class ShapePropagator { // tensor outputs : 1 // Additionally: // - First input should be the only tensor input - static const register_formula_for all_reduce_ops_with_integer_upcast_and_dtype{ - { - "aten::sum(Tensor self, *, int? dtype) -> Tensor", - "aten::prod(Tensor self, *, int? dtype) -> Tensor", - }, - [](Node* node) -> type_vec_t { - if (auto type = - node->input(0)->type()->cast()) { - type = type->withDim(0); - at::optional maybe_dtype_option = node->get(attr::dtype); - if( maybe_dtype_option && ! maybe_dtype_option->isNone() ) { - return {type->withScalarType(maybe_dtype_option->toScalarType())}; - } - if (type->scalarType()) { - return { at::isFloatingType(*type->scalarType()) - ? type - : type->withScalarType(at::kLong)}; - } else { - return { type }; - } - } - return {}; - }}; - - static const auto reduce_op_handler = - [](Node* node, - int64_t num_reduced_dim = 0, - bool upcast_integer = false, - c10::optional opt_dtype = c10::nullopt) -> type_vec_t { - if(auto type = node->input(0)->type()->cast()) { - if (!type->scalarType() || !type->dim()) { + static const register_formula_for + all_reduce_ops_with_integer_upcast_and_dtype{ + { + "aten::sum(Tensor self, *, int? dtype) -> Tensor", + "aten::prod(Tensor self, *, int? dtype) -> Tensor", + }, + [](Node* node) -> type_vec_t { + if (auto type = node->input(0)->type()->cast()) { + type = type->withDim(0); + at::optional maybe_dtype_option = + node->get(attr::dtype); + if (maybe_dtype_option && !maybe_dtype_option->isNone()) { + return { + type->withScalarType(maybe_dtype_option->toScalarType())}; + } + if (type->scalarType()) { + return {at::isFloatingType(*type->scalarType()) + ? type + : type->withScalarType(at::kLong)}; + } else { + return {type}; + } + } return {}; - } - if( opt_dtype && ! opt_dtype->isNone() ) { - type = type->withScalarType(opt_dtype->toScalarType()); - } else if(upcast_integer && !at::isFloatingType(*type->scalarType())) { - type = type->withScalarType(at::kLong); - } - if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) { - return {type->withDim(*type->dim() - num_reduced_dim)}; - } else { - return {type}; - } - } + }}; + + static const auto reduce_op_handler = [](Node* node, + int64_t num_reduced_dim = 0, + bool upcast_integer = false, + c10::optional opt_dtype = + c10::nullopt) -> type_vec_t { + if (auto type = node->input(0)->type()->cast()) { + if (!type->scalarType() || !type->dim()) { return {}; + } + if (opt_dtype && !opt_dtype->isNone()) { + type = type->withScalarType(opt_dtype->toScalarType()); + } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) { + type = type->withScalarType(at::kLong); + } + if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) { + return {type->withDim(*type->dim() - num_reduced_dim)}; + } else { + return {type}; + } + } + return {}; }; static const auto multidim_reduce_with_keepdim = @@ -1175,8 +1172,7 @@ class ShapePropagator { "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor", }, [](Node* node) -> type_vec_t { - if (auto type = - node->input(0)->type()->cast()) { + if (auto type = node->input(0)->type()->cast()) { if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) { return {type->withDim(0)}; } else { @@ -1300,7 +1296,7 @@ class ShapePropagator { (maybe_dtype_option->isNone() ? at::kDouble : maybe_dtype_option->toScalarType()); - return {ProfiledTensorType::create( + return {TensorType::create( dtype, device, dim, /*requires_grad=*/c10::nullopt)}; }; @@ -1324,9 +1320,8 @@ class ShapePropagator { "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", }, [](Node* node) -> type_vec_t { - if (auto type = node->namedInput(attr::self) - ->type() - ->cast()) { + if (auto type = + node->namedInput(attr::self)->type()->cast()) { if (type->dim()) { return factory_with_ndim(node, *type->dim()); } @@ -1398,9 +1393,8 @@ class ShapePropagator { "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor", }, [](Node* node) -> type_vec_t { - if (auto type = node->namedInput(attr::self) - ->type() - ->cast()) { + if (auto type = + node->namedInput(attr::self)->type()->cast()) { return {type->withScalarType(get_cast_scalar_type(node))}; } return {}; @@ -1428,7 +1422,7 @@ class ShapePropagator { // This section implements shape prop for an assorted set of nodes that only // need partial information about their input types. const auto input_type = [node](size_t index) { - auto result = node->input(index)->type()->cast(); + auto result = node->input(index)->type()->cast(); if (result) { result = result->dimensionedOnly(); } @@ -1562,12 +1556,11 @@ class ShapePropagator { // The code below implements formulas that need type information for all // their tensor inputs, and have exactly one output. - std::vector tensor_types; + std::vector tensor_types; static const auto reshape_prop = [](Node* node, Symbol shape_input, - const std::vector& tensor_types) - -> ProfiledTensorTypePtr { + const std::vector& tensor_types) -> TensorTypePtr { if (auto list_size = determineListSize(node->namedInput(shape_input))) { return tensor_types.at(0)->withDim(*list_size); } @@ -1593,7 +1586,7 @@ class ShapePropagator { return reshape_prop(node, attr::size, tensor_types); } else if (node->matches("aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) { TypePtr input_type = node->inputs().at(0)->type(); - if (auto type = input_type->cast()) { + if (auto type = input_type->cast()) { if (type->scalarType() && type->device()) { at::ScalarType default_type = *type->scalarType(); c10::Device default_device = *type->device(); @@ -1615,7 +1608,7 @@ class ShapePropagator { default_device = inp->toDevice(); } } - node->output()->setType(ProfiledTensorType::create( + node->output()->setType(TensorType::create( default_type, default_device, type->dim(), @@ -1722,7 +1715,7 @@ class ShapePropagator { bool PropagateCompleteShapeOnNode( Node* node, bool insert_expands, - std::vector tensor_types) { + std::vector tensor_types) { // For expensive ops we can directly encode their shape propagation // here, otherwise we fallback to running a fake version of the op // to get a quick and dirty propagation. @@ -1775,7 +1768,7 @@ class ShapePropagator { auto rhs_sizes = rhs_type->sizes().concrete_sizes().value(); SHAPE_ASSERT( *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2); - node->output()->setType(ProfiledTensorType::createContiguous( + node->output()->setType(TensorType::createContiguous( *lhs_type->scalarType(), *lhs_type->device(), at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]})); @@ -1941,7 +1934,7 @@ class ShapePropagator { (int64_t)*tensor_types.at(0)->sizes().size()}; at::IntArrayRef dims(dim_vec); node->output()->setType( - ProfiledTensorType::createContiguous(at::kLong, at::kCPU, dims)); + TensorType::createContiguous(at::kLong, at::kCPU, dims)); return true; } else if (node->kind() == ::c10::onnx::Reshape) { setUnshapedType(node); diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp index 10d3c36a2e9..3adc8dc50a9 100644 --- a/torch/csrc/jit/passes/specialize_autogradzero.cpp +++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp @@ -16,7 +16,7 @@ void specializeAutogradZero(Graph& g) { for (Value* input : g.inputs()) { const auto& tp = input->type(); - if (auto tt = tp->cast()) { + if (auto tt = tp->cast()) { if (tt->autogradZero() && *tt->autogradZero()) { state[input] = State::Zero; } else { diff --git a/torch/csrc/jit/profiling_record.cpp b/torch/csrc/jit/profiling_record.cpp index f8b89a27418..07f0b203b18 100644 --- a/torch/csrc/jit/profiling_record.cpp +++ b/torch/csrc/jit/profiling_record.cpp @@ -29,22 +29,26 @@ void ProfilingRecord::instrumentBlock(Block* block) { auto pn = createProfileNode(nullptr, {i}); auto pno = pn->addOutput(); + bool first = true; pno->setType(TensorType::get()); - std::function shape_profiler = [this, pno](Stack& stack) { - IValue t; - pop(stack, t); - if (t.isTensor()) { - auto pttp = ProfiledTensorType::create(t.toTensor()); - std::lock_guard lock(this->mutex_); - if (auto type = pno->type()->cast()) { - pno->setType(type->merge(pttp)); - } else { - pno->setType(pttp); - } - } - // passing t through - push(stack, t); - }; + std::function shape_profiler = + [this, pno, first](Stack& stack) mutable { + IValue t; + pop(stack, t); + if (t.isTensor()) { + auto pttp = TensorType::create(t.toTensor()); + std::lock_guard lock(this->mutex_); + if (auto type = pno->type()->cast()) { + if (!first) { + pttp = pttp->merge(type); + } + pno->setType(pttp); + first = false; + } + } + // passing t through + push(stack, t); + }; pn->setCallback(shape_profiler); pn->insertBefore(n); @@ -77,11 +81,10 @@ std::unique_ptr ProfilingRecord::instrumentGraph( return pr; } -ProfiledTensorTypePtr ProfilingRecord::toProfiledTensorTypePtr( - const IValue& ival) { +TensorTypePtr ProfilingRecord::toTensorTypePtr(const IValue& ival) { if (ival.isTensor()) { auto tensor = ival.toTensor(); - return ProfiledTensorType::create(tensor); + return TensorType::create(tensor); } return {nullptr}; diff --git a/torch/csrc/jit/profiling_record.h b/torch/csrc/jit/profiling_record.h index 49acfeb5f5e..8a15ff4c55d 100644 --- a/torch/csrc/jit/profiling_record.h +++ b/torch/csrc/jit/profiling_record.h @@ -13,7 +13,7 @@ namespace torch { namespace jit { -using ::c10::ProfiledTensorTypePtr; +using ::c10::TensorTypePtr; struct ProfilingRecord { // N.B. ProfilingRecord's copy and move c-tor are disabled, so we won't @@ -21,7 +21,7 @@ struct ProfilingRecord { // are captured in callbacks_ ProfilingRecord(const ProfilingRecord&) = delete; ProfilingRecord(ProfilingRecord&&) noexcept = delete; - static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival); + static TensorTypePtr toTensorTypePtr(const IValue& ival); TORCH_API static std::unique_ptr instrumentGraph( const std::shared_ptr& graph); diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 66f3ea913f7..45ae7e524e5 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -93,7 +93,7 @@ inline MatchTypeReturn tryToInferType(py::handle input) { // Try tensor types if (THPVariable_Check(input.ptr())) { auto tensor = py::cast(input); - return MatchTypeReturn(ProfiledTensorType::create(tensor)); + return MatchTypeReturn(TensorType::create(tensor)); } if (input.is(py::none())) { @@ -310,8 +310,7 @@ inline IValue toIValue( const TypePtr& type, c10::optional N) { switch (type->kind()) { - case TypeKind::TensorType: - case TypeKind::ProfiledTensorType: { + case TypeKind::TensorType: { auto var = py::cast(obj); if (var.is_sparse()) { AT_WARN( @@ -391,7 +390,6 @@ inline IValue toIValue( } return repeated; } - case TypeKind::ProfiledTensorType: case TypeKind::TensorType: return c10::impl::toList(py::cast>(obj)); default: diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 5c9df2f20a0..3f20c939460 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -631,15 +631,14 @@ void initPythonIRBindings(PyObject* module_) { .def( "dim", [](Type& t) { - auto vshape = - ProfiledTensorType::create(t.shared_from_this())->sizes(); + auto vshape = t.shared_from_this()->expect()->sizes(); return vshape.size() ? py::cast(*vshape.size()) : py::cast(Py_None); }) .def( "sizes", [](Type& t) -> py::object { - if (auto ptt = t.expect()) { + if (auto ptt = t.expect()) { if (auto cs = ptt->sizes().concrete_sizes()) { return py::cast(*cs); } @@ -649,7 +648,7 @@ void initPythonIRBindings(PyObject* module_) { .def( "sizes", [](Type& t) -> py::object { - if (auto ptt = t.expect()) { + if (auto ptt = t.expect()) { if (auto cs = ptt->strides().concrete_sizes()) { return py::cast(*cs); } @@ -660,13 +659,13 @@ void initPythonIRBindings(PyObject* module_) { "contiguous", [](Type& t) { return std::static_pointer_cast( - t.expect()->contiguous()); + t.expect()->contiguous()); }) .def( "scalarType", [](Type& t) { auto scalar_type = - ProfiledTensorType::create(t.shared_from_this())->scalarType(); + t.shared_from_this()->expect()->scalarType(); return (scalar_type) ? toString(*scalar_type) : nullptr; }) .def( diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 2cca9774fef..82d0bebfb53 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -280,7 +280,7 @@ struct VISIBILITY_HIDDEN ModuleSelf : public Self { }; static TypePtr getTensorType(const at::Tensor& t, bool complete) { - auto r = ProfiledTensorType::create(t); + auto r = TensorType::create(t); if (!complete) { r = r->dimensionedOnly(); } @@ -354,8 +354,8 @@ static std::shared_ptr _assign_output_shapes( for (size_t i = 0; i < outputs.size(); ++i) { auto scalar_type = outputs[i].scalar_type(); auto sizes = outputs[i].sizes(); - auto type = torch::jit::ProfiledTensorType::createContiguous( - scalar_type, at::kCPU, sizes); + auto type = + torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes); retval->outputs()[i]->setType(type); } return retval; diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index 6017642b128..36adb3505c0 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ b/torch/csrc/jit/script/schema_type_parser.cpp @@ -147,7 +147,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { L.expect('*'); num_dims++; }); - ptr = at::ProfiledTensorType::create( + ptr = at::TensorType::create( dtype, at::DeviceType::CPU, c10::VaryingShape(num_dims), @@ -166,8 +166,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { dims.push_back(dim); }); at::IntArrayRef dims_ref(dims); - ptr = at::ProfiledTensorType::create( - dtype, at::DeviceType::CPU, dims_ref, false); + ptr = at::TensorType::create(dtype, at::DeviceType::CPU, dims_ref, false); } return ptr; } diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index ec35397f8e4..de6f7fd9999 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -21,11 +21,7 @@ struct SymbolicVariable { return g.addInput()->setType(std::move(type)); } std::vector sizes() const { - return v->type() - ->expect() - ->sizes() - .concrete_sizes() - .value(); + return v->type()->expect()->sizes().concrete_sizes().value(); } void addAsOutput() const { v->owningGraph()->registerOutput(v); @@ -317,7 +313,7 @@ struct SymbolicVariable { return v->owningGraph()->insertConstant(std::move(value)); } SymbolicVariable typeLike(SymbolicVariable other) const { - if (auto other_type = other.v->type()->cast()) + if (auto other_type = other.v->type()->cast()) v->setType(other_type->contiguous()); return *this; } @@ -340,7 +336,7 @@ struct SymbolicVariable { SymbolicVariable typeLikeWithScalarType( SymbolicVariable other, at::ScalarType type) const { - if (auto other_type = other.v->type()->cast()) { + if (auto other_type = other.v->type()->cast()) { auto new_type = other_type->withScalarType(type)->contiguous(); v->setType(new_type); } @@ -349,8 +345,8 @@ struct SymbolicVariable { SymbolicVariable typeLikeWithRhsScalarType( SymbolicVariable other, SymbolicVariable rhs) const { - auto other_type = other.v->type()->cast(); - auto rhs_type = rhs.v->type()->cast(); + auto other_type = other.v->type()->cast(); + auto rhs_type = rhs.v->type()->cast(); if (other_type && rhs_type && other_type->isComplete() && rhs_type->isComplete()) { auto new_type = diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 472bbbea7ab..36d53b3b87e 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -32,18 +32,8 @@ from functools import wraps # # In general, we should avoid depending on the type of Tensor Values contained # within the trace graph. However, this is sometimes unavoidable (due to ONNX -# spec requirements, etc). If you are implementing a symbolic and need Tensor -# type information, note that there are several levels of Tensor types, defined -# in aten/src/ATen/core/jit_type.h: -# -# TensorType - This is a Tensor, but we don't know anything about its -# properties (e.g. scalar type, # dims, shapes). -# Appears as `Tensor` in graph print-outs. -# ProfiledTensorType <: TensorType - Denotes a Tensor for which we know the -# concrete sizes in addition to the information -# contained in TensorTyper. This adds a sizes() -# method which can be used to retrieve the -# concrete sizes. +# spec requirements, etc). The TensorType object has accessors for these properties +# that return the property if it is statically known and return nullopt otherwise. # # In general, we should prefer to rely on the least specific information possible. # For example, not relying on tensor properties at all is better than relying