From cb8d9f99aaf66947e371ae750eccd23b508fee91 Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Thu, 27 Feb 2020 21:43:17 -0800 Subject: [PATCH] [JIT] Implement Tensor.tolist() (#33472) Summary: **Summary** This commit adds an implementation of `Tensor.tolist()` to the JIT interpreter. **Testing** This commit adds several unit tests that test that this function works correctly for 0D, 1D, 2D and 3D tensors of type `float`, `int` and `bool`. ``` (base) meghanl-mbp:pytorch meghanl$ python test/test_jit.py TestList.test_to_list -v Fail to import hypothesis in common_utils, tests are not derandomized test_to_list (jit.test_list_dict.TestList) Unit tests for Tensor.tolist() function. ... ok ---------------------------------------------------------------------- Ran 1 test in 0.329s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/33472 Differential Revision: D20109738 Pulled By: SplitInfinity fbshipit-source-id: a6e3fee5e3201d5e1f0c4ca45048488ae2bf5e33 --- aten/src/ATen/core/interned_strings.h | 1 + test/jit/test_list_dict.py | 175 ++++++++++++++++++ torch/csrc/jit/frontend/ir_emitter.cpp | 25 ++- torch/csrc/jit/frontend/sugared_value.cpp | 5 + torch/csrc/jit/ir/alias_analysis.cpp | 1 + torch/csrc/jit/ir/ir.cpp | 34 ++++ torch/csrc/jit/ir/ir.h | 4 + torch/csrc/jit/runtime/operator.cpp | 2 + torch/csrc/jit/runtime/register_prim_ops.cpp | 117 ++++++++++++ torch/csrc/jit/serialization/python_print.cpp | 5 + 10 files changed, 365 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index b5997f4cd1a..10cb7ee1f72 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -97,6 +97,7 @@ namespace c10 { _(prim, range) \ _(prim, rangelist) \ _(prim, isinstance) \ + _(prim, tolist) \ _(prim, unchecked_cast) \ _(aten, _grad_sum_to_size) \ _(aten, _size_if_not_equal) \ diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index e90780ed048..04a8a789d0b 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -892,6 +892,181 @@ class TestList(JitTestCase): check_list(min_floatlist, float_li) check_list(max_floatlist, float_li) + def test_to_list(self): + """Unit tests for Tensor.tolist() function.""" + + """ + Boolean dtype unit tests. 
+ """ + def to_list_bool_0D(x): + # type: (torch.Tensor) -> bool + li = torch.jit.annotate(bool, x.tolist()) + return li + + def to_list_bool_1D(x): + # type: (torch.Tensor) -> List[bool] + li = torch.jit.annotate(List[bool], x.tolist()) + return li + + def to_list_bool_2D(x): + # type: (torch.Tensor) -> List[List[bool]] + li = torch.jit.annotate(List[List[bool]], x.tolist()) + return li + + def to_list_bool_3D(x): + # type: (torch.Tensor) -> List[List[List[bool]]] + li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) + return li + + self.checkScript(to_list_bool_0D, (torch.tensor(False, dtype=torch.bool),)) + bool_input_1D = torch.tensor([True, False, True, False], dtype=torch.bool) + self.checkScript(to_list_bool_1D, (bool_input_1D,)) + bool_input_2D = torch.tensor( + [[True, True, False], [False, True, False]], dtype=torch.bool + ) + self.checkScript(to_list_bool_2D, (bool_input_2D,)) + bool_input_3D = torch.tensor( + [[[True, False], [False, True]], [[True, False], [False, False]]], + dtype=torch.bool, + ) + self.checkScript(to_list_bool_3D, (bool_input_3D,)) + bool_input_noncontiguous = torch.tensor( + [[[True, False], [False, True]], [[True, False], [False, False]]], + dtype=torch.bool, + ).transpose(0, 1) + self.checkScript(to_list_bool_3D, (bool_input_noncontiguous,)) + + """ + Int dtype unit tests. + """ + def to_list_int_0D(x): + # type: (torch.Tensor) -> int + li = torch.jit.annotate(int, x.tolist()) + return li + + def to_list_int_1D(x): + # type: (torch.Tensor) -> List[int] + li = torch.jit.annotate(List[int], x.tolist()) + return li + + def to_list_int_2D(x): + # type: (torch.Tensor) -> List[List[int]] + li = torch.jit.annotate(List[List[int]], x.tolist()) + return li + + def to_list_int_3D(x): + # type: (torch.Tensor) -> List[List[List[int]]] + li = torch.jit.annotate(List[List[List[int]]], x.tolist()) + return li + + self.checkScript(to_list_int_0D, (torch.tensor(1, dtype=torch.long),)) + int_input_1D = torch.tensor([1, 2, 3, 4], dtype=torch.long) + self.checkScript(to_list_int_1D, (int_input_1D,)) + int_input_2D = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.long) + self.checkScript(to_list_int_2D, (int_input_2D,)) + int_input_3D = torch.tensor( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long + ) + self.checkScript(to_list_int_3D, (int_input_3D,)) + int_input_noncontiguous = torch.tensor( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long + ).transpose(0, 1) + self.checkScript(to_list_int_3D, (int_input_noncontiguous,)) + + """ + Float dtype unit tests. 
+ """ + def to_list_float_0D(x): + # type: (torch.Tensor) -> float + li = torch.jit.annotate(float, x.tolist()) + return li + + def to_list_float_1D(x): + # type: (torch.Tensor) -> List[float] + li = torch.jit.annotate(List[float], x.tolist()) + return li + + def to_list_float_2D(x): + # type: (torch.Tensor) -> List[List[float]] + li = torch.jit.annotate(List[List[float]], x.tolist()) + return li + + def to_list_float_3D(x): + # type: (torch.Tensor) -> List[List[List[float]]] + li = torch.jit.annotate(List[List[List[float]]], x.tolist()) + return li + + self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.double)[0],)) + self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double),)) + self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.double),)) + self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double),)) + self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double).transpose(0, 1),)) + + """ + Non-happy path tests: + - missing type annotation + - mismatch between type annotation and input + - type annotation with unsupported type + - type annotation with the wrong dimension + - type annotation with scalar type that doesn't match the input scalar type + """ + def to_list_missing_type_annotation(x): + # type: (torch.Tensor) -> List[float] + li = x.tolist() + return li + + def to_list_incorrect_type_annotation(x): + # type: (torch.Tensor) -> List[float] + li = torch.jit.annotate(float, x.tolist()) + return li + + def to_list_unsupported_type_annotation(x): + # type: (torch.Tensor) -> List[float] + li = torch.jit.annotate(List[str], x.tolist()) + return li + + def to_list_type_annotation_wrong_dim(x): + # type: (torch.Tensor) -> List[List[float]] + li = torch.jit.annotate(List[List[float]], x.tolist()) + return li + + def to_list_type_annotation_incorrect_scalar_type(x): + # type: (torch.Tensor) -> List[float] + li = torch.jit.annotate(List[float], x.tolist()) + return li + + with self.assertRaisesRegex( + RuntimeError, r"Expected type hint for result of tolist()" + ): + self.checkScript(to_list_missing_type_annotation, (torch.randn(5),)) + + with self.assertRaisesRegex( + RuntimeError, + r"Return value was annotated as having type List\[float\] but is actually of type float", + ): + self.checkScript(to_list_incorrect_type_annotation, (torch.randn(5),)) + + with self.assertRaisesRegex( + RuntimeError, r"str is not one of the supported element types for tolist" + ): + self.checkScript(to_list_unsupported_type_annotation, (torch.randn(5),)) + + with self.assertRaisesRegex( + RuntimeError, + r"Output annotation list dimension and runtime tensor dimension must match", + ): + self.checkScript(to_list_type_annotation_wrong_dim, (torch.randn(5, dtype=torch.double),)) + + with self.assertRaisesRegex( + RuntimeError, + r"Output annotation element type and runtime tensor element type must match", + ): + self.checkScript( + to_list_type_annotation_incorrect_scalar_type, + (torch.ones(5, dtype=torch.long),), + ) + + class TestDict(JitTestCase): def dict(self): return {u'a': torch.ones(1), u'b': torch.ones(1) + 1, u'c': torch.ones(1) + 2} diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 642b96776ad..8ede70cc4f8 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -2387,11 +2387,14 @@ struct to_ir { } } - std::shared_ptr emitApplyExpr(Apply& apply, size_t n_binders) { + std::shared_ptr emitApplyExpr( + Apply& apply, + size_t n_binders, + 
const TypePtr& type_hint = nullptr) { auto sv = emitSugaredExpr(apply.callee(), 1); auto loc = apply.callee().range(); if (auto special_form = dynamic_cast(sv.get())) { - return emitApplySpecialForm(special_form->form(), apply); + return emitApplySpecialForm(special_form->form(), apply, type_hint); } auto inputs = getNamedValues(apply.inputs(), true); auto attributes = emitAttributes(apply.attributes()); @@ -2405,7 +2408,8 @@ struct to_ir { // evaluation order. std::shared_ptr emitApplySpecialForm( Symbol form, - Apply& apply) { + Apply& apply, + const TypePtr& type_hint = nullptr) { switch (form) { case prim::fork: { auto& trees = apply.inputs().tree()->trees(); @@ -2495,6 +2499,19 @@ struct to_ir { auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]); return std::make_shared(result.value()); } + case prim::tolist: { + auto select = Select(apply.callee()); + auto value = select.value(); + auto operand = emitSugaredExpr(value, 1); + + if (!type_hint) { + throw ErrorReport(apply) + << "Expected type hint for result of tolist()"; + } + + return std::make_shared(graph->insertToList( + operand->asValue(value.range(), method), type_hint)); + } case prim::HasAttr: { checkApplyNumInputs(apply, 2); const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]); @@ -2638,7 +2655,7 @@ struct to_ir { } case TK_APPLY: { auto apply = Apply(tree); - return emitApplyExpr(apply, n_binders); + return emitApplyExpr(apply, n_binders, type_hint); } break; default: return std::make_shared(emitSimpleExpr(tree, type_hint)); diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 77b05936be2..2c31e87973d 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -147,6 +147,11 @@ std::shared_ptr SimpleValue::attr( return builtin; } + // Handle calling tolist() on a Tensor. + if (value_->type()->isSubtypeOf(TensorType::get()) && field == "tolist") { + return SpecialFormValue::create(prim::tolist); + } + ErrorReport report(loc); report << "Tried to access nonexistent attribute or method '" << field << "' of type '" << value_->type()->python_str() << "'."; diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index d46fe3602b6..047f53a39b0 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -336,6 +336,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::ChunkSizes: case prim::Function: case prim::CreateObject: + case prim::tolist: return analyzeCreator(node); case prim::TupleConstruct: case prim::DictConstruct: diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index fc2dbbc51eb..284ac06776d 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1686,6 +1686,40 @@ Value* Graph::insertUncheckedCast(Value* v, TypePtr type) { return n->output(); } +Value* Graph::insertToList(Value* v, TypePtr type) { + int dim = 0; + TypePtr ptr = type; + + // Unwrap the type to determine the number of dimensions. + while (auto list_type = ptr->cast()) { + ptr = list_type->getElementType(); + ++dim; + } + + // Encode the base element type as an integer. 
+ int elem_ty = 0; + if (ptr == IntType::get()) { + elem_ty = 0; + } else if (ptr == FloatType::get()) { + elem_ty = 1; + } else if (ptr == BoolType::get()) { + elem_ty = 2; + } else { + TORCH_CHECK( + false, + ptr->python_str(), + " is not one of the supported element types for tolist: int, float, bool"); + } + + // Pass in the number of dimensions and base element type as arguments + // to the op. + Value* dim_val = insertConstant(IValue(dim)); + Value* elem_ty_val = insertConstant(IValue(elem_ty)); + Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val})); + n->output()->setType(std::move(type)); + return n->output(); +} + Value* Graph::insertFunctionCall( Function* callee, const script::MatchedSchema& matched) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index c2fb3b32ad2..476484632c9 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1128,6 +1128,10 @@ struct Graph { TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type); + // Insert a ToList operator with argument \p v and output type \p type. + // \returns the output of the operation. + TORCH_API Value* insertToList(Value* v, TypePtr type); + TORCH_API Value* insertFunctionCall( Function* callee, const script::MatchedSchema& matched); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 188370be30d..2798d27d1ef 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -157,6 +157,7 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::CallFunction, prim::isinstance, prim::unchecked_cast, + prim::tolist, }; // WARNING: by adding a value to this set, you are asserting that your @@ -239,6 +240,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { aten::wait, prim::isinstance, prim::unchecked_cast, + prim::tolist, }; // Operators that should not be used by alias analysis diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 483eaf88aeb..a49753659bc 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -103,6 +103,73 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) { } } +// Convert the tensor pointed to by \p data to a nested list. \p dim is the +// number of dimensions in the tensor and \p cur_dim is the dimension being +// processed by the current invocation. \p ty is the expected output IR type of +// the operation. \p sizes and \p strides are the sizes and strides of the +// tensor operand and \p element_size is the size in bytes of one tensor +// element. +IValue tensorToListRecursive( + char* data, + int64_t cur_dim, + int64_t num_tensor_dims, + TypePtr ty, + at::IntArrayRef sizes, + at::IntArrayRef strides, + size_t element_size) { + // If ty is a ListType, get the element type. + if (auto list_type = ty->cast()) { + ty = list_type->getElementType(); + } else { + // If the output type is a scalar, read and push one scalar of + // the right type onto the stack. + if (ty == IntType::get()) { + int64_t scalar = *(int64_t*)data; + return IValue(scalar); + } else if (ty == FloatType::get()) { + double scalar = *(double*)data; + return IValue(scalar); + } else if (ty == BoolType::get()) { + bool scalar = *(bool*)data; + return IValue(scalar); + } else { + TORCH_CHECK( + false, + ty->python_str(), + " is not one of the supported types for tolist: int, float, bool"); + } + } + + // Make the result list consisting of elements of type ty. 
Since this + // invocation is processing dimension cur_dim, there will be sizes[cur_dim] + // output elements. + auto result = c10::impl::GenericList(ty); + result.reserve(sizes[cur_dim]); + + // Since ty was a list type, tensorToListRecursive needs to be called + // recursively on each slice of the tensor in the current dimension. + for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) { + auto inner_result = tensorToListRecursive( + data, cur_dim + 1, num_tensor_dims, ty, sizes, strides, element_size); + + if (inner_result.isList()) { + result.emplace_back(inner_result.toList()); + } else if (inner_result.isDouble()) { + result.emplace_back(inner_result.toDouble()); + } else if (inner_result.isInt()) { + result.emplace_back(inner_result.toInt()); + } else if (inner_result.isBool()) { + result.emplace_back(inner_result.toBool()); + } else { + TORCH_INTERNAL_ASSERT("Unknown return type for tensorToListRecursive"); + } + + data += strides[cur_dim] * element_size; + } + + return result; +} + static int64_t floordiv(int64_t a, int64_t b) { if (b == 0) { throw std::runtime_error("division by 0"); @@ -967,6 +1034,56 @@ RegisterOperators reg( return 0; }, aliasAnalysisSpecialCase()), + Operator( + prim::tolist, + // This operator has to be unschematized because the return type depends on the type hint and input. + // The implementation of this operator below is intended to be as close to the Python implementation in + // torch/csrc/utils/tensor_list.cpp as possible. + [](const Node* node) -> Operation { + return [](Stack& stack) { + int elem_ty_val; + int dim_val; + at::Tensor t; + + pop(stack, elem_ty_val); + pop(stack, dim_val); + pop(stack, t); + + // Rebuild the output type using elem_ty_val and dim_val. Start + // with the element type corresponding to elem_ty_val. + TypePtr out_ty; + if (elem_ty_val == 0) { + out_ty = IntType::get(); + } else if (elem_ty_val == 1) { + out_ty = FloatType::get(); + } else if (elem_ty_val == 2) { + out_ty = BoolType::get(); + } else { + TORCH_CHECK(false, "Unsupported element type for tolist; only int, float and bool are supported"); + } + + // Check that type of the Tensor matches that of the annotation. + TORCH_CHECK(tryScalarTypeFromJitType(out_ty) == t.scalar_type(), "Output annotation element type and runtime tensor element type must match for tolist()") + + // Check that the dimension of the Tensor matches that of the annotation. + TORCH_CHECK(dim_val == t.dim(), "Output annotation list dimension and runtime tensor dimension must match for tolist()") + + // Wrap out_ty in a ListType dim times. 
+ for (int i = 0; i < dim_val; ++i) { + out_ty = ListType::create(out_ty); + } + + int64_t dim = t.dim(); + auto sizes = t.sizes(); + auto strides = t.strides(); + size_t element_size = t.element_size(); + char* data = (char*)t.data_ptr(); + auto result = tensorToListRecursive(data, 0, dim, out_ty, sizes, strides, element_size); + push(stack, std::move(result)); + return 0; + }; + }, + aliasAnalysisSpecialCase()), Operator( prim::ConstantChunk, [](const Node* node) -> Operation { diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 709416437a1..6887195ed52 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -1043,6 +1043,11 @@ struct PythonPrintImpl { } stmt << ")"; } break; + case prim::tolist: { + stmt << "annotate(" << node->output()->type()->python_str() << ", "; + stmt << useOf(node->input(0)) << ".tolist()" + << ")"; + } break; default: { printOpName(stmt, node->kind()); const FunctionSchema& schema = node->schema();
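
A minimal usage sketch of the feature this patch adds (illustrative only, not part of the patch; assumes a PyTorch build that includes this change). In TorchScript, `Tensor.tolist()` requires a `torch.jit.annotate` type hint so the compiler knows the list nesting depth and element type; both are checked against the tensor at runtime, as the new tests above exercise:

```python
from typing import List

import torch

@torch.jit.script
def matrix_to_list(x):
    # type: (torch.Tensor) -> List[List[float]]
    # The annotation tells the compiler the result is a 2D list of float;
    # the nesting depth and element type are validated against x at runtime.
    return torch.jit.annotate(List[List[float]], x.tolist())

# float in TorchScript corresponds to a double (float64) tensor.
print(matrix_to_list(torch.randn(2, 3, dtype=torch.double)))
```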