Replace flatten tensors with flatten loops. (#46539)
Summary: This diff changes `TensorExprKernel::generateStmt` to use flattened loops instead of flattened tensors. Checked all tests on CPU as well as CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46539

Reviewed By: nickgg

Differential Revision: D24395956

Pulled By: navahgar

fbshipit-source-id: f3792903f2069bda37b571c9f0a840e6fb02f189
This commit is contained in:
parent 9c02e2112e
commit 2f51ddb81f
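Editor's note: the removed flattenTensors pass materialized each output as an extra 1-D "_flat" tensor whose body recovered the multi-dimensional indices with Mod/Div arithmetic; after this change, LoopNest::flatten collapses the output's loop nest directly inside generateStmt. Below is a minimal, self-contained sketch (plain C++, standard library only) of the row-major index arithmetic the removed lambda computed. The helper name unflattenIndex and the 2x3x4 shape are hypothetical illustrations, not NNC API.

// Sketch of the row-major unflattening the removed flattenTensors lambda did:
// peel per-dimension indices off a flat index from innermost to outermost
// (mirroring Mod::make and division on ExprHandle), then reverse the order.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> unflattenIndex(int64_t flat, const std::vector<int64_t>& dims) {
  std::vector<int64_t> idx;
  int64_t value = flat;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; i--) {
    // Innermost dimension first; the outermost index needs no modulo.
    idx.push_back(i > 0 ? value % dims[i] : value);
    value = value / dims[i];
  }
  std::reverse(idx.begin(), idx.end());
  return idx;
}

int main() {
  const std::vector<int64_t> dims = {2, 3, 4}; // hypothetical output shape
  // Flat index 17 in a 2x3x4 tensor corresponds to (1, 1, 1): 1*12 + 1*4 + 1.
  const auto idx = unflattenIndex(17, dims);
  assert((idx == std::vector<int64_t>{1, 1, 1}));
  return 0;
}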
@@ -1342,56 +1342,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
   }
 }
 
-void TensorExprKernel::flattenTensors(BackendType backendType) {
-  if (backendType != BackendType::kCudaCodeGen &&
-      backendType != BackendType::kBlockCodeGen) {
-    // We only need to flatten for GPU, for other backends just use the same
-    // tensors.
-    flatTensorOutputs_ = tensorOutputs_;
-    return;
-  }
-
-  flatTensorOutputs_.resize(tensorOutputs_.size());
-  for (size_t tensorIdx = 0; tensorIdx < tensorOutputs_.size(); tensorIdx++) {
-    Tensor* tensor = tensorOutputs_[tensorIdx];
-    ExprHandle totalCount = ExprHandle(tensor->dim(0));
-    for (int i = 1; i < tensor->ndim(); i++) {
-      const IntImm* totalCountImm = totalCount.AsNode<IntImm>();
-      const IntImm* tensorDimImm = dynamic_cast<const IntImm*>(tensor->dim(i));
-      if (totalCountImm && tensorDimImm) {
-        // TODO: switch to real constant folding when it is available.
-        totalCount = ExprHandle(totalCountImm->value() * tensorDimImm->value());
-      } else {
-        totalCount = totalCount * ExprHandle(tensor->dim(i));
-      }
-    }
-    // Flatten the index for GPU kernels.
-    // TODO: move this to fusing axis when it is ready.
-    Tensor* newOut = Compute(
-        tensor->buf()->name_hint() + "_flat",
-        {totalCount},
-        [tensor](const VarHandle& index) -> ExprHandle {
-          std::vector<ExprHandle> dims;
-          ExprHandle value = index;
-          for (int i = tensor->ndim() - 1; i >= 0; i--) {
-            ExprHandle idx = value;
-            if (i > 0) {
-              idx = Mod::make(value, ExprHandle(tensor->dim(i)));
-            }
-            dims.push_back(idx);
-            value = value / ExprHandle(tensor->dim(i));
-          }
-          std::reverse(dims.begin(), dims.end());
-          return tensor->call(dims);
-        });
-    flatTensorOutputs_[tensorIdx] = newOut;
-  }
-}
-
 Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
-  flattenTensors(backendType);
-
-  torch::jit::tensorexpr::LoopNest l(flatTensorOutputs_);
+  torch::jit::tensorexpr::LoopNest l(tensorOutputs_);
   GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
 
   bool hasReduction = NodeFinder<ReduceOp>::find(l.root_stmt()).size() != 0;
@@ -1404,12 +1356,15 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
     l.computeInline(p.second->buf());
   }
   if (backendType == kCudaCodeGen) {
-    for (size_t i = 0; i < flatTensorOutputs_.size(); i++) {
-      Tensor* tensor = flatTensorOutputs_[i];
-
+    for (auto tensor : tensorOutputs_) {
       // For every output tensor we've created a flattened 1D tensor - let's
       // mark the original output tensor with computeInline
-      l.computeInline(tensorOutputs_[i]->buf());
+      l.computeInline(tensor->buf());
+
+      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+      For* flattened;
+      LoopNest::flatten(loops, &flattened);
 
       int loopLevels = getTECudaPointwiseLoopLevels();
       const int kDefaultLoopLevels = 2;
@@ -1424,8 +1379,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
       if (blockSize < 0) {
         blockSize = kDefaultBlockSize;
       }
-      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-      l.splitWithMask(loops[0], blockSize, &outer, &inner);
+      l.splitWithMask(flattened, blockSize, &outer, &inner);
       l.setGPUBlockIndex(outer, 0);
       l.setGPUThreadIndex(inner, 0);
     } else if (loopLevels == 3) {
@@ -1438,8 +1392,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
       const int kDefaultBlockSize = 256;
      blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
      blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
-      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-      l.splitWithMask(loops[0], blockCount * blockSize, &outer, &inner);
+      l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner);
       l.splitWithMask(inner, blockSize, &inner1, &inner2);
       l.setGPUBlockIndex(inner1, 0);
       l.setGPUThreadIndex(inner2, 0);
@@ -1452,12 +1405,11 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
 
   if (backendType == kBlockCodeGen) {
     auto block_analysis = std::make_unique<CreateBufferMap>();
-    for (size_t i = 0; i < flatTensorOutputs_.size(); i++) {
+    for (auto tensor : tensorOutputs_) {
       const int default_fp16_blocksize = 16;
       const int default_uint8_blocksize = 32;
       int blockSize = default_fp16_blocksize;
       // We only handle looplevels == 2 for now
-      Tensor* tensor = flatTensorOutputs_[i];
       // Run Block analysis to get multi dim buffer info
       auto root_stmt = l.root_stmt();
       root_stmt->accept(block_analysis.get());
@@ -1465,12 +1417,16 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
       if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) {
         blockSize = default_uint8_blocksize;
       }
-      l.computeInline(l.getLoopBodyFor(tensorOutputs_[i]));
+      l.computeInline(l.getLoopBodyFor(tensor));
+
+      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+      For* flattened;
+      LoopNest::flatten(loops, &flattened);
+
       For* outer;
       For* inner;
-      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
-      TORCH_INTERNAL_ASSERT(loops.size() > 0, "loops should not be empty");
-      l.splitWithMask(loops[0], blockSize, &outer, &inner);
+      l.splitWithMask(flattened, blockSize, &outer, &inner);
       l.setGPUBlockIndex(outer, 0);
       l.setGPUThreadIndex(inner, 0);
       l.setBufferMap(outer, block_analysis->getBufferMap());
@@ -1518,7 +1474,7 @@ std::vector<CodeGen::BufferArg> TensorExprKernel::prepareBufferArgs() {
       params.emplace_back(stride.var);
     }
   }
-  for (auto& o : flatTensorOutputs_) {
+  for (auto& o : tensorOutputs_) {
    params.emplace_back(o);
  }
  return params;
@@ -125,7 +125,6 @@ class TORCH_API TensorExprKernel {
 
   Tensor* computeValue(const torch::jit::Value* v);
 
-  void flattenTensors(BackendType backendType);
   Stmt* generateStmt(BackendType backendType);
   std::vector<CodeGen::BufferArg> prepareBufferArgs();
 
@@ -191,7 +190,6 @@ class TORCH_API TensorExprKernel {
   int64_t nInputs_ = 0;
   std::vector<KernelArg> kernelArgs_;
   std::vector<Tensor*> tensorOutputs_;
-  std::vector<Tensor*> flatTensorOutputs_;
   std::unordered_map<int64_t, Tensor*> tensors_;
   std::unordered_map<int64_t, VarHandle> scalars_;
   std::unique_ptr<CodeGen> codegen_;