Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. double support in expression evaluator
- bug fixes:
  1. dropout fix - rework RNG to support broadcasted dropout (Fixes #82784)
  2. expand fix - patch expand+reduction and expand+view, rework view analysis and guard
- scheduler:
  1. manual transpose schedule example
  2. WIP transpose scheduler

Commits in this PR from the devel branch:
```
b7435afcd22c917713c2f41a7237bc26e1183f14 Transpose scheduler, step 1 (#1854)
8a45dbf72034684eb8e18b1835b533e90b68f184 Add an example on how to manually schedule transpose (#1889)
83dbf56a9554b2efbd5416461d938fff477b0b27 Patch dropout fix (#1898)
69d3519a532250719b1aa8341b50e067b181b42d Expand+Reduction, Expand+View support, rework View analysis and guards (#1883)
15091c488e96343bdc49e3990acbf238a3b3da51 Rework RNG to correctly support broadcasted dropout (#1888)
aafe2d048aaac596e503596a41303423619f3954 Make ExpressionEvaluator support Double (#1885)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38657074](https://our.internmc.facebook.com/intern/diff/D38657074)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83239
Approved by: https://github.com/davidberard98
288 lines · 8.6 KiB · C++
#include <torch/csrc/jit/codegen/cuda/partition.h>

#include <ATen/core/jit_type.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

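// Sentinel device index: getDevice(const Node*) sets a merged device's index
// to this value when a node's tensors live on conflicting devices.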
const c10::DeviceIndex INVALID_INDEX = -2;

namespace {

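// Returns true if `node` contains anything other than element-wise work: for
// a prim::CudaFusionGroup this recurses into the subgraph; otherwise a node is
// flagged unless it is a prim::Constant or passes isElementWiseNode().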
bool hasNonElementWiseOperation(const Node* node) {
  if (node->kind() == prim::CudaFusionGroup) {
    for (auto n : node->g(attr::Subgraph)->nodes()) {
      if (hasNonElementWiseOperation(n)) {
        return true;
      }
    }
  } else {
    // prim::Constant is not parsible, but it is also not nonElementWise
    if (node->kind() != prim::Constant && !isElementWiseNode(node)) {
      return true;
    }
  }
  return false;
}

// Check all outputs are:
// 1. TensorType
// 2. on the same device;
// TODO: update this when codegen can output scalar
static c10::optional<c10::Device> getDevice(const Value* value) {
  if (!value->type()->isSubtypeOf(*TensorType::get())) {
    // not a tensor type; there is no device to report (codegen does not
    // output scalars).
    return c10::nullopt;
  }
  auto tensor_type = value->type()->expectRef<TensorType>();
  // special case for scalar tensor: return c10::nullopt instead of cpu device.
  // this allows us to fuse scalar cpu tensors with cuda tensors, while
  // avoiding merging ops with pure scalar cpu tensors.
  if (is_cpu_scalar(tensor_type)) {
    return c10::nullopt;
  }
  return tensor_type.device();
}

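// Returns true if any tensor input or output of `node` has scalar type
// BFloat16.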
static bool hasBfloat(const Node* node) {
  auto has_bfloat = [](const Value* value) {
    if (!value->type()->isSubtypeOf(*TensorType::get())) {
      return false;
    }
    auto opt_scalar_type = value->type()->expectRef<TensorType>().scalarType();
    if (opt_scalar_type.has_value() &&
        opt_scalar_type.value() == at::ScalarType::BFloat16) {
      return true;
    }
    return false;
  };

  if (std::any_of(node->inputs().begin(), node->inputs().end(), has_bfloat) ||
      std::any_of(node->outputs().begin(), node->outputs().end(), has_bfloat)) {
    return true;
  }
  return false;
}

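// Merges the devices of all tensor inputs and outputs of `node`. Returns
// c10::nullopt when no device information is available; if two tensors live
// on different devices, the returned device has its index set to
// INVALID_INDEX to signal the conflict.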
static c10::optional<c10::Device> getDevice(const Node* node) {
  c10::optional<c10::Device> ret = c10::nullopt;
  auto merge_devices = [&ret](const c10::optional<c10::Device>& device) {
    if (device.has_value()) {
      if (ret.has_value()) {
        if (ret.value() != device.value()) {
          // invalidate device to reflect conflicts
          ret->set_index(INVALID_INDEX);
          // return false to indicate early termination
          return false;
        } else {
          // same device, do nothing
          return true;
        }
      } else {
        // initialize return device
        ret = device.value();
        return true;
      }
    }
    // no device information, do nothing
    return true;
  };
  for (auto val : node->inputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  for (auto val : node->outputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  return ret;
}

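// Rejects devices the fuser cannot target: non-CUDA devices, non-element-wise
// fusion on pre-Volta GPUs (compute capability < 7), and BFloat16 on
// pre-Ampere GPUs (compute capability < 8).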
static bool isDeviceCompatible(const Node* node, const c10::Device& device) {
  // only fuses cuda device
  if (!device.is_cuda()) {
    GRAPH_UPDATE("rejecting node (non-cuda device): ", *node);
    return false;
  }
  const auto major = at::cuda::getDeviceProperties(device.index())->major;
  // disable non-elementwise fusion on pre-volta devices
  if (major < 7 && hasNonElementWiseOperation(node)) {
    GRAPH_UPDATE(
        "rejecting node (non element-wise op not supported on SM < 7X): ",
        *node);
    return false;
  }
  // disable bfloat fusion on pre-ampere devices
  if (major < 8 && hasBfloat(node)) {
    GRAPH_UPDATE("rejecting node (bfloat not supported on SM < 8X): ", *node);
    return false;
  }
  return true;
}

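// Returns true if `node` can be merged into an existing fusion that resides
// on `device`.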
static bool isFusibleDevice(const Node* node, const c10::Device& device) {
  TORCH_INTERNAL_ASSERT(
      device.index() != INVALID_INDEX, "fusible device needs to be validated");
  auto opt_device = getDevice(node);
  // we can be more relaxed here as we know that this function tries to merge
  // node into an existing `device`
  if (opt_device.has_value() &&
      (opt_device->index() == INVALID_INDEX || opt_device != device)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (outputs device not matching fusion): ",
        *node);
    return false;
  }
  if (!isDeviceCompatible(node, device)) {
    return false;
  }
  return true;
}

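// Returns true if `node` has a single, conflict-free CUDA device that the
// fuser can target.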
// TODO: we need to check input type when we handle `to()`
static bool isFusibleDevice(const Node* node) {
  auto device = getDevice(node);
  // be conservative and only fuse cuda operations, this avoids us initializing
  // operations that produce cpu scalar outputs
  if (!device.has_value() || device->index() == INVALID_INDEX) {
    return false;
  }

  if (!isDeviceCompatible(node, device.value())) {
    return false;
  }
  return true;
}

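// Returns false for tensor types codegen cannot handle: scalar types that map
// to DataType::Null, complex types (unless EnableOption::Complex is set), and
// tensors of rank greater than 8.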
bool compatibleType(const torch::jit::Value* val) {
  if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
    if (tensor_type->scalarType().has_value()) {
      if (aten_to_data_type(tensor_type->scalarType().value()) ==
          DataType::Null) {
        return false;
      }
      if (!isOptionEnabled(EnableOption::Complex)) {
        // Complex is disabled by default until its support is completely added
        // TODO: remove this logic
        if (isComplexType(
                aten_to_data_type(tensor_type->scalarType().value()))) {
          return false;
        }
      }
    }
    // magic number 8 here since our kernel argument only supports rank <= 8
    if (tensor_type->dim().has_value() && (tensor_type->dim().value() > 8)) {
      return false;
    }
  }
  return true;
}

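// Verifies that every input tensor type of `node` is supported.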
bool checkInputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->inputs().size())) {
    const auto& val = node->inputs()[i];
    if (!compatibleType(val)) {
      // special case on aten::_batch_norm_impl_index_backward, the 11th output
      // is going to be discarded, so no need to check data type there.
      if (node->kind() ==
              c10::Symbol::fromQualString(
                  "aten::_batch_norm_impl_index_backward") &&
          i == 11) {
        continue;
      }
      return false;
    }
  }
  return true;
}

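// Verifies that every output tensor type of `node` is supported.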
bool checkOutputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->outputs().size())) {
    const auto& val = node->outputs()[i];
    if (!compatibleType(val)) {
      // special case on aten::_batch_norm_impl_index, the 4th output
      // is going to be discarded, so no need to check data type there.
      if (node->kind() ==
              c10::Symbol::fromQualString("aten::_batch_norm_impl_index") &&
          i == 3) {
        continue;
      }
      return false;
    }
  }
  return true;
}

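// Node-level check: a node is fusible if it is already a CudaFusionGroup, or
// if it has a parsing rule and all of its input/output tensor types are
// supported.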
inline bool isFusibleNode(const Node* node) {
  // Check if already part of a fusion group
  if (node->kind() == prim::CudaFusionGroup)
    return true;
  // Check we have a parsing rule
  if (!isNodeParsible(node)) {
    // ignoring profile nodes & constant nodes to avoid noise from debugging
    if (node->kind() != prim::Constant &&
        node->kind() != prim::profile_ivalue && node->kind() != prim::profile &&
        node->kind() != prim::Param) {
      GRAPH_UPDATE("rejecting node from fusion (node not parsible): ", *node);
    }
    return false;
  }
  // Check that, if we have a tensor type, it's one we support
  if (!checkInputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (input scalar type not supported): ",
        *node);
    return false;
  }
  if (!checkOutputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (output scalar type not supported): ",
        *node);
    return false;
  }
  return true;
}

} // namespace

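// Returns true if `node` on its own is fusible by the CUDA fuser.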
bool isFusibleCudaFusionGroup(const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");

  if (isFusibleNode(node)) {
    auto ret = isFusibleDevice(node);
    return ret;
  }
  return false;
}

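// Returns true if `node` can be fused into the existing fusion group
// `fusion`, i.e. it is fusible on its own and its device is compatible with
// the fusion's device.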
bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
  bool fused = false;
  // TODO: lift the restriction of not fusing producer containing reduction
  // when we have proper scheduling.
  if (isFusibleNode(node)) {
    // ensure if the node has a designated device, it's on the same device with
    // fusion.
    // TODO: is there a danger of us fusing operations that are supposed to be
    // on separate GPUs? And is that necessarily bad?
    auto device = getDevice(fusion);
    fused = (!device.has_value() || isFusibleDevice(node, device.value()));
  }
  return fused;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
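For reference, the two `isFusibleCudaFusionGroup` overloads above are the surface this file exports through `partition.h`. Below is a minimal, hypothetical sketch (not part of this file; `sketchPartition` and its merge-less loop are illustrative only) of how a partitioning pass might consult them:

```cpp
// Hypothetical illustration only; not part of partition.cpp.
#include <torch/csrc/jit/codegen/cuda/partition.h>
#include <torch/csrc/jit/ir/ir.h>

namespace {

// Walk a block and report which nodes the CUDA fuser would accept, either as
// the start of a new fusion group or as an addition to the running one.
void sketchPartition(torch::jit::Block* block) {
  using namespace torch::jit;
  Node* current_group = nullptr;
  for (Node* node : block->nodes()) {
    if (current_group == nullptr) {
      // Can this node be fused at all (parsing rule, types, device)?
      if (fuser::cuda::isFusibleCudaFusionGroup(node)) {
        current_group = node;
      }
    } else if (fuser::cuda::isFusibleCudaFusionGroup(current_group, node)) {
      // Device- and type-compatible with the running group; a real pass would
      // merge `node` into `current_group` here.
    }
  }
}

} // namespace
```

In the actual graph fuser, accepted nodes end up inside prim::CudaFusionGroup subgraphs, which the checks above then recurse into (see hasNonElementWiseOperation).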