Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
add torch.tensor requires grad (#19445)
Summary: Add support for setting requires_grad = True on torch.tensor within TorchScript. Within constant propagation, we can't insert any constants that require grad. Also added shape analysis and requires-grad analysis for torch.tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19445
Differential Revision: D15039713
Pulled By: eellison
fbshipit-source-id: 47f1931b6fc4a1137c13d80110cc404465bfdf06
This commit is contained in:
parent
8be6d5ffd8
commit
d2b03512da
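Illustration (not part of the diff): a minimal Python sketch of what this change enables, mirroring the test_grad_from_script test added below. The same function now yields matching gradients in eager mode and under torch.jit.script:

    import torch

    def fn():
        a = torch.tensor(2.5, requires_grad=True)
        b = a * 2
        return a, b

    # eager
    a, b = fn()
    b.backward()

    # scripted: torch.tensor now honors requires_grad inside TorchScript
    a_s, b_s = torch.jit.script(fn)()
    b_s.backward()
    assert torch.equal(a.grad, a_s.grad)  # both are tensor(2.)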
@@ -1199,6 +1199,20 @@ inline TypePtr CompleteTensorType::fromBoolType() {
   return CompleteTensorType::create(at::kLong, at::kCPU, {});
 }
 
+inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
+  if (type == FloatType::get()) {
+    return at::ScalarType::Double;
+  } else if (type == IntType::get()) {
+    return at::ScalarType::Long;
+  } else if (type == BoolType::get()) {
+    return at::ScalarType::Byte;
+  }
+  AT_ASSERTM(
+      0,
+      "Add new condition, expected Float, Int, or Bool but got",
+      type->str());
+}
+
 // Attempt to find the correct supertype of t1 and t2. If none is found then
 // nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
 // then t2 will be returned (and vice versa).
@@ -62,12 +62,13 @@ graph():
   %1 : int = prim::Constant[value=1]()
   %5 : int? = prim::Constant()
   %7 : Device? = prim::Constant()
+  %15: bool = prim::Constant[value=0]()
   %10 : int = prim::Constant[value=6]()
   %3 : int[] = prim::ListConstruct(%1, %2)
-  %x : Tensor = aten::tensor(%3, %5, %7)
-  %y : Tensor = aten::tensor(%3, %10, %7)
+  %x : Tensor = aten::tensor(%3, %5, %7, %15)
+  %y : Tensor = aten::tensor(%3, %10, %7, %15)
   %9 : int[] = prim::ListConstruct(%1, %2)
-  %z : Tensor = aten::tensor(%9, %10, %7)
+  %z : Tensor = aten::tensor(%9, %10, %7, %15)
   %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
   return (%14)
 )IR",
@@ -5834,6 +5834,66 @@ a")
 
         self.checkScript(func, ())
 
+    def test_tensor_shape_prop(self):
+        template = dedent('''
+        def func():
+            li = {list_create}
+            return torch.tensor(li)
+        ''')
+
+        list_input = ["[1]", "[False]", "[2.5]", "0.5", "1", "False", "[[1]]"]
+        expected_shape = ["Long(*)", ("Byte(*)"), "Double(*)", "Double()", "Long()", "Byte()", "Long(*, *)"]
+
+        for list_i, expect in zip(list_input, expected_shape):
+            code = template.format(list_create=list_i)
+            scope = {}
+            exec(code, globals(), scope)
+            cu = torch.jit.CompilationUnit(code)
+            g = cu.func
+            torch._C._jit_pass_complete_shape_analysis(g.graph, (), False)
+            FileCheck().check(expect).check("aten::tensor").run(g.graph)
+
+        @torch.jit.script
+        def test_dtype(inp_dtype):
+            # type: (int) -> Tuple[Tensor, Tensor]
+            a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
+            return a, torch.tensor(1.0, dtype=inp_dtype)
+
+        test_dtype(5)
+        g = test_dtype.graph_for(5)
+        # first should have type set second should not
+        FileCheck().check("Float() = aten::tensor").check("Tensor = aten::tensor").run(g)
+
+    def test_tensor_requires_grad(self):
+        @torch.jit.script
+        def test(b):
+            # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
+            a = torch.tensor(1., requires_grad=b)
+            b = torch.tensor(1., requires_grad=True)
+            c = torch.tensor(1., requires_grad=False)
+            return a, b, c
+
+        g = test.graph_for(True)
+        out = next(g.outputs())
+        out_inp = list(out.node().inputs())
+
+        self.assertTrue(out_inp[0].requires_grad())
+        self.assertTrue(out_inp[1].requires_grad())
+        self.assertFalse(out_inp[2].requires_grad())
+
+    def test_grad_from_script(self):
+        def test():
+            a = torch.tensor(2.5, requires_grad=True)
+            b = a * 2
+            return a, b
+
+        a, b = test()
+        b.backward()
+
+        a_script, b_script = torch.jit.script(test)()
+        b_script.backward()
+        self.assertEqual(a.grad, a_script.grad)
+
     def test_torch_tensor(self):
         template = dedent('''
         def func():
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <ATen/core/functional.h>
 #include <ATen/core/ivalue.h>
 #include <c10/util/Exception.h>
+#include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -36,6 +37,9 @@ std::vector<IValue> runNode(Node* n) {
     if (v.isTensor()) {
       auto t = std::move(v).toTensor();
       if (t.defined()) {
+        if (t.requires_grad()) {
+          throw c10::Error("Can't insert requires grad as constant", "");
+        }
         return IValue(autograd::as_variable_ref(t).data());
       } else {
         return t;
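A hedged sketch of what this guard means in practice (illustrative names, not from the diff): constant propagation may fold a tensor that does not require grad into a prim::Constant, but a tensor that requires grad must stay a live aten::tensor call so autograd can track it:

    import torch

    @torch.jit.script
    def f():
        a = torch.tensor(1.0, requires_grad=True)  # must not be folded into a constant
        b = torch.tensor(1.0)                      # eligible for constant folding
        return a, b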
@@ -1,5 +1,6 @@
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
 
@@ -64,6 +65,17 @@ void PropagateRequiresGradSimpleNode(Node* node) {
     return setRequiresGrad(node->output(), node->input(0)->requires_grad());
   } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
     return setRequiresGrad(node->output(), false);
+  } else if (node->kind() == aten::tensor) {
+    if (auto grad_index =
+            node->schema().argumentIndexWithName("requires_grad")) {
+      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
+        return setRequiresGrad(node->output(), *const_arg);
+      }
+    }
+    if (auto type = node->output()->type()->cast<DimensionedTensorType>()) {
+      setRequiresGrad(node->output(), at::isFloatingType(type->scalarType()));
+    }
+    return;
   }
 
   auto inputs = node->inputs();
@@ -418,6 +418,41 @@ class ShapePropagator {
       setUnshapedType(cat_node);
     }
 
+  void propagateTorchTensorShape(Node* node) {
+    auto input_type = node->inputs().at(0)->type();
+
+    size_t dims = 0;
+    auto input_base_type = input_type;
+    auto list_type = input_type->cast<ListType>();
+    while (list_type) {
+      dims++;
+      input_base_type = list_type->getElementType();
+      list_type = input_base_type->cast<ListType>();
+    }
+
+    at::ScalarType default_type = scalarTypeFromJitType(input_base_type);
+    if (auto grad_index = node->schema().argumentIndexWithName("dtype")) {
+      auto inp = toIValue(node->inputs().at(*grad_index));
+      if (inp == c10::nullopt) {
+        return;
+      } else if (!inp->isNone()) {
+        default_type = inp->toScalarType();
+      }
+    }
+
+    at::Device default_device = at::kCPU;
+    if (auto device_index = node->schema().argumentIndexWithName("device")) {
+      auto inp = toIValue(node->inputs().at(*device_index));
+      if (inp == c10::nullopt) {
+        return;
+      } else if (!inp->isNone()) {
+        default_device = inp->toDevice();
+      }
+    }
+    node->output()->setType(
+        DimensionedTensorType::create(default_type, default_device, dims));
+  }
+
   bool mayAliasResizedSet(at::ArrayRef<Value*> vs) {
     bool in_resize = false;
     for (auto v : vs) {
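A sketch of the inference propagateTorchTensorShape performs, with dtype expectations taken from test_tensor_shape_prop above: the nesting depth of the input list type gives the dimension count, and the element type (or an explicit dtype argument) picks the scalar type:

    import torch

    @torch.jit.script
    def examples():
        a = torch.tensor([1, 2])   # int list   -> Long,   1 dim
        b = torch.tensor([[2.5]])  # float list -> Double, 2 dims
        c = torch.tensor(False)    # bool       -> Byte,   0 dims
        return a, b, c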
@@ -493,6 +528,9 @@ class ShapePropagator {
         }
         return;
       }
+      case aten::tensor: {
+        return propagateTorchTensorShape(node);
+      }
       case prim::TupleConstruct: {
         // We refresh the tuple type, because the input types could have been
         // refined.
@@ -1,11 +1,14 @@
 #include <aten/src/ATen/Context.h>
 
+#include <ATen/core/jit_type.h>
+#include <aten/src/ATen/ExpandUtils.h>
+#include <torch/csrc/api/include/torch/utils.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/operator.h>
-#include <torch/csrc/api/include/torch/utils.h>
-#include <aten/src/ATen/ExpandUtils.h>
 
+#include <c10/core/ScalarType.h>
 #include <aten/src/ATen/InitialTensorOptions.h>
-#include <c10/core/ScalarType.h>
 #include <torch/csrc/jit/script/error_report.h>
 
 #include <regex>
@@ -16,45 +19,33 @@ namespace jit {
 
 namespace {
 
 void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
-  if (!elem_type->isSubtypeOf(NumberType::get()) && elem_type != BoolType::get()) {
+  if (!elem_type->isSubtypeOf(NumberType::get()) &&
+      elem_type != BoolType::get()) {
     auto error = script::ErrorReport(node->getSourceLocation());
-    error << "Input list to torch.tensor must be of ints, floats, or bools, " <<
-      "got " << elem_type->str();
+    error << "Input list to torch.tensor must be of ints, floats, or bools, "
+          << "got " << elem_type->str();
     // special case empty list torch.tensor([])
     if (elem_type->isSubtypeOf(TensorType::get())) {
       auto input = node->inputs().at(0);
-      if (input->node()->kind() == prim::ListConstruct && input->node()->inputs().size() == 0) {
+      if (input->node()->kind() == prim::ListConstruct &&
+          input->node()->inputs().size() == 0) {
         error << "\n(Note: empty lists are constructed as Tensor[]; \n"
-        << "if you want an empty list of a different type, \n"
-        << "use `torch.jit.annotate(List[T], [])`, \n"
-        << "where `T` is the type of elements in the list)";
+              << "if you want an empty list of a different type, \n"
+              << "use `torch.jit.annotate(List[T], [])`, \n"
+              << "where `T` is the type of elements in the list)";
       }
     }
     throw error;
   }
 }
 
-at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
-  if (type == FloatType::get()) {
-    return at::ScalarType::Double;
-  } else if (type == IntType::get()) {
-    return at::ScalarType::Long;
-  } else if (type == BoolType::get()) {
-    return at::ScalarType::Byte;
-  }
-  AT_ASSERTM(0, "Add new condition, expected Float, Int, or Bool but got",
-             type->str());
-}
-
 int64_t list_size(const IValue& list) {
   if (list.isGenericList()) {
     return list.toGenericListRef().size();
   } else if (list.isIntList()) {
     return list.toIntListRef().size();
-  } else if (list.isDoubleList()){
+  } else if (list.isDoubleList()) {
     return list.toDoubleListRef().size();
   } else if (list.isBoolList()) {
     return list.toBoolListRef().size();
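A sketch of the user-facing error this check produces (assumed example; str elements are neither Number nor Bool, so compilation should fail with the message above):

    import torch

    @torch.jit.script
    def bad():
        # RuntimeError: Input list to torch.tensor must be of ints, floats, or bools, got str
        return torch.tensor(["not", "numeric"])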
@@ -80,13 +71,25 @@ std::vector<int64_t> compute_sizes(const IValue& seq) {
 
 void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
   if (seq_size != n) {
-    AT_ERROR("Expected sequence of length ", n, " at dim ", dim, " (got ", seq_size, ")");
+    AT_ERROR(
+        "Expected sequence of length ",
+        n,
+        " at dim ",
+        dim,
+        " (got ",
+        seq_size,
+        ")");
   }
 }
 
 template <typename DTYPE>
-void storeLastDimension(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-    int elementSize, const std::vector<DTYPE>& obj) {
+void storeLastDimension(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    const std::vector<DTYPE>& obj) {
   auto n = sizes[dim];
   auto seq_size = obj.size();
   checkSequenceSize(n, dim, seq_size);
@@ -97,9 +100,14 @@ void storeLastDimension(char* data, const std::vector<int64_t>& sizes, const c10
 }
 
 // bool vector needs to be cast to uint8_t
-template<>
-void storeLastDimension<bool>(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-    int elementSize, const std::vector<bool>& obj) {
+template <>
+void storeLastDimension<bool>(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    const std::vector<bool>& obj) {
   auto n = sizes[dim];
   auto seq_size = obj.size();
   checkSequenceSize(n, dim, seq_size);
@@ -111,9 +119,13 @@ void storeLastDimension<bool>(char* data, const std::vector<int64_t>& sizes, con
 
 // refernce python implementation recursive_store in tensor_new.cpp
 
-void recursiveStore(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-    int elementSize, const IValue& obj) {
+void recursiveStore(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    const IValue& obj) {
   auto ndim = sizes.size();
   auto n = sizes[dim];
   auto seq_size = list_size(obj);
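The comment above points at recursive_store in tensor_new.cpp; a hedged Python sketch of the same idea (my simplification, ignoring strides and element size): recurse over nested lists, check each level against the expected size, and write scalars at the last dimension:

    def recursive_store(out, sizes, dim, obj):
        # mirrors checkSequenceSize: every level must match the computed size
        if len(obj) != sizes[dim]:
            raise ValueError("Expected sequence of length %d at dim %d (got %d)"
                             % (sizes[dim], dim, len(obj)))
        if dim == len(sizes) - 1:
            out.extend(obj)  # storeLastDimension: flat copy of the scalars
        else:
            for item in obj:
                recursive_store(out, sizes, dim + 1, item)

    buf = []
    recursive_store(buf, [2, 2], 0, [[1, 2], [3, 4]])
    assert buf == [1, 2, 3, 4]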
@@ -127,240 +139,265 @@ void recursiveStore(char* data, const std::vector<int64_t>& sizes, const c10::Ar
  } else {
    AT_ASSERT(obj.isIntList() || obj.isDoubleList() || obj.isBoolList());
    if (obj.isIntList()) {
      storeLastDimension<int64_t>(data, sizes, strides, dim, elementSize, obj.toIntListRef());
    } else if (obj.isDoubleList()){
      storeLastDimension<double>(data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
      storeLastDimension<int64_t>(
          data, sizes, strides, dim, elementSize, obj.toIntListRef());
    } else if (obj.isDoubleList()) {
      storeLastDimension<double>(
          data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
    } else {
      storeLastDimension<bool>(data, sizes, strides, dim, elementSize, obj.toBoolListRef());
      storeLastDimension<bool>(
          data, sizes, strides, dim, elementSize, obj.toBoolListRef());
    }
  }
 }
 
 RegisterOperators reg({
    Operator(
        "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
        [](Stack& stack) {
          RECORD_FUNCTION("split_with_sizes", last(stack, 3));
 RegisterOperators reg(
    {Operator(
         "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
         [](Stack& stack) {
           RECORD_FUNCTION("split_with_sizes", last(stack, 3));
 
          auto result = at::split_with_sizes(
              (std::move(peek(stack, 0, 3))).toTensor(),
              (std::move(peek(stack, 1, 3))).toIntList()->elements(),
              (std::move(peek(stack, 2, 3))).toInt());
          drop(stack, 3);
          pack(stack, std::move(result));
          return 0;
        }),
    Operator(
        "aten::Size(int[] sizes) -> int[]",
        [](Stack& stack) { return 0; }),
    Operator(
        "aten::size(Tensor self) -> int[]",
        [](Stack& stack) {
          RECORD_FUNCTION("size", last(stack, 1));
           auto result = at::split_with_sizes(
               (std::move(peek(stack, 0, 3))).toTensor(),
               (std::move(peek(stack, 1, 3))).toIntList()->elements(),
               (std::move(peek(stack, 2, 3))).toInt());
           drop(stack, 3);
           pack(stack, std::move(result));
           return 0;
         }),
     Operator(
         "aten::Size(int[] sizes) -> int[]",
         [](Stack& stack) { return 0; }),
     Operator(
         "aten::size(Tensor self) -> int[]",
         [](Stack& stack) {
           RECORD_FUNCTION("size", last(stack, 1));
 
          auto t = std::move(pop(stack)).toTensor();
          pack(stack, t.sizes().vec());
          return 0;
        }),
    Operator(
        "aten::list_with_default(int[] list, int[] defaults) -> int[]",
        [](Stack& stack) {
          RECORD_FUNCTION("sizes", last(stack, 2));
           auto t = std::move(pop(stack)).toTensor();
           pack(stack, t.sizes().vec());
           return 0;
         }),
     Operator(
         "aten::list_with_default(int[] list, int[] defaults) -> int[]",
         [](Stack& stack) {
           RECORD_FUNCTION("sizes", last(stack, 2));
 
          auto list = peek(stack, 0, 2).toIntListRef();
          auto defaults = peek(stack, 1, 2).toIntListRef();
          drop(stack, 2);
           auto list = peek(stack, 0, 2).toIntListRef();
           auto defaults = peek(stack, 1, 2).toIntListRef();
           drop(stack, 2);
 
          AT_ASSERT(defaults.size() > list.size());
           AT_ASSERT(defaults.size() > list.size());
 
          // TODO: allow list of optionals to be filled in with defaults
          // i.e. list_with_default([1, 2, None], [1, 2, 3]) -> [1, 2, 3]
           // TODO: allow list of optionals to be filled in with defaults
           // i.e. list_with_default([1, 2, None], [1, 2, 3]) -> [1, 2, 3]
 
          push(stack, list);
          return 0;
        }),
    Operator(
        "aten::_infer_size(int[] a, int[] b) -> int[]",
        [](const Node* node) {
          return [](Stack& stack) {
            auto a = pop(stack).toIntList()->elements();
            auto b = pop(stack).toIntList()->elements();
            push(stack, at::infer_size(a, b));
            return 0;
          };
        }),
    Operator(
        "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
        [](const Node* node) {
          return [](Stack& stack) {
            at::Tensor weight;
            at::Tensor input;
            double max_norm;
            double norm_type;
            pop(stack, weight, input, max_norm, norm_type);
           push(stack, list);
           return 0;
         }),
     Operator(
         "aten::_infer_size(int[] a, int[] b) -> int[]",
         [](const Node* node) {
           return [](Stack& stack) {
             auto a = pop(stack).toIntList()->elements();
             auto b = pop(stack).toIntList()->elements();
             push(stack, at::infer_size(a, b));
             return 0;
           };
         }),
     Operator(
         "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
         [](const Node* node) {
           return [](Stack& stack) {
             at::Tensor weight;
             at::Tensor input;
             double max_norm;
             double norm_type;
             pop(stack, weight, input, max_norm, norm_type);
 
            // TODO: remove when script supports setting grad mode
            torch::NoGradGuard no_grad;
             // TODO: remove when script supports setting grad mode
             torch::NoGradGuard no_grad;
 
            at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
            push(stack, result);
             at::Tensor result =
                 at::embedding_renorm_(weight, input, max_norm, norm_type);
             push(stack, result);
 
            return 0;
          };
             return 0;
           };
        }),
    Operator(
        "aten::format(str self, ...) -> str",
        [](const Node* node) {
          size_t num_inputs = node->inputs().size();
          std::regex unsupported_options("\\{(.*)\\}");
          return [num_inputs, unsupported_options](Stack& stack) {
            auto format = peek(stack, 0, num_inputs).toStringRef();
 
            if (std::regex_search(format, unsupported_options)) {
              AT_WARN("Format options are not supported.");
            }
 
            auto args = last(stack, num_inputs - 1);
            std::stringstream ss;
            for (size_t begin = 0, used_args = 0; true; ++used_args) {
              size_t loc = format.find("{}", begin);
              if (loc == std::string::npos) {
                ss << format.substr(begin);
                break;
              }
              ss << format.substr(begin, loc - begin);
              if (used_args >= args.size()) {
                AT_ERROR("Too few arguments for format string: ", format);
              }
              ss << args[used_args];
              begin = loc + 2;
            }
 
            drop(stack, num_inputs);
            push(stack, ss.str());
            return 0;
          };
        }),
 
 #define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)  \
  Operator(                                                                 \
      "aten::tensor(" #operator_type                                        \
      " t, *, ScalarType? dtype=None, Device? device=None"                  \
      ", bool requires_grad=False) -> Tensor",                              \
      [](const Node* node) {                                                \
        auto initial_scalar_type =                                          \
            scalarTypeFromJitType(node->inputs().at(0)->type());            \
        return [initial_scalar_type](Stack& stack) {                        \
          c_type scalar_val;                                                \
          IValue dtype;                                                     \
          IValue device;                                                    \
          bool requires_grad;                                               \
          pop(stack, scalar_val, dtype, device, requires_grad);             \
          auto tensor = autograd::make_variable(tensor_creation_op);        \
          at::ScalarType scalar_type =                                      \
              dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType(); \
          c10::Device dev =                                                 \
              device.isNone() ? tensor.device() : device.toDevice();        \
          if (scalar_type != initial_scalar_type || dev != tensor.device()) { \
            tensor = tensor.to(dev, scalar_type);                           \
          }                                                                 \
          push(stack, tensor);                                              \
          tensor.set_requires_grad(requires_grad);                          \
          return 0;                                                         \
        };                                                                  \
      }),
     Operator(
         "aten::format(str self, ...) -> str",
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
           std::regex unsupported_options("\\{(.*)\\}");
           return [num_inputs, unsupported_options](Stack& stack) {
             auto format = peek(stack, 0, num_inputs).toStringRef();
 
             if (std::regex_search(format, unsupported_options)) {
               AT_WARN("Format options are not supported.");
             }
 DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
 DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
 DEFINE_TORCH_TENSOR_OP(
     bool,
     bool,
     at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
 
             auto args = last(stack, num_inputs - 1);
             std::stringstream ss;
             for (size_t begin = 0, used_args = 0; true; ++used_args) {
               size_t loc = format.find("{}", begin);
               if (loc == std::string::npos) {
                 ss << format.substr(begin);
                 break;
               }
               ss << format.substr(begin, loc - begin);
               if (used_args >= args.size()) {
                 AT_ERROR("Too few arguments for format string: ", format);
               }
               ss << args[used_args];
               begin = loc + 2;
             }
 // reference python implementation: internal_new_from_data in
 // tensor_new.cpp
     Operator(
         "aten::_infer_size(int[] a, int[] b) -> int[]",
         [](const Node* node) {
           return [](Stack& stack) {
             auto a = pop(stack).toIntList()->elements();
             auto b = pop(stack).toIntList()->elements();
             push(stack, at::infer_size(a, b));
             return 0;
           };
         }),
     Operator(
         "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
         [](const Node* node) {
           return [](Stack& stack) {
             at::Tensor weight;
             at::Tensor input;
             double max_norm;
             double norm_type;
             pop(stack, weight, input, max_norm, norm_type);
 
             drop(stack, num_inputs);
             push(stack, ss.str());
             return 0;
           };
         }),
             // TODO: remove when script supports setting grad mode
             torch::NoGradGuard no_grad;
 
 #define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)  \
  Operator(                                                                 \
      "aten::tensor(" #operator_type " t, *, ScalarType? dtype=None, Device? device=None"\
      ") -> Tensor",                                                        \
      [](const Node* node) {                                                \
        auto initial_scalar_type = scalarTypeFromJitType(node->inputs().at(0)->type()); \
        return [initial_scalar_type](Stack& stack) {                        \
          c_type scalar_val;                                                \
          IValue dtype;                                                     \
          IValue device;                                                    \
          pop(stack, scalar_val, dtype, device);                            \
          auto tensor = autograd::make_variable(tensor_creation_op);        \
          at::ScalarType scalar_type = dtype.isNone() ?                     \
            tensor.scalar_type() : dtype.toScalarType();                    \
          c10::Device dev = device.isNone() ? tensor.device() : device.toDevice(); \
          if (scalar_type != initial_scalar_type || dev != tensor.device()) { \
            tensor = tensor.to(dev, scalar_type);                           \
          }                                                                 \
          push(stack, tensor);                                              \
          return 0;                                                         \
        };                                                                  \
      }),
             at::Tensor result =
                 at::embedding_renorm_(weight, input, max_norm, norm_type);
             push(stack, result);
 
 DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
 DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
 DEFINE_TORCH_TENSOR_OP(bool, bool, at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
             return 0;
           };
         }),
     Operator(
         "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor",
         [](const Node* node) {
           auto input = node->inputs().at(0);
           auto elem_type = input->type();
           while (auto list_type = elem_type->cast<ListType>()) {
             elem_type = list_type->getElementType();
           }
           checkListInputType(elem_type, node);
           at::ScalarType initial_scalar_type =
               scalarTypeFromJitType(elem_type);
           return [initial_scalar_type, elem_type](Stack& stack) {
             bool requires_grad;
             IValue data;
             IValue dtype;
             IValue device;
             pop(stack, data, dtype, device, requires_grad);
             auto sizes = compute_sizes(data);
             auto tensor = autograd::make_variable(at::empty(
                 sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
 
             recursiveStore(
                 (char*)tensor.data_ptr(),
                 sizes,
                 tensor.strides(),
                 0,
                 tensor.element_size(),
                 data);
 
 // reference python implementation: internal_new_from_data in tensor_new.cpp
    Operator(
        "aten::_infer_size(int[] a, int[] b) -> int[]",
        [](const Node* node) {
          return [](Stack& stack) {
            auto a = pop(stack).toIntList()->elements();
            auto b = pop(stack).toIntList()->elements();
            push(stack, at::infer_size(a, b));
            return 0;
          };
        }),
    Operator(
        "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
        [](const Node* node) {
          return [](Stack& stack) {
            at::Tensor weight;
            at::Tensor input;
            double max_norm;
            double norm_type;
            pop(stack, weight, input, max_norm, norm_type);
             at::ScalarType scalar_type =
                 dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
             c10::Device dev =
                 device.isNone() ? tensor.device() : device.toDevice();
             if (scalar_type != initial_scalar_type || dev != tensor.device()) {
               tensor = tensor.to(dev, scalar_type);
             }
 
            // TODO: remove when script supports setting grad mode
            torch::NoGradGuard no_grad;
             auto default_type =
                 at::typeMetaToScalarType(at::get_default_dtype());
 
            at::Tensor result =
                at::embedding_renorm_(weight, input, max_norm, norm_type);
            push(stack, result);
             if (dtype.isNone() && tensor.scalar_type() != default_type &&
                 tensor.numel() == 0) {
               AT_WARN(
                   "Creating a tensor from an empty ",
                   elem_type->str(),
                   "list will create a tensor of default floating point type (currently ",
                   default_type,
                   ") in python but a tensor of type ",
                   elem_type->str(),
                   " in torchscript.\n",
                   "Pass in a dtype argument to ensure consistent behavior");
             }
             tensor.set_requires_grad(requires_grad);
             push(stack, tensor);
             return 0;
           };
         }),
     Operator(
         "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
         [](const Node* node) {
           return [](Stack& stack) {
             // Everything is a list at the point this is used, so don't do
             // anything
             drop(stack, 3);
             return 0;
           };
         }),
     Operator(
         "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, "
         "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)",
         [](Stack& stack) { return 0; })
 
            return 0;
          };
        }),
    Operator(
        "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
        [](const Node* node) {
          auto input = node->inputs().at(0);
          auto elem_type = input->type();
          while (auto list_type = elem_type->cast<ListType>()) {
            elem_type = list_type->getElementType();
          }
          checkListInputType(elem_type, node);
          at::ScalarType initial_scalar_type = scalarTypeFromJitType(elem_type);
          return [initial_scalar_type, elem_type](Stack& stack) {
            IValue data;
            IValue dtype;
            IValue device;
            pop(stack, data, dtype, device);
            auto sizes = compute_sizes(data);
            auto tensor = autograd::make_variable(
                at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
 
            recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
                tensor.element_size(), data);
 
            at::ScalarType scalar_type = dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
            c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();
            if (scalar_type != initial_scalar_type || dev != tensor.device()) {
              tensor = tensor.to(dev, scalar_type);
            }
 
            auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
 
            if (dtype.isNone() && tensor.scalar_type() != default_type &&
                tensor.numel() == 0) {
              AT_WARN("Creating a tensor from an empty ", elem_type->str(),
                "list will create a tensor of default floating point type (currently ", default_type,
                ") in python but a tensor of type ", elem_type->str(), " in torchscript.\n",
                "Pass in a dtype argument to ensure consistent behavior");
            }
 
            push(stack, tensor);
            return 0;
          };
        }),
    Operator(
        "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
        [](const Node* node) {
          return [](Stack& stack) {
            // Everything is a list at the point this is used, so don't do
            // anything
            drop(stack, 3);
            return 0;
          };
        }),
    Operator(
        "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, "
        "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)",
        [](Stack& stack) {
          return 0;
        })
 
     });
 }
 });
 } // namespace
 } // namespace jit
 } // namespace torch
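The AT_WARN in this hunk flags a Python/TorchScript divergence for empty lists; a sketch of the two behaviors, assuming the default dtype is float32:

    import torch
    from typing import List

    # eager Python: empty data falls back to the default floating point dtype
    assert torch.tensor([]).dtype == torch.float32

    # TorchScript: an annotated empty int list keeps its element type (Long),
    # hence the warning suggesting an explicit dtype argument
    @torch.jit.script
    def empty_long():
        li = torch.jit.annotate(List[int], [])
        return torch.tensor(li)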