Add support for log_softmax (#47409)
Summary: This diff adds support for the `log_softmax` op in NNC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47409
Reviewed By: ejguan
Differential Revision: D24750203
Pulled By: navahgar
fbshipit-source-id: c4dacc7f62f9df65ae467f0d578ea03d3698273d
This commit is contained in:
parent 582e852fba
commit 8eb228a7f3
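For background, the two ops involved are the standard ATen softmax and log_softmax. A minimal eager-mode sketch of the relationship the new fused kernels must reproduce (plain ATen usage, not part of this diff; the input shape mirrors the 2D test below):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a = at::rand({5, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto s = a.softmax(/*dim=*/1);       // exp(a) / sum(exp(a)) along dim 1
  auto ls = a.log_softmax(/*dim=*/1);  // log of that same quantity
  // log_softmax must agree with log(softmax) up to floating-point error.
  std::cout << std::boolalpha << at::allclose(ls, s.log()) << "\n";  // true
  return 0;
}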
@@ -635,7 +635,7 @@ void testKernelSoftmax2D() {
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -652,39 +652,43 @@ void testKernelSoftmax2D() {
 # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
 # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    auto other_dim = (softmax_dim + 1) % a.dim();
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      auto other_dim = (softmax_dim + 1) % a.dim();

-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);

-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);

-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();

-    std::ostringstream oss;
-    oss << *s;
+      std::ostringstream oss;
+      oss << *s;

-    TemplateEnv ver_env;
-    ver_env.d("other_dim", other_dim);
-    ver_env.d("other_dim_size", a.sizes()[other_dim]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      TemplateEnv ver_env;
+      ver_env.d("other_dim", other_dim);
+      ver_env.d("other_dim_size", a.sizes()[other_dim]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }
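For concreteness, with ${dim} = 1 and ${op} = log_softmax the TemplateEnv substitution above produces the following graph string (instantiated by hand here; the test never prints it). %2 carries the optional ScalarType dtype argument:

graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
  %1 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=7]()
  %3 : Tensor = aten::log_softmax(%0, %1, %2)
  return (%3)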
@@ -693,7 +697,7 @@ void testKernelSoftmax3D() {
       graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -713,47 +717,51 @@ void testKernelSoftmax3D() {
 # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
 # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    std::vector<int> other_dims;
-    for (int i = 0; i < a.dim(); ++i) {
-      if (i != softmax_dim) {
-        other_dims.push_back(i);
-      }
-    }
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      std::vector<int> other_dims;
+      for (int i = 0; i < a.dim(); ++i) {
+        if (i != softmax_dim) {
+          other_dims.push_back(i);
+        }
+      }

-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);

-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);

-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();

-    std::ostringstream oss;
-    oss << *s;
+      std::ostringstream oss;
+      oss << *s;

-    TemplateEnv ver_env;
-    ver_env.d("dim1", other_dims[0]);
-    ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
-    ver_env.d("dim2", other_dims[1]);
-    ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      TemplateEnv ver_env;
+      ver_env.d("dim1", other_dims[0]);
+      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
+      ver_env.d("dim2", other_dims[1]);
+      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();

-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }
@@ -762,7 +770,7 @@ void testKernelSoftmax4D() {
       graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
         %1 : int = prim::Constant[value=${dim}]()
         %2 : int = prim::Constant[value=7]()
-        %3 : Tensor = aten::softmax(%0, %1, %2)
+        %3 : Tensor = aten::${op}(%0, %1, %2)
         return (%3))IR";

   auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));
@@ -785,48 +793,52 @@ void testKernelSoftmax4D() {
 # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
 # CHECK-NEXT: aten_softmax)IR";

-  for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
-    auto softmax_dim_size = a.sizes()[softmax_dim];
-    std::vector<int> other_dims;
-    for (int i = 0; i < a.dim(); ++i) {
-      if (i != softmax_dim) {
-        other_dims.push_back(i);
-      }
-    }
+  for (auto log_softmax : {false, true}) {
+    for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) {
+      auto softmax_dim_size = a.sizes()[softmax_dim];
+      std::vector<int> other_dims;
+      for (int i = 0; i < a.dim(); ++i) {
+        if (i != softmax_dim) {
+          other_dims.push_back(i);
+        }
+      }

-    KernelScope kernel_scope;
-    TemplateEnv env;
-    env.d("dim", softmax_dim);
-    const auto graph_string = format(graph_template, env);
+      KernelScope kernel_scope;
+      TemplateEnv env;
+      env.d("dim", softmax_dim);
+      env.s("op", log_softmax ? "log_softmax" : "softmax");
+      const auto graph_string = format(graph_template, env);

-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
+      auto graph = std::make_shared<Graph>();
+      parseIR(graph_string, &*graph);

-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    Stmt* s = k.getCodeGenStmt();
+      TensorExprKernel k(graph);
+      std::vector<at::Tensor> inputs = {a};
+      Stmt* s = k.getCodeGenStmt();

-    std::ostringstream oss;
-    oss << *s;
+      std::ostringstream oss;
+      oss << *s;

-    TemplateEnv ver_env;
-    ver_env.d("dim1", other_dims[0]);
-    ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
-    ver_env.d("dim2", other_dims[1]);
-    ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
-    ver_env.d("dim3", other_dims[2]);
-    ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
-    ver_env.d("softmax_dim", softmax_dim);
-    ver_env.d("softmax_dim_size", softmax_dim_size);
-    const auto verification_pattern = format(verification_template, ver_env);
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+      TemplateEnv ver_env;
+      ver_env.d("dim1", other_dims[0]);
+      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
+      ver_env.d("dim2", other_dims[1]);
+      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
+      ver_env.d("dim3", other_dims[2]);
+      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
+      ver_env.d("softmax_dim", softmax_dim);
+      ver_env.d("softmax_dim_size", softmax_dim_size);
+      const auto verification_pattern = format(verification_template, ver_env);
+      torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    auto output = stack[0].toTensor();
-    auto ref = a.softmax(softmax_dim);
-    ASSERT_EQ(output.sizes(), ref.sizes());
-    ASSERT_TRUE(at::allclose(output, ref));
+      std::vector<IValue> stack = fmap<IValue>(inputs);
+      k.run(stack);
+      auto output = stack[0].toTensor();
+      auto ref =
+          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
+      ASSERT_EQ(output.sizes(), ref.sizes());
+      ASSERT_TRUE(at::allclose(output, ref));
+    }
   }
 }
@@ -1226,21 +1226,29 @@ class TestTensorExprFuser(BaseTestClass):
         assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1

     def _test_softmax(self, device):
-        def test(x, y):
+        def test_softmax(x, y):
             a = F.softmax(x, dim=0, dtype=torch.float32)
             b = F.softmax(y, dim=0, dtype=torch.float32)
             c = F.softmax(x, dim=1, dtype=torch.float32)
             d = F.softmax(y, dim=1, dtype=torch.float32)
             return a + b + c + d

-        old = torch._C._jit_set_texpr_reductions_enabled(True)
-        traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device)))
-        inp = torch.randn(2, 3, device=device)
-        res = traced(inp, inp)
-        # Use eager mode as reference.
-        ref = test(inp, inp)
-        np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
-        torch._C._jit_set_texpr_reductions_enabled(old)
+        def test_log_softmax(x, y):
+            a = F.log_softmax(x, dim=0, dtype=torch.float32)
+            b = F.log_softmax(y, dim=0, dtype=torch.float32)
+            c = F.log_softmax(x, dim=1, dtype=torch.float32)
+            d = F.log_softmax(y, dim=1, dtype=torch.float32)
+            return a + b + c + d
+
+        for test in (test_softmax, test_log_softmax):
+            old = torch._C._jit_set_texpr_reductions_enabled(True)
+            traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device)))
+            inp = torch.randn(2, 3, device=device)
+            res = traced(inp, inp)
+            # Use eager mode as reference.
+            ref = test(inp, inp)
+            np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
+            torch._C._jit_set_texpr_reductions_enabled(old)

     def test_softmax_cpu(self):
         llvm = LLVMCodeGenExecuted()
@@ -153,6 +153,7 @@ bool isSupported(Node* node) {
       "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
       "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
       "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
+      "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
   };
   // clang-format on
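For context, this list is a whitelist of full function schemas: a node is only handed to NNC if its schema matches an entry, so the new `aten::log_softmax.int` line is what lets the fuser claim the op at all. A hedged illustration of this kind of schema check (Node::matches is existing torch::jit API, but this helper is hypothetical, not code from the diff):

// Hypothetical helper: true only for nodes whose full schema matches
// the whitelisted log_softmax signature above.
bool is_fusable_log_softmax(torch::jit::Node* node) {
  return node->matches(
      "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor");
}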
@@ -371,7 +371,8 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
     }

     case aten::softmax:
-      // Output of Softmax has the same shape as input 0.
+    case aten::log_softmax:
+      // Output of softmax / log_softmax has the same shape as input 0.
       return sizesForValue(v->node()->input(0));

     case aten::slice:
@@ -1353,7 +1354,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     }

     case aten::softmax: {
-      return computeSoftmax(v);
+      return computeSoftmax(v, false);
+    }
+
+    case aten::log_softmax: {
+      return computeSoftmax(v, true);
     }

     default: {
@@ -1626,7 +1631,9 @@ Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) {
       reduction_info.reductionDims);
 }

-Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
+Tensor* TensorExprKernel::computeSoftmax(
+    const torch::jit::Value* v,
+    bool log_softmax) {
   // Softmax is computed as follows:
   //    softmax(vi) = exp(vi) / sum(exp(vi))
   //
@@ -1641,6 +1648,21 @@ Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
   //   - Third loop computes the sum over the softmax dim.
   //   - Final loop computes softmax for every element in v.

+  // LogSoftmax is computed as follows:
+  //    log_softmax(vi) = log(softmax(vi))
+  //                    = vi - log(sum(exp(vi)))
+  //
+  // Using the same max trick as above:
+  //    log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
+  //
+  // This is implemented as 5 loopnests:
+  //   - First loop computes the max over the softmax dim.
+  //   - Second loop computes exp for every element in v after subtracting
+  //     the max of the softmax dim it belongs to.
+  //   - Third loop computes the sum over the softmax dim.
+  //   - Fourth loop computes log for every element in the sum.
+  //   - Final loop computes the log_softmax for every element in v.
+
   TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3);
   auto output_dims = dimsFromSizes(sizesForValue(v));
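Spelling out the identity the new comment block relies on (standard algebra, not quoted from the diff):

\[
\log \operatorname{softmax}(v_i)
  = \log \frac{e^{v_i}}{\sum_j e^{v_j}}
  = v_i - \log \sum_j e^{v_j}
  = v_i - m - \log \sum_j e^{v_j - m},
  \qquad m = \max_j v_j.
\]

Subtracting m keeps every exponent non-positive, so the sum cannot overflow; this is the same max trick the existing softmax path already uses.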
@@ -1724,10 +1746,23 @@ Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) {
         return e->call(move_softmax_dim_index_to_pos(indices));
       },
       {output_dims[softmax_dim]});
-  auto res = Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
-    return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
-  });
-  return res;
+  if (!log_softmax) {
+    return Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
+      return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
+    });
+  }
+
+  auto log_sum = Compute(
+      "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
+        return log(sum->call(indices));
+      });
+  return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
+    auto inp = tensorOrConstant(
+        v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
+    auto non_softmax_indices = remove_softmax_dim_index(indices);
+    return inp - max->call(non_softmax_indices) -
+        log_sum->call(non_softmax_indices);
+  });
 }

 TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo(
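As a plain-C++ mental model of what the staged Computes express, here is a hand-written scalar reference for a 1-D input, following the five loopnests named in the comments above (an illustrative sketch, not NNC output):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for log_softmax over one dimension; each numbered
// step corresponds to one loopnest. Assumes v is non-empty.
std::vector<float> log_softmax_ref(const std::vector<float>& v) {
  float m = v[0];                                    // 1) max over the dim
  for (float x : v) m = std::max(m, x);
  std::vector<float> e(v.size());                    // 2) exp(v - max)
  for (std::size_t i = 0; i < v.size(); ++i) e[i] = std::exp(v[i] - m);
  float sum = 0.f;                                   // 3) sum over the dim
  for (float x : e) sum += x;
  float log_sum = std::log(sum);                     // 4) log of the sum
  std::vector<float> out(v.size());                  // 5) v - max - log(sum)
  for (std::size_t i = 0; i < v.size(); ++i) out[i] = v[i] - m - log_sum;
  return out;
}

In the non-log path the final stage instead divides e by sum, which is why computeSoftmax can return early before the log_sum stage.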
@@ -121,7 +121,7 @@ class TORCH_API TensorExprKernel {

   Tensor* computeSum(const torch::jit::Value* v);

-  Tensor* computeSoftmax(const torch::jit::Value* v);
+  Tensor* computeSoftmax(const torch::jit::Value* v, bool log_softmax);

   Tensor* computeValue(const torch::jit::Value* v);
