Revert "nn.Linear: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (#166071)"

This reverts commit 467c21ad9a.

Reverted https://github.com/pytorch/pytorch/pull/166071 on behalf of https://github.com/atalman due to multiple CI breakages: test/profiler/test_profiler_tree.py::TestProfilerTree::test_profiler_experimental_tree_with_stack_and_modules [GH job link](https://github.com/pytorch/pytorch/actions/runs/18909087335/job/53976915830) [HUD commit link](467c21ad9a) ([comment](https://github.com/pytorch/pytorch/pull/166071#issuecomment-3462458968))
This commit is contained in:
PyTorch MergeBot 2025-10-29 16:05:27 +00:00
parent 14102fb1f3
commit c594950e86

View File

@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
// before passing it to the linear operation, so that an nD (n > 2) input can be
// handled by a single 2D `addmm`, then restores the leading dimensions on the
// result.
//
// Shapes (symbolic-int aware):
//   input:  (d_0, ..., d_{k-1}, in_features)
//   weight: (out_features, in_features)
//   bias:   broadcastable against (d_0 * ... * d_{k-1}, out_features)
//   return: (d_0, ..., d_{k-1}, out_features)
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
    const auto input_sizes = input.sym_sizes();
    // Compute the flattened row count explicitly: we can't use -1 in reshape
    // because it errors out when one of the dimensions is 0.
    c10::SymInt flattened_dim = 1;
    for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
      flattened_dim = flattened_dim * input_sizes[i];
    }
    // Collapse all leading dims into a single row dimension; the last dim
    // (in_features) is kept as the column dimension for the matmul.
    auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() - 1)});
    // Fused bias + matmul in one kernel: bias + inp_reshape @ weight^T.
    const auto result = at::addmm(bias, inp_reshape, weight.t());
    // Unflatten: leading dims come from the original input, the last dim is
    // taken from the addmm result (out_features).
    auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
    c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
    sizes_vec.push_back(result.sym_size(1));
    // view is safe here: `result` is freshly produced by addmm and contiguous.
    return result.view_symint(sizes_vec);
}
@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
// Fused op is marginally faster. // Fused op is marginally faster.
return at::addmm(*bias, input, weight.t()); return at::addmm(*bias, input, weight.t());
} }
if (bias->defined() && !input.is_xla()) {
const auto is_bias_likely_fusable = ( // Also hit the fused path for contiguous 3D input, if not using xla
bias->defined() &&
// cuBLASLt: will fuse in the epilogue without copies
// when input/weight/bias are all strided.
// When weight is not strided, bias will not be fused,
// but we can still dispatch here to avoid at::matmul
// path which will probably use a very similar
// flattening optimization.
(bias->dim() == 1 && bias->is_contiguous_or_false())
);
if (is_bias_likely_fusable && !input.is_xla()) {
// Also hit the fused path for contiguous nD input, if not using xla
// backend. Reshaping/flattening has some performance implications on xla. // backend. Reshaping/flattening has some performance implications on xla.
if (input.is_contiguous_or_false()) { bool is_contiguous = input.is_contiguous_or_false();
if (is_contiguous && input_dim == 3) {
return _flatten_nd_linear(input, weight, *bias); return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d()) { } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d() && input_dim == 3) {
// If user forces flattening via env var // If user forces flattening via env var
const Tensor input_cont = input.contiguous(); const Tensor input_cont = input.contiguous();
return _flatten_nd_linear(input_cont, weight, *bias); return _flatten_nd_linear(input_cont, weight, *bias);