mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[JIT] Implement Tensor.tolist() (#33472)
Summary: **Summary** This commit adds an implementation of `Tensor.tolist()` to the JIT interpreter. **Testing** This commit adds several unit tests that test that this function works correctly for 0D, 1D, 2D and 3D tensors of type `float`, `int` and `bool`. ``` (base) meghanl-mbp:pytorch meghanl$ python test/test_jit.py TestList.test_to_list -v Fail to import hypothesis in common_utils, tests are not derandomized test_to_list (jit.test_list_dict.TestList) Unit tests for Tensor.tolist() function. ... ok ---------------------------------------------------------------------- Ran 1 test in 0.329s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/33472 Differential Revision: D20109738 Pulled By: SplitInfinity fbshipit-source-id: a6e3fee5e3201d5e1f0c4ca45048488ae2bf5e33
This commit is contained in:
parent
5029ff001b
commit
cb8d9f99aa
|
|
@ -97,6 +97,7 @@ namespace c10 {
|
||||||
_(prim, range) \
|
_(prim, range) \
|
||||||
_(prim, rangelist) \
|
_(prim, rangelist) \
|
||||||
_(prim, isinstance) \
|
_(prim, isinstance) \
|
||||||
|
_(prim, tolist) \
|
||||||
_(prim, unchecked_cast) \
|
_(prim, unchecked_cast) \
|
||||||
_(aten, _grad_sum_to_size) \
|
_(aten, _grad_sum_to_size) \
|
||||||
_(aten, _size_if_not_equal) \
|
_(aten, _size_if_not_equal) \
|
||||||
|
|
|
||||||
|
|
@ -892,6 +892,181 @@ class TestList(JitTestCase):
|
||||||
check_list(min_floatlist, float_li)
|
check_list(min_floatlist, float_li)
|
||||||
check_list(max_floatlist, float_li)
|
check_list(max_floatlist, float_li)
|
||||||
|
|
||||||
|
def test_to_list(self):
|
||||||
|
"""Unit tests for Tensor.tolist() function."""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Boolean dtype unit tests.
|
||||||
|
"""
|
||||||
|
def to_list_bool_0D(x):
|
||||||
|
# type: (torch.Tensor) -> bool
|
||||||
|
li = torch.jit.annotate(bool, x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_bool_1D(x):
|
||||||
|
# type: (torch.Tensor) -> List[bool]
|
||||||
|
li = torch.jit.annotate(List[bool], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_bool_2D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[bool]]
|
||||||
|
li = torch.jit.annotate(List[List[bool]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_bool_3D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[List[bool]]]
|
||||||
|
li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
self.checkScript(to_list_bool_0D, (torch.tensor(False, dtype=torch.bool),))
|
||||||
|
bool_input_1D = torch.tensor([True, False, True, False], dtype=torch.bool)
|
||||||
|
self.checkScript(to_list_bool_1D, (bool_input_1D,))
|
||||||
|
bool_input_2D = torch.tensor(
|
||||||
|
[[True, True, False], [False, True, False]], dtype=torch.bool
|
||||||
|
)
|
||||||
|
self.checkScript(to_list_bool_2D, (bool_input_2D,))
|
||||||
|
bool_input_3D = torch.tensor(
|
||||||
|
[[[True, False], [False, True]], [[True, False], [False, False]]],
|
||||||
|
dtype=torch.bool,
|
||||||
|
)
|
||||||
|
self.checkScript(to_list_bool_3D, (bool_input_3D,))
|
||||||
|
bool_input_noncontiguous = torch.tensor(
|
||||||
|
[[[True, False], [False, True]], [[True, False], [False, False]]],
|
||||||
|
dtype=torch.bool,
|
||||||
|
).transpose(0, 1)
|
||||||
|
self.checkScript(to_list_bool_3D, (bool_input_noncontiguous,))
|
||||||
|
|
||||||
|
"""
|
||||||
|
Int dtype unit tests.
|
||||||
|
"""
|
||||||
|
def to_list_int_0D(x):
|
||||||
|
# type: (torch.Tensor) -> int
|
||||||
|
li = torch.jit.annotate(int, x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_int_1D(x):
|
||||||
|
# type: (torch.Tensor) -> List[int]
|
||||||
|
li = torch.jit.annotate(List[int], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_int_2D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[int]]
|
||||||
|
li = torch.jit.annotate(List[List[int]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_int_3D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[List[int]]]
|
||||||
|
li = torch.jit.annotate(List[List[List[int]]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
self.checkScript(to_list_int_0D, (torch.tensor(1, dtype=torch.long),))
|
||||||
|
int_input_1D = torch.tensor([1, 2, 3, 4], dtype=torch.long)
|
||||||
|
self.checkScript(to_list_int_1D, (int_input_1D,))
|
||||||
|
int_input_2D = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.long)
|
||||||
|
self.checkScript(to_list_int_2D, (int_input_2D,))
|
||||||
|
int_input_3D = torch.tensor(
|
||||||
|
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
|
||||||
|
)
|
||||||
|
self.checkScript(to_list_int_3D, (int_input_3D,))
|
||||||
|
int_input_noncontiguous = torch.tensor(
|
||||||
|
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
|
||||||
|
).transpose(0, 1)
|
||||||
|
self.checkScript(to_list_int_3D, (int_input_noncontiguous,))
|
||||||
|
|
||||||
|
"""
|
||||||
|
Float dtype unit tests.
|
||||||
|
"""
|
||||||
|
def to_list_float_0D(x):
|
||||||
|
# type: (torch.Tensor) -> float
|
||||||
|
li = torch.jit.annotate(float, x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_float_1D(x):
|
||||||
|
# type: (torch.Tensor) -> List[float]
|
||||||
|
li = torch.jit.annotate(List[float], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_float_2D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[float]]
|
||||||
|
li = torch.jit.annotate(List[List[float]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_float_3D(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[List[float]]]
|
||||||
|
li = torch.jit.annotate(List[List[List[float]]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.double)[0],))
|
||||||
|
self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double),))
|
||||||
|
self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.double),))
|
||||||
|
self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double),))
|
||||||
|
self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double).transpose(0, 1),))
|
||||||
|
|
||||||
|
"""
|
||||||
|
Non-happy path tests:
|
||||||
|
- missing type annotation
|
||||||
|
- mismatch between type annotation and input
|
||||||
|
- type annotation with unsupported type
|
||||||
|
- type annotation with the wrong dimension
|
||||||
|
- type annotation with scalar type that doesn't match the input scalar type
|
||||||
|
"""
|
||||||
|
def to_list_missing_type_annotation(x):
|
||||||
|
# type: (torch.Tensor) -> List[float]
|
||||||
|
li = x.tolist()
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_incorrect_type_annotation(x):
|
||||||
|
# type: (torch.Tensor) -> List[float]
|
||||||
|
li = torch.jit.annotate(float, x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_unsupported_type_annotation(x):
|
||||||
|
# type: (torch.Tensor) -> List[float]
|
||||||
|
li = torch.jit.annotate(List[str], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_type_annotation_wrong_dim(x):
|
||||||
|
# type: (torch.Tensor) -> List[List[float]]
|
||||||
|
li = torch.jit.annotate(List[List[float]], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
def to_list_type_annotation_incorrect_scalar_type(x):
|
||||||
|
# type: (torch.Tensor) -> List[float]
|
||||||
|
li = torch.jit.annotate(List[float], x.tolist())
|
||||||
|
return li
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, r"Expected type hint for result of tolist()"
|
||||||
|
):
|
||||||
|
self.checkScript(to_list_missing_type_annotation, (torch.randn(5),))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
r"Return value was annotated as having type List\[float\] but is actually of type float",
|
||||||
|
):
|
||||||
|
self.checkScript(to_list_incorrect_type_annotation, (torch.randn(5),))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, r"str is not one of the supported element types for tolist"
|
||||||
|
):
|
||||||
|
self.checkScript(to_list_unsupported_type_annotation, (torch.randn(5),))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
r"Output annotation list dimension and runtime tensor dimension must match",
|
||||||
|
):
|
||||||
|
self.checkScript(to_list_type_annotation_wrong_dim, (torch.randn(5, dtype=torch.double),))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
r"Output annotation element type and runtime tensor element type must match",
|
||||||
|
):
|
||||||
|
self.checkScript(
|
||||||
|
to_list_type_annotation_incorrect_scalar_type,
|
||||||
|
(torch.ones(5, dtype=torch.long),),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDict(JitTestCase):
|
class TestDict(JitTestCase):
|
||||||
def dict(self):
|
def dict(self):
|
||||||
return {u'a': torch.ones(1), u'b': torch.ones(1) + 1, u'c': torch.ones(1) + 2}
|
return {u'a': torch.ones(1), u'b': torch.ones(1) + 1, u'c': torch.ones(1) + 2}
|
||||||
|
|
|
||||||
|
|
@ -2387,11 +2387,14 @@ struct to_ir {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<SugaredValue> emitApplyExpr(Apply& apply, size_t n_binders) {
|
std::shared_ptr<SugaredValue> emitApplyExpr(
|
||||||
|
Apply& apply,
|
||||||
|
size_t n_binders,
|
||||||
|
const TypePtr& type_hint = nullptr) {
|
||||||
auto sv = emitSugaredExpr(apply.callee(), 1);
|
auto sv = emitSugaredExpr(apply.callee(), 1);
|
||||||
auto loc = apply.callee().range();
|
auto loc = apply.callee().range();
|
||||||
if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
|
if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
|
||||||
return emitApplySpecialForm(special_form->form(), apply);
|
return emitApplySpecialForm(special_form->form(), apply, type_hint);
|
||||||
}
|
}
|
||||||
auto inputs = getNamedValues(apply.inputs(), true);
|
auto inputs = getNamedValues(apply.inputs(), true);
|
||||||
auto attributes = emitAttributes(apply.attributes());
|
auto attributes = emitAttributes(apply.attributes());
|
||||||
|
|
@ -2405,7 +2408,8 @@ struct to_ir {
|
||||||
// evaluation order.
|
// evaluation order.
|
||||||
std::shared_ptr<SugaredValue> emitApplySpecialForm(
|
std::shared_ptr<SugaredValue> emitApplySpecialForm(
|
||||||
Symbol form,
|
Symbol form,
|
||||||
Apply& apply) {
|
Apply& apply,
|
||||||
|
const TypePtr& type_hint = nullptr) {
|
||||||
switch (form) {
|
switch (form) {
|
||||||
case prim::fork: {
|
case prim::fork: {
|
||||||
auto& trees = apply.inputs().tree()->trees();
|
auto& trees = apply.inputs().tree()->trees();
|
||||||
|
|
@ -2495,6 +2499,19 @@ struct to_ir {
|
||||||
auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
|
auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
|
||||||
return std::make_shared<SimpleValue>(result.value());
|
return std::make_shared<SimpleValue>(result.value());
|
||||||
}
|
}
|
||||||
|
case prim::tolist: {
|
||||||
|
auto select = Select(apply.callee());
|
||||||
|
auto value = select.value();
|
||||||
|
auto operand = emitSugaredExpr(value, 1);
|
||||||
|
|
||||||
|
if (!type_hint) {
|
||||||
|
throw ErrorReport(apply)
|
||||||
|
<< "Expected type hint for result of tolist()";
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<SimpleValue>(graph->insertToList(
|
||||||
|
operand->asValue(value.range(), method), type_hint));
|
||||||
|
}
|
||||||
case prim::HasAttr: {
|
case prim::HasAttr: {
|
||||||
checkApplyNumInputs(apply, 2);
|
checkApplyNumInputs(apply, 2);
|
||||||
const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
|
const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
|
||||||
|
|
@ -2638,7 +2655,7 @@ struct to_ir {
|
||||||
}
|
}
|
||||||
case TK_APPLY: {
|
case TK_APPLY: {
|
||||||
auto apply = Apply(tree);
|
auto apply = Apply(tree);
|
||||||
return emitApplyExpr(apply, n_binders);
|
return emitApplyExpr(apply, n_binders, type_hint);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
|
return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,11 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||||
return builtin;
|
return builtin;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle calling tolist() on a Tensor.
|
||||||
|
if (value_->type()->isSubtypeOf(TensorType::get()) && field == "tolist") {
|
||||||
|
return SpecialFormValue::create(prim::tolist);
|
||||||
|
}
|
||||||
|
|
||||||
ErrorReport report(loc);
|
ErrorReport report(loc);
|
||||||
report << "Tried to access nonexistent attribute or method '" << field
|
report << "Tried to access nonexistent attribute or method '" << field
|
||||||
<< "' of type '" << value_->type()->python_str() << "'.";
|
<< "' of type '" << value_->type()->python_str() << "'.";
|
||||||
|
|
|
||||||
|
|
@ -336,6 +336,7 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||||
case prim::ChunkSizes:
|
case prim::ChunkSizes:
|
||||||
case prim::Function:
|
case prim::Function:
|
||||||
case prim::CreateObject:
|
case prim::CreateObject:
|
||||||
|
case prim::tolist:
|
||||||
return analyzeCreator(node);
|
return analyzeCreator(node);
|
||||||
case prim::TupleConstruct:
|
case prim::TupleConstruct:
|
||||||
case prim::DictConstruct:
|
case prim::DictConstruct:
|
||||||
|
|
|
||||||
|
|
@ -1686,6 +1686,40 @@ Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
|
||||||
return n->output();
|
return n->output();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value* Graph::insertToList(Value* v, TypePtr type) {
|
||||||
|
int dim = 0;
|
||||||
|
TypePtr ptr = type;
|
||||||
|
|
||||||
|
// Unwrap the type to determine the number of dimensions.
|
||||||
|
while (auto list_type = ptr->cast<ListType>()) {
|
||||||
|
ptr = list_type->getElementType();
|
||||||
|
++dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the base element type as an integer.
|
||||||
|
int elem_ty = 0;
|
||||||
|
if (ptr == IntType::get()) {
|
||||||
|
elem_ty = 0;
|
||||||
|
} else if (ptr == FloatType::get()) {
|
||||||
|
elem_ty = 1;
|
||||||
|
} else if (ptr == BoolType::get()) {
|
||||||
|
elem_ty = 2;
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
ptr->python_str(),
|
||||||
|
" is not one of the supported element types for tolist: int, float, bool");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass in the number of dimensions and base element type as arguments
|
||||||
|
// to the op.
|
||||||
|
Value* dim_val = insertConstant(IValue(dim));
|
||||||
|
Value* elem_ty_val = insertConstant(IValue(elem_ty));
|
||||||
|
Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
|
||||||
|
n->output()->setType(std::move(type));
|
||||||
|
return n->output();
|
||||||
|
}
|
||||||
|
|
||||||
Value* Graph::insertFunctionCall(
|
Value* Graph::insertFunctionCall(
|
||||||
Function* callee,
|
Function* callee,
|
||||||
const script::MatchedSchema& matched) {
|
const script::MatchedSchema& matched) {
|
||||||
|
|
|
||||||
|
|
@ -1128,6 +1128,10 @@ struct Graph {
|
||||||
|
|
||||||
TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);
|
TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);
|
||||||
|
|
||||||
|
// Insert a ToList operator with argument \p v and output type \p type.
|
||||||
|
// \returns the output of the operation.
|
||||||
|
TORCH_API Value* insertToList(Value* v, TypePtr type);
|
||||||
|
|
||||||
TORCH_API Value* insertFunctionCall(
|
TORCH_API Value* insertFunctionCall(
|
||||||
Function* callee,
|
Function* callee,
|
||||||
const script::MatchedSchema& matched);
|
const script::MatchedSchema& matched);
|
||||||
|
|
|
||||||
|
|
@ -157,6 +157,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
|
||||||
prim::CallFunction,
|
prim::CallFunction,
|
||||||
prim::isinstance,
|
prim::isinstance,
|
||||||
prim::unchecked_cast,
|
prim::unchecked_cast,
|
||||||
|
prim::tolist,
|
||||||
};
|
};
|
||||||
|
|
||||||
// WARNING: by adding a value to this set, you are asserting that your
|
// WARNING: by adding a value to this set, you are asserting that your
|
||||||
|
|
@ -239,6 +240,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
|
||||||
aten::wait,
|
aten::wait,
|
||||||
prim::isinstance,
|
prim::isinstance,
|
||||||
prim::unchecked_cast,
|
prim::unchecked_cast,
|
||||||
|
prim::tolist,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Operators that should not be used by alias analysis
|
// Operators that should not be used by alias analysis
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,73 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert the tensor pointed to by \p data to a nested list. \p dim is the
|
||||||
|
// number of dimensions in the tensor and \p cur_dim is the dimension being
|
||||||
|
// processed by the current invocation. \p ty is the expected output IR type of
|
||||||
|
// the operation. \p sizes and \p strides are the sizes and strides of the
|
||||||
|
// tensor operand and \p element_size is the size in bytes of one tensor
|
||||||
|
// element.
|
||||||
|
IValue tensorToListRecursive(
|
||||||
|
char* data,
|
||||||
|
int64_t cur_dim,
|
||||||
|
int64_t num_tensor_dims,
|
||||||
|
TypePtr ty,
|
||||||
|
at::IntArrayRef sizes,
|
||||||
|
at::IntArrayRef strides,
|
||||||
|
size_t element_size) {
|
||||||
|
// If ty is a ListType, get the element type.
|
||||||
|
if (auto list_type = ty->cast<ListType>()) {
|
||||||
|
ty = list_type->getElementType();
|
||||||
|
} else {
|
||||||
|
// If the output type is a scalar, read and push one scalar of
|
||||||
|
// the right type onto the stack.
|
||||||
|
if (ty == IntType::get()) {
|
||||||
|
int64_t scalar = *(int64_t*)data;
|
||||||
|
return IValue(scalar);
|
||||||
|
} else if (ty == FloatType::get()) {
|
||||||
|
double scalar = *(double*)data;
|
||||||
|
return IValue(scalar);
|
||||||
|
} else if (ty == BoolType::get()) {
|
||||||
|
bool scalar = *(bool*)data;
|
||||||
|
return IValue(scalar);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
ty->python_str(),
|
||||||
|
" is not one of the supported types for tolist: int, float, bool");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the result list consisting of elements of type ty. Since this
|
||||||
|
// invocation is processing dimension cur_dim, there will be sizes[cur_dim]
|
||||||
|
// output elements.
|
||||||
|
auto result = c10::impl::GenericList(ty);
|
||||||
|
result.reserve(sizes[cur_dim]);
|
||||||
|
|
||||||
|
// Since ty was a list type, tensorToListRecursive needs to be called
|
||||||
|
// recursively on each slice of the tensor in the current dimension.
|
||||||
|
for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
|
||||||
|
auto inner_result = tensorToListRecursive(
|
||||||
|
data, cur_dim + 1, num_tensor_dims, ty, sizes, strides, element_size);
|
||||||
|
|
||||||
|
if (inner_result.isList()) {
|
||||||
|
result.emplace_back(inner_result.toList());
|
||||||
|
} else if (inner_result.isDouble()) {
|
||||||
|
result.emplace_back(inner_result.toDouble());
|
||||||
|
} else if (inner_result.isInt()) {
|
||||||
|
result.emplace_back(inner_result.toInt());
|
||||||
|
} else if (inner_result.isBool()) {
|
||||||
|
result.emplace_back(inner_result.toBool());
|
||||||
|
} else {
|
||||||
|
TORCH_INTERNAL_ASSERT("Unknown return type for tensorToListRecursive");
|
||||||
|
}
|
||||||
|
|
||||||
|
data += strides[cur_dim] * element_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
static int64_t floordiv(int64_t a, int64_t b) {
|
static int64_t floordiv(int64_t a, int64_t b) {
|
||||||
if (b == 0) {
|
if (b == 0) {
|
||||||
throw std::runtime_error("division by 0");
|
throw std::runtime_error("division by 0");
|
||||||
|
|
@ -967,6 +1034,56 @@ RegisterOperators reg(
|
||||||
return 0;
|
return 0;
|
||||||
},
|
},
|
||||||
aliasAnalysisSpecialCase()),
|
aliasAnalysisSpecialCase()),
|
||||||
|
Operator(
|
||||||
|
prim::tolist,
|
||||||
|
// This operator has to be unschematized because the return type depends on the type hint and input.
|
||||||
|
// The implementation of this operator below is intended to be as close to the Python implementation in
|
||||||
|
// torch/csrc/utils/tensor_list.cpp as possible.
|
||||||
|
[](const Node* node) -> Operation {
|
||||||
|
return [](Stack& stack) {
|
||||||
|
int elem_ty_val;
|
||||||
|
int dim_val;
|
||||||
|
at::Tensor t;
|
||||||
|
|
||||||
|
pop(stack, elem_ty_val);
|
||||||
|
pop(stack, dim_val);
|
||||||
|
pop(stack, t);
|
||||||
|
|
||||||
|
// Rebuild the output type using elem_ty_val and dim_val. Start
|
||||||
|
// with the element type corresponding to elem_ty_val.
|
||||||
|
TypePtr out_ty;
|
||||||
|
if (elem_ty_val == 0) {
|
||||||
|
out_ty = IntType::get();
|
||||||
|
} else if (elem_ty_val == 1) {
|
||||||
|
out_ty = FloatType::get();
|
||||||
|
} else if (elem_ty_val == 2) {
|
||||||
|
out_ty = BoolType::get();
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported element type for tolist; only int, float and bool are supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that type of the Tensor matches that of the annotation.
|
||||||
|
TORCH_CHECK(tryScalarTypeFromJitType(out_ty) == t.scalar_type(), "Output annotation element type and runtime tensor element type must match for tolist()")
|
||||||
|
|
||||||
|
// Check that the dimension of the Tensor matches that of the annotation.
|
||||||
|
TORCH_CHECK(dim_val == t.dim(), "Output annotation list dimension and runtime tensor dimension must match for tolist()")
|
||||||
|
|
||||||
|
// Wrap out_ty in a ListType dim times.
|
||||||
|
for (int i = 0; i < dim_val; ++i) {
|
||||||
|
out_ty = ListType::create(out_ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t dim = t.dim();
|
||||||
|
auto sizes = t.sizes();
|
||||||
|
auto strides = t.strides();
|
||||||
|
size_t element_size = t.element_size();
|
||||||
|
char* data = (char*)t.data_ptr();
|
||||||
|
auto result = tensorToListRecursive(data, 0, dim, out_ty, sizes, strides, element_size);
|
||||||
|
push(stack, std::move(result));
|
||||||
|
return 0;
|
||||||
|
};
|
||||||
|
},
|
||||||
|
aliasAnalysisSpecialCase()),
|
||||||
Operator(
|
Operator(
|
||||||
prim::ConstantChunk,
|
prim::ConstantChunk,
|
||||||
[](const Node* node) -> Operation {
|
[](const Node* node) -> Operation {
|
||||||
|
|
|
||||||
|
|
@ -1043,6 +1043,11 @@ struct PythonPrintImpl {
|
||||||
}
|
}
|
||||||
stmt << ")";
|
stmt << ")";
|
||||||
} break;
|
} break;
|
||||||
|
case prim::tolist: {
|
||||||
|
stmt << "annotate(" << node->output()->type()->python_str() << ", ";
|
||||||
|
stmt << useOf(node->input(0)) << ".tolist()"
|
||||||
|
<< ")";
|
||||||
|
} break;
|
||||||
default: {
|
default: {
|
||||||
printOpName(stmt, node->kind());
|
printOpName(stmt, node->kind());
|
||||||
const FunctionSchema& schema = node->schema();
|
const FunctionSchema& schema = node->schema();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user