mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Correctly set dnn_version in device_description when parsing from proto.
Removing the setting from the other 2 places as it is no longer necessary. PiperOrigin-RevId: 822533070
This commit is contained in:
parent
dfea7bb9a7
commit
3cc86433e3
5
third_party/xla/xla/service/compiler.cc
vendored
5
third_party/xla/xla/service/compiler.cc
vendored
|
|
@ -68,6 +68,11 @@ absl::StatusOr<Compiler::TargetConfig> Compiler::TargetConfig::FromProto(
|
|||
proto.runtime_version().minor(),
|
||||
proto.runtime_version().patch());
|
||||
target_config.device_description.set_runtime_version(runtime_version);
|
||||
se::SemanticVersion dnn_version(
|
||||
static_cast<unsigned>(proto.dnn_version_info().major()),
|
||||
static_cast<unsigned>(proto.dnn_version_info().minor()),
|
||||
static_cast<unsigned>(proto.dnn_version_info().patch()));
|
||||
target_config.device_description.set_dnn_version(dnn_version);
|
||||
return target_config;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -102,17 +102,11 @@ absl::StatusOr<std::unique_ptr<AutotunerPass>> AutotunerPass::Create(
|
|||
GetProfileOptions(debug_options), allocator);
|
||||
}
|
||||
|
||||
se::DeviceDescription device_description = target_config->device_description;
|
||||
device_description.set_dnn_version(
|
||||
{static_cast<unsigned>(target_config->dnn_version_info.major_version()),
|
||||
static_cast<unsigned>(target_config->dnn_version_info.minor_version()),
|
||||
static_cast<unsigned>(target_config->dnn_version_info.patch())});
|
||||
|
||||
std::unique_ptr<AutotunerCacheInterface> cache =
|
||||
std::make_unique<LegacyCache>(
|
||||
debug_options.xla_gpu_experimental_autotuner_cache_dir(),
|
||||
debug_options.xla_gpu_experimental_autotune_cache_mode(),
|
||||
device_description);
|
||||
target_config->device_description);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Autotuner> autotuner,
|
||||
|
|
|
|||
11
third_party/xla/xla/service/gpu/gpu_compiler.cc
vendored
11
third_party/xla/xla/service/gpu/gpu_compiler.cc
vendored
|
|
@ -366,15 +366,8 @@ DeviceOrDevicelessConfig GetDeviceConfig(
|
|||
return DeviceOrDevicelessConfig{
|
||||
DeviceConfig{stream_exec, options.device_allocator}};
|
||||
}
|
||||
se::DeviceDescription device_description =
|
||||
gpu_target_config.device_description;
|
||||
device_description.set_dnn_version(
|
||||
{static_cast<unsigned>(
|
||||
gpu_target_config.dnn_version_info.major_version()),
|
||||
static_cast<unsigned>(
|
||||
gpu_target_config.dnn_version_info.minor_version()),
|
||||
static_cast<unsigned>(gpu_target_config.dnn_version_info.patch())});
|
||||
return DeviceOrDevicelessConfig{DevicelessConfig{device_description}};
|
||||
return DeviceOrDevicelessConfig{
|
||||
DevicelessConfig{gpu_target_config.device_description}};
|
||||
}
|
||||
|
||||
se::GpuComputeCapability GetGpuVersion(const se::StreamExecutor* stream_exec) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user