mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improved JIT support for torch.einsum (#59265)
Summary: Added JIT support for the vararg version of `torch.einsum`. Note that JIT does not support the Python's Ellipsis object (`...`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/59265 Reviewed By: VitalyFedyunin Differential Revision: D29328469 Pulled By: heitorschueroff fbshipit-source-id: 5e4b177fda93255251f45d735b00c08220f0f124
This commit is contained in:
parent
d46eb77b04
commit
8f658d537d
|
|
@ -1743,20 +1743,18 @@ graph(%Ra, %Rb):
|
||||||
def equation_format(x, y):
|
def equation_format(x, y):
|
||||||
return torch.einsum('i,j->ij', (x, y))
|
return torch.einsum('i,j->ij', (x, y))
|
||||||
|
|
||||||
|
def equation_format_varargs(x, y):
|
||||||
|
return torch.einsum('i,j->ij', x, y)
|
||||||
|
|
||||||
def sublist_format(x, y):
|
def sublist_format(x, y):
|
||||||
return torch.einsum(x, [0], y, [1], [0, 1])
|
return torch.einsum(x, [0], y, [1], [0, 1])
|
||||||
|
|
||||||
# Sublist format cannot be scripted because it is
|
|
||||||
# a NumPy API only feature
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
torch.jit.script(sublist_format)
|
|
||||||
|
|
||||||
x = make_tensor((5,), 'cpu', torch.float32)
|
x = make_tensor((5,), 'cpu', torch.float32)
|
||||||
y = make_tensor((10,), 'cpu', torch.float32)
|
y = make_tensor((10,), 'cpu', torch.float32)
|
||||||
|
|
||||||
check(equation_format, torch.jit.script(equation_format), x, y)
|
for fn in [equation_format, equation_format_varargs, sublist_format]:
|
||||||
check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y)
|
check(fn, torch.jit.script(fn), x, y)
|
||||||
check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y)
|
check(fn, torch.jit.trace(fn, (x, y)), x, y)
|
||||||
|
|
||||||
def test_python_ivalue(self):
|
def test_python_ivalue(self):
|
||||||
# Test if pure python object can be hold as IValue and conversion
|
# Test if pure python object can be hold as IValue and conversion
|
||||||
|
|
|
||||||
|
|
@ -442,6 +442,14 @@ RegisterOperators reg(
|
||||||
format(*stack, num_inputs);
|
format(*stack, num_inputs);
|
||||||
},
|
},
|
||||||
aliasAnalysisFromSchema()),
|
aliasAnalysisFromSchema()),
|
||||||
|
OperatorGenerator(
|
||||||
|
TORCH_SELECTIVE_SCHEMA(
|
||||||
|
"aten::einsum.sublist(Tensor a, ...) -> Tensor"),
|
||||||
|
[](Stack* stack) {
|
||||||
|
size_t num_inputs = pop(stack).toInt();
|
||||||
|
einsum(*stack, num_inputs);
|
||||||
|
},
|
||||||
|
aliasAnalysisFromSchema()),
|
||||||
OperatorGenerator(
|
OperatorGenerator(
|
||||||
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
|
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
|
||||||
[](Stack* stack) {
|
[](Stack* stack) {
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,84 @@ void format(Stack& stack, size_t num_inputs) {
|
||||||
push(stack, ss.str());
|
push(stack, ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void einsum(Stack& stack, size_t num_inputs) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
num_inputs >= 2,
|
||||||
|
"einsum(): must specify the equation string and at least one operand, ",
|
||||||
|
"or at least one operand and its subscripts list");
|
||||||
|
|
||||||
|
const auto args = last(stack, num_inputs);
|
||||||
|
|
||||||
|
// Convert the subscript list format which is an interleaving of operand and
|
||||||
|
// its subscripts list with an optional output subscripts list at the end
|
||||||
|
// (see documentation for more details on this) to the equation string
|
||||||
|
// format by creating the equation string from the subscripts list and
|
||||||
|
// grouping the input operands into a tensorlist (List[Tensor]).
|
||||||
|
std::stringstream ss;
|
||||||
|
|
||||||
|
auto parse_sublist = [&ss](const c10::List<int64_t>& l, size_t arg_num) {
|
||||||
|
for (const auto i : c10::irange(l.size())) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
l[i] >= 0 && l[i] < 52,
|
||||||
|
"einsum(): expected subscript ",
|
||||||
|
i,
|
||||||
|
" in argument ",
|
||||||
|
arg_num,
|
||||||
|
" to be within the range [0, 52), but got ",
|
||||||
|
l[i]);
|
||||||
|
if (l[i] < 26) {
|
||||||
|
ss << static_cast<char>(l[i] + 'A');
|
||||||
|
} else {
|
||||||
|
ss << static_cast<char>(l[i] - 26 + 'a');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse subscripts for input operands
|
||||||
|
for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
args[i].isIntList(),
|
||||||
|
"einsum(): expected List[int] in argument ",
|
||||||
|
i,
|
||||||
|
", but got ",
|
||||||
|
args[i].type()->repr_str());
|
||||||
|
parse_sublist(args[i].toIntList(), i);
|
||||||
|
if (i + 2 < num_inputs) {
|
||||||
|
ss << ',';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse optional output subscripts (provided if #args is odd)
|
||||||
|
if (num_inputs % 2 == 1) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
args.back().isIntList(),
|
||||||
|
"einsum(): expected List[int] in argument ",
|
||||||
|
num_inputs - 1,
|
||||||
|
", but got ",
|
||||||
|
args.back().type()->repr_str());
|
||||||
|
ss << "->";
|
||||||
|
parse_sublist(args.back().toIntList(), num_inputs - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto equation = ss.str();
|
||||||
|
std::vector<at::Tensor> operands;
|
||||||
|
|
||||||
|
// Parse input operands
|
||||||
|
const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs;
|
||||||
|
for (auto i = decltype(num_inputs){0}; i < end; i += 2) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
args[i].isTensor(),
|
||||||
|
"einsum(): expected Tensor in argument ",
|
||||||
|
i,
|
||||||
|
", but got ",
|
||||||
|
args[i].type()->repr_str());
|
||||||
|
operands.emplace_back(args[i].toTensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(stack, num_inputs);
|
||||||
|
push(stack, at::einsum(equation, operands));
|
||||||
|
}
|
||||||
|
|
||||||
void percentFormat(Stack& stack, size_t num_inputs) {
|
void percentFormat(Stack& stack, size_t num_inputs) {
|
||||||
auto format_str = peek(stack, 0, num_inputs).toStringRef();
|
auto format_str = peek(stack, 0, num_inputs).toStringRef();
|
||||||
auto args = last(stack, num_inputs - 1)[0];
|
auto args = last(stack, num_inputs - 1)[0];
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ void tupleUnpack(Stack& stack);
|
||||||
|
|
||||||
void format(Stack& stack, size_t num_inputs);
|
void format(Stack& stack, size_t num_inputs);
|
||||||
|
|
||||||
|
void einsum(Stack& stack, size_t num_inputs);
|
||||||
|
|
||||||
void percentFormat(Stack& stack, size_t num_inputs);
|
void percentFormat(Stack& stack, size_t num_inputs);
|
||||||
|
|
||||||
void listUnpack(Stack& stack, size_t num_outputs);
|
void listUnpack(Stack& stack, size_t num_outputs);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user