add torch.tensor requires grad (#19445)

Summary:
Add support for setting requires_grad on torch.tensor calls within TorchScript.

Within constant propagation, we can't insert any constants that require grad, so such nodes are now skipped rather than folded.

Also adds shape analysis and requires-grad analysis for torch.tensor.
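A minimal sketch of what this enables from Python (illustrative only):

    import torch

    @torch.jit.script
    def make_leaf():
        # requires_grad is now honored inside TorchScript
        w = torch.tensor([1.0, 2.0], requires_grad=True)
        loss = (w * w).sum()
        return w, loss

    w, loss = make_leaf()
    loss.backward()
    print(w.grad)  # tensor([2., 4.])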
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19445

Differential Revision: D15039713

Pulled By: eellison

fbshipit-source-id: 47f1931b6fc4a1137c13d80110cc404465bfdf06
Authored by Elias Ellison on 2019-04-22 17:56:51 -07:00; committed by Facebook Github Bot
parent 8be6d5ffd8
commit d2b03512da
7 changed files with 413 additions and 247 deletions

View File

@ -1199,6 +1199,20 @@ inline TypePtr CompleteTensorType::fromBoolType() {
return CompleteTensorType::create(at::kLong, at::kCPU, {});
}
inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
if (type == FloatType::get()) {
return at::ScalarType::Double;
} else if (type == IntType::get()) {
return at::ScalarType::Long;
} else if (type == BoolType::get()) {
return at::ScalarType::Byte;
}
AT_ASSERTM(
0,
"Add new condition, expected Float, Int, or Bool but got",
type->str());
}
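The mapping above makes the script-side defaults explicit: TorchScript's float is a double, its int is 64-bit, and bool tensors are uint8 in this release. Observable from script (a sketch; note that eager torch.tensor(1.0) gives float32 instead):

    import torch

    @torch.jit.script
    def defaults():
        return torch.tensor(1.0), torch.tensor(1), torch.tensor(False)

    f, i, b = defaults()
    print(f.dtype, i.dtype, b.dtype)  # torch.float64 torch.int64 torch.uint8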
// Attempt to find the correct supertype of t1 and t2. If none is found then
// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
// then t2 will be returned (and vice versa).

View File

@ -62,12 +62,13 @@ graph():
%1 : int = prim::Constant[value=1]()
%5 : int? = prim::Constant()
%7 : Device? = prim::Constant()
%15 : bool = prim::Constant[value=0]()
%10 : int = prim::Constant[value=6]()
%3 : int[] = prim::ListConstruct(%1, %2)
%x : Tensor = aten::tensor(%3, %5, %7, %15)
%y : Tensor = aten::tensor(%3, %10, %7, %15)
%9 : int[] = prim::ListConstruct(%1, %2)
%z : Tensor = aten::tensor(%9, %10, %7, %15)
%14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
return (%14)
)IR",

View File

@ -5834,6 +5834,66 @@ a")
self.checkScript(func, ())
def test_tensor_shape_prop(self):
    template = dedent('''
    def func():
        li = {list_create}
        return torch.tensor(li)
    ''')
    list_input = ["[1]", "[False]", "[2.5]", "0.5", "1", "False", "[[1]]"]
    expected_shape = ["Long(*)", "Byte(*)", "Double(*)", "Double()", "Long()", "Byte()", "Long(*, *)"]
    for list_i, expect in zip(list_input, expected_shape):
        code = template.format(list_create=list_i)
        scope = {}
        exec(code, globals(), scope)
        cu = torch.jit.CompilationUnit(code)
        g = cu.func
        torch._C._jit_pass_complete_shape_analysis(g.graph, (), False)
        FileCheck().check(expect).check("aten::tensor").run(g.graph)

    @torch.jit.script
    def test_dtype(inp_dtype):
        # type: (int) -> Tuple[Tensor, Tensor]
        a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
        return a, torch.tensor(1.0, dtype=inp_dtype)

    test_dtype(5)
    g = test_dtype.graph_for(5)
    # the first tensor should have its type set; the second should not
    FileCheck().check("Float() = aten::tensor").check("Tensor = aten::tensor").run(g)

def test_tensor_requires_grad(self):
    @torch.jit.script
    def test(b):
        # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
        a = torch.tensor(1., requires_grad=b)
        b = torch.tensor(1., requires_grad=True)
        c = torch.tensor(1., requires_grad=False)
        return a, b, c

    g = test.graph_for(True)
    out = next(g.outputs())
    out_inp = list(out.node().inputs())

    self.assertTrue(out_inp[0].requires_grad())
    self.assertTrue(out_inp[1].requires_grad())
    self.assertFalse(out_inp[2].requires_grad())

def test_grad_from_script(self):
    def test():
        a = torch.tensor(2.5, requires_grad=True)
        b = a * 2
        return a, b

    a, b = test()
    b.backward()

    a_script, b_script = torch.jit.script(test)()
    b_script.backward()
    self.assertEqual(a.grad, a_script.grad)

def test_torch_tensor(self):
    template = dedent('''
    def func():

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/interpreter.h>
@ -36,6 +37,9 @@ std::vector<IValue> runNode(Node* n) {
if (v.isTensor()) {
auto t = std::move(v).toTensor();
if (t.defined()) {
if (t.requires_grad()) {
throw c10::Error("Can't insert requires grad as constant", "");
}
return IValue(autograd::as_variable_ref(t).data());
} else {
return t;
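Why runNode refuses here instead of folding: baking a tensor that requires grad into the graph as a constant would share one autograd leaf across every execution, rather than creating a fresh leaf per call. A hedged illustration of the behavior this preserves:

    import torch

    @torch.jit.script
    def f():
        # must not be constant-folded even though its inputs are constants
        a = torch.tensor(2.5, requires_grad=True)
        return a, a * 2

    a1, b1 = f()
    a2, b2 = f()
    b1.backward()
    print(a1.grad)  # tensor(2.)
    print(a2.grad)  # None: each call produced a separate leaf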

View File

@ -1,5 +1,6 @@
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
@ -64,6 +65,17 @@ void PropagateRequiresGradSimpleNode(Node* node) {
return setRequiresGrad(node->output(), node->input(0)->requires_grad());
} else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
return setRequiresGrad(node->output(), false);
} else if (node->kind() == aten::tensor) {
if (auto grad_index =
node->schema().argumentIndexWithName("requires_grad")) {
if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
return setRequiresGrad(node->output(), *const_arg);
}
}
if (auto type = node->output()->type()->cast<DimensionedTensorType>()) {
setRequiresGrad(node->output(), at::isFloatingType(type->scalarType()));
}
return;
}
auto inputs = node->inputs();
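The aten::tensor case above applies two rules: a compile-time-constant requires_grad argument is authoritative; otherwise fall back to "floating-point outputs may require grad", since integral tensors never can. A Python paraphrase (hypothetical helper, not a PyTorch API):

    def infer_requires_grad(requires_grad_const, is_floating_point):
        # rule 1: a constant requires_grad argument decides outright
        if requires_grad_const is not None:
            return requires_grad_const
        # rule 2: otherwise only floating-point tensors may require grad
        return is_floating_point

    assert infer_requires_grad(True, False) is True
    assert infer_requires_grad(None, True) is True
    assert infer_requires_grad(None, False) is False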

View File

@ -418,6 +418,41 @@ class ShapePropagator {
setUnshapedType(cat_node);
}
void propagateTorchTensorShape(Node* node) {
auto input_type = node->inputs().at(0)->type();
size_t dims = 0;
auto input_base_type = input_type;
auto list_type = input_type->cast<ListType>();
while (list_type) {
dims++;
input_base_type = list_type->getElementType();
list_type = input_base_type->cast<ListType>();
}
at::ScalarType default_type = scalarTypeFromJitType(input_base_type);
if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) {
auto inp = toIValue(node->inputs().at(*dtype_index));
if (inp == c10::nullopt) {
return;
} else if (!inp->isNone()) {
default_type = inp->toScalarType();
}
}
at::Device default_device = at::kCPU;
if (auto device_index = node->schema().argumentIndexWithName("device")) {
auto inp = toIValue(node->inputs().at(*device_index));
if (inp == c10::nullopt) {
return;
} else if (!inp->isNone()) {
default_device = inp->toDevice();
}
}
node->output()->setType(
DimensionedTensorType::create(default_type, default_device, dims));
}
bool mayAliasResizedSet(at::ArrayRef<Value*> vs) {
bool in_resize = false;
for (auto v : vs) {
@ -493,6 +528,9 @@ class ShapePropagator {
}
return;
}
case aten::tensor: {
return propagateTorchTensorShape(node);
}
case prim::TupleConstruct: {
// We refresh the tuple type, because the input types could have been
// refined.
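In propagateTorchTensorShape above, the output rank equals the nesting depth of the input's static list type, and dtype/device are refined only when those arguments are compile-time constants. A Python paraphrase of the depth count (hypothetical helper working on values, where the pass works on types):

    def infer_dims(sample):
        # int -> 0-dim, List[int] -> 1-dim, List[List[int]] -> 2-dim, ...
        dims = 0
        while isinstance(sample, list):
            dims += 1
            sample = sample[0] if sample else 0  # sketch: treat empty as ints
        return dims

    assert infer_dims(1) == 0
    assert infer_dims([1, 2]) == 1
    assert infer_dims([[1], [2]]) == 2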

View File

@ -1,11 +1,14 @@
#include <aten/src/ATen/Context.h>
#include <ATen/core/jit_type.h>
#include <aten/src/ATen/ExpandUtils.h>
#include <aten/src/ATen/InitialTensorOptions.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/api/include/torch/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/error_report.h>
#include <regex>
@ -16,45 +19,33 @@ namespace jit {
namespace {
void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
if (!elem_type->isSubtypeOf(NumberType::get()) &&
elem_type != BoolType::get()) {
auto error = script::ErrorReport(node->getSourceLocation());
error << "Input list to torch.tensor must be of ints, floats, or bools, " <<
"got " << elem_type->str();
error << "Input list to torch.tensor must be of ints, floats, or bools, "
<< "got " << elem_type->str();
// special case empty list torch.tensor([])
if (elem_type->isSubtypeOf(TensorType::get())) {
auto input = node->inputs().at(0);
if (input->node()->kind() == prim::ListConstruct &&
input->node()->inputs().size() == 0) {
error << "\n(Note: empty lists are constructed as Tensor[]; \n"
<< "if you want an empty list of a different type, \n"
<< "use `torch.jit.annotate(List[T], [])`, \n"
<< "where `T` is the type of elements in the list)";
<< "if you want an empty list of a different type, \n"
<< "use `torch.jit.annotate(List[T], [])`, \n"
<< "where `T` is the type of elements in the list)";
}
}
throw error;
}
}
int64_t list_size(const IValue& list) {
if (list.isGenericList()) {
return list.toGenericListRef().size();
} else if (list.isIntList()) {
return list.toIntListRef().size();
} else if (list.isDoubleList()){
} else if (list.isDoubleList()) {
return list.toDoubleListRef().size();
} else if (list.isBoolList()) {
return list.toBoolListRef().size();
@ -80,13 +71,25 @@ std::vector<int64_t> compute_sizes(const IValue& seq) {
void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
if (seq_size != n) {
AT_ERROR("Expected sequence of length ", n, " at dim ", dim, " (got ", seq_size, ")");
AT_ERROR(
"Expected sequence of length ",
n,
" at dim ",
dim,
" (got ",
seq_size,
")");
}
}
template <typename DTYPE>
void storeLastDimension(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
const std::vector<DTYPE>& obj) {
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
@ -97,9 +100,14 @@ void storeLastDimension(char* data, const std::vector<int64_t>& sizes, const c10
}
// bool vector needs to be cast to uint8_t
template <>
void storeLastDimension<bool>(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
const std::vector<bool>& obj) {
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
@ -111,9 +119,13 @@ void storeLastDimension<bool>(char* data, const std::vector<int64_t>& sizes, con
// reference python implementation recursive_store in tensor_new.cpp
void recursiveStore(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
const IValue& obj) {
auto ndim = sizes.size();
auto n = sizes[dim];
auto seq_size = list_size(obj);
@ -127,240 +139,265 @@ void recursiveStore(char* data, const std::vector<int64_t>& sizes, const c10::Ar
} else {
AT_ASSERT(obj.isIntList() || obj.isDoubleList() || obj.isBoolList());
if (obj.isIntList()) {
storeLastDimension<int64_t>(
data, sizes, strides, dim, elementSize, obj.toIntListRef());
} else if (obj.isDoubleList()) {
storeLastDimension<double>(
data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
} else {
storeLastDimension<bool>(
data, sizes, strides, dim, elementSize, obj.toBoolListRef());
}
}
}
RegisterOperators reg(
{Operator(
"aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
[](Stack& stack) {
RECORD_FUNCTION("split_with_sizes", last(stack, 3));
auto result = at::split_with_sizes(
(std::move(peek(stack, 0, 3))).toTensor(),
(std::move(peek(stack, 1, 3))).toIntList()->elements(),
(std::move(peek(stack, 2, 3))).toInt());
drop(stack, 3);
pack(stack, std::move(result));
return 0;
}),
Operator(
"aten::Size(int[] sizes) -> int[]",
[](Stack& stack) { return 0; }),
Operator(
"aten::size(Tensor self) -> int[]",
[](Stack& stack) {
RECORD_FUNCTION("size", last(stack, 1));
auto t = std::move(pop(stack)).toTensor();
pack(stack, t.sizes().vec());
return 0;
}),
Operator(
"aten::list_with_default(int[] list, int[] defaults) -> int[]",
[](Stack& stack) {
RECORD_FUNCTION("sizes", last(stack, 2));
auto list = peek(stack, 0, 2).toIntListRef();
auto defaults = peek(stack, 1, 2).toIntListRef();
drop(stack, 2);
AT_ASSERT(defaults.size() > list.size());
// TODO: allow list of optionals to be filled in with defaults
// i.e. list_with_default([1, 2, None], [1, 2, 3]) -> [1, 2, 3]
push(stack, list);
return 0;
}),
Operator(
"aten::_infer_size(int[] a, int[] b) -> int[]",
[](const Node* node) {
return [](Stack& stack) {
auto a = pop(stack).toIntList()->elements();
auto b = pop(stack).toIntList()->elements();
push(stack, at::infer_size(a, b));
return 0;
};
}),
Operator(
"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
[](const Node* node) {
return [](Stack& stack) {
at::Tensor weight;
at::Tensor input;
double max_norm;
double norm_type;
pop(stack, weight, input, max_norm, norm_type);
// TODO: remove when script supports setting grad mode
torch::NoGradGuard no_grad;
at::Tensor result =
at::embedding_renorm_(weight, input, max_norm, norm_type);
push(stack, result);
return 0;
};
}),
Operator(
"aten::format(str self, ...) -> str",
[](const Node* node) {
size_t num_inputs = node->inputs().size();
std::regex unsupported_options("\\{(.*)\\}");
return [num_inputs, unsupported_options](Stack& stack) {
auto format = peek(stack, 0, num_inputs).toStringRef();
if (std::regex_search(format, unsupported_options)) {
AT_WARN("Format options are not supported.");
}
auto args = last(stack, num_inputs - 1);
std::stringstream ss;
for (size_t begin = 0, used_args = 0; true; ++used_args) {
size_t loc = format.find("{}", begin);
if (loc == std::string::npos) {
ss << format.substr(begin);
break;
}
ss << format.substr(begin, loc - begin);
if (used_args >= args.size()) {
AT_ERROR("Too few arguments for format string: ", format);
}
ss << args[used_args];
begin = loc + 2;
}
drop(stack, num_inputs);
push(stack, ss.str());
return 0;
};
}),
#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op) \
Operator( \
"aten::tensor(" #operator_type \
" t, *, ScalarType? dtype=None, Device? device=None" \
", bool requires_grad=False) -> Tensor", \
[](const Node* node) { \
auto initial_scalar_type = \
scalarTypeFromJitType(node->inputs().at(0)->type()); \
return [initial_scalar_type](Stack& stack) { \
c_type scalar_val; \
IValue dtype; \
IValue device; \
bool requires_grad; \
pop(stack, scalar_val, dtype, device, requires_grad); \
auto tensor = autograd::make_variable(tensor_creation_op); \
at::ScalarType scalar_type = \
dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType(); \
c10::Device dev = \
device.isNone() ? tensor.device() : device.toDevice(); \
if (scalar_type != initial_scalar_type || dev != tensor.device()) { \
tensor = tensor.to(dev, scalar_type); \
} \
push(stack, tensor); \
tensor.set_requires_grad(requires_grad); \
return 0; \
}; \
}),
DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
DEFINE_TORCH_TENSOR_OP(
bool,
bool,
at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
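With the three instantiations above, each scalar overload accepts requires_grad from script, whether or not the flag is a compile-time constant (a usage sketch):

    import torch

    @torch.jit.script
    def scalar_leaf(flag):
        # type: (bool) -> Tensor
        return torch.tensor(1.5, requires_grad=flag)

    print(scalar_leaf(True).requires_grad)   # True
    print(scalar_leaf(False).requires_grad)  # False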
// reference python implementation: internal_new_from_data in
// tensor_new.cpp
Operator(
"aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor",
[](const Node* node) {
auto input = node->inputs().at(0);
auto elem_type = input->type();
while (auto list_type = elem_type->cast<ListType>()) {
elem_type = list_type->getElementType();
}
checkListInputType(elem_type, node);
at::ScalarType initial_scalar_type =
scalarTypeFromJitType(elem_type);
return [initial_scalar_type, elem_type](Stack& stack) {
bool requires_grad;
IValue data;
IValue dtype;
IValue device;
pop(stack, data, dtype, device, requires_grad);
auto sizes = compute_sizes(data);
auto tensor = autograd::make_variable(at::empty(
sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
recursiveStore(
(char*)tensor.data_ptr(),
sizes,
tensor.strides(),
0,
tensor.element_size(),
data);
at::ScalarType scalar_type =
dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
c10::Device dev =
device.isNone() ? tensor.device() : device.toDevice();
if (scalar_type != initial_scalar_type || dev != tensor.device()) {
tensor = tensor.to(dev, scalar_type);
}
auto default_type =
at::typeMetaToScalarType(at::get_default_dtype());
if (dtype.isNone() && tensor.scalar_type() != default_type &&
tensor.numel() == 0) {
AT_WARN(
"Creating a tensor from an empty ",
elem_type->str(),
"list will create a tensor of default floating point type (currently ",
default_type,
") in python but a tensor of type ",
elem_type->str(),
" in torchscript.\n",
"Pass in a dtype argument to ensure consistent behavior");
}
tensor.set_requires_grad(requires_grad);
push(stack, tensor);
return 0;
};
}),
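The list overload infers dtype from the element type and rank from the nesting depth, and now honors requires_grad as well (a usage sketch):

    import torch

    @torch.jit.script
    def lists():
        # float elements -> Double in script; two levels of nesting -> 2-dim
        return torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)

    x = lists()
    print(x.shape, x.dtype, x.requires_grad)
    # torch.Size([2, 2]) torch.float64 True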
Operator(
"aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
[](const Node* node) {
return [](Stack& stack) {
// Everything is a list at the point this is used, so don't do
// anything
drop(stack, 3);
return 0;
};
}),
Operator(
"aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, "
"Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)",
[](Stack& stack) { return 0; })
});
} // namespace
} // namespace jit
} // namespace torch