mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Migrate conv_depthwise_test to use PjRt.
PiperOrigin-RevId: 825064898
This commit is contained in:
parent
7c6d13443d
commit
4aabddab2d
19
third_party/xla/xla/tests/BUILD
vendored
19
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -525,17 +525,22 @@ xla_test(
|
|||
"conv_depthwise_test.cc",
|
||||
],
|
||||
shard_count = 30,
|
||||
tags = [
|
||||
"test_migrated_to_hlo_runner_pjrt",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":conv_depthwise_common",
|
||||
":hlo_test_base",
|
||||
":xla_internal_test_main",
|
||||
"//xla:execution_options_util",
|
||||
"//xla:status_macros",
|
||||
"//xla/hlo/builder:xla_computation",
|
||||
":hlo_pjrt_interpreter_reference_mixin",
|
||||
":hlo_pjrt_test_base",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//xla:error_spec",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/testlib:test",
|
||||
"//xla/hlo/transforms:despecializer",
|
||||
"//xla/hlo/transforms/simplifiers:float_normalization",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:test",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -940,9 +945,7 @@ cc_library(
|
|||
testonly = True,
|
||||
srcs = ["conv_depthwise_common.cc"],
|
||||
hdrs = ["conv_depthwise_common.h"],
|
||||
tags = ["test_migrated_to_hlo_runner_pjrt"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
"//xla/hlo/testlib:test",
|
||||
"//xla/tsl/platform:test",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
|
|
|
|||
17
third_party/xla/xla/tests/conv_depthwise_test.cc
vendored
17
third_party/xla/xla/tests/conv_depthwise_test.cc
vendored
|
|
@ -13,23 +13,28 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "xla/execution_options_util.h"
|
||||
#include "xla/hlo/builder/xla_computation.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/error_spec.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/hlo/testlib/test.h"
|
||||
#include "xla/hlo/transforms/despecializer.h"
|
||||
#include "xla/hlo/transforms/simplifiers/float_normalization.h"
|
||||
#include "xla/status_macros.h"
|
||||
#include "xla/tests/client_library_test_base.h"
|
||||
#include "xla/tests/conv_depthwise_common.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class DepthwiseConvolution2DTest
|
||||
: public HloTestBase,
|
||||
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>,
|
||||
public ::testing::WithParamInterface<
|
||||
::testing::tuple<DepthwiseConvolution2DSpec, bool>> {};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user