Add option to tag PJRT wheels with nightly timestamp

PiperOrigin-RevId: 825706994
This commit is contained in:
Jake Harmon 2025-10-29 14:21:56 -07:00 committed by TensorFlower Gardener
parent cecce70fb2
commit 83051de423
4 changed files with 47 additions and 1 deletions

View File

@ -205,3 +205,6 @@ use_repo(rocm_configure, "local_config_rocm")
tensorrt_configure = use_extension("//third_party/extensions:tensorrt_configure.bzl", "tensorrt_configure_ext")
use_repo(tensorrt_configure, "local_config_tensorrt")
pjrt_nightly_timestamp = use_extension("//build_tools/pjrt_wheels:nightly.bzl", "nightly_timestamp_repo_bzlmod")
use_repo(pjrt_nightly_timestamp, "nightly_timestamp")

View File

@ -152,3 +152,8 @@ load(
nvshmem_redist_init_repository(
nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
)
# This is used for building nightly PJRT wheels.
load("//build_tools/pjrt_wheels:nightly.bzl", "nightly_timestamp_repo")
nightly_timestamp_repo(name = "nightly_timestamp")

View File

@ -1,10 +1,13 @@
load("@cuda_cudart//:version.bzl", cuda_major_version = "VERSION")
load("@nightly_timestamp//:timestamp.bzl", "XLA_NIGHTLY_TIMESTAMP")
load("@rules_python//python:packaging.bzl", "py_wheel")
# This ensures we can only build plugins for selected CUDA versions.
cuda_label = "cuda" + cuda_major_version if cuda_major_version else "null"
wheel_version = "0.0.0.dev0"
# If we're building a nightly, append .devYYYYMMDD to the version
# If we're not, the timestamp is the empty string
wheel_version = "0.0.0" + XLA_NIGHTLY_TIMESTAMP
wheel_platform = select({
"//conditions:default": "manylinux_2_27_x86_64",

View File

@ -0,0 +1,35 @@
"""If we're building a nightly, we use this to pass a timestamp for the wheel version."""
def _nightly_timestamp_impl(rctx):
timestamp_val = rctx.getenv("XLA_NIGHTLY_TIMESTAMP", "") # Default to ""
# Smuggle the value via a new .bzl file
if timestamp_val:
rctx.file(
"timestamp.bzl",
content = 'XLA_NIGHTLY_TIMESTAMP = ".dev{}"'.format(timestamp_val),
)
else:
rctx.file(
"timestamp.bzl",
content = 'XLA_NIGHTLY_TIMESTAMP = ""',
)
# Create a BUILD file to make timestamp.bzl addressable
rctx.file("BUILD.bazel", content = "")
nightly_timestamp_repo = repository_rule(
implementation = _nightly_timestamp_impl,
environ = ["XLA_NIGHTLY_TIMESTAMP"],
)
# bzlmod implementation
def _nightly_timestamp_ext_impl(mctx): # @unused
nightly_timestamp_repo(
name = "nightly_timestamp",
)
nightly_timestamp_repo_bzlmod = module_extension(
implementation = _nightly_timestamp_ext_impl,
environ = ["XLA_NIGHTLY_TIMESTAMP"],
)