Add Gloo TCP_TLS transport (#56442)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56442

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D27896285

Pulled By: pbelevich

fbshipit-source-id: 589af59ca4c7c9bab2329f079382c09b71cfcf9e
This commit is contained in:
Pavel Belevich 2021-05-07 13:34:43 -07:00 committed by Facebook GitHub Bot
parent 96fce78ac4
commit 96e1a83fb2
7 changed files with 169 additions and 1 deletions

View File

@ -204,6 +204,10 @@ if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
./xla/scripts/apply_patches.sh
fi
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-build || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-build ]]; then
export USE_GLOO_WITH_OPENSSL=ON
fi
if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then
set -e

View File

@ -0,0 +1,96 @@
from datetime import datetime, timedelta
from tempfile import mkdtemp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
temp_dir = mkdtemp()
print(temp_dir)
def genrsa(path):
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
with open(path, "wb") as f:
f.write(key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
))
return key
def create_cert(path, C, ST, L, O, key):
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
])
cert = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.utcnow()
).not_valid_after(
# Our certificate will be valid for 10 days
datetime.utcnow() + timedelta(days=10)
).add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True,
).sign(key, hashes.SHA256())
# Write our certificate out to disk.
with open(path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
return cert
def create_req(path, C, ST, L, O, key):
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
# Provide various details about who we are.
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
])).sign(key, hashes.SHA256())
with open(path, "wb") as f:
f.write(csr.public_bytes(serialization.Encoding.PEM))
return csr
def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
cert = x509.CertificateBuilder().subject_name(
csr_cert.subject
).issuer_name(
ca_cert.subject
).public_key(
csr_cert.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.utcnow()
).not_valid_after(
# Our certificate will be valid for 10 days
datetime.utcnow() + timedelta(days=10)
# Sign our certificate with our private key
).sign(private_ca_key, hashes.SHA256())
with open(path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
return cert
ca_key = genrsa(temp_dir + "/ca.key")
ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key)
pkey = genrsa(temp_dir + "/pkey.key")
csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey)
cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)

View File

@ -0,0 +1,18 @@
#!/bin/bash
CREATE_TEST_CERT="$(dirname "${BASH_SOURCE[0]}")/create_test_cert.py"
TMP_CERT_DIR=$(python "$CREATE_TEST_CERT")
openssl verify -CAfile "${TMP_CERT_DIR}/ca.pem" "${TMP_CERT_DIR}/cert.pem"
export GLOO_DEVICE_TRANSPORT=TCP_TLS
export GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY=${TMP_CERT_DIR}/pkey.key
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT=${TMP_CERT_DIR}/cert.pem
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE=${TMP_CERT_DIR}/ca.pem
time python test/run_test.py --include distributed/test_c10d_gloo --verbose --determine-from="$DETERMINE_FROM" -- ProcessGroupGlooTest
unset GLOO_DEVICE_TRANSPORT
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE

View File

@ -153,6 +153,11 @@ test_python() {
assert_git_not_dirty
}
test_python_gloo_with_tls() {
source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh"
assert_git_not_dirty
}
test_aten() {
# Test ATen
@ -478,6 +483,9 @@ else
test_distributed
test_benchmarks
test_rpc
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-test || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-test ]]; then
test_python_gloo_with_tls
fi
fi
if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then

View File

@ -277,6 +277,9 @@ cmake_dependent_option(
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_GLOO_WITH_OPENSSL "Use Gloo with OpenSSL. Only available if USE_GLOO is on." OFF
"USE_GLOO AND LINUX AND NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(
USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
@ -327,6 +330,10 @@ if(WIN32)
endif()
endif()
if(USE_GLOO_WITH_OPENSSL)
set(USE_TCP_OPENSSL_LOAD ON CACHE STRING "")
endif()
# Linux distributions do not want too many embedded sources, in that sense we
# need to be able to build pytorch with an (almost) empty third_party
# directory.

View File

@ -631,12 +631,14 @@ class TestFile:
self.test_suites[suite_name] = TestSuite(suite_name)
if test_case.name in self.test_suites[suite_name].test_cases:
# We expect duplicate tests for test_cpp_extensions_aot, distributed/test_distributed_fork,
# and distributed/test_distributed_spawn. In these cases, we store the test case that took the longest,
# and distributed/test_distributed_spawn and test_c10d_gloo.
# In these cases, we store the test case that took the longest,
# as in these jobs, the duplicate tests are run in parallel.
# For other unexpected cases, we should raise a warning.
if self.name == 'test_cpp_extensions_aot' or \
self.name == 'distributed/test_distributed_fork' or \
self.name == 'distributed/test_distributed_spawn' or \
self.name == 'distributed/test_c10d_gloo' or \
self.name == 'cpp': # The caffe2 cpp tests spawn duplicate test cases as well.
time_difference = self.test_suites[suite_name].replace(test_case)
self.total_time += time_difference

View File

@ -8,6 +8,10 @@
#include <gloo/transport/tcp/device.h>
#endif
#if GLOO_HAVE_TRANSPORT_TCP_TLS
#include <gloo/transport/tcp/tls/device.h>
#endif
#if GLOO_HAVE_TRANSPORT_UV
#include <gloo/transport/uv/device.h>
#endif
@ -59,6 +63,35 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
#endif
#if GLOO_HAVE_TRANSPORT_TCP_TLS
static std::string cstr_to_std_string(const char* chars) {
return std::string (chars != nullptr ? chars : "");
}
static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
const std::string& interface,
const std::string& hostname) {
TORCH_CHECK(
!interface.empty() || !hostname.empty(),
"GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname "
"can't be empty");
::gloo::transport::tcp::attr attr;
if (!interface.empty()) {
attr.iface = interface;
} else {
attr.hostname = hostname;
}
const auto pkey = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
const auto cert = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
const auto caFile = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
const auto caPath = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
return ::gloo::transport::tcp::tls::CreateDevice(attr, pkey, cert, caFile, caPath);
}
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice);
#endif
#if GLOO_HAVE_TRANSPORT_UV
static std::shared_ptr<::gloo::transport::Device> makeUVDevice(
const std::string& interfaceName,