Replace flattened tensors with flattened loops. (#46539)

Summary:
This diff changes `TensorExprKernel::generateStmt` to flatten loops instead of flattening tensors: the GPU backends now flatten each output tensor's loop nest with `LoopNest::flatten`, and the `flattenTensors` pass together with the `flatTensorOutputs_` member is removed.
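
To make the change easier to follow, here is a minimal sketch of what the CUDA path of `generateStmt` now does per output tensor, using only the `LoopNest` calls visible in the diff below (`getLoopStmtsFor`, `LoopNest::flatten`, `splitWithMask`, `setGPUBlockIndex`, `setGPUThreadIndex`). The helper name `lowerOutputForCuda`, the `kBlockSize` parameter, and the include list are illustrative, not part of this PR:

// Rough sketch only -- not code from this PR. lowerOutputForCuda and
// kBlockSize are hypothetical names; the include set is approximate.
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <c10/util/Exception.h>

using namespace torch::jit::tensorexpr;

void lowerOutputForCuda(LoopNest& l, Tensor* out, int kBlockSize) {
  // Loop nest that writes the output tensor (possibly multi-dimensional).
  std::vector<For*> loops = l.getLoopStmtsFor(out);
  TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");

  // Fuse the whole nest into a single 1-D loop. This replaces the old
  // flattenTensors() pass, which built an extra "<name>_flat" tensor.
  For* flattened = nullptr;
  LoopNest::flatten(loops, &flattened);

  // Split the flat loop and bind the two halves to CUDA block/thread axes.
  For* outer = nullptr;
  For* inner = nullptr;
  l.splitWithMask(flattened, kBlockSize, &outer, &inner);
  l.setGPUBlockIndex(outer, 0);
  l.setGPUThreadIndex(inner, 0);
}

Flattening the loops directly avoids materializing a separate `<name>_flat` tensor for every output, so `prepareBufferArgs` can keep using `tensorOutputs_` as-is.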

Verified all tests on both CPU and CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46539

Reviewed By: nickgg

Differential Revision: D24395956

Pulled By: navahgar

fbshipit-source-id: f3792903f2069bda37b571c9f0a840e6fb02f189
Authored by Raghavan Raman on 2020-10-20 12:13:56 -07:00; committed by Facebook GitHub Bot
parent 9c02e2112e
commit 2f51ddb81f
2 changed files with 20 additions and 66 deletions


@@ -1342,56 +1342,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
   }
 }
-void TensorExprKernel::flattenTensors(BackendType backendType) {
-  if (backendType != BackendType::kCudaCodeGen &&
-      backendType != BackendType::kBlockCodeGen) {
-    // We only need to flatten for GPU, for other backends just use the same
-    // tensors.
-    flatTensorOutputs_ = tensorOutputs_;
-    return;
-  }
-  flatTensorOutputs_.resize(tensorOutputs_.size());
-  for (size_t tensorIdx = 0; tensorIdx < tensorOutputs_.size(); tensorIdx++) {
-    Tensor* tensor = tensorOutputs_[tensorIdx];
-    ExprHandle totalCount = ExprHandle(tensor->dim(0));
-    for (int i = 1; i < tensor->ndim(); i++) {
-      const IntImm* totalCountImm = totalCount.AsNode<IntImm>();
-      const IntImm* tensorDimImm = dynamic_cast<const IntImm*>(tensor->dim(i));
-      if (totalCountImm && tensorDimImm) {
-        // TODO: switch to real constant folding when it is available.
-        totalCount = ExprHandle(totalCountImm->value() * tensorDimImm->value());
-      } else {
-        totalCount = totalCount * ExprHandle(tensor->dim(i));
-      }
-    }
-    // Flatten the index for GPU kernels.
-    // TODO: move this to fusing axis when it is ready.
-    Tensor* newOut = Compute(
-        tensor->buf()->name_hint() + "_flat",
-        {totalCount},
-        [tensor](const VarHandle& index) -> ExprHandle {
-          std::vector<ExprHandle> dims;
-          ExprHandle value = index;
-          for (int i = tensor->ndim() - 1; i >= 0; i--) {
-            ExprHandle idx = value;
-            if (i > 0) {
-              idx = Mod::make(value, ExprHandle(tensor->dim(i)));
-            }
-            dims.push_back(idx);
-            value = value / ExprHandle(tensor->dim(i));
-          }
-          std::reverse(dims.begin(), dims.end());
-          return tensor->call(dims);
-        });
-    flatTensorOutputs_[tensorIdx] = newOut;
-  }
-}
 Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
-  flattenTensors(backendType);
-  torch::jit::tensorexpr::LoopNest l(flatTensorOutputs_);
+  torch::jit::tensorexpr::LoopNest l(tensorOutputs_);
   GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
   bool hasReduction = NodeFinder<ReduceOp>::find(l.root_stmt()).size() != 0;
@@ -1404,12 +1356,15 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
     l.computeInline(p.second->buf());
   }
   if (backendType == kCudaCodeGen) {
-    for (size_t i = 0; i < flatTensorOutputs_.size(); i++) {
-      Tensor* tensor = flatTensorOutputs_[i];
+    for (auto tensor : tensorOutputs_) {
       // For every output tensor we've created a flattened 1D tensor - let's
       // mark the original output tensor with computeInline
-      l.computeInline(tensorOutputs_[i]->buf());
+      l.computeInline(tensor->buf());
+      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+      For* flattened;
+      LoopNest::flatten(loops, &flattened);
       int loopLevels = getTECudaPointwiseLoopLevels();
       const int kDefaultLoopLevels = 2;
@@ -1424,8 +1379,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
         if (blockSize < 0) {
           blockSize = kDefaultBlockSize;
         }
-        std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-        l.splitWithMask(loops[0], blockSize, &outer, &inner);
+        l.splitWithMask(flattened, blockSize, &outer, &inner);
         l.setGPUBlockIndex(outer, 0);
         l.setGPUThreadIndex(inner, 0);
       } else if (loopLevels == 3) {
@@ -1438,8 +1392,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
         const int kDefaultBlockSize = 256;
         blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
         blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
-        std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-        l.splitWithMask(loops[0], blockCount * blockSize, &outer, &inner);
+        l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner);
         l.splitWithMask(inner, blockSize, &inner1, &inner2);
         l.setGPUBlockIndex(inner1, 0);
         l.setGPUThreadIndex(inner2, 0);
@@ -1452,12 +1405,11 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
   if (backendType == kBlockCodeGen) {
     auto block_analysis = std::make_unique<CreateBufferMap>();
-    for (size_t i = 0; i < flatTensorOutputs_.size(); i++) {
+    for (auto tensor : tensorOutputs_) {
       const int default_fp16_blocksize = 16;
       const int default_uint8_blocksize = 32;
       int blockSize = default_fp16_blocksize;
       // We only handle looplevels == 2 for now
-      Tensor* tensor = flatTensorOutputs_[i];
       // Run Block analysis to get multi dim buffer info
       auto root_stmt = l.root_stmt();
       root_stmt->accept(block_analysis.get());
@@ -1465,12 +1417,16 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
       if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) {
         blockSize = default_uint8_blocksize;
       }
-      l.computeInline(l.getLoopBodyFor(tensorOutputs_[i]));
+      l.computeInline(l.getLoopBodyFor(tensor));
+      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+      For* flattened;
+      LoopNest::flatten(loops, &flattened);
       For* outer;
       For* inner;
-      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-      TORCH_INTERNAL_ASSERT(loops.size() > 0, "loops should not be empty");
-      l.splitWithMask(loops[0], blockSize, &outer, &inner);
+      l.splitWithMask(flattened, blockSize, &outer, &inner);
       l.setGPUBlockIndex(outer, 0);
       l.setGPUThreadIndex(inner, 0);
       l.setBufferMap(outer, block_analysis->getBufferMap());
@@ -1518,7 +1474,7 @@ std::vector<CodeGen::BufferArg> TensorExprKernel::prepareBufferArgs() {
       params.emplace_back(stride.var);
     }
  }
-  for (auto& o : flatTensorOutputs_) {
+  for (auto& o : tensorOutputs_) {
    params.emplace_back(o);
  }
  return params;


@@ -125,7 +125,6 @@ class TORCH_API TensorExprKernel {
   Tensor* computeValue(const torch::jit::Value* v);
-  void flattenTensors(BackendType backendType);
   Stmt* generateStmt(BackendType backendType);
   std::vector<CodeGen::BufferArg> prepareBufferArgs();
@@ -191,7 +190,6 @@ class TORCH_API TensorExprKernel {
   int64_t nInputs_ = 0;
   std::vector<KernelArg> kernelArgs_;
   std::vector<Tensor*> tensorOutputs_;
-  std::vector<Tensor*> flatTensorOutputs_;
   std::unordered_map<int64_t, Tensor*> tensors_;
   std::unordered_map<int64_t, VarHandle> scalars_;
   std::unique_ptr<CodeGen> codegen_;