Make the fuser raise NotImplementedError when unknown device is hit (#54709)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54709

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D27338815

Pulled By: ezyang

fbshipit-source-id: 5cbaf3c19b9b85cc3f171f3b405d0cd98f832e65
This commit is contained in:
Edward Yang 2021-03-27 11:50:04 -07:00 committed by Facebook GitHub Bot
parent 6445c9a1cb
commit c782949e17
2 changed files with 4 additions and 2 deletions

View File

@ -187,8 +187,9 @@ struct GraphFuser {
return canFuseOnGPU();
} else if ((*device).is_xpu()) {
return false;
} else {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for graph fuser");
}
throw std::runtime_error("Unknown device");
}
// Default fusability check - used when the user doesn't pass in

View File

@ -888,8 +888,9 @@ class TensorExprFuser {
return canFuseOnGPU();
} else if (device->is_xpu()) {
return false;
} else {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for tensorexpr fuser")
}
throw std::runtime_error("Unknown device");
}
bool isFusableOnDevice(Node* node) {