Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[JIT] Implement Tensor.tolist() (#33472)
Summary:
**Summary**
This commit adds an implementation of `Tensor.tolist()` to the JIT interpreter.

**Testing**
This commit adds several unit tests that check that this function works correctly for 0D, 1D, 2D, and 3D tensors of type `float`, `int`, and `bool`.

```
(base) meghanl-mbp:pytorch meghanl$ python test/test_jit.py TestList.test_to_list -v
Fail to import hypothesis in common_utils, tests are not derandomized
test_to_list (jit.test_list_dict.TestList)
Unit tests for Tensor.tolist() function. ... ok

----------------------------------------------------------------------
Ran 1 test in 0.329s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33472

Differential Revision: D20109738

Pulled By: SplitInfinity

fbshipit-source-id: a6e3fee5e3201d5e1f0c4ca45048488ae2bf5e33
This commit is contained in:
parent 5029ff001b, commit cb8d9f99aa
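For context, the user-facing pattern this commit enables looks like the following minimal sketch, based on the new unit tests in the diff below (the function name and example input are illustrative):

```python
import torch
from typing import List

@torch.jit.script
def matrix_to_list(x: torch.Tensor) -> List[List[float]]:
    # The full list type must be spelled out with torch.jit.annotate so
    # the compiler knows the dimension (2) and element type (float).
    return torch.jit.annotate(List[List[float]], x.tolist())

# A float element type corresponds to a double tensor, as in the tests.
print(matrix_to_list(torch.randn(2, 3, dtype=torch.double)))
```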
```diff
@@ -97,6 +97,7 @@ namespace c10 {
   _(prim, range) \
   _(prim, rangelist) \
   _(prim, isinstance) \
+  _(prim, tolist) \
   _(prim, unchecked_cast) \
   _(aten, _grad_sum_to_size) \
   _(aten, _size_if_not_equal) \
```
```diff
@@ -892,6 +892,181 @@ class TestList(JitTestCase):
         check_list(min_floatlist, float_li)
         check_list(max_floatlist, float_li)
 
+    def test_to_list(self):
+        """Unit tests for Tensor.tolist() function."""
+
+        """
+        Boolean dtype unit tests.
+        """
+        def to_list_bool_0D(x):
+            # type: (torch.Tensor) -> bool
+            li = torch.jit.annotate(bool, x.tolist())
+            return li
+
+        def to_list_bool_1D(x):
+            # type: (torch.Tensor) -> List[bool]
+            li = torch.jit.annotate(List[bool], x.tolist())
+            return li
+
+        def to_list_bool_2D(x):
+            # type: (torch.Tensor) -> List[List[bool]]
+            li = torch.jit.annotate(List[List[bool]], x.tolist())
+            return li
+
+        def to_list_bool_3D(x):
+            # type: (torch.Tensor) -> List[List[List[bool]]]
+            li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
+            return li
+
+        self.checkScript(to_list_bool_0D, (torch.tensor(False, dtype=torch.bool),))
+        bool_input_1D = torch.tensor([True, False, True, False], dtype=torch.bool)
+        self.checkScript(to_list_bool_1D, (bool_input_1D,))
+        bool_input_2D = torch.tensor(
+            [[True, True, False], [False, True, False]], dtype=torch.bool
+        )
+        self.checkScript(to_list_bool_2D, (bool_input_2D,))
+        bool_input_3D = torch.tensor(
+            [[[True, False], [False, True]], [[True, False], [False, False]]],
+            dtype=torch.bool,
+        )
+        self.checkScript(to_list_bool_3D, (bool_input_3D,))
+        bool_input_noncontiguous = torch.tensor(
+            [[[True, False], [False, True]], [[True, False], [False, False]]],
+            dtype=torch.bool,
+        ).transpose(0, 1)
+        self.checkScript(to_list_bool_3D, (bool_input_noncontiguous,))
+
+        """
+        Int dtype unit tests.
+        """
+        def to_list_int_0D(x):
+            # type: (torch.Tensor) -> int
+            li = torch.jit.annotate(int, x.tolist())
+            return li
+
+        def to_list_int_1D(x):
+            # type: (torch.Tensor) -> List[int]
+            li = torch.jit.annotate(List[int], x.tolist())
+            return li
+
+        def to_list_int_2D(x):
+            # type: (torch.Tensor) -> List[List[int]]
+            li = torch.jit.annotate(List[List[int]], x.tolist())
+            return li
+
+        def to_list_int_3D(x):
+            # type: (torch.Tensor) -> List[List[List[int]]]
+            li = torch.jit.annotate(List[List[List[int]]], x.tolist())
+            return li
+
+        self.checkScript(to_list_int_0D, (torch.tensor(1, dtype=torch.long),))
+        int_input_1D = torch.tensor([1, 2, 3, 4], dtype=torch.long)
+        self.checkScript(to_list_int_1D, (int_input_1D,))
+        int_input_2D = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.long)
+        self.checkScript(to_list_int_2D, (int_input_2D,))
+        int_input_3D = torch.tensor(
+            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
+        )
+        self.checkScript(to_list_int_3D, (int_input_3D,))
+        int_input_noncontiguous = torch.tensor(
+            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
+        ).transpose(0, 1)
+        self.checkScript(to_list_int_3D, (int_input_noncontiguous,))
+
+        """
+        Float dtype unit tests.
+        """
+        def to_list_float_0D(x):
+            # type: (torch.Tensor) -> float
+            li = torch.jit.annotate(float, x.tolist())
+            return li
+
+        def to_list_float_1D(x):
+            # type: (torch.Tensor) -> List[float]
+            li = torch.jit.annotate(List[float], x.tolist())
+            return li
+
+        def to_list_float_2D(x):
+            # type: (torch.Tensor) -> List[List[float]]
+            li = torch.jit.annotate(List[List[float]], x.tolist())
+            return li
+
+        def to_list_float_3D(x):
+            # type: (torch.Tensor) -> List[List[List[float]]]
+            li = torch.jit.annotate(List[List[List[float]]], x.tolist())
+            return li
+
+        self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.double)[0],))
+        self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double),))
+        self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.double),))
+        self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double),))
+        self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double).transpose(0, 1),))
+
+        """
+        Non-happy path tests:
+        - missing type annotation
+        - mismatch between type annotation and input
+        - type annotation with unsupported type
+        - type annotation with the wrong dimension
+        - type annotation with scalar type that doesn't match the input scalar type
+        """
+        def to_list_missing_type_annotation(x):
+            # type: (torch.Tensor) -> List[float]
+            li = x.tolist()
+            return li
+
+        def to_list_incorrect_type_annotation(x):
+            # type: (torch.Tensor) -> List[float]
+            li = torch.jit.annotate(float, x.tolist())
+            return li
+
+        def to_list_unsupported_type_annotation(x):
+            # type: (torch.Tensor) -> List[float]
+            li = torch.jit.annotate(List[str], x.tolist())
+            return li
+
+        def to_list_type_annotation_wrong_dim(x):
+            # type: (torch.Tensor) -> List[List[float]]
+            li = torch.jit.annotate(List[List[float]], x.tolist())
+            return li
+
+        def to_list_type_annotation_incorrect_scalar_type(x):
+            # type: (torch.Tensor) -> List[float]
+            li = torch.jit.annotate(List[float], x.tolist())
+            return li
+
+        with self.assertRaisesRegex(
+            RuntimeError, r"Expected type hint for result of tolist()"
+        ):
+            self.checkScript(to_list_missing_type_annotation, (torch.randn(5),))
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Return value was annotated as having type List\[float\] but is actually of type float",
+        ):
+            self.checkScript(to_list_incorrect_type_annotation, (torch.randn(5),))
+
+        with self.assertRaisesRegex(
+            RuntimeError, r"str is not one of the supported element types for tolist"
+        ):
+            self.checkScript(to_list_unsupported_type_annotation, (torch.randn(5),))
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Output annotation list dimension and runtime tensor dimension must match",
+        ):
+            self.checkScript(to_list_type_annotation_wrong_dim, (torch.randn(5, dtype=torch.double),))
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Output annotation element type and runtime tensor element type must match",
+        ):
+            self.checkScript(
+                to_list_type_annotation_incorrect_scalar_type,
+                (torch.ones(5, dtype=torch.long),),
+            )
+
+
 class TestDict(JitTestCase):
     def dict(self):
         return {u'a': torch.ones(1), u'b': torch.ones(1) + 1, u'c': torch.ones(1) + 2}
```
```diff
@@ -2387,11 +2387,14 @@ struct to_ir {
     }
   }
 
-  std::shared_ptr<SugaredValue> emitApplyExpr(Apply& apply, size_t n_binders) {
+  std::shared_ptr<SugaredValue> emitApplyExpr(
+      Apply& apply,
+      size_t n_binders,
+      const TypePtr& type_hint = nullptr) {
     auto sv = emitSugaredExpr(apply.callee(), 1);
     auto loc = apply.callee().range();
     if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
-      return emitApplySpecialForm(special_form->form(), apply);
+      return emitApplySpecialForm(special_form->form(), apply, type_hint);
     }
     auto inputs = getNamedValues(apply.inputs(), true);
     auto attributes = emitAttributes(apply.attributes());
```
```diff
@@ -2405,7 +2408,8 @@ struct to_ir {
   // evaluation order.
   std::shared_ptr<SugaredValue> emitApplySpecialForm(
       Symbol form,
-      Apply& apply) {
+      Apply& apply,
+      const TypePtr& type_hint = nullptr) {
     switch (form) {
       case prim::fork: {
         auto& trees = apply.inputs().tree()->trees();
```
```diff
@@ -2495,6 +2499,19 @@ struct to_ir {
         auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
         return std::make_shared<SimpleValue>(result.value());
       }
+      case prim::tolist: {
+        auto select = Select(apply.callee());
+        auto value = select.value();
+        auto operand = emitSugaredExpr(value, 1);
+
+        if (!type_hint) {
+          throw ErrorReport(apply)
+              << "Expected type hint for result of tolist()";
+        }
+
+        return std::make_shared<SimpleValue>(graph->insertToList(
+            operand->asValue(value.range(), method), type_hint));
+      }
       case prim::HasAttr: {
         checkApplyNumInputs(apply, 2);
         const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
```
```diff
@@ -2638,7 +2655,7 @@ struct to_ir {
       }
       case TK_APPLY: {
         auto apply = Apply(tree);
-        return emitApplyExpr(apply, n_binders);
+        return emitApplyExpr(apply, n_binders, type_hint);
       } break;
       default:
         return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
```
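The hunks above thread the type hint from the surrounding annotation down to the special-form handler, and `prim::tolist` refuses to compile without one. A small sketch of that failure mode, mirroring `to_list_missing_type_annotation` from the tests (the function name here is illustrative):

```python
import torch
from typing import List

def missing_hint(x: torch.Tensor) -> List[float]:
    li = x.tolist()  # no torch.jit.annotate(...) wrapping the call
    return li

# Expected to raise a RuntimeError containing
# "Expected type hint for result of tolist()".
try:
    torch.jit.script(missing_hint)
except RuntimeError as e:
    print(e)
```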
```diff
@@ -147,6 +147,11 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
     return builtin;
   }
 
+  // Handle calling tolist() on a Tensor.
+  if (value_->type()->isSubtypeOf(TensorType::get()) && field == "tolist") {
+    return SpecialFormValue::create(prim::tolist);
+  }
+
   ErrorReport report(loc);
   report << "Tried to access nonexistent attribute or method '" << field
          << "' of type '" << value_->type()->python_str() << "'.";
```
```diff
@@ -336,6 +336,7 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::ChunkSizes:
     case prim::Function:
     case prim::CreateObject:
+    case prim::tolist:
       return analyzeCreator(node);
     case prim::TupleConstruct:
     case prim::DictConstruct:
```
```diff
@@ -1686,6 +1686,40 @@ Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
   return n->output();
 }
 
+Value* Graph::insertToList(Value* v, TypePtr type) {
+  int dim = 0;
+  TypePtr ptr = type;
+
+  // Unwrap the type to determine the number of dimensions.
+  while (auto list_type = ptr->cast<ListType>()) {
+    ptr = list_type->getElementType();
+    ++dim;
+  }
+
+  // Encode the base element type as an integer.
+  int elem_ty = 0;
+  if (ptr == IntType::get()) {
+    elem_ty = 0;
+  } else if (ptr == FloatType::get()) {
+    elem_ty = 1;
+  } else if (ptr == BoolType::get()) {
+    elem_ty = 2;
+  } else {
+    TORCH_CHECK(
+        false,
+        ptr->python_str(),
+        " is not one of the supported element types for tolist: int, float, bool");
+  }
+
+  // Pass in the number of dimensions and base element type as arguments
+  // to the op.
+  Value* dim_val = insertConstant(IValue(dim));
+  Value* elem_ty_val = insertConstant(IValue(elem_ty));
+  Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
+  n->output()->setType(std::move(type));
+  return n->output();
+}
+
 Value* Graph::insertFunctionCall(
     Function* callee,
     const script::MatchedSchema& matched) {
```
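`insertToList` reduces the annotation to two integers: a dimension count and an element-type code (0 = int, 1 = float, 2 = bool). A rough Python model of that encoding, with strings standing in for `TypePtr` (the names here are illustrative, not part of any PyTorch API):

```python
ELEM_CODES = {"int": 0, "float": 1, "bool": 2}

def encode_annotation(ty):
    # Unwrap List[...] layers to count dimensions, mirroring the
    # while loop over ListType in insertToList.
    dim = 0
    while ty.startswith("List[") and ty.endswith("]"):
        ty = ty[len("List["):-1]
        dim += 1
    if ty not in ELEM_CODES:
        raise TypeError(
            ty + " is not one of the supported element types for tolist: int, float, bool")
    return dim, ELEM_CODES[ty]

assert encode_annotation("List[List[float]]") == (2, 1)
assert encode_annotation("bool") == (0, 2)
```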
```diff
@@ -1128,6 +1128,10 @@ struct Graph {
 
   TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);
 
+  // Insert a ToList operator with argument \p v and output type \p type.
+  // \returns the output of the operation.
+  TORCH_API Value* insertToList(Value* v, TypePtr type);
+
   TORCH_API Value* insertFunctionCall(
       Function* callee,
       const script::MatchedSchema& matched);
```
```diff
@@ -157,6 +157,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
       prim::CallFunction,
       prim::isinstance,
       prim::unchecked_cast,
+      prim::tolist,
   };
 
   // WARNING: by adding a value to this set, you are asserting that your
```
```diff
@@ -239,6 +240,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       aten::wait,
      prim::isinstance,
       prim::unchecked_cast,
+      prim::tolist,
   };
 
   // Operators that should not be used by alias analysis
```
```diff
@@ -103,6 +103,73 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
   }
 }
 
+// Convert the tensor pointed to by \p data to a nested list. \p dim is the
+// number of dimensions in the tensor and \p cur_dim is the dimension being
+// processed by the current invocation. \p ty is the expected output IR type of
+// the operation. \p sizes and \p strides are the sizes and strides of the
+// tensor operand and \p element_size is the size in bytes of one tensor
+// element.
+IValue tensorToListRecursive(
+    char* data,
+    int64_t cur_dim,
+    int64_t num_tensor_dims,
+    TypePtr ty,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    size_t element_size) {
+  // If ty is a ListType, get the element type.
+  if (auto list_type = ty->cast<ListType>()) {
+    ty = list_type->getElementType();
+  } else {
+    // If the output type is a scalar, read and push one scalar of
+    // the right type onto the stack.
+    if (ty == IntType::get()) {
+      int64_t scalar = *(int64_t*)data;
+      return IValue(scalar);
+    } else if (ty == FloatType::get()) {
+      double scalar = *(double*)data;
+      return IValue(scalar);
+    } else if (ty == BoolType::get()) {
+      bool scalar = *(bool*)data;
+      return IValue(scalar);
+    } else {
+      TORCH_CHECK(
+          false,
+          ty->python_str(),
+          " is not one of the supported types for tolist: int, float, bool");
+    }
+  }
+
+  // Make the result list consisting of elements of type ty. Since this
+  // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
+  // output elements.
+  auto result = c10::impl::GenericList(ty);
+  result.reserve(sizes[cur_dim]);
+
+  // Since ty was a list type, tensorToListRecursive needs to be called
+  // recursively on each slice of the tensor in the current dimension.
+  for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
+    auto inner_result = tensorToListRecursive(
+        data, cur_dim + 1, num_tensor_dims, ty, sizes, strides, element_size);
+
+    if (inner_result.isList()) {
+      result.emplace_back(inner_result.toList());
+    } else if (inner_result.isDouble()) {
+      result.emplace_back(inner_result.toDouble());
+    } else if (inner_result.isInt()) {
+      result.emplace_back(inner_result.toInt());
+    } else if (inner_result.isBool()) {
+      result.emplace_back(inner_result.toBool());
+    } else {
+      TORCH_INTERNAL_ASSERT("Unknown return type for tensorToListRecursive");
+    }
+
+    data += strides[cur_dim] * element_size;
+  }
+
+  return result;
+}
+
 static int64_t floordiv(int64_t a, int64_t b) {
   if (b == 0) {
     throw std::runtime_error("division by 0");
```
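The recursion walks the tensor's raw buffer using sizes and strides rather than assuming contiguous storage, which is why the transposed inputs in the tests round-trip correctly. A Python model of the traversal (illustrative only; the real code reads typed scalars from a `char*` and builds `IValue`s):

```python
def tensor_to_list_recursive(flat, offset, cur_dim, dim, sizes, strides):
    # Base case: all dimensions consumed, read one scalar element.
    if cur_dim == dim:
        return flat[offset]
    result = []
    for _ in range(sizes[cur_dim]):
        result.append(
            tensor_to_list_recursive(flat, offset, cur_dim + 1, dim, sizes, strides))
        # Advance by the stride of the current dimension, matching
        # `data += strides[cur_dim] * element_size` in the C++ code.
        offset += strides[cur_dim]
    return result

# A 2x2 row-major buffer viewed through transposed strides (1, 2):
flat = [1, 2, 3, 4]
print(tensor_to_list_recursive(flat, 0, 0, 2, (2, 2), (1, 2)))  # [[1, 3], [2, 4]]
```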
```diff
@@ -967,6 +1034,56 @@ RegisterOperators reg(
           return 0;
         },
         aliasAnalysisSpecialCase()),
+    Operator(
+        prim::tolist,
+        // This operator has to be unschematized because the return type depends on the type hint and input.
+        // The implementation of this operator below is intended to be as close to the Python implementation in
+        // torch/csrc/utils/tensor_list.cpp as possible.
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            int elem_ty_val;
+            int dim_val;
+            at::Tensor t;
+
+            pop(stack, elem_ty_val);
+            pop(stack, dim_val);
+            pop(stack, t);
+
+            // Rebuild the output type using elem_ty_val and dim_val. Start
+            // with the element type corresponding to elem_ty_val.
+            TypePtr out_ty;
+            if (elem_ty_val == 0) {
+              out_ty = IntType::get();
+            } else if (elem_ty_val == 1) {
+              out_ty = FloatType::get();
+            } else if (elem_ty_val == 2) {
+              out_ty = BoolType::get();
+            } else {
+              TORCH_CHECK(false, "Unsupported element type for tolist; only int, float and bool are supported");
+            }
+
+            // Check that type of the Tensor matches that of the annotation.
+            TORCH_CHECK(tryScalarTypeFromJitType(out_ty) == t.scalar_type(), "Output annotation element type and runtime tensor element type must match for tolist()")
+
+            // Check that the dimension of the Tensor matches that of the annotation.
+            TORCH_CHECK(dim_val == t.dim(), "Output annotation list dimension and runtime tensor dimension must match for tolist()")
+
+            // Wrap out_ty in a ListType dim times.
+            for (int i = 0; i < dim_val; ++i) {
+              out_ty = ListType::create(out_ty);
+            }
+
+            int64_t dim = t.dim();
+            auto sizes = t.sizes();
+            auto strides = t.strides();
+            size_t element_size = t.element_size();
+            char* data = (char*)t.data_ptr();
+            auto result = tensorToListRecursive(data, 0, dim, out_ty, sizes, strides, element_size);
+            push(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisSpecialCase()),
     Operator(
         prim::ConstantChunk,
         [](const Node* node) -> Operation {
```
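Since the operator is unschematized, the dimension count and element-type code travel as constant inputs to the `prim::tolist` node. One way to observe this is to print a scripted graph (a sketch; the exact graph text varies across versions):

```python
import torch
from typing import List

@torch.jit.script
def to_int_list(x: torch.Tensor) -> List[int]:
    return torch.jit.annotate(List[int], x.tolist())

# The graph should contain a prim::tolist node taking the tensor plus
# two prim::Constant inputs (the dim count and the element-type code).
print(to_int_list.graph)
```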
```diff
@@ -1043,6 +1043,11 @@ struct PythonPrintImpl {
         }
         stmt << ")";
       } break;
+      case prim::tolist: {
+        stmt << "annotate(" << node->output()->type()->python_str() << ", ";
+        stmt << useOf(node->input(0)) << ".tolist()"
+             << ")";
+      } break;
       default: {
         printOpName(stmt, node->kind());
         const FunctionSchema& schema = node->schema();
```
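Per this hunk, the serializer round-trips the node as an `annotate(...)` call, so reparsing the printed code recovers the same type hint. A quick way to see the printed form (a sketch; the exact text may differ):

```python
import torch
from typing import List

@torch.jit.script
def to_int_list(x: torch.Tensor) -> List[int]:
    return torch.jit.annotate(List[int], x.tolist())

# The serialized body should print the call back as something like:
#     return annotate(List[int], x.tolist())
print(to_int_list.code)
```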