mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate cpu_gpu_fusion_test to use PjRt.
PiperOrigin-RevId: 824627421
This commit is contained in:
parent
a1219cfa94
commit
56a660da4b
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -3121,9 +3121,13 @@ xla_test(
|
||||||
"gpu",
|
"gpu",
|
||||||
"interpreter",
|
"interpreter",
|
||||||
],
|
],
|
||||||
|
tags = [
|
||||||
|
"test_migrated_to_hlo_runner_pjrt",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_runner_mixin",
|
":client_library_test_runner_mixin",
|
||||||
":hlo_test_base",
|
":hlo_pjrt_interpreter_reference_mixin",
|
||||||
|
":hlo_pjrt_test_base",
|
||||||
":literal_test_util",
|
":literal_test_util",
|
||||||
":xla_internal_test_main", # fixdeps: keep
|
":xla_internal_test_main", # fixdeps: keep
|
||||||
"//xla:array2d",
|
"//xla:array2d",
|
||||||
|
|
|
||||||
11
third_party/xla/xla/tests/cpu_gpu_fusion_test.cc
vendored
11
third_party/xla/xla/tests/cpu_gpu_fusion_test.cc
vendored
|
|
@ -52,7 +52,8 @@ limitations under the License.
|
||||||
#include "xla/stream_executor/platform.h"
|
#include "xla/stream_executor/platform.h"
|
||||||
#include "xla/stream_executor/stream_executor_memory_allocator.h"
|
#include "xla/stream_executor/stream_executor_memory_allocator.h"
|
||||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||||
#include "xla/tests/hlo_test_base.h"
|
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||||
|
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||||
#include "xla/tests/literal_test_util.h"
|
#include "xla/tests/literal_test_util.h"
|
||||||
#include "xla/tsl/platform/env.h"
|
#include "xla/tsl/platform/env.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
|
|
@ -74,7 +75,8 @@ const float test_float_vals[3][test_width][test_height] = {
|
||||||
|
|
||||||
// Test whether fusion operations are emitted with no errors and compute
|
// Test whether fusion operations are emitted with no errors and compute
|
||||||
// accurate outputs.
|
// accurate outputs.
|
||||||
class CpuGpuFusionTest : public HloTestBase {
|
class CpuGpuFusionTest
|
||||||
|
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
|
||||||
protected:
|
protected:
|
||||||
template <typename T, int Arity>
|
template <typename T, int Arity>
|
||||||
void TestElementwise2D(
|
void TestElementwise2D(
|
||||||
|
|
@ -155,7 +157,7 @@ class CpuGpuFusionTest : public HloTestBase {
|
||||||
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||||
absl::Span<const float> xs);
|
absl::Span<const float> xs);
|
||||||
DebugOptions GetDebugOptionsForTest() const override {
|
DebugOptions GetDebugOptionsForTest() const override {
|
||||||
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
|
DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest();
|
||||||
debug_options.add_xla_disable_hlo_passes("layout-assignment");
|
debug_options.add_xla_disable_hlo_passes("layout-assignment");
|
||||||
return debug_options;
|
return debug_options;
|
||||||
}
|
}
|
||||||
|
|
@ -884,7 +886,8 @@ TEST_F(CpuGpuFusionTest, Clamp2D) {
|
||||||
}
|
}
|
||||||
|
|
||||||
class FusionClientLibraryTest
|
class FusionClientLibraryTest
|
||||||
: public ClientLibraryTestRunnerMixin<HloTestBase> {};
|
: public ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
|
||||||
|
|
||||||
TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
|
TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
|
||||||
// On the GPU backend, it's possible to have too many transposes within one
|
// On the GPU backend, it's possible to have too many transposes within one
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user