First step to introduce GpuComputeCapability custom class instead of std::variant

PiperOrigin-RevId: 820940828

parent 4d358b2bac, commit 4a42fca868
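At a glance, the refactor replaces ad-hoc `std::visit`/`std::get_if` dispatch over the bare variant with named accessors on the new class. Below is a minimal, compilable sketch of the before/after pattern; `CudaCc`, `RocmCc`, `OldGpuCc`, `NewGpuCc`, and `ArchNameOld` are illustrative stand-ins, not the real stream_executor types:

```cpp
#include <iostream>
#include <string>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaCc {
  std::string ToString() const { return "9.0"; }
};
struct RocmCc {
  std::string ToString() const { return "gfx942"; }
};

// Before: a bare alias, so every call site needs std::visit or std::get_if.
using OldGpuCc = std::variant<CudaCc, RocmCc>;

std::string ArchNameOld(const OldGpuCc& cc) {
  return std::visit([](const auto& c) { return c.ToString(); }, cc);
}

// After: a thin class deriving from the same variant, mirroring the
// GpuComputeCapability class added in this commit.
class NewGpuCc : public std::variant<CudaCc, RocmCc> {
 public:
  using std::variant<CudaCc, RocmCc>::variant;
  bool IsCuda() const { return std::holds_alternative<CudaCc>(*this); }
  bool IsRocm() const { return std::holds_alternative<RocmCc>(*this); }
  const CudaCc* cuda() const { return std::get_if<CudaCc>(this); }
  const RocmCc* rocm() const { return std::get_if<RocmCc>(this); }
  std::string ToString() const {
    if (auto* c = cuda()) return c->ToString();
    return rocm()->ToString();
  }
};

int main() {
  OldGpuCc old_cc = RocmCc{};
  std::cout << ArchNameOld(old_cc) << "\n";  // old style: visit-based dispatch

  NewGpuCc new_cc = CudaCc{};
  std::cout << new_cc.IsCuda() << " " << new_cc.ToString() << "\n";  // new style
}
```

Since the new class still is-a `std::variant`, `std::visit` call sites that are not migrated in this first step can keep compiling (P2162 made visiting variant-derived types well-defined on mainstream standard libraries).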
@@ -163,8 +163,10 @@ class GpuKernelToBlobPass
         target->Options.AllowFPOpFusion =
             llvm::FPOpFusion::FPOpFusionMode::Fast;
       };
-      TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
-                                               llvm_module_copy.get(), cc,
+      TF_ASSIGN_OR_RETURN(
+          std::string ptx,
+          xla::gpu::nvptx::CompileToPtx(
+              llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc),
               options, enable_fusion));
       if (print_ptx_) {
         llvm::dbgs() << "Generated PTX code for module '"
@@ -68,12 +68,12 @@ se::CudaComputeCapability GetComputeCapability() {
 }
 
 bool IsRocm() {
-  return std::holds_alternative<se::RocmComputeCapability>(
-      se::GPUMachineManager()
+  return se::GPUMachineManager()
       ->ExecutorForDevice(0)
       .value()
       ->GetDeviceDescription()
-          .gpu_compute_capability());
+      .gpu_compute_capability()
+      .IsRocm();
 }
 
 void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
@@ -516,7 +516,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
 
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
       use_cudnn = !procm->gfx9_mi200_or_later();
     }
     BlasScratchAllocator scratch_allocator(context);
@@ -604,7 +604,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
 
     const auto& cc =
         stream->parent()->GetDeviceDescription().gpu_compute_capability();
-    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    if (auto* procm = cc.rocm_compute_capability()) {
      bCublasLtSupport = procm->gfx9_mi200_or_later();
    }
 
@@ -2000,8 +2000,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& gpu_cc = device_info.gpu_compute_capability();
   TF_RETURN_IF_ERROR(CheckAtLeastAmpere(gpu_cc));
-  std::string arch_name =
-      std::visit([](auto& cc) { return cc.ToString(); }, gpu_cc);
+  std::string arch_name = gpu_cc.ToString();
 
   const HloModuleConfig& hlo_config = hlo_module.config();
 
@@ -68,38 +68,25 @@ bool IsTritonSupportedDotOutputType(
     case F32:
       return true;
     case F8E5M2:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
-
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     case F8E4M3FN:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastHopper();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastHopper();
+      }
+      return false;
     case BF16:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) { return true; },
-              [](const se::RocmComputeCapability& cc) {
-                return cc.has_bf16_dtype_support();
-              }),
-          gpu_version);
+      if (auto ptr = gpu_version.rocm_compute_capability()) {
+        return ptr->has_bf16_dtype_support();
+      }
+      return true;
     case S32:
-      return std::visit(
-          absl::Overload(
-              [](const se::CudaComputeCapability& cc) {
-                return cc.IsAtLeastAmpere();
-              },
-              [](const se::RocmComputeCapability& cc) { return false; }),
-          gpu_version);
+      if (auto ptr = gpu_version.cuda_compute_capability()) {
+        return ptr->IsAtLeastAmpere();
+      }
+      return false;
     default:
       return false;
   }
@@ -100,19 +100,16 @@ CommandBufferConfig GetCommandBufferConfig(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (device_info.gpu_compute_capability().IsCuda()) {
    if (std::min(device_info.runtime_version(), device_info.driver_version()) <
        se::SemanticVersion{12, 3, 0}) {
      erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
      erase(kRequireConditionals);  // on-device control flow
    }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  }
+  if (device_info.gpu_compute_capability().IsRocm()) {
    erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_info.gpu_compute_capability());
+  }
 
   return config;
 }
third_party/xla/xla/service/gpu/BUILD (vendored, 1 change)
@@ -1388,6 +1388,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
         "//xla/stream_executor/rocm:rocm_compute_capability",
         "@com_google_absl//absl/functional:overload",
     ],
 )
@@ -146,16 +146,14 @@ CompileModuleResults InitializeResults(const HloModule* hlo_module,
 }
 
 std::string GetDumpName(const se::DeviceDescription& device_desc) {
-  struct GetCcStr {
-    std::string operator()(const se::CudaComputeCapability& cc) const {
-      return absl::StrCat("sm_", cc.ToString());
+  std::string prefix;
+  if (auto* cc =
+          device_desc.gpu_compute_capability().cuda_compute_capability()) {
+    prefix = absl::StrCat("sm_", cc->ToString());
+  } else if (auto* cc = device_desc.gpu_compute_capability()
+                            .rocm_compute_capability()) {
+    prefix = cc->gfx_version();
   }
-    std::string operator()(const se::RocmComputeCapability& cc) const {
-      return cc.gfx_version();
-    }
-  };
-  std::string prefix =
-      std::visit(GetCcStr(), device_desc.gpu_compute_capability());
   return absl::StrCat(prefix, "_gpu_", kAfterOptimizationsDumpName);
 }
 
@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/shape.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/util.h"
 
 namespace xla {
@@ -33,26 +34,24 @@ namespace {
 
 bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type,
                               const se::GpuComputeCapability& gpu_cc) {
-  return std::visit(
-      absl::Overload(
-          [&](const se::CudaComputeCapability& cc) {
+  if (const se::CudaComputeCapability* cc = gpu_cc.cuda_compute_capability()) {
    for (const auto& req : CublasPaddingRequirements) {
-      if (cc.SupportsAllFeaturesOf(req.min_compute_capability) &&
+      if (cc->SupportsAllFeaturesOf(req.min_compute_capability) &&
          data_type == req.data_type && size % req.multiple_of != 0) {
        return true;
      }
    }
    return false;
-          },
-          [&](const se::RocmComputeCapability& cc) {
+  }
+  if (const se::RocmComputeCapability* cc = gpu_cc.rocm_compute_capability()) {
    for (const auto& req : HipblasPaddingRequirements) {
      if (data_type == req.data_type && size % req.multiple_of != 0) {
        return true;
      }
    }
    return false;
-          }),
-      gpu_cc);
+  }
+  return false;
 }
 
 bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size,
@@ -75,15 +75,9 @@ ENTRY e {
 }
 
 TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) {
-  bool skip_test =
-      std::visit(absl::Overload(
-                     [](const se::CudaComputeCapability& cc) {
-                       return !cc.IsAtLeast(se::CudaComputeCapability::kAmpere);
-                     },
-                     [](const se::RocmComputeCapability&) { return true; }),
-                 GetGpuComputeCapability());
-
-  if (skip_test) {
+  if (GetGpuComputeCapability().IsRocm() ||
+      !GetGpuComputeCapability().cuda_compute_capability()->IsAtLeast(
+          se::CudaComputeCapability::kAmpere)) {
     GTEST_SKIP() << "Not supported on this GPU architecture";
   }
 
@@ -116,24 +116,21 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest {
 
 TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) {
   auto gpu_compute_capability = GetGpuComputeCapability();
-  auto visitor = [](const auto& c) {
-    using cc = std::remove_const_t<std::remove_reference_t<decltype(c)>>;
-    if constexpr (std::is_same_v<stream_executor::CudaComputeCapability, cc>) {
-      if (!c.IsAtLeast(se::CudaComputeCapability::kPascal)) {
+  if (gpu_compute_capability.IsRocm()) {
    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
  }
-      if (c.major == 12 && c.minor == 1) {
+
+  auto* c = gpu_compute_capability.cuda_compute_capability();
+  if (!c->IsAtLeast(se::CudaComputeCapability::kPascal)) {
+    GTEST_SKIP() << "This test is for Pascal+ GPUs.";
+  }
+  if (c->major == 12 && c->minor == 1) {
    // Skip this test for Spark. Because of the AllReduce, the test uses
    // gpu_collective_performance_model, which only makes sense in a
    // datacenter network setting.
    GTEST_SKIP() << "This test is for datacenter GPUs.";
  }
-    } else if (!std::is_same_v<stream_executor::RocmComputeCapability, cc>) {
-      GTEST_SKIP() << "This test is for Pascal+ GPUs.";
-    }
-  };
-
-  std::visit(visitor, gpu_compute_capability);
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();
@@ -440,11 +440,15 @@ GpuPerformanceWithCollectiveModel::ComputeAllreduceTime(
     const se::DeviceDescription& gpu_device_info) {
   // We use nccl group call to launch multiple allreduces so launch overhead
   // only occurs once.
-  const auto visitor = [&](const auto& cc) {
+  if (auto ptr =
+          gpu_device_info.gpu_compute_capability().cuda_compute_capability()) {
    return ComputeAllreduceTimeImpl(instr, cost_analysis, gpu_device_info,
-                                    CreateSettings(cc));
-  };
-  return std::visit(visitor, gpu_device_info.gpu_compute_capability());
+                                    CreateSettings(*ptr));
+  }
+  return ComputeAllreduceTimeImpl(
+      instr, cost_analysis, gpu_device_info,
+      CreateSettings(
+          *gpu_device_info.gpu_compute_capability().rocm_compute_capability()));
 }
 
 /*static*/ absl::Duration
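One subtlety in the hunk above: the ROCm fallback dereferences `rocm_compute_capability()` without a null check, which is safe only under the invariant that the capability always holds either the CUDA or the ROCm alternative. A fully guarded form would look like the sketch below, using the same stand-in types as the earlier example (hypothetical; not part of this commit):

```cpp
#include <iostream>
#include <variant>

struct CudaCc {};
struct RocmCc {};

class GpuCc : public std::variant<CudaCc, RocmCc> {
 public:
  using std::variant<CudaCc, RocmCc>::variant;
  const CudaCc* cuda() const { return std::get_if<CudaCc>(this); }
  const RocmCc* rocm() const { return std::get_if<RocmCc>(this); }
};

int main() {
  GpuCc cc = RocmCc{};
  // Guard each branch instead of assuming "not CUDA implies ROCm".
  if (cc.cuda() != nullptr) {
    std::cout << "CUDA settings\n";
  } else if (cc.rocm() != nullptr) {
    std::cout << "ROCm settings\n";
  }
}
```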
@@ -378,6 +378,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:semantic_version",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
         "//xla/stream_executor/rocm:rocm_compute_capability",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
@@ -56,6 +56,7 @@ limitations under the License.
 #include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/stream_executor/semantic_version.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/statusor.h"
@@ -843,20 +844,19 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
   };
 
   // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+  if (auto* cuda_comp = device_description_.gpu_compute_capability()
+                            .cuda_compute_capability()) {
    if (std::min(device_description_.runtime_version(),
                 device_description_.driver_version()) <
        se::SemanticVersion{12, 3, 0}) {
      erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
      erase(kRequireConditionals);  // on-device control flow
    }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+  } else if (const se::RocmComputeCapability* rocm_comp =
+                 device_description_.gpu_compute_capability()
+                     .rocm_compute_capability()) {
    erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(absl::Overload(erase_cuda, erase_rocm),
-             device_description_.gpu_compute_capability());
+  }
 
   auto order = module->MakeComputationPostOrder();
   std::reverse(order.begin(), order.end());
@@ -36,8 +36,35 @@ limitations under the License.
 
 namespace stream_executor {
 
-using GpuComputeCapability =
-    std::variant<CudaComputeCapability, RocmComputeCapability>;
+class GpuComputeCapability
+    : public std::variant<CudaComputeCapability, RocmComputeCapability> {
+ public:
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::variant;
+  using std::variant<CudaComputeCapability, RocmComputeCapability>::operator=;
+
+  bool IsCuda() const {
+    return std::holds_alternative<CudaComputeCapability>(*this);
+  }
+
+  bool IsRocm() const {
+    return std::holds_alternative<RocmComputeCapability>(*this);
+  }
+
+  const CudaComputeCapability* cuda_compute_capability() const {
+    return std::get_if<CudaComputeCapability>(this);
+  }
+
+  const RocmComputeCapability* rocm_compute_capability() const {
+    return std::get_if<RocmComputeCapability>(this);
+  }
+
+  std::string ToString() const {
+    if (auto ptr = cuda_compute_capability()) {
+      return ptr->ToString();
+    }
+    return rocm_compute_capability()->ToString();
+  }
+};
 
 // Data that describes the execution target of the StreamExecutor, in terms of
 // important logical parameters. These include dimensionality limits and
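The two using-declarations are what keep existing construction sites source-compatible: the class inherits the variant's converting constructors and assignment operators, which is why the first hunk can simply wrap the existing `cc` as `stream_executor::GpuComputeCapability(cc)` with no new conversion code. A reduced, compilable sketch of that mechanism (stand-in types, hypothetical names):

```cpp
#include <variant>

struct CudaCc {};
struct RocmCc {};

class GpuCc : public std::variant<CudaCc, RocmCc> {
 public:
  using std::variant<CudaCc, RocmCc>::variant;    // inherit converting ctors
  using std::variant<CudaCc, RocmCc>::operator=;  // inherit assignment
};

int main() {
  CudaCc cuda;
  GpuCc a(cuda);   // direct construction, as at the CompileToPtx call site
  GpuCc b = cuda;  // implicit conversion from an alternative also works
  b = RocmCc{};    // inherited assignment from the other alternative
  (void)a;
}
```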
@@ -193,32 +220,24 @@ class DeviceDescription {
   // also we do not count what occupies cache, but rather claim that what is
   // much smaller than the cache size will likely stay in it.
   constexpr int64_t l1_cache_size_per_SM() const {
-    return std::visit(
-        [](const auto& capability) -> int64_t {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
      // MI100 and MI200 has 16KB L1 cache per CU.
-            if (capability.gfx9_mi100() || capability.gfx9_mi200()) {
+      if (capability->gfx9_mi100() || capability->gfx9_mi200()) {
        return 16 * 1024;
      }
      // MI300 has 32KB L1 cache per CU.
-            if (capability.gfx9_mi300_series()) {
+      if (capability->gfx9_mi300_series()) {
        return 32 * 1024;
      }
    }
    // Default return for other GPUs (e.g., RTX A6000).
    return 2 * 1024;
-        },
-        gpu_compute_capability_);
  }
 
  constexpr int64_t dram_to_l2_transaction_size_bytes() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
      // DRAM->L2 bus is 128 Byte width for MI300.
-            if (capability.gfx9_mi300_series()) {
+      if (capability->gfx9_mi300_series()) {
        return 128;
      }
    }
@@ -229,24 +248,17 @@ class DeviceDescription {
    // (page 10).
    // return 64 Bytes by default.
    return 64;
-        },
-        gpu_compute_capability_);
  }
 
  constexpr int64_t memory_transactions_per_clock() const {
-    return std::visit(
-        [](const auto& capability) -> int {
-          if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                       RocmComputeCapability>) {
+    if (auto* capability = gpu_compute_capability_.rocm_compute_capability()) {
      // 16 works well on MI300.
-            if (capability.gfx9_mi300_series()) {
+      if (capability->gfx9_mi300_series()) {
        return 16;
      }
    }
    // Default return for other GPUs.
    return 32;
-        },
-        gpu_compute_capability_);
  }
 
  GpuDeviceInfoProto ToGpuProto() const;