mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
The problem was that the `@python_<>_host` repository was not aware of the Python build type (controlled by the build flag `@rules_python//python/config_settings:py_freethreaded`). It therefore always selected the first candidate, e.g. the GIL build of Python, even when a free-threaded build was requested. PiperOrigin-RevId: 808825370
54 lines
2.2 KiB
Diff
54 lines
2.2 KiB
Diff
diff --git a/python/private/python_register_toolchains.bzl b/python/private/python_register_toolchains.bzl
|
|
index 2e0748de..42223d51 100644
|
|
--- a/python/private/python_register_toolchains.bzl
|
|
+++ b/python/private/python_register_toolchains.bzl
|
|
@@ -94,6 +94,7 @@ def python_register_toolchains(
|
|
minor_mapping = minor_mapping or MINOR_MAPPING
|
|
|
|
python_version = full_version(version = python_version, minor_mapping = minor_mapping)
|
|
+ python_version_kind = kwargs.pop("python_version_kind", "")
|
|
|
|
toolchain_repo_name = "{name}_toolchains".format(name = name)
|
|
|
|
@@ -189,6 +190,7 @@ def python_register_toolchains(
|
|
name = name + "_host",
|
|
platforms = loaded_platforms,
|
|
python_version = python_version,
|
|
+ python_version_kind = python_version_kind,
|
|
)
|
|
|
|
toolchains_repo(
|
|
diff --git a/python/private/toolchains_repo.bzl b/python/private/toolchains_repo.bzl
|
|
index 93bbb521..ad9cd7fd 100644
|
|
--- a/python/private/toolchains_repo.bzl
|
|
+++ b/python/private/toolchains_repo.bzl
|
|
@@ -214,7 +214,7 @@ def python_toolchain_build_file_content(
|
|
user_repository_name = "{}_{}".format(user_repository_name, platform),
|
|
python_version = python_version,
|
|
set_python_version_constraint = set_python_version_constraint,
|
|
- target_settings = [],
|
|
+ target_settings = meta.target_settings,
|
|
))
|
|
return "\n\n".join(entries)
|
|
|
|
@@ -445,6 +445,10 @@ Full python version, Major.Minor.Micro.
|
|
Only set in workspace calls.
|
|
""",
|
|
),
|
|
+ "python_version_kind": attr.string(
|
|
+ doc = "Python version kind, e.g. ft (free-threaded)",
|
|
+ default = ""
|
|
+ ),
|
|
"python_versions": attr.string_dict(
|
|
doc = """
|
|
If set, the Python version for the corresponding selected platform. Values in
|
|
@@ -603,6 +609,9 @@ def _get_host_impl_repo_name(*, rctx, logger, python_version, os_name, cpu_name,
|
|
else:
|
|
candidates = [preference]
|
|
|
|
+ if rctx.attr.python_version_kind == "ft":
|
|
+ candidates = [c for c in candidates if c[0].endswith("freethreaded")]
|
|
+
|
|
if candidates:
|
|
platform_name, meta = candidates[0]
|
|
suffix = meta.impl_repo_name
|