mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable accelerator to perform streaming backward (#153412)
Also see https://github.com/pytorch/pytorch/pull/142097 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153412 Approved by: https://github.com/albanD ghstack dependencies: #151079
This commit is contained in:
parent
71c8231742
commit
d5d26ce436
|
|
@ -196,9 +196,7 @@ void InputBuffer::add(
|
|||
}
|
||||
const auto device = var.device();
|
||||
const auto device_type = device.type();
|
||||
// TODO: Use at::accelerator::isAccelerator(device->type()) instead
|
||||
bool is_accelerator =
|
||||
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
|
||||
bool is_accelerator = at::accelerator::isAccelerator(device.type());
|
||||
//
|
||||
// Non-accelerator case
|
||||
//
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user