diff --git a/third_party/xla/MODULE.bazel b/third_party/xla/MODULE.bazel index d037084670e..02456f332d0 100644 --- a/third_party/xla/MODULE.bazel +++ b/third_party/xla/MODULE.bazel @@ -205,3 +205,6 @@ use_repo(rocm_configure, "local_config_rocm") tensorrt_configure = use_extension("//third_party/extensions:tensorrt_configure.bzl", "tensorrt_configure_ext") use_repo(tensorrt_configure, "local_config_tensorrt") + +pjrt_nightly_timestamp = use_extension("//build_tools/pjrt_wheels:nightly.bzl", "nightly_timestamp_repo_bzlmod") +use_repo(pjrt_nightly_timestamp, "nightly_timestamp") diff --git a/third_party/xla/WORKSPACE b/third_party/xla/WORKSPACE index 31e7e3b6c80..82aa675c8bb 100644 --- a/third_party/xla/WORKSPACE +++ b/third_party/xla/WORKSPACE @@ -152,3 +152,8 @@ load( nvshmem_redist_init_repository( nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS, ) + +# This is used for building nightly PJRT wheels. +load("//build_tools/pjrt_wheels:nightly.bzl", "nightly_timestamp_repo") + +nightly_timestamp_repo(name = "nightly_timestamp") diff --git a/third_party/xla/build_tools/pjrt_wheels/BUILD.bazel b/third_party/xla/build_tools/pjrt_wheels/BUILD.bazel index d04ded8e2e2..6f0e08fcb29 100644 --- a/third_party/xla/build_tools/pjrt_wheels/BUILD.bazel +++ b/third_party/xla/build_tools/pjrt_wheels/BUILD.bazel @@ -1,10 +1,13 @@ load("@cuda_cudart//:version.bzl", cuda_major_version = "VERSION") +load("@nightly_timestamp//:timestamp.bzl", "XLA_NIGHTLY_TIMESTAMP") load("@rules_python//python:packaging.bzl", "py_wheel") # This ensures we can only build plugins for selected CUDA versions. cuda_label = "cuda" + cuda_major_version if cuda_major_version else "null" -wheel_version = "0.0.0.dev0" +# If we're building a nightly, append .devYYYYMMDD to the version +# If we're not, the timestamp is the empty string +wheel_version = "0.0.0" + XLA_NIGHTLY_TIMESTAMP wheel_platform = select({ "//conditions:default": "manylinux_2_27_x86_64", diff --git a/third_party/xla/build_tools/pjrt_wheels/nightly.bzl b/third_party/xla/build_tools/pjrt_wheels/nightly.bzl new file mode 100644 index 00000000000..f479196ac2f --- /dev/null +++ b/third_party/xla/build_tools/pjrt_wheels/nightly.bzl @@ -0,0 +1,35 @@ +"""If we're building a nightly, we use this to pass a timestamp for the wheel version.""" + +def _nightly_timestamp_impl(rctx): + timestamp_val = rctx.getenv("XLA_NIGHTLY_TIMESTAMP", "") # Default to "" + + # Smuggle the value via a new .bzl file + if timestamp_val: + rctx.file( + "timestamp.bzl", + content = 'XLA_NIGHTLY_TIMESTAMP = ".dev{}"'.format(timestamp_val), + ) + else: + rctx.file( + "timestamp.bzl", + content = 'XLA_NIGHTLY_TIMESTAMP = ""', + ) + + # Create a BUILD file to make timestamp.bzl addressable + rctx.file("BUILD.bazel", content = "") + +nightly_timestamp_repo = repository_rule( + implementation = _nightly_timestamp_impl, + environ = ["XLA_NIGHTLY_TIMESTAMP"], +) + +# bzlmod implementation +def _nightly_timestamp_ext_impl(mctx): # @unused + nightly_timestamp_repo( + name = "nightly_timestamp", + ) + +nightly_timestamp_repo_bzlmod = module_extension( + implementation = _nightly_timestamp_ext_impl, + environ = ["XLA_NIGHTLY_TIMESTAMP"], +)