From 4aabddab2d89e5956f9beeffe76b62d8d30a10fd Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 28 Oct 2025 09:03:25 -0700 Subject: [PATCH] Migrate conv_depthwise_test to use PjRt. PiperOrigin-RevId: 825064898 --- third_party/xla/xla/tests/BUILD | 19 +++++++++++-------- .../xla/xla/tests/conv_depthwise_test.cc | 17 +++++++++++------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 85036fbd350..4bae0a1699b 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -525,17 +525,22 @@ xla_test( "conv_depthwise_test.cc", ], shard_count = 30, + tags = [ + "test_migrated_to_hlo_runner_pjrt", + ], deps = [ - ":client_library_test_base", ":conv_depthwise_common", - ":hlo_test_base", - ":xla_internal_test_main", - "//xla:execution_options_util", - "//xla:status_macros", - "//xla/hlo/builder:xla_computation", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//xla:error_spec", + "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test", "//xla/hlo/transforms:despecializer", "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "@com_google_absl//absl/status", ], ) @@ -940,9 +945,7 @@ cc_library( testonly = True, srcs = ["conv_depthwise_common.cc"], hdrs = ["conv_depthwise_common.h"], - tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ - ":client_library_test_base", "//xla/hlo/testlib:test", "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", diff --git a/third_party/xla/xla/tests/conv_depthwise_test.cc b/third_party/xla/xla/tests/conv_depthwise_test.cc index f96eac7c272..5e4209a5a90 100644 --- a/third_party/xla/xla/tests/conv_depthwise_test.cc +++ b/third_party/xla/xla/tests/conv_depthwise_test.cc @@ -13,23 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include -#include "xla/execution_options_util.h" -#include "xla/hlo/builder/xla_computation.h" +#include "absl/status/status.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" -#include "xla/status_macros.h" -#include "xla/tests/client_library_test_base.h" #include "xla/tests/conv_depthwise_common.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { class DepthwiseConvolution2DTest - : public HloTestBase, + : public HloPjRtInterpreterReferenceMixin, public ::testing::WithParamInterface< ::testing::tuple> {};