tensorflow/third_party/py/rules_python_freethreaded.patch
A. Unique TensorFlower 94e89d1103 Fix freethreaded builds by passing HERMETIC_PYTHON_VERSION_KIND in repository rules.
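The problem was that the `@python_<version>_host` repository was not aware of the Python build kind. That kind is controlled by the build flag `@rules_python//python/config_settings:py_freethreaded`. Without it, the host repository simply selected the first matching candidate, for example the GIL-enabled build of Python, even when a free-threaded build was requested.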

PiperOrigin-RevId: 808825370
2025-09-18 18:54:36 -07:00

54 lines
2.2 KiB
Diff

diff --git a/python/private/python_register_toolchains.bzl b/python/private/python_register_toolchains.bzl
index 2e0748de..42223d51 100644
--- a/python/private/python_register_toolchains.bzl
+++ b/python/private/python_register_toolchains.bzl
@@ -94,6 +94,7 @@ def python_register_toolchains(
minor_mapping = minor_mapping or MINOR_MAPPING
python_version = full_version(version = python_version, minor_mapping = minor_mapping)
+ python_version_kind = kwargs.pop("python_version_kind", "")
toolchain_repo_name = "{name}_toolchains".format(name = name)
@@ -189,6 +190,7 @@ def python_register_toolchains(
name = name + "_host",
platforms = loaded_platforms,
python_version = python_version,
+ python_version_kind = python_version_kind,
)
toolchains_repo(
diff --git a/python/private/toolchains_repo.bzl b/python/private/toolchains_repo.bzl
index 93bbb521..ad9cd7fd 100644
--- a/python/private/toolchains_repo.bzl
+++ b/python/private/toolchains_repo.bzl
@@ -214,7 +214,7 @@ def python_toolchain_build_file_content(
user_repository_name = "{}_{}".format(user_repository_name, platform),
python_version = python_version,
set_python_version_constraint = set_python_version_constraint,
- target_settings = [],
+ target_settings = meta.target_settings,
))
return "\n\n".join(entries)
@@ -445,6 +445,10 @@ Full python version, Major.Minor.Micro.
Only set in workspace calls.
""",
),
+ "python_version_kind": attr.string(
+ doc = "Python version kind, e.g. ft (free-threaded)",
+ default = ""
+ ),
"python_versions": attr.string_dict(
doc = """
If set, the Python version for the corresponding selected platform. Values in
@@ -603,6 +609,9 @@ def _get_host_impl_repo_name(*, rctx, logger, python_version, os_name, cpu_name,
else:
candidates = [preference]
+ if rctx.attr.python_version_kind == "ft":
+ candidates = [c for c in candidates if c[0].endswith("freethreaded")]
+
if candidates:
platform_name, meta = candidates[0]
suffix = meta.impl_repo_name