mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54804 Improve the implementation of the utility functions to handle more edge cases and also have a new set of unit tests to cover their usage. Test Plan: Run the existing and newly introduced unit tests. Reviewed By: kiukchung Differential Revision: D27327898 fbshipit-source-id: 96b6fe2d910e3de69f44947a0e8a9f687ab50633
101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import re
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
|
|
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
|
|
"""Extracts key-value pairs from a rendezvous configuration string.
|
|
|
|
Args:
|
|
config_str:
|
|
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
|
|
"""
|
|
config: Dict[str, str] = {}
|
|
|
|
config_str = config_str.strip()
|
|
if not config_str:
|
|
return config
|
|
|
|
key_values = config_str.split(",")
|
|
for kv in key_values:
|
|
key, *values = kv.split("=", 1)
|
|
|
|
key = key.strip()
|
|
if not key:
|
|
raise ValueError(
|
|
"The rendezvous configuration string must be in format "
|
|
"<key1>=<value1>,...,<keyN>=<valueN>."
|
|
)
|
|
|
|
value: Optional[str]
|
|
if values:
|
|
value = values[0].strip()
|
|
else:
|
|
value = None
|
|
if not value:
|
|
raise ValueError(
|
|
f"The rendezvous configuration option '{key}' must have a value specified."
|
|
)
|
|
|
|
config[key] = value
|
|
return config
|
|
|
|
|
|
def _try_parse_port(port_str: str) -> Optional[int]:
|
|
"""Tries to extract the port number from `port_str`."""
|
|
if port_str and re.match(r"^[0-9]{1,5}$", port_str):
|
|
return int(port_str)
|
|
return None
|
|
|
|
|
|
def _parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
|
|
"""Extracts the hostname and the port number from a rendezvous endpoint.
|
|
|
|
Args:
|
|
endpoint:
|
|
A string in format <hostname>[:<port>].
|
|
default_port:
|
|
The port number to use if the endpoint does not include one.
|
|
|
|
Returns:
|
|
A tuple of hostname and port number.
|
|
"""
|
|
if endpoint is not None:
|
|
endpoint = endpoint.strip()
|
|
|
|
if not endpoint:
|
|
return ("localhost", default_port)
|
|
|
|
# An endpoint that starts and ends with brackets represents an IPv6 address.
|
|
if endpoint[0] == "[" and endpoint[-1] == "]":
|
|
host, *rest = endpoint, *[]
|
|
else:
|
|
host, *rest = endpoint.rsplit(":", 1)
|
|
|
|
# Sanitize the IPv6 address.
|
|
if len(host) > 1 and host[0] == "[" and host[-1] == "]":
|
|
host = host[1:-1]
|
|
|
|
if len(rest) == 1:
|
|
port = _try_parse_port(rest[0])
|
|
if port is None or port >= 2 ** 16:
|
|
raise ValueError(
|
|
f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
|
|
"between 0 and 65536."
|
|
)
|
|
else:
|
|
port = default_port
|
|
|
|
if not re.match(r"^[\w\.:-]+$", host):
|
|
raise ValueError(
|
|
f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
|
|
"labels, an IPv4 address, or an IPv6 address."
|
|
)
|
|
|
|
return host, port
|