mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Migrate cpu_gpu_fusion_test to use PjRt.
PiperOrigin-RevId: 824627421
This commit is contained in:
parent
a1219cfa94
commit
56a660da4b
6
third_party/xla/xla/tests/BUILD
vendored
6
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -3121,9 +3121,13 @@ xla_test(
|
|||
"gpu",
|
||||
"interpreter",
|
||||
],
|
||||
tags = [
|
||||
"test_migrated_to_hlo_runner_pjrt",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_runner_mixin",
|
||||
":hlo_test_base",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":literal_test_util",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:array2d",
|
||||
|
|
|
|||
11
third_party/xla/xla/tests/cpu_gpu_fusion_test.cc
vendored
11
third_party/xla/xla/tests/cpu_gpu_fusion_test.cc
vendored
|
|
@ -52,7 +52,8 @@ limitations under the License.
|
|||
#include "xla/stream_executor/platform.h"
|
||||
#include "xla/stream_executor/stream_executor_memory_allocator.h"
|
||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tests/literal_test_util.h"
|
||||
#include "xla/tsl/platform/env.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
|
@ -74,7 +75,8 @@ const float test_float_vals[3][test_width][test_height] = {
|
|||
|
||||
// Test whether fusion operations are emitted with no errors and compute
|
||||
// accurate outputs.
|
||||
class CpuGpuFusionTest : public HloTestBase {
|
||||
class CpuGpuFusionTest
|
||||
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
|
||||
protected:
|
||||
template <typename T, int Arity>
|
||||
void TestElementwise2D(
|
||||
|
|
@ -155,7 +157,7 @@ class CpuGpuFusionTest : public HloTestBase {
|
|||
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||
absl::Span<const float> xs);
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
|
||||
DebugOptions debug_options = HloPjRtTestBase::GetDebugOptionsForTest();
|
||||
debug_options.add_xla_disable_hlo_passes("layout-assignment");
|
||||
return debug_options;
|
||||
}
|
||||
|
|
@ -884,7 +886,8 @@ TEST_F(CpuGpuFusionTest, Clamp2D) {
|
|||
}
|
||||
|
||||
class FusionClientLibraryTest
|
||||
: public ClientLibraryTestRunnerMixin<HloTestBase> {};
|
||||
: public ClientLibraryTestRunnerMixin<
|
||||
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
|
||||
|
||||
TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
|
||||
// On the GPU backend, it's possible to have too many transposes within one
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user