mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Merge pull request #48495 from ROCmSoftwarePlatform/google_upstream_r25_port_pr_48336
[r2.5 port][ROCm] Port PR 48336 to r2.5
This commit is contained in:
commit
30afb3b823
2
.bazelrc
2
.bazelrc
|
|
@ -248,8 +248,6 @@ build:tensorrt --repo_env TF_NEED_TENSORRT=1
|
|||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm --repo_env TF_NEED_ROCM=1
|
||||
# Generated kernels are not yet supported on ROCm.
|
||||
build:rocm --//tensorflow/core/kernels/mlir_generated:enable_gpu=false
|
||||
|
||||
# Options extracted from configure script
|
||||
build:numa --define=with_numa_support=true
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ def _gen_kernel_library(
|
|||
)
|
||||
|
||||
# We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.
|
||||
gpu_arch_option = "sm_70,compute_75" if cuda_gpu_architectures() else ",".join(rocm_gpu_architectures())
|
||||
native.sh_test(
|
||||
name = "{op}_{platform}_{type}_{output_type}_gen_test".format(
|
||||
op = op,
|
||||
|
|
@ -288,7 +289,7 @@ def _gen_kernel_library(
|
|||
type = type,
|
||||
output_type = output_type,
|
||||
),
|
||||
"--cpu_codegen=true" if enable_cpu else "--arch=sm_70,compute_75",
|
||||
"--cpu_codegen=true" if enable_cpu else "--arch={}".format(gpu_arch_option),
|
||||
],
|
||||
size = "medium",
|
||||
data = [
|
||||
|
|
|
|||
|
|
@ -58,12 +58,17 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
|||
|
||||
// Test only values in the function domain. The otherwise returned nan value
|
||||
// fails comparison for equality.
|
||||
#if defined(TENSORFLOW_USE_ROCM)
|
||||
auto acos_test_config = test::OpsTestConfig();
|
||||
#else
|
||||
auto acos_test_config = test::OpsTestConfig().ExpectStrictlyEqual();
|
||||
#endif
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acos, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
std::acos, acos_test_config)
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acos, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
std::acos, acos_test_config)
|
||||
|
||||
/// Test `tf.Acosh`.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user