[NNC] Use index for stride mapping in kernel.cpp (#72266)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72266

Within the kernel, we may manipulate `Value *` in `OptimizeCat`, which would invalidate the input `Value *` -> Stride mapping.

Fix for https://github.com/pytorch/pytorch/issues/72173

Test Plan: Imported from OSS

Reviewed By: dagitses, davidberard98

Differential Revision: D33986306

Pulled By: eellison

fbshipit-source-id: dc33cd2b545e49e90d1e46b9fcf1e6dbb4b829db
(cherry picked from commit 5e4555968a)
This commit is contained in:
Elias Ellison 2022-02-03 16:02:17 -08:00 committed by PyTorch MergeBot
parent 9d4d782e42
commit defde3bb04
3 changed files with 54 additions and 19 deletions

View File

@ -2100,6 +2100,14 @@ class TestTEFuser(JitTestCase):
return F.hardsigmoid(x) * 1.01
self._test_fwd_bwd(eager)
def test_cat_graph_opt(self):
def foo(x, y, z):
return torch.log(torch.cat([x, y, z]))
self.checkScript(foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5])))
# TODO: not sure why the updated graph isn't reflected in last_optimized_graph
self.assertLastGraphAllFused()
def test_dynamic_cat(self):
with inline_fusion_groups():
@torch.jit.script

View File

@ -918,6 +918,17 @@ ExprHandle TensorExprKernel::getStrideArg(
return it->second;
}
std::vector<torch::jit::StrideInput>& TensorExprKernel::
getSymbolicInputStrideDesc(const torch::jit::Value* value) {
for (size_t i : c10::irange(graph_->inputs().size())) {
if (value == graph_->inputs().at(i)) {
TORCH_INTERNAL_ASSERT(sym_stride_inputs_.count(i));
return sym_stride_inputs_[i];
}
}
TORCH_INTERNAL_ASSERT(false);
}
std::vector<ExprHandle> TensorExprKernel::getInputStrides(
const torch::jit::Value* input,
const std::vector<ExprHandle>& inputTensorDims) {
@ -933,8 +944,7 @@ std::vector<ExprHandle> TensorExprKernel::getInputStrides(
}
size_t rank = inputTensorDims.size();
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
std::vector<StrideInput>& stride_input = symbolic_strides_[input];
std::vector<StrideInput>& stride_input = getSymbolicInputStrideDesc(input);
if (stride_input.size() == 1 &&
(stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
stride_input[0] == StrideInput::TENSOR_CONT)) {
@ -999,14 +1009,17 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
auto tt = input->type()->cast<TensorType>();
bool contiguous_concrete_tensor =
(input->isCompleteTensor() && isContiguous(input));
bool contiguous_strided_tensor = symbolic_strides_.count(input) &&
symbolic_strides_[input].size() == 1 &&
symbolic_strides_[input][0] == torch::jit::StrideInput::TENSOR_CONT;
bool contiguous_sym_tensor = false;
if (has_symbolic_shapes_) {
auto desc = getSymbolicInputStrideDesc(input);
contiguous_sym_tensor =
desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
}
// We don't need to copy the input if:
// 1) it is not an output AND
// 2) it is contiguous
bool contiguous = contiguous_concrete_tensor || contiguous_strided_tensor;
bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
if (!outputs_set.count(input) && contiguous) {
BufHandle inBuffer(
"t" + input_name_map_[input],
@ -1382,8 +1395,7 @@ BlockPtr TensorExprKernel::bindAllInputs() {
if (!tt) {
continue;
}
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
auto symbolic_stride = symbolic_strides_[input];
auto symbolic_stride = getSymbolicInputStrideDesc(input);
for (size_t j = 0; j < symbolic_stride.size(); ++j) {
if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
VarHandle v("v" + input_name_map_[input], kLong);
@ -1492,7 +1504,8 @@ void TensorExprKernel::compile() {
}
// Move output operands from `bufs_` to `bufOutputs_`
for (auto& output : graph_->outputs()) {
for (auto i : c10::irange(graph_->outputs().size())) {
auto& output = graph_->outputs().at(i);
if (!bufs_.count(output)) {
throw malformed_input("cannot find output Tensor");
}
@ -1513,10 +1526,9 @@ void TensorExprKernel::compile() {
if (has_symbolic_shapes_) {
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
tensorOutputSymbolicSizes_.push_back(sizes);
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
auto stride_desc = symbolic_strides_[output];
TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
tensorOutputStrideDesc_.push_back(stride_desc[0]);
TORCH_INTERNAL_ASSERT(sym_stride_outputs_.count(i));
auto stride_desc = sym_stride_outputs_[i];
tensorOutputStrideDesc_.push_back(stride_desc);
Tensor properly_strided_output =
convertSymbolicOutputToCorrectStrides(output);
if (properly_strided_output.stmt()) {
@ -1595,8 +1607,22 @@ TensorExprKernel::TensorExprKernel(
symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
custom_lowerings_(std::move(custom_lowerings)),
pre_alloc_(pre_alloc),
kernel_func_name_(kernel_func_name),
symbolic_strides_(std::move(symbolic_strides)) {
kernel_func_name_(kernel_func_name) {
// convert symbolic_stride to map by output and input index,
// since we may manipulate output pointers in graph manipulation
for (size_t i : c10::irange(graph_->inputs().size())) {
if (symbolic_strides.count(graph_->inputs().at(i))) {
sym_stride_inputs_[i] = symbolic_strides[graph_->inputs().at(i)];
}
}
for (size_t i : c10::irange(graph_->outputs().size())) {
if (symbolic_strides.count(graph_->outputs().at(i))) {
auto& desc = symbolic_strides[graph_->outputs().at(i)];
TORCH_INTERNAL_ASSERT(desc.size() == 1);
sym_stride_outputs_[i] = desc[0];
}
}
allow_fallback_ = fallbackAllowed();
if (!allow_fallback_) {

View File

@ -271,6 +271,8 @@ class TORCH_API TensorExprKernel {
std::vector<ExprHandle> getInputStrides(
const torch::jit::Value* input,
const std::vector<ExprHandle>& inputTensorDims);
std::vector<torch::jit::StrideInput>& getSymbolicInputStrideDesc(
const torch::jit::Value* value);
int64_t nInputs_ = 0;
int64_t nOutputs_ = 0;
@ -319,10 +321,9 @@ class TORCH_API TensorExprKernel {
// map from <input index, tensor dimension> to stride as arg VarHandle
std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
strideArgToVar_;
std::unordered_map<
const torch::jit::Value*,
std::vector<torch::jit::StrideInput>>
symbolic_strides_;
std::unordered_map<size_t, std::vector<torch::jit::StrideInput>>
sym_stride_inputs_;
std::unordered_map<size_t, torch::jit::StrideInput> sym_stride_outputs_;
};
TORCH_API int& getTECudaPointwiseLoopLevels();