Merge pull request #48495 from ROCmSoftwarePlatform/google_upstream_r25_port_pr_48336

[r2.5 port][ROCm] Port PR 48336 to r2.5
This commit is contained in:
Mihai Maruseac 2021-04-22 15:27:55 -07:00 committed by GitHub
commit 30afb3b823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 5 deletions

View File

@ -248,8 +248,6 @@ build:tensorrt --repo_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
# Generated kernels are not yet supported on ROCm.
build:rocm --//tensorflow/core/kernels/mlir_generated:enable_gpu=false
# Options extracted from configure script
build:numa --define=with_numa_support=true

View File

@ -271,6 +271,7 @@ def _gen_kernel_library(
)
# We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.
gpu_arch_option = "sm_70,compute_75" if cuda_gpu_architectures() else ",".join(rocm_gpu_architectures())
native.sh_test(
name = "{op}_{platform}_{type}_{output_type}_gen_test".format(
op = op,
@ -288,7 +289,7 @@ def _gen_kernel_library(
type = type,
output_type = output_type,
),
"--cpu_codegen=true" if enable_cpu else "--arch=sm_70,compute_75",
"--cpu_codegen=true" if enable_cpu else "--arch={}".format(gpu_arch_option),
],
size = "medium",
data = [

View File

@ -58,12 +58,17 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
// Test only values in the function domain. The otherwise returned nan value
// fails comparison for equality.
#if defined(TENSORFLOW_USE_ROCM)
auto acos_test_config = test::OpsTestConfig();
#else
auto acos_test_config = test::OpsTestConfig().ExpectStrictlyEqual();
#endif
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Acos, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
std::acos, acos_test_config)
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Acos, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
std::acos, acos_test_config)
/// Test `tf.Acosh`.