Migrate conv_depthwise_test to use PjRt.

PiperOrigin-RevId: 825064898
This commit is contained in:
Niklas Vangerow 2025-10-28 09:03:25 -07:00 committed by TensorFlower Gardener
parent 7c6d13443d
commit 4aabddab2d
2 changed files with 22 additions and 14 deletions

View File

@ -525,17 +525,22 @@ xla_test(
"conv_depthwise_test.cc",
],
shard_count = 30,
tags = [
"test_migrated_to_hlo_runner_pjrt",
],
deps = [
":client_library_test_base",
":conv_depthwise_common",
":hlo_test_base",
":xla_internal_test_main",
"//xla:execution_options_util",
"//xla:status_macros",
"//xla/hlo/builder:xla_computation",
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":xla_internal_test_main", # fixdeps: keep
"//xla:error_spec",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:test",
"//xla/hlo/transforms:despecializer",
"//xla/hlo/transforms/simplifiers:float_normalization",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:test",
"@com_google_absl//absl/status",
],
)
@ -940,9 +945,7 @@ cc_library(
testonly = True,
srcs = ["conv_depthwise_common.cc"],
hdrs = ["conv_depthwise_common.h"],
tags = ["test_migrated_to_hlo_runner_pjrt"],
deps = [
":client_library_test_base",
"//xla/hlo/testlib:test",
"//xla/tsl/platform:test",
"@com_google_absl//absl/algorithm:container",

View File

@ -13,23 +13,28 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "xla/execution_options_util.h"
#include "xla/hlo/builder/xla_computation.h"
#include "absl/status/status.h"
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/test.h"
#include "xla/hlo/transforms/despecializer.h"
#include "xla/hlo/transforms/simplifiers/float_normalization.h"
#include "xla/status_macros.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/conv_depthwise_common.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/test.h"
namespace xla {
namespace {
class DepthwiseConvolution2DTest
: public HloTestBase,
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>,
public ::testing::WithParamInterface<
::testing::tuple<DepthwiseConvolution2DSpec, bool>> {};