mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[BLAS] Avoid downcasts for fp16fp16->fp32 BLAS (#161999)"
This reverts commit 02c83f1334.
Reverted https://github.com/pytorch/pytorch/pull/161999 on behalf of https://github.com/jeanschmidt due to break a few internal tests ([comment](https://github.com/pytorch/pytorch/pull/161999#issuecomment-3255381925))
This commit is contained in:
parent
afa6e5604d
commit
c3d54dea9f
|
|
@ -496,18 +496,18 @@ void gemm(
|
|||
// for the fallback path, first compute gemm with beta = 0,
|
||||
// and then add c in full precision.
|
||||
int64_t c_size = n * m;
|
||||
std::vector<float> float_c(c_size, 0.f);
|
||||
gemm_no_downcast_stub(
|
||||
std::vector<at::Half> float16_c(c_size, 0.f);
|
||||
gemm_stub(
|
||||
at::kCPU, at::kHalf,
|
||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
|
||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m);
|
||||
for (const auto j : c10::irange(n)) {
|
||||
for (const auto i : c10::irange(m)) {
|
||||
auto offset = j * ldc + i;
|
||||
// beta == 0 won't propagate NaN from C
|
||||
if (beta == 0.f) {
|
||||
c[offset] = float_c[j * m + i];
|
||||
c[offset] = c10::convert<float>(float16_c[j * m + i]);
|
||||
} else {
|
||||
c[offset] = beta * c[offset] + float_c[j * m + i];
|
||||
c[offset] = beta * c[offset] + c10::convert<float>(float16_c[j * m + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user