mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improved JIT support for torch.einsum (#59265)
Summary: Added JIT support for the vararg version of `torch.einsum`. Note that JIT does not support the Python's Ellipsis object (`...`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/59265 Reviewed By: VitalyFedyunin Differential Revision: D29328469 Pulled By: heitorschueroff fbshipit-source-id: 5e4b177fda93255251f45d735b00c08220f0f124
This commit is contained in:
parent
d46eb77b04
commit
8f658d537d
|
|
@ -1743,20 +1743,18 @@ graph(%Ra, %Rb):
|
|||
def equation_format(x, y):
|
||||
return torch.einsum('i,j->ij', (x, y))
|
||||
|
||||
def equation_format_varargs(x, y):
|
||||
return torch.einsum('i,j->ij', x, y)
|
||||
|
||||
def sublist_format(x, y):
|
||||
return torch.einsum(x, [0], y, [1], [0, 1])
|
||||
|
||||
# Sublist format cannot be scripted because it is
|
||||
# a NumPy API only feature
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.jit.script(sublist_format)
|
||||
|
||||
x = make_tensor((5,), 'cpu', torch.float32)
|
||||
y = make_tensor((10,), 'cpu', torch.float32)
|
||||
|
||||
check(equation_format, torch.jit.script(equation_format), x, y)
|
||||
check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y)
|
||||
check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y)
|
||||
for fn in [equation_format, equation_format_varargs, sublist_format]:
|
||||
check(fn, torch.jit.script(fn), x, y)
|
||||
check(fn, torch.jit.trace(fn, (x, y)), x, y)
|
||||
|
||||
def test_python_ivalue(self):
|
||||
# Test if pure python object can be hold as IValue and conversion
|
||||
|
|
|
|||
|
|
@ -442,6 +442,14 @@ RegisterOperators reg(
|
|||
format(*stack, num_inputs);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGenerator(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::einsum.sublist(Tensor a, ...) -> Tensor"),
|
||||
[](Stack* stack) {
|
||||
size_t num_inputs = pop(stack).toInt();
|
||||
einsum(*stack, num_inputs);
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGenerator(
|
||||
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
|
||||
[](Stack* stack) {
|
||||
|
|
|
|||
|
|
@ -133,6 +133,84 @@ void format(Stack& stack, size_t num_inputs) {
|
|||
push(stack, ss.str());
|
||||
}
|
||||
|
||||
void einsum(Stack& stack, size_t num_inputs) {
|
||||
TORCH_CHECK(
|
||||
num_inputs >= 2,
|
||||
"einsum(): must specify the equation string and at least one operand, ",
|
||||
"or at least one operand and its subscripts list");
|
||||
|
||||
const auto args = last(stack, num_inputs);
|
||||
|
||||
// Convert the subscript list format which is an interleaving of operand and
|
||||
// its subscripts list with an optional output subscripts list at the end
|
||||
// (see documentation for more details on this) to the equation string
|
||||
// format by creating the equation string from the subscripts list and
|
||||
// grouping the input operands into a tensorlist (List[Tensor]).
|
||||
std::stringstream ss;
|
||||
|
||||
auto parse_sublist = [&ss](const c10::List<int64_t>& l, size_t arg_num) {
|
||||
for (const auto i : c10::irange(l.size())) {
|
||||
TORCH_CHECK(
|
||||
l[i] >= 0 && l[i] < 52,
|
||||
"einsum(): expected subscript ",
|
||||
i,
|
||||
" in argument ",
|
||||
arg_num,
|
||||
" to be within the range [0, 52), but got ",
|
||||
l[i]);
|
||||
if (l[i] < 26) {
|
||||
ss << static_cast<char>(l[i] + 'A');
|
||||
} else {
|
||||
ss << static_cast<char>(l[i] - 26 + 'a');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Parse subscripts for input operands
|
||||
for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) {
|
||||
TORCH_CHECK(
|
||||
args[i].isIntList(),
|
||||
"einsum(): expected List[int] in argument ",
|
||||
i,
|
||||
", but got ",
|
||||
args[i].type()->repr_str());
|
||||
parse_sublist(args[i].toIntList(), i);
|
||||
if (i + 2 < num_inputs) {
|
||||
ss << ',';
|
||||
}
|
||||
}
|
||||
|
||||
// Parse optional output subscripts (provided if #args is odd)
|
||||
if (num_inputs % 2 == 1) {
|
||||
TORCH_CHECK(
|
||||
args.back().isIntList(),
|
||||
"einsum(): expected List[int] in argument ",
|
||||
num_inputs - 1,
|
||||
", but got ",
|
||||
args.back().type()->repr_str());
|
||||
ss << "->";
|
||||
parse_sublist(args.back().toIntList(), num_inputs - 1);
|
||||
}
|
||||
|
||||
const auto equation = ss.str();
|
||||
std::vector<at::Tensor> operands;
|
||||
|
||||
// Parse input operands
|
||||
const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs;
|
||||
for (auto i = decltype(num_inputs){0}; i < end; i += 2) {
|
||||
TORCH_CHECK(
|
||||
args[i].isTensor(),
|
||||
"einsum(): expected Tensor in argument ",
|
||||
i,
|
||||
", but got ",
|
||||
args[i].type()->repr_str());
|
||||
operands.emplace_back(args[i].toTensor());
|
||||
}
|
||||
|
||||
drop(stack, num_inputs);
|
||||
push(stack, at::einsum(equation, operands));
|
||||
}
|
||||
|
||||
void percentFormat(Stack& stack, size_t num_inputs) {
|
||||
auto format_str = peek(stack, 0, num_inputs).toStringRef();
|
||||
auto args = last(stack, num_inputs - 1)[0];
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ void tupleUnpack(Stack& stack);
|
|||
|
||||
void format(Stack& stack, size_t num_inputs);
|
||||
|
||||
void einsum(Stack& stack, size_t num_inputs);
|
||||
|
||||
void percentFormat(Stack& stack, size_t num_inputs);
|
||||
|
||||
void listUnpack(Stack& stack, size_t num_outputs);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user