mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[NNC] Use index for stride mapping in kernel.cpp (#72266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72266
Within the kernel, we may manipulate `Value *` in `OptimizeCat`, which would invalidate the input `Value *` -> Stride mapping.
Fix for https://github.com/pytorch/pytorch/issues/72173
Test Plan: Imported from OSS
Reviewed By: dagitses, davidberard98
Differential Revision: D33986306
Pulled By: eellison
fbshipit-source-id: dc33cd2b545e49e90d1e46b9fcf1e6dbb4b829db
(cherry picked from commit 5e4555968a)
This commit is contained in:
parent
9d4d782e42
commit
defde3bb04
|
|
@ -2100,6 +2100,14 @@ class TestTEFuser(JitTestCase):
|
|||
return F.hardsigmoid(x) * 1.01
|
||||
self._test_fwd_bwd(eager)
|
||||
|
||||
def test_cat_graph_opt(self):
    # Added together with the fix for
    # https://github.com/pytorch/pytorch/issues/72173: scripts log(cat(...))
    # over three inputs whose leading dimensions differ (5, 2 and 1 rows).
    def foo(x, y, z):
        joined = torch.cat([x, y, z])
        return torch.log(joined)

    sample_inputs = (
        torch.rand([5, 5]),
        torch.rand([2, 5]),
        torch.rand([1, 5]),
    )
    self.checkScript(foo, sample_inputs)
    # TODO: not sure why the updated graph isn't reflected in last_optimized_graph
    self.assertLastGraphAllFused()
|
||||
|
||||
def test_dynamic_cat(self):
|
||||
with inline_fusion_groups():
|
||||
@torch.jit.script
|
||||
|
|
|
|||
|
|
@ -918,6 +918,17 @@ ExprHandle TensorExprKernel::getStrideArg(
|
|||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<torch::jit::StrideInput>& TensorExprKernel::
|
||||
getSymbolicInputStrideDesc(const torch::jit::Value* value) {
|
||||
for (size_t i : c10::irange(graph_->inputs().size())) {
|
||||
if (value == graph_->inputs().at(i)) {
|
||||
TORCH_INTERNAL_ASSERT(sym_stride_inputs_.count(i));
|
||||
return sym_stride_inputs_[i];
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
}
|
||||
|
||||
std::vector<ExprHandle> TensorExprKernel::getInputStrides(
|
||||
const torch::jit::Value* input,
|
||||
const std::vector<ExprHandle>& inputTensorDims) {
|
||||
|
|
@ -933,8 +944,7 @@ std::vector<ExprHandle> TensorExprKernel::getInputStrides(
|
|||
}
|
||||
|
||||
size_t rank = inputTensorDims.size();
|
||||
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
|
||||
std::vector<StrideInput>& stride_input = symbolic_strides_[input];
|
||||
std::vector<StrideInput>& stride_input = getSymbolicInputStrideDesc(input);
|
||||
if (stride_input.size() == 1 &&
|
||||
(stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
|
||||
stride_input[0] == StrideInput::TENSOR_CONT)) {
|
||||
|
|
@ -999,14 +1009,17 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
|
|||
auto tt = input->type()->cast<TensorType>();
|
||||
bool contiguous_concrete_tensor =
|
||||
(input->isCompleteTensor() && isContiguous(input));
|
||||
bool contiguous_strided_tensor = symbolic_strides_.count(input) &&
|
||||
symbolic_strides_[input].size() == 1 &&
|
||||
symbolic_strides_[input][0] == torch::jit::StrideInput::TENSOR_CONT;
|
||||
bool contiguous_sym_tensor = false;
|
||||
if (has_symbolic_shapes_) {
|
||||
auto desc = getSymbolicInputStrideDesc(input);
|
||||
contiguous_sym_tensor =
|
||||
desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
|
||||
}
|
||||
|
||||
// We don't need to copy the input if:
|
||||
// 1) it is not an output AND
|
||||
// 2) it is contiguous
|
||||
bool contiguous = contiguous_concrete_tensor || contiguous_strided_tensor;
|
||||
bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
|
||||
if (!outputs_set.count(input) && contiguous) {
|
||||
BufHandle inBuffer(
|
||||
"t" + input_name_map_[input],
|
||||
|
|
@ -1382,8 +1395,7 @@ BlockPtr TensorExprKernel::bindAllInputs() {
|
|||
if (!tt) {
|
||||
continue;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
|
||||
auto symbolic_stride = symbolic_strides_[input];
|
||||
auto symbolic_stride = getSymbolicInputStrideDesc(input);
|
||||
for (size_t j = 0; j < symbolic_stride.size(); ++j) {
|
||||
if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
|
||||
VarHandle v("v" + input_name_map_[input], kLong);
|
||||
|
|
@ -1492,7 +1504,8 @@ void TensorExprKernel::compile() {
|
|||
}
|
||||
|
||||
// Move output operands from `bufs_` to `bufOutputs_`
|
||||
for (auto& output : graph_->outputs()) {
|
||||
for (auto i : c10::irange(graph_->outputs().size())) {
|
||||
auto& output = graph_->outputs().at(i);
|
||||
if (!bufs_.count(output)) {
|
||||
throw malformed_input("cannot find output Tensor");
|
||||
}
|
||||
|
|
@ -1513,10 +1526,9 @@ void TensorExprKernel::compile() {
|
|||
if (has_symbolic_shapes_) {
|
||||
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
tensorOutputSymbolicSizes_.push_back(sizes);
|
||||
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
|
||||
auto stride_desc = symbolic_strides_[output];
|
||||
TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
|
||||
tensorOutputStrideDesc_.push_back(stride_desc[0]);
|
||||
TORCH_INTERNAL_ASSERT(sym_stride_outputs_.count(i));
|
||||
auto stride_desc = sym_stride_outputs_[i];
|
||||
tensorOutputStrideDesc_.push_back(stride_desc);
|
||||
Tensor properly_strided_output =
|
||||
convertSymbolicOutputToCorrectStrides(output);
|
||||
if (properly_strided_output.stmt()) {
|
||||
|
|
@ -1595,8 +1607,22 @@ TensorExprKernel::TensorExprKernel(
|
|||
symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
|
||||
custom_lowerings_(std::move(custom_lowerings)),
|
||||
pre_alloc_(pre_alloc),
|
||||
kernel_func_name_(kernel_func_name),
|
||||
symbolic_strides_(std::move(symbolic_strides)) {
|
||||
kernel_func_name_(kernel_func_name) {
|
||||
// convert symbolic_strides to maps keyed by input and output index,
|
||||
// since we may manipulate output pointers in graph manipulation
|
||||
for (size_t i : c10::irange(graph_->inputs().size())) {
|
||||
if (symbolic_strides.count(graph_->inputs().at(i))) {
|
||||
sym_stride_inputs_[i] = symbolic_strides[graph_->inputs().at(i)];
|
||||
}
|
||||
}
|
||||
for (size_t i : c10::irange(graph_->outputs().size())) {
|
||||
if (symbolic_strides.count(graph_->outputs().at(i))) {
|
||||
auto& desc = symbolic_strides[graph_->outputs().at(i)];
|
||||
TORCH_INTERNAL_ASSERT(desc.size() == 1);
|
||||
sym_stride_outputs_[i] = desc[0];
|
||||
}
|
||||
}
|
||||
|
||||
allow_fallback_ = fallbackAllowed();
|
||||
|
||||
if (!allow_fallback_) {
|
||||
|
|
|
|||
|
|
@ -271,6 +271,8 @@ class TORCH_API TensorExprKernel {
|
|||
std::vector<ExprHandle> getInputStrides(
|
||||
const torch::jit::Value* input,
|
||||
const std::vector<ExprHandle>& inputTensorDims);
|
||||
std::vector<torch::jit::StrideInput>& getSymbolicInputStrideDesc(
|
||||
const torch::jit::Value* value);
|
||||
|
||||
int64_t nInputs_ = 0;
|
||||
int64_t nOutputs_ = 0;
|
||||
|
|
@ -319,10 +321,9 @@ class TORCH_API TensorExprKernel {
|
|||
// map from <input index, tensor dimension> to stride as arg VarHandle
|
||||
std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
|
||||
strideArgToVar_;
|
||||
std::unordered_map<
|
||||
const torch::jit::Value*,
|
||||
std::vector<torch::jit::StrideInput>>
|
||||
symbolic_strides_;
|
||||
std::unordered_map<size_t, std::vector<torch::jit::StrideInput>>
|
||||
sym_stride_inputs_;
|
||||
std::unordered_map<size_t, torch::jit::StrideInput> sym_stride_outputs_;
|
||||
};
|
||||
|
||||
TORCH_API int& getTECudaPointwiseLoopLevels();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user