Add ZT fastpath for torch.{dot, vdot} (#71129)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71129

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D34012577

Pulled By: anjali411

fbshipit-source-id: 02d2f2d761f7c9332e2f3cc529e8f1c6b60d7da2
(cherry picked from commit 87318a2e0d)
This commit is contained in:
anjali411 2022-02-07 08:33:21 -08:00 committed by PyTorch MergeBot
parent 4e98a4b6e3
commit 9d8f0c7842
3 changed files with 18 additions and 0 deletions

View File

@ -93,6 +93,8 @@ namespace at {
m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
m.impl("copy_", torch::CppFunction::makeFallthrough());
m.impl("clone", torch::CppFunction::makeFallthrough());
m.impl("dot", torch::CppFunction::makeFallthrough());
m.impl("vdot", torch::CppFunction::makeFallthrough());
// The functions in the list below have a specific registration in native_functions.yaml and
// do not use the fallback.
// m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());

View File

@ -154,6 +154,10 @@ Tensor dot(const Tensor &self, const Tensor &other){
at::NoNamesGuard guard;
dot_check(self, other);
if (self._is_zerotensor() || other._is_zerotensor()) {
return at::_efficientzerotensor({}, self.options());
}
if (use_mkldnn_bf16_matmul(self, other, /*result=*/Tensor())){
// mkldnn matmul expects the result to have size info in order to create an ideep tensor
auto r = at::empty({1, 1}, self.options());
@ -188,6 +192,10 @@ Tensor vdot(const Tensor &self, const Tensor &other){
// For complex dtypes.
dot_check(self, other);
if (self._is_zerotensor() || other._is_zerotensor()) {
return at::_efficientzerotensor({}, self.options());
}
return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
Tensor result = at::empty({}, self.options());
result.fill_(vdot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));

View File

@ -357,6 +357,10 @@ Tensor dot_cuda(const Tensor& self, const Tensor& other) {
incy = 1;
}
if (self._is_zerotensor() || other._is_zerotensor()) {
return at::_efficientzerotensor({}, self.options());
}
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
self.scalar_type(), "dot",
@ -396,6 +400,10 @@ Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
at::NoNamesGuard guard;
dot_check(self, other);
if (self._is_zerotensor() || other._is_zerotensor()) {
return at::_efficientzerotensor({}, self.options());
}
const int n = static_cast<int>(self.numel());
int incx = static_cast<int>(self.stride(0));
int incy = static_cast<int>(other.stride(0));