Revert D33834916: Set correct device id on efficientzerotensors
Test Plan: revert-hammer

Differential Revision: D33834916 (a18cfb790d)

Original commit changeset: 11cec343e95e

Original Phabricator Diff: D33834916 (a18cfb790d)

fbshipit-source-id: 3d3f60b760b445383768161b1d21ea4dadbe5d7c
(cherry picked from commit eba41aa646)
This commit is contained in:
parent 6208c2800e
commit 1e4aefaa2f
@@ -427,6 +427,27 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
 namespace {
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+  ZeroTensorAllocator(at::Device device) : device_(device) {};
+  ~ZeroTensorAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t nbytes) const override {
+    return {nullptr, nullptr, &deleter, device_};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+  at::Device device_;
+};
+
+at::Allocator* GetZeroTensorAllocator(ZeroTensorAllocator& zt) {
+  return &zt;
+}
+
 // Performs dtype inference for full
 TensorOptions infer_full_options(
     const Scalar& fill_value,
@@ -1055,11 +1076,11 @@ Tensor _efficientzerotensor(IntArrayRef size,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
   auto device_ = device_or_default(device);
-  auto allocator = at::native::ZeroTensorAllocator(device_);
+  auto allocator = ZeroTensorAllocator(device_);
   auto dtype_ = dtype_or_default(dtype);
-  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CPU) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
-  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
-  return out;
+  constexpr auto zero_ks = at::DispatchKeySet(at::DispatchKey::ZeroTensor);
+  return at::detail::empty_generic(
+      size, &allocator, zero_ks, dtype_, c10::nullopt);
 }
 
 Tensor& zeros_out(IntArrayRef size, Tensor& result) {
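For orientation, a minimal usage sketch of the operator touched above (not part of the diff; the _is_zerotensor check is an assumption about the private API that ships alongside _efficientzerotensor):

    import torch

    # Sketch only: _efficientzerotensor is a private helper; the ZeroTensorAllocator
    # above always hands back nullptr, so the result is a zero tensor with no real storage.
    y = torch._efficientzerotensor(3)
    print(y.device)            # cpu
    print(y._is_zerotensor())  # expected True: the tensor carries the ZeroTensor dispatch key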
@@ -87,23 +87,6 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
   }
 }
 
-// The ZeroTensor allocator ignores whatever allocation is requested and always
-// gives you nullptr
-struct ZeroTensorAllocator final : public at::Allocator {
-  ZeroTensorAllocator(at::Device device) : device_(device) {};
-  ~ZeroTensorAllocator() override = default;
-  static void deleter(void* const pointer) {
-    TORCH_INTERNAL_ASSERT(!pointer);
-  }
-  DataPtr allocate(const size_t nbytes) const override {
-    return {nullptr, nullptr, &deleter, device_};
-  }
-  DeleterFnPtr raw_deleter() const override {
-    return deleter;
-  }
-  at::Device device_;
-};
-
 using binary_fn = void (*)(TensorIterator&);
 
 DECLARE_DISPATCH(binary_fn, complex_stub);
@@ -40,23 +40,6 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
   return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
 }
 
-Tensor _efficientzerotensor_cuda(IntArrayRef size,
-    c10::optional<ScalarType> dtype,
-    c10::optional<Layout> layout,
-    c10::optional<Device> device,
-    c10::optional<bool> pin_memory) {
-  auto device_ = device_or_default(device);
-  if (!device_.has_index()) {
-    device_.set_index(at::cuda::current_device());
-  }
-  auto allocator = at::native::ZeroTensorAllocator(device_);
-  auto dtype_ = dtype_or_default(dtype);
-  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CUDA) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
-  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
-  return out;
-}
-
-
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 }
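The device-index handling deleted above was the substance of the reverted change: when no index is given, it resolved the bare cuda device to the current device so the result reports a fully indexed device such as cuda:0. A hedged Python sketch of the observable difference (essentially the test removed further down):

    import torch

    # Sketch only; mirrors the test_cuda_device_idx test deleted below.
    if torch.cuda.is_available():
        x = torch.zeros(3, device='cuda')                  # resolves to cuda:<current index>
        y = torch._efficientzerotensor(3, device='cuda')
        print(x.device, y.device)                          # with this revert applied, these may differ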
@@ -4807,8 +4807,7 @@
 
 - func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CPU: _efficientzerotensor
-    CUDA: _efficientzerotensor_cuda
+    CompositeExplicitAutograd: _efficientzerotensor
 
 - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -5477,10 +5477,6 @@ class TestDevicePrecision(TestCase):
         actual = x[..., :1].clamp(lb, ub)
         self.assertEqual(expect, actual)
 
-    def test_cuda_device_idx(self, device):
-        x = torch.zeros(3, device=device)
-        y = torch._efficientzerotensor(3, device=device)
-        self.assertEqual(x.device, y.device)
 
     # we implemented custom deallocation for subclasses, so it behooves
     # us to make sure all of these bits work. We'll use __del__ to
@@ -9069,6 +9069,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9085,6 +9088,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9101,6 +9107,9 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
+               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
+               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9681,6 +9690,11 @@ op_db: List[OpInfo] = [
                # RuntimeError:
                # Arguments for call are not valid.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950
+               # 69925: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
+               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', device_type='cuda'),
+               # (ROCm) Memory exception on virtual address 0x7f6f3deb7000, node id 4: Page not present
+               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
+                            device_type='cuda', dtypes=[torch.float64, torch.complex128], active_if=TEST_WITH_ROCM),
            ),
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_gradient),
@@ -14018,7 +14032,15 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            supports_out=False,
-           sample_inputs_func=sample_cumulative_trapezoid,),
+           sample_inputs_func=sample_cumulative_trapezoid,
+           skips=(
+               # Two failures:
+               # 1. (CUDA) RuntimeError: Expected all tensors to be on the same device, but found at
+               # least two devices, cuda:0 and cpu!
+               # 2. (ROCm) Memory exception on virtual address 0x7f6a2216f000, node id 4: Page not present
+               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients',
+                            'test_fn_fwgrad_bwgrad', device_type='cuda'),
+           )),
    OpInfo('unsqueeze',
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           supports_out=False,