Revert D33834916: Set correct device id on efficientzerotensors

Test Plan: revert-hammer

Differential Revision: D33834916 (a18cfb790d)

Original commit changeset: 11cec343e95e

Original Phabricator Diff: D33834916 (a18cfb790d)

fbshipit-source-id: 3d3f60b760b445383768161b1d21ea4dadbe5d7c
(cherry picked from commit eba41aa646)
Author: Anjali Chourdia
Committed by: PyTorch MergeBot
Date: 2022-01-30 19:45:25 -08:00
Parent: 6208c2800e
Commit: 1e4aefaa2f
6 changed files with 49 additions and 45 deletions

aten/src/ATen/native/TensorFactories.cpp

@@ -427,6 +427,27 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
 namespace {
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+  ZeroTensorAllocator(at::Device device) : device_(device) {};
+  ~ZeroTensorAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t nbytes) const override {
+    return {nullptr, nullptr, &deleter, device_};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+  at::Device device_;
+};
+
+at::Allocator* GetZeroTensorAllocator(ZeroTensorAllocator& zt) {
+  return &zt;
+}
+
 // Performs dtype inference for full
 TensorOptions infer_full_options(
     const Scalar& fill_value,
@@ -1055,11 +1076,11 @@ Tensor _efficientzerotensor(IntArrayRef size,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
   auto device_ = device_or_default(device);
-  auto allocator = at::native::ZeroTensorAllocator(device_);
+  auto allocator = ZeroTensorAllocator(device_);
   auto dtype_ = dtype_or_default(dtype);
-  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CPU) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
-  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
-  return out;
+  constexpr auto zero_ks = at::DispatchKeySet(at::DispatchKey::ZeroTensor);
+  return at::detail::empty_generic(
+      size, &allocator, zero_ks, dtype_, c10::nullopt);
 }
 
 Tensor& zeros_out(IntArrayRef size, Tensor& result) {
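
For reference, the restored CPU path above allocates nothing (the ZeroTensorAllocator hands back a null DataPtr) and tags the result with the ZeroTensor dispatch key. Below is a minimal sketch of how this surfaces from Python, assuming a CUDA build of this revision; the calls mirror the test_cuda_device_idx test removed further down, and whether the reported devices agree after this revert is exactly the behavior D33834916 was changing.

import torch

# Hedged sketch: compare the device of an ordinary zeros tensor with the
# storage-less zero tensor produced by the private factory touched in this diff.
if torch.cuda.is_available():
    x = torch.zeros(3, device='cuda')                  # always reports an indexed device, e.g. cuda:0
    y = torch._efficientzerotensor(3, device='cuda')   # with this revert applied, the index may be left unset
    print(x.device, y.device)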

aten/src/ATen/native/TensorFactories.h

@@ -87,23 +87,6 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
   }
 }
-// The ZeroTensor allocator ignores whatever allocation is requested and always
-// gives you nullptr
-struct ZeroTensorAllocator final : public at::Allocator {
-  ZeroTensorAllocator(at::Device device) : device_(device) {};
-  ~ZeroTensorAllocator() override = default;
-  static void deleter(void* const pointer) {
-    TORCH_INTERNAL_ASSERT(!pointer);
-  }
-  DataPtr allocate(const size_t nbytes) const override {
-    return {nullptr, nullptr, &deleter, device_};
-  }
-  DeleterFnPtr raw_deleter() const override {
-    return deleter;
-  }
-  at::Device device_;
-};
 using binary_fn = void (*)(TensorIterator&);
 
 DECLARE_DISPATCH(binary_fn, complex_stub);

aten/src/ATen/native/cuda/TensorFactories.cu

@@ -40,23 +40,6 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
   return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
 }
-Tensor _efficientzerotensor_cuda(IntArrayRef size,
-    c10::optional<ScalarType> dtype,
-    c10::optional<Layout> layout,
-    c10::optional<Device> device,
-    c10::optional<bool> pin_memory) {
-  auto device_ = device_or_default(device);
-  if (!device_.has_index()) {
-    device_.set_index(at::cuda::current_device());
-  }
-  auto allocator = at::native::ZeroTensorAllocator(device_);
-  auto dtype_ = dtype_or_default(dtype);
-  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CUDA) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
-  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
-  return out;
-}
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 }
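
The deleted CUDA overload differs from the CPU path mainly in the device_.set_index(at::cuda::current_device()) call, which pins an unindexed cuda device to the current GPU before building the tensor. The lines below are a rough Python illustration of that normalization step, offered only as a sketch and assuming a CUDA-enabled build; torch.cuda.current_device() stands in for at::cuda::current_device().

import torch

# Resolve a bare 'cuda' device to an explicitly indexed one, mirroring the
# has_index()/set_index() logic in the removed _efficientzerotensor_cuda.
dev = torch.device('cuda')
if dev.index is None and torch.cuda.is_available():
    dev = torch.device('cuda', torch.cuda.current_device())
print(dev)  # e.g. cuda:0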

aten/src/ATen/native/native_functions.yaml

@@ -4807,8 +4807,7 @@
 - func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CPU: _efficientzerotensor
-    CUDA: _efficientzerotensor_cuda
+    CompositeExplicitAutograd: _efficientzerotensor
 
 - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
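
The dispatch change above replaces the separate CPU and CUDA registrations with a single CompositeExplicitAutograd kernel, i.e. one implementation registered for every backend. A tiny hedged example of invoking the operator through the schema declared above, assuming a build with this revert applied:

import torch

# One registered kernel now backs this schema for all backends.
z = torch._efficientzerotensor((2, 3), dtype=torch.float32)
print(z.shape, z.dtype, z.device)  # torch.Size([2, 3]) torch.float32 cpu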

test/test_torch.py

@@ -5477,10 +5477,6 @@ class TestDevicePrecision(TestCase):
         actual = x[..., :1].clamp(lb, ub)
         self.assertEqual(expect, actual)
 
-    def test_cuda_device_idx(self, device):
-        x = torch.zeros(3, device=device)
-        y = torch._efficientzerotensor(3, device=device)
-        self.assertEqual(x.device, y.device)
-
     # we implemented custom deallocation for subclasses, so it behooves
     # us to make sure all of these bits work. We'll use __del__ to

torch/testing/_internal/common_methods_invocations.py

@@ -9069,6 +9069,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9085,6 +9088,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9101,6 +9107,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9681,6 +9690,11 @@ op_db: List[OpInfo] = [
                # RuntimeError:
                # Arguments for call are not valid.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)),  # noqa: B950
+               # 69925: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
+               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', device_type='cuda'),
+               # (ROCm) Memory exception on virtual address 0x7f6f3deb7000, node id 4: Page not present
+               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.float64, torch.complex128], active_if=TEST_WITH_ROCM),
            ),
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_gradient),
@@ -14018,7 +14032,15 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            supports_out=False,
-           sample_inputs_func=sample_cumulative_trapezoid,),
+           sample_inputs_func=sample_cumulative_trapezoid,
+           skips=(
+               # Two failures:
+               # 1. (CUDA) RuntimeError: Expected all tensors to be on the same device, but found at
+               #    least two devices, cuda:0 and cpu!
+               # 2. (ROCm) Memory exception on virtual address 0x7f6a2216f000, node id 4: Page not present
+               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients',
+                            'test_fn_fwgrad_bwgrad', device_type='cuda'),
+           )),
     OpInfo('unsqueeze',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
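
All of the test changes above add the same kind of entry: a DecorateInfo that attaches a unittest decorator to one generated OpInfo test, filtered by test class, test name, device type and dtypes. Below is a standalone hedged sketch of that pattern, using the same arguments as the skips added above and assuming the internal DecorateInfo helper keeps this constructor (it lives in torch.testing._internal.common_methods_invocations in this revision).

import unittest
import torch
from torch.testing._internal.common_methods_invocations import DecorateInfo

# Hypothetical standalone skip entry, shaped like the ones added in this diff:
# disable TestGradients.test_fn_fwgrad_bwgrad for CUDA double/cdouble inputs.
skip_fwgrad_bwgrad = DecorateInfo(
    unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
    device_type='cuda', dtypes=[torch.double, torch.cdouble])

Placed inside an OpInfo's skips tuple, the OpInfo machinery applies the decorator to every matching generated test, which is how a single entry here silences the failing fwgrad_bwgrad variants.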