Add support for local wheel whitelisting and blacklisting

Also fix python version matching logic for wheels which do not require a specific python version.

PiperOrigin-RevId: 650383841
This commit is contained in:
Vadym Matsishevskyi 2024-07-08 14:55:57 -07:00 committed by TensorFlower Gardener
parent de2c871507
commit 5dd1758095
6 changed files with 192 additions and 18 deletions

View File

@ -7,7 +7,9 @@ def python_init_repositories(
requirements = {}, requirements = {},
local_wheel_workspaces = [], local_wheel_workspaces = [],
local_wheel_dist_folder = None, local_wheel_dist_folder = None,
default_python_version = None): default_python_version = None,
local_wheel_inclusion_list = ["*"],
local_wheel_exclusion_list = []):
python_repository( python_repository(
name = "python_version_repo", name = "python_version_repo",
requirements_versions = requirements.keys(), requirements_versions = requirements.keys(),
@ -15,5 +17,7 @@ def python_init_repositories(
local_wheel_workspaces = local_wheel_workspaces, local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder, local_wheel_dist_folder = local_wheel_dist_folder,
default_python_version = default_python_version, default_python_version = default_python_version,
local_wheel_inclusion_list = local_wheel_inclusion_list,
local_wheel_exclusion_list = local_wheel_exclusion_list,
) )
py_repositories() py_repositories()

View File

@ -2,7 +2,6 @@
Repository rule to manage hermetic Python interpreter under Bazel. Repository rule to manage hermetic Python interpreter under Bazel.
Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11" Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11"
Defaults to 3.11.
To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu"
""" """
@ -123,6 +122,7 @@ def _get_injected_local_wheels(
local_wheels_dir): local_wheels_dir):
local_wheel_requirements = [] local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_ver_marker = "-cp%s-" % py_version.replace(".", "")
py_major_ver_marker = "-py%s-" % py_version.split(".")[0]
wheels = {} wheels = {}
if local_wheel_workspaces: if local_wheel_workspaces:
@ -132,12 +132,26 @@ def _get_injected_local_wheels(
dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder) dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
if local_wheels_dir: if local_wheels_dir:
dist_folder_path = ctx.path(local_wheels_dir) dist_folder_path = ctx.path(local_wheels_dir)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
for wheel_name, wheel_path in wheels.items(): for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append( local_wheel_requirements.append(
@ -172,6 +186,14 @@ python_repository = repository_rule(
mandatory = False, mandatory = False,
default = DEFAULT_VERSION, default = DEFAULT_VERSION,
), ),
"local_wheel_inclusion_list": attr.string_list(
mandatory = False,
default = ["*"],
),
"local_wheel_exclusion_list": attr.string_list(
mandatory = False,
default = [],
),
}, },
environ = [ environ = [
"TF_PYTHON_VERSION", "TF_PYTHON_VERSION",
@ -179,12 +201,23 @@ python_repository = repository_rule(
"WHEEL_NAME", "WHEEL_NAME",
"WHEEL_COLLAB", "WHEEL_COLLAB",
], ],
local = True,
) )
def _process_dist_wheels(dist_wheels, wheels, py_ver_marker): def _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
local_wheel_inclusion_list,
local_wheel_exclusion_list):
for wheel in dist_wheels: for wheel in dist_wheels:
bn = wheel.basename bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0: if not bn.endswith(".whl") or (bn.find(py_ver_marker) < 0 and bn.find(py_major_ver_marker) < 0):
continue
if not _basic_wildcard_match(bn, local_wheel_inclusion_list, True, False):
continue
if not _basic_wildcard_match(bn, local_wheel_exclusion_list, False, True):
continue continue
name_components = bn.split("-") name_components = bn.split("-")
@ -199,6 +232,27 @@ def _process_dist_wheels(dist_wheels, wheels, py_ver_marker):
if not latest_wheel or latest_wheel.basename < wheel.basename: if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel wheels[package_name] = wheel
def _basic_wildcard_match(name, patterns, expected_match_result, match_all):
match = False
for pattern in patterns:
match = False
if pattern.startswith("*") and pattern.endswith("*"):
match = name.find(pattern[1:-1]) >= 0
elif pattern.startswith("*"):
match = name.endswith(pattern[1:])
elif pattern.endswith("*"):
match = name.startswith(pattern[:-1])
else:
match = name == pattern
if match_all:
if match != expected_match_result:
return False
elif match == expected_match_result:
return True
return match == expected_match_result
def _custom_python_interpreter_impl(ctx): def _custom_python_interpreter_impl(ctx):
version = ctx.attr.version version = ctx.attr.version
strip_prefix = ctx.attr.strip_prefix.format(version = version) strip_prefix = ctx.attr.strip_prefix.format(version = version)

View File

@ -7,7 +7,9 @@ def python_init_repositories(
requirements = {}, requirements = {},
local_wheel_workspaces = [], local_wheel_workspaces = [],
local_wheel_dist_folder = None, local_wheel_dist_folder = None,
default_python_version = None): default_python_version = None,
local_wheel_inclusion_list = ["*"],
local_wheel_exclusion_list = []):
python_repository( python_repository(
name = "python_version_repo", name = "python_version_repo",
requirements_versions = requirements.keys(), requirements_versions = requirements.keys(),
@ -15,5 +17,7 @@ def python_init_repositories(
local_wheel_workspaces = local_wheel_workspaces, local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder, local_wheel_dist_folder = local_wheel_dist_folder,
default_python_version = default_python_version, default_python_version = default_python_version,
local_wheel_inclusion_list = local_wheel_inclusion_list,
local_wheel_exclusion_list = local_wheel_exclusion_list,
) )
py_repositories() py_repositories()

View File

@ -2,7 +2,6 @@
Repository rule to manage hermetic Python interpreter under Bazel. Repository rule to manage hermetic Python interpreter under Bazel.
Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11" Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11"
Defaults to 3.11.
To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu"
""" """
@ -123,6 +122,7 @@ def _get_injected_local_wheels(
local_wheels_dir): local_wheels_dir):
local_wheel_requirements = [] local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_ver_marker = "-cp%s-" % py_version.replace(".", "")
py_major_ver_marker = "-py%s-" % py_version.split(".")[0]
wheels = {} wheels = {}
if local_wheel_workspaces: if local_wheel_workspaces:
@ -132,12 +132,26 @@ def _get_injected_local_wheels(
dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder) dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
if local_wheels_dir: if local_wheels_dir:
dist_folder_path = ctx.path(local_wheels_dir) dist_folder_path = ctx.path(local_wheels_dir)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
for wheel_name, wheel_path in wheels.items(): for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append( local_wheel_requirements.append(
@ -172,6 +186,14 @@ python_repository = repository_rule(
mandatory = False, mandatory = False,
default = DEFAULT_VERSION, default = DEFAULT_VERSION,
), ),
"local_wheel_inclusion_list": attr.string_list(
mandatory = False,
default = ["*"],
),
"local_wheel_exclusion_list": attr.string_list(
mandatory = False,
default = [],
),
}, },
environ = [ environ = [
"TF_PYTHON_VERSION", "TF_PYTHON_VERSION",
@ -179,12 +201,23 @@ python_repository = repository_rule(
"WHEEL_NAME", "WHEEL_NAME",
"WHEEL_COLLAB", "WHEEL_COLLAB",
], ],
local = True,
) )
def _process_dist_wheels(dist_wheels, wheels, py_ver_marker): def _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
local_wheel_inclusion_list,
local_wheel_exclusion_list):
for wheel in dist_wheels: for wheel in dist_wheels:
bn = wheel.basename bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0: if not bn.endswith(".whl") or (bn.find(py_ver_marker) < 0 and bn.find(py_major_ver_marker) < 0):
continue
if not _basic_wildcard_match(bn, local_wheel_inclusion_list, True, False):
continue
if not _basic_wildcard_match(bn, local_wheel_exclusion_list, False, True):
continue continue
name_components = bn.split("-") name_components = bn.split("-")
@ -199,6 +232,27 @@ def _process_dist_wheels(dist_wheels, wheels, py_ver_marker):
if not latest_wheel or latest_wheel.basename < wheel.basename: if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel wheels[package_name] = wheel
def _basic_wildcard_match(name, patterns, expected_match_result, match_all):
match = False
for pattern in patterns:
match = False
if pattern.startswith("*") and pattern.endswith("*"):
match = name.find(pattern[1:-1]) >= 0
elif pattern.startswith("*"):
match = name.endswith(pattern[1:])
elif pattern.endswith("*"):
match = name.startswith(pattern[:-1])
else:
match = name == pattern
if match_all:
if match != expected_match_result:
return False
elif match == expected_match_result:
return True
return match == expected_match_result
def _custom_python_interpreter_impl(ctx): def _custom_python_interpreter_impl(ctx):
version = ctx.attr.version version = ctx.attr.version
strip_prefix = ctx.attr.strip_prefix.format(version = version) strip_prefix = ctx.attr.strip_prefix.format(version = version)

View File

@ -7,7 +7,9 @@ def python_init_repositories(
requirements = {}, requirements = {},
local_wheel_workspaces = [], local_wheel_workspaces = [],
local_wheel_dist_folder = None, local_wheel_dist_folder = None,
default_python_version = None): default_python_version = None,
local_wheel_inclusion_list = ["*"],
local_wheel_exclusion_list = []):
python_repository( python_repository(
name = "python_version_repo", name = "python_version_repo",
requirements_versions = requirements.keys(), requirements_versions = requirements.keys(),
@ -15,5 +17,7 @@ def python_init_repositories(
local_wheel_workspaces = local_wheel_workspaces, local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder, local_wheel_dist_folder = local_wheel_dist_folder,
default_python_version = default_python_version, default_python_version = default_python_version,
local_wheel_inclusion_list = local_wheel_inclusion_list,
local_wheel_exclusion_list = local_wheel_exclusion_list,
) )
py_repositories() py_repositories()

View File

@ -2,7 +2,6 @@
Repository rule to manage hermetic Python interpreter under Bazel. Repository rule to manage hermetic Python interpreter under Bazel.
Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11" Version can be set via build parameter "--repo_env=HERMETIC_PYTHON_VERSION=3.11"
Defaults to 3.11.
To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu"
""" """
@ -123,6 +122,7 @@ def _get_injected_local_wheels(
local_wheels_dir): local_wheels_dir):
local_wheel_requirements = [] local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_ver_marker = "-cp%s-" % py_version.replace(".", "")
py_major_ver_marker = "-py%s-" % py_version.split(".")[0]
wheels = {} wheels = {}
if local_wheel_workspaces: if local_wheel_workspaces:
@ -132,12 +132,26 @@ def _get_injected_local_wheels(
dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder) dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
if local_wheels_dir: if local_wheels_dir:
dist_folder_path = ctx.path(local_wheels_dir) dist_folder_path = ctx.path(local_wheels_dir)
if dist_folder_path.exists: if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir() dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker) _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
ctx.attr.local_wheel_inclusion_list,
ctx.attr.local_wheel_exclusion_list,
)
for wheel_name, wheel_path in wheels.items(): for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append( local_wheel_requirements.append(
@ -172,6 +186,14 @@ python_repository = repository_rule(
mandatory = False, mandatory = False,
default = DEFAULT_VERSION, default = DEFAULT_VERSION,
), ),
"local_wheel_inclusion_list": attr.string_list(
mandatory = False,
default = ["*"],
),
"local_wheel_exclusion_list": attr.string_list(
mandatory = False,
default = [],
),
}, },
environ = [ environ = [
"TF_PYTHON_VERSION", "TF_PYTHON_VERSION",
@ -179,12 +201,23 @@ python_repository = repository_rule(
"WHEEL_NAME", "WHEEL_NAME",
"WHEEL_COLLAB", "WHEEL_COLLAB",
], ],
local = True,
) )
def _process_dist_wheels(dist_wheels, wheels, py_ver_marker): def _process_dist_wheels(
dist_wheels,
wheels,
py_ver_marker,
py_major_ver_marker,
local_wheel_inclusion_list,
local_wheel_exclusion_list):
for wheel in dist_wheels: for wheel in dist_wheels:
bn = wheel.basename bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0: if not bn.endswith(".whl") or (bn.find(py_ver_marker) < 0 and bn.find(py_major_ver_marker) < 0):
continue
if not _basic_wildcard_match(bn, local_wheel_inclusion_list, True, False):
continue
if not _basic_wildcard_match(bn, local_wheel_exclusion_list, False, True):
continue continue
name_components = bn.split("-") name_components = bn.split("-")
@ -199,6 +232,27 @@ def _process_dist_wheels(dist_wheels, wheels, py_ver_marker):
if not latest_wheel or latest_wheel.basename < wheel.basename: if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel wheels[package_name] = wheel
def _basic_wildcard_match(name, patterns, expected_match_result, match_all):
match = False
for pattern in patterns:
match = False
if pattern.startswith("*") and pattern.endswith("*"):
match = name.find(pattern[1:-1]) >= 0
elif pattern.startswith("*"):
match = name.endswith(pattern[1:])
elif pattern.endswith("*"):
match = name.startswith(pattern[:-1])
else:
match = name == pattern
if match_all:
if match != expected_match_result:
return False
elif match == expected_match_result:
return True
return match == expected_match_result
def _custom_python_interpreter_impl(ctx): def _custom_python_interpreter_impl(ctx):
version = ctx.attr.version version = ctx.attr.version
strip_prefix = ctx.attr.strip_prefix.format(version = version) strip_prefix = ctx.attr.strip_prefix.format(version = version)