mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add ZT fastpath for torch.{dot, vdot} (#71129)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71129
cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D34012577
Pulled By: anjali411
fbshipit-source-id: 02d2f2d761f7c9332e2f3cc529e8f1c6b60d7da2
(cherry picked from commit 87318a2e0d)
This commit is contained in:
parent
4e98a4b6e3
commit
9d8f0c7842
|
|
@ -93,6 +93,8 @@ namespace at {
|
|||
m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
|
||||
m.impl("copy_", torch::CppFunction::makeFallthrough());
|
||||
m.impl("clone", torch::CppFunction::makeFallthrough());
|
||||
m.impl("dot", torch::CppFunction::makeFallthrough());
|
||||
m.impl("vdot", torch::CppFunction::makeFallthrough());
|
||||
// The functions in the list below have a specific registration in native_functions.yaml and
|
||||
// do not use the fallback.
|
||||
// m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
|
||||
|
|
|
|||
|
|
@ -154,6 +154,10 @@ Tensor dot(const Tensor &self, const Tensor &other){
|
|||
at::NoNamesGuard guard;
|
||||
dot_check(self, other);
|
||||
|
||||
if (self._is_zerotensor() || other._is_zerotensor()) {
|
||||
return at::_efficientzerotensor({}, self.options());
|
||||
}
|
||||
|
||||
if (use_mkldnn_bf16_matmul(self, other, /*result=*/Tensor())){
|
||||
// mkldnn matmul expects the result to have size info in order to create the ideep tensor
|
||||
auto r = at::empty({1, 1}, self.options());
|
||||
|
|
@ -188,6 +192,10 @@ Tensor vdot(const Tensor &self, const Tensor &other){
|
|||
// For complex dtypes.
|
||||
dot_check(self, other);
|
||||
|
||||
if (self._is_zerotensor() || other._is_zerotensor()) {
|
||||
return at::_efficientzerotensor({}, self.options());
|
||||
}
|
||||
|
||||
return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
|
||||
Tensor result = at::empty({}, self.options());
|
||||
result.fill_(vdot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
|
||||
|
|
|
|||
|
|
@ -357,6 +357,10 @@ Tensor dot_cuda(const Tensor& self, const Tensor& other) {
|
|||
incy = 1;
|
||||
}
|
||||
|
||||
if (self._is_zerotensor() || other._is_zerotensor()) {
|
||||
return at::_efficientzerotensor({}, self.options());
|
||||
}
|
||||
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
ScalarType::Half, ScalarType::BFloat16,
|
||||
self.scalar_type(), "dot",
|
||||
|
|
@ -396,6 +400,10 @@ Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
|
|||
at::NoNamesGuard guard;
|
||||
dot_check(self, other);
|
||||
|
||||
if (self._is_zerotensor() || other._is_zerotensor()) {
|
||||
return at::_efficientzerotensor({}, self.options());
|
||||
}
|
||||
|
||||
const int n = static_cast<int>(self.numel());
|
||||
int incx = static_cast<int>(self.stride(0));
|
||||
int incy = static_cast<int>(other.stride(0));
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user