Enable accelerator to perform streaming backward (#153412)

Also see https://github.com/pytorch/pytorch/pull/142097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153412
Approved by: https://github.com/albanD
ghstack dependencies: #151079
This commit is contained in:
soulitzer 2025-05-12 14:27:34 -07:00 committed by PyTorch MergeBot
parent 71c8231742
commit d5d26ce436

View File

@ -196,9 +196,7 @@ void InputBuffer::add(
}
const auto device = var.device();
const auto device_type = device.type();
// TODO: Use at::accelerator::isAccelerator(device->type()) instead
bool is_accelerator =
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
bool is_accelerator = at::accelerator::isAccelerator(device.type());
//
// Non-accelerator case
//