Improved JIT support for torch.einsum (#59265)

Summary:
Added JIT support for the vararg version of `torch.einsum`, which also makes the sublist input format scriptable. Note that JIT does not support Python's Ellipsis object (`...`), so `...` cannot be used as a subscript in scripted code.
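
For illustration, a minimal sketch of the call forms this makes scriptable (the function names and the `torch.outer` comparison are only for this example, not part of the patch):

import torch

@torch.jit.script
def outer_vararg(x, y):
    # vararg equation-string form: operands passed directly rather than in a tuple
    return torch.einsum('i,j->ij', x, y)

@torch.jit.script
def outer_sublist(x, y):
    # sublist form: each operand is followed by its subscript list,
    # with an optional output subscript list at the end
    return torch.einsum(x, [0], y, [1], [0, 1])

x, y = torch.rand(5), torch.rand(10)
assert torch.allclose(outer_vararg(x, y), torch.outer(x, y))
assert torch.allclose(outer_sublist(x, y), torch.outer(x, y))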

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59265

Reviewed By: VitalyFedyunin

Differential Revision: D29328469

Pulled By: heitorschueroff

fbshipit-source-id: 5e4b177fda93255251f45d735b00c08220f0f124
Heitor Schueroff, 2021-06-29 13:59:02 -07:00, committed by Facebook GitHub Bot
Parent: d46eb77b04
Commit: 8f658d537d

4 changed files with 94 additions and 8 deletions


@@ -1743,20 +1743,18 @@ graph(%Ra, %Rb):
         def equation_format(x, y):
             return torch.einsum('i,j->ij', (x, y))
 
+        def equation_format_varargs(x, y):
+            return torch.einsum('i,j->ij', x, y)
+
         def sublist_format(x, y):
             return torch.einsum(x, [0], y, [1], [0, 1])
 
-        # Sublist format cannot be scripted because it is
-        # a NumPy API only feature
-        with self.assertRaises(RuntimeError):
-            torch.jit.script(sublist_format)
-
         x = make_tensor((5,), 'cpu', torch.float32)
         y = make_tensor((10,), 'cpu', torch.float32)
 
-        check(equation_format, torch.jit.script(equation_format), x, y)
-        check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y)
-        check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y)
+        for fn in [equation_format, equation_format_varargs, sublist_format]:
+            check(fn, torch.jit.script(fn), x, y)
+            check(fn, torch.jit.trace(fn, (x, y)), x, y)
 
     def test_python_ivalue(self):
         # Test if pure python object can be hold as IValue and conversion


@@ -442,6 +442,14 @@ RegisterOperators reg(
             format(*stack, num_inputs);
           },
           aliasAnalysisFromSchema()),
+      OperatorGenerator(
+          TORCH_SELECTIVE_SCHEMA(
+              "aten::einsum.sublist(Tensor a, ...) -> Tensor"),
+          [](Stack* stack) {
+            size_t num_inputs = pop(stack).toInt();
+            einsum(*stack, num_inputs);
+          },
+          aliasAnalysisFromSchema()),
       OperatorGenerator(
           TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
           [](Stack* stack) {


@@ -133,6 +133,84 @@ void format(Stack& stack, size_t num_inputs) {
   push(stack, ss.str());
 }
 
+void einsum(Stack& stack, size_t num_inputs) {
+  TORCH_CHECK(
+      num_inputs >= 2,
+      "einsum(): must specify the equation string and at least one operand, ",
+      "or at least one operand and its subscripts list");
+
+  const auto args = last(stack, num_inputs);
+
+  // Convert the subscript list format which is an interleaving of operand and
+  // its subscripts list with an optional output subscripts list at the end
+  // (see documentation for more details on this) to the equation string
+  // format by creating the equation string from the subscripts list and
+  // grouping the input operands into a tensorlist (List[Tensor]).
+  std::stringstream ss;
+  auto parse_sublist = [&ss](const c10::List<int64_t>& l, size_t arg_num) {
+    for (const auto i : c10::irange(l.size())) {
+      TORCH_CHECK(
+          l[i] >= 0 && l[i] < 52,
+          "einsum(): expected subscript ",
+          i,
+          " in argument ",
+          arg_num,
+          " to be within the range [0, 52), but got ",
+          l[i]);
+      if (l[i] < 26) {
+        ss << static_cast<char>(l[i] + 'A');
+      } else {
+        ss << static_cast<char>(l[i] - 26 + 'a');
+      }
+    }
+  };
+
+  // Parse subscripts for input operands
+  for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) {
+    TORCH_CHECK(
+        args[i].isIntList(),
+        "einsum(): expected List[int] in argument ",
+        i,
+        ", but got ",
+        args[i].type()->repr_str());
+    parse_sublist(args[i].toIntList(), i);
+    if (i + 2 < num_inputs) {
+      ss << ',';
+    }
+  }
+
+  // Parse optional output subscripts (provided if #args is odd)
+  if (num_inputs % 2 == 1) {
+    TORCH_CHECK(
+        args.back().isIntList(),
+        "einsum(): expected List[int] in argument ",
+        num_inputs - 1,
+        ", but got ",
+        args.back().type()->repr_str());
+    ss << "->";
+    parse_sublist(args.back().toIntList(), num_inputs - 1);
+  }
+
+  const auto equation = ss.str();
+  std::vector<at::Tensor> operands;
+
+  // Parse input operands
+  const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs;
+  for (auto i = decltype(num_inputs){0}; i < end; i += 2) {
+    TORCH_CHECK(
+        args[i].isTensor(),
+        "einsum(): expected Tensor in argument ",
+        i,
+        ", but got ",
+        args[i].type()->repr_str());
+    operands.emplace_back(args[i].toTensor());
+  }
+
+  drop(stack, num_inputs);
+  push(stack, at::einsum(equation, operands));
+}
+
 void percentFormat(Stack& stack, size_t num_inputs) {
   auto format_str = peek(stack, 0, num_inputs).toStringRef();
   auto args = last(stack, num_inputs - 1)[0];
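
As a rough sketch of what the conversion above does, here is a plain-Python equivalent; the helper names `subscript_to_char` and `sublist_to_equation` are made up for illustration, and the real lowering is the C++ `einsum` in this hunk:

def subscript_to_char(s):
    # 0..25 -> 'A'..'Z', 26..51 -> 'a'..'z', mirroring parse_sublist above
    assert 0 <= s < 52
    return chr(ord('A') + s) if s < 26 else chr(ord('a') + s - 26)

def sublist_to_equation(*args):
    # args interleave operands and their subscript lists; an odd number of
    # arguments means the last one is the output subscript list.
    operands = list(args[0::2])
    sublists = [''.join(map(subscript_to_char, l)) for l in args[1::2]]
    equation = ','.join(sublists)
    if len(args) % 2 == 1:
        out = operands.pop()  # trailing output subscript list
        equation += '->' + ''.join(map(subscript_to_char, out))
    return equation, operands

# torch.einsum(x, [0], y, [1], [0, 1]) is lowered to the equivalent of:
#   equation, operands = sublist_to_equation(x, [0], y, [1], [0, 1])
#   at::einsum(equation, operands)   # equation == 'A,B->AB'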


@@ -12,6 +12,8 @@ void tupleUnpack(Stack& stack);
 void format(Stack& stack, size_t num_inputs);
 
+void einsum(Stack& stack, size_t num_inputs);
+
 void percentFormat(Stack& stack, size_t num_inputs);
 void listUnpack(Stack& stack, size_t num_outputs);