Add support for log_softmax (#47409)

Summary:
This diff adds support for the `log_softmax` op in NNC.
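
For reference, the new op can be exercised end-to-end the same way the updated Python test in this diff does. This is a minimal sketch, assuming a CPU build with the texpr fuser available; the function `f` and the shapes are illustrative only:

```python
import torch
import torch.nn.functional as F

def f(x, y):
    # log_softmax over both dims, as in the updated _test_softmax test
    a = F.log_softmax(x, dim=0, dtype=torch.float32)
    b = F.log_softmax(y, dim=1, dtype=torch.float32)
    return a + b

# Reductions must be enabled for the texpr fuser to take softmax/log_softmax.
old = torch._C._jit_set_texpr_reductions_enabled(True)
try:
    traced = torch.jit.trace(f, (torch.randn(2, 3), torch.randn(2, 3)))
    inp = torch.randn(2, 3)
    # Eager mode serves as the reference for the fused kernel's output.
    assert torch.allclose(traced(inp, inp), f(inp, inp), atol=1e-6)
finally:
    torch._C._jit_set_texpr_reductions_enabled(old)
```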

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47409

Reviewed By: ejguan

Differential Revision: D24750203

Pulled By: navahgar

fbshipit-source-id: c4dacc7f62f9df65ae467f0d578ea03d3698273d
Author: Raghavan Raman
Date: 2020-11-06 13:25:01 -08:00
Committed by: Facebook GitHub Bot
Parent: 582e852fba
Commit: 8eb228a7f3

5 changed files with 181 additions and 125 deletions


@@ -635,7 +635,7 @@ void testKernelSoftmax2D() {
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -652,39 +652,43 @@ void testKernelSoftmax2D() {
       # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
       # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    auto other_dim = (softmax_dim + 1) % a.dim();
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      auto other_dim = (softmax_dim + 1) % a.dim();

-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);

-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);

-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();

-    std::ostringstream oss;
-    oss << *s;
+      std::ostringstream oss;
+      oss << *s;

-    TemplateEnv ver_env;
-    ver_env.d("other_dim", other_dim);
-    ver_env.d("other_dim_size", a.sizes()[other_dim]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      TemplateEnv ver_env;
+      ver_env.d("other_dim", other_dim);
+      ver_env.d("other_dim_size", a.sizes()[other_dim]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }
@@ -693,7 +697,7 @@ void testKernelSoftmax3D() {
       graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -713,47 +717,51 @@ void testKernelSoftmax3D() {
       # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
       # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    std::vector<int> other_dims;
-    for (int i = 0; i < a.dim(); ++i) {
-      if (i != softmax_dim) {
-        other_dims.push_back(i);
-      }
-    }
-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
-    std::ostringstream oss;
-    oss << *s;
-    TemplateEnv ver_env;
-    ver_env.d("dim1", other_dims[0]);
-    ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
-    ver_env.d("dim2", other_dims[1]);
-    ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      std::vector<int> other_dims;
+      for (int i = 0; i < a.dim(); ++i) {
+        if (i != softmax_dim) {
+          other_dims.push_back(i);
+        }
+      }
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();
+      std::ostringstream oss;
+      oss << *s;
+      TemplateEnv ver_env;
+      ver_env.d("dim1", other_dims[0]);
+      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
+      ver_env.d("dim2", other_dims[1]);
+      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }
@@ -762,7 +770,7 @@ void testKernelSoftmax4D() {
       graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -785,48 +793,52 @@ void testKernelSoftmax4D() {
       # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
       # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    std::vector<int> other_dims;
-    for (int i = 0; i < a.dim(); ++i) {
-      if (i != softmax_dim) {
-        other_dims.push_back(i);
-      }
-    }
-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
-    std::ostringstream oss;
-    oss << *s;
-    TemplateEnv ver_env;
-    ver_env.d("dim1", other_dims[0]);
-    ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
-    ver_env.d("dim2", other_dims[1]);
-    ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
-    ver_env.d("dim3", other_dims[2]);
-    ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      std::vector<int> other_dims;
+      for (int i = 0; i < a.dim(); ++i) {
+        if (i != softmax_dim) {
+          other_dims.push_back(i);
+        }
+      }
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();
+      std::ostringstream oss;
+      oss << *s;
+      TemplateEnv ver_env;
+      ver_env.d("dim1", other_dims[0]);
+      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
+      ver_env.d("dim2", other_dims[1]);
+      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
+      ver_env.d("dim3", other_dims[2]);
+      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }


@@ -1226,21 +1226,29 @@ class TestTensorExprFuser(BaseTestClass):
         assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def _test_softmax(self, device):
-        def test(x, y):
+        def test_softmax(x, y):
             a = F.softmax(x, dim=0, dtype=torch.float32)
             b = F.softmax(y, dim=0, dtype=torch.float32)
             c = F.softmax(x, dim=1, dtype=torch.float32)
             d = F.softmax(y, dim=1, dtype=torch.float32)
             return a + b + c + d

-        old = torch._C._jit_set_texpr_reductions_enabled(True)
-        traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device)))
-        inp = torch.randn(2, 3, device=device)
-        res = traced(inp, inp)
-        # Use eager mode as reference.
-        ref = test(inp, inp)
-        np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
-        torch._C._jit_set_texpr_reductions_enabled(old)
+        def test_log_softmax(x, y):
+            a = F.log_softmax(x, dim=0, dtype=torch.float32)
+            b = F.log_softmax(y, dim=0, dtype=torch.float32)
+            c = F.log_softmax(x, dim=1, dtype=torch.float32)
+            d = F.log_softmax(y, dim=1, dtype=torch.float32)
+            return a + b + c + d
+
+        for test in (test_softmax, test_log_softmax):
+            old = torch._C._jit_set_texpr_reductions_enabled(True)
+            traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device)))
+            inp = torch.randn(2, 3, device=device)
+            res = traced(inp, inp)
+            # Use eager mode as reference.
+            ref = test(inp, inp)
+            np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
+            torch._C._jit_set_texpr_reductions_enabled(old)

     def test_softmax_cpu(self):
         llvm = LLVMCodeGenExecuted()


@@ -153,6 +153,7 @@ bool isSupported(Node* node) {
       "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
       "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
       "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
+      "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
   };
   // clang-format on


@@ -371,7 +371,8 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
     }
     case aten::softmax:
-      // Output of Softmax has the same shape as input 0.
+    case aten::log_softmax:
+      // Output of softmax / log_softmax has the same shape as input 0.
       return sizesForValue(v->node()->input(0));
     case aten::slice:
@@ -1353,7 +1354,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     }
     case aten::softmax: {
-      return computeSoftmax(v);
+      return computeSoftmax(v, false);
+    }
+
+    case aten::log_softmax: {
+      return computeSoftmax(v, true);
     }
     default: {
@@ -1626,7 +1631,9 @@ Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) {
       reduction_info.reductionDims);
 }

-Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
+Tensor* TensorExprKernel::computeSoftmax(
+    const torch::jit::Value* v,
+    bool log_softmax) {
   // Softmax is computed as follows:
   //    softmax(vi) = exp(vi) / sum(exp(vi))
   //
@@ -1641,6 +1648,21 @@ Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
   //   - Third loop computes the sum over the softmax dim.
   //   - Final loop computes softmax for every element in v.

+  // LogSoftmax is computed as follows:
+  //    log_softmax(vi) = log(softmax(vi))
+  //                    = vi - log(sum(exp(vi)))
+  //
+  // Using the same max trick as above:
+  //    log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
+  //
+  // This is implemented as 5 loopnests:
+  //   - First loop computes the max over the softmax dim.
+  //   - Second loop computes exp for every element in v after subtracting
+  //     the max of the softmax dim it belongs to.
+  //   - Third loop computes the sum over the softmax dim.
+  //   - Fourth loop computes log for every element in the sum.
+  //   - Final loop computes the log_softmax for every element in v.
+
   TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3);
   auto output_dims = dimsFromSizes(sizesForValue(v));
@@ -1724,10 +1746,23 @@ Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
         return e->call(move_softmax_dim_index_to_pos(indices));
       },
       {output_dims[softmax_dim]});
-  auto res = Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
-    return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
-  });
+  if (!log_softmax) {
+    return Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
+      return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
+    });
+  }
+
+  auto log_sum = Compute(
+      "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
+        return log(sum->call(indices));
+      });
+  return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
+    auto inp = tensorOrConstant(
+        v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
+    auto non_softmax_indices = remove_softmax_dim_index(indices);
+    return inp - max->call(non_softmax_indices) -
+        log_sum->call(non_softmax_indices);
+  });
-  return res;
 }

 TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo(

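As a quick sanity check of the LogSoftmax comment above (not part of the diff), the five loopnests can be written out sequentially in NumPy; the 1-D input `v` is illustrative only:

```python
import numpy as np

v = np.random.randn(5).astype(np.float32)

m = v.max()          # 1) max over the softmax dim
e = np.exp(v - m)    # 2) exp of every element after subtracting the max
s = e.sum()          # 3) sum over the softmax dim
log_s = np.log(s)    # 4) log of the sum
out = v - m - log_s  # 5) final log_softmax

# Agrees with the naive log(softmax(v)), confirming the identity
# log_softmax(v) = v - max(v) - log(sum(exp(v - max(v)))).
ref = np.log(np.exp(v) / np.exp(v).sum())
assert np.allclose(out, ref, atol=1e-6)
```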

@@ -121,7 +121,7 @@ class TORCH_API TensorExprKernel {
   Tensor* computeSum(const torch::jit::Value* v);

-  Tensor* computeSoftmax(const torch::jit::Value* v);
+  Tensor* computeSoftmax(const torch::jit::Value* v, bool log_softmax);

   Tensor* computeValue(const torch::jit::Value* v);