mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add Gloo TCP_TLS transport (#56442)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56442 Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D27896285 Pulled By: pbelevich fbshipit-source-id: 589af59ca4c7c9bab2329f079382c09b71cfcf9e
This commit is contained in:
parent
96fce78ac4
commit
96e1a83fb2
|
|
@ -204,6 +204,10 @@ if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
|
|||
./xla/scripts/apply_patches.sh
|
||||
fi
|
||||
|
||||
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-build || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-build ]]; then
|
||||
export USE_GLOO_WITH_OPENSSL=ON
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then
|
||||
set -e
|
||||
|
||||
|
|
|
|||
96
.jenkins/pytorch/create_test_cert.py
Normal file
96
.jenkins/pytorch/create_test_cert.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from datetime import datetime, timedelta
|
||||
from tempfile import mkdtemp
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
temp_dir = mkdtemp()
|
||||
print(temp_dir)
|
||||
|
||||
|
||||
def genrsa(path):
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
))
|
||||
return key
|
||||
|
||||
|
||||
def create_cert(path, C, ST, L, O, key):
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
|
||||
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
|
||||
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
|
||||
])
|
||||
cert = x509.CertificateBuilder().subject_name(
|
||||
subject
|
||||
).issuer_name(
|
||||
issuer
|
||||
).public_key(
|
||||
key.public_key()
|
||||
).serial_number(
|
||||
x509.random_serial_number()
|
||||
).not_valid_before(
|
||||
datetime.utcnow()
|
||||
).not_valid_after(
|
||||
# Our certificate will be valid for 10 days
|
||||
datetime.utcnow() + timedelta(days=10)
|
||||
).add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=None), critical=True,
|
||||
).sign(key, hashes.SHA256())
|
||||
# Write our certificate out to disk.
|
||||
with open(path, "wb") as f:
|
||||
f.write(cert.public_bytes(serialization.Encoding.PEM))
|
||||
return cert
|
||||
|
||||
|
||||
def create_req(path, C, ST, L, O, key):
|
||||
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
|
||||
# Provide various details about who we are.
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
|
||||
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
|
||||
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
|
||||
])).sign(key, hashes.SHA256())
|
||||
with open(path, "wb") as f:
|
||||
f.write(csr.public_bytes(serialization.Encoding.PEM))
|
||||
return csr
|
||||
|
||||
|
||||
def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
|
||||
cert = x509.CertificateBuilder().subject_name(
|
||||
csr_cert.subject
|
||||
).issuer_name(
|
||||
ca_cert.subject
|
||||
).public_key(
|
||||
csr_cert.public_key()
|
||||
).serial_number(
|
||||
x509.random_serial_number()
|
||||
).not_valid_before(
|
||||
datetime.utcnow()
|
||||
).not_valid_after(
|
||||
# Our certificate will be valid for 10 days
|
||||
datetime.utcnow() + timedelta(days=10)
|
||||
# Sign our certificate with our private key
|
||||
).sign(private_ca_key, hashes.SHA256())
|
||||
with open(path, "wb") as f:
|
||||
f.write(cert.public_bytes(serialization.Encoding.PEM))
|
||||
return cert
|
||||
|
||||
|
||||
ca_key = genrsa(temp_dir + "/ca.key")
|
||||
ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key)
|
||||
|
||||
pkey = genrsa(temp_dir + "/pkey.key")
|
||||
csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey)
|
||||
|
||||
cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)
|
||||
18
.jenkins/pytorch/run_glootls_test.sh
Executable file
18
.jenkins/pytorch/run_glootls_test.sh
Executable file
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
|
||||
CREATE_TEST_CERT="$(dirname "${BASH_SOURCE[0]}")/create_test_cert.py"
|
||||
TMP_CERT_DIR=$(python "$CREATE_TEST_CERT")
|
||||
|
||||
openssl verify -CAfile "${TMP_CERT_DIR}/ca.pem" "${TMP_CERT_DIR}/cert.pem"
|
||||
|
||||
export GLOO_DEVICE_TRANSPORT=TCP_TLS
|
||||
export GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY=${TMP_CERT_DIR}/pkey.key
|
||||
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT=${TMP_CERT_DIR}/cert.pem
|
||||
export GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE=${TMP_CERT_DIR}/ca.pem
|
||||
|
||||
time python test/run_test.py --include distributed/test_c10d_gloo --verbose --determine-from="$DETERMINE_FROM" -- ProcessGroupGlooTest
|
||||
|
||||
unset GLOO_DEVICE_TRANSPORT
|
||||
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY
|
||||
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT
|
||||
unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE
|
||||
|
|
@ -153,6 +153,11 @@ test_python() {
|
|||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_python_gloo_with_tls() {
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh"
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
|
||||
test_aten() {
|
||||
# Test ATen
|
||||
|
|
@ -478,6 +483,9 @@ else
|
|||
test_distributed
|
||||
test_benchmarks
|
||||
test_rpc
|
||||
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-test || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-test ]]; then
|
||||
test_python_gloo_with_tls
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then
|
||||
|
|
|
|||
|
|
@ -277,6 +277,9 @@ cmake_dependent_option(
|
|||
cmake_dependent_option(
|
||||
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
|
||||
"USE_DISTRIBUTED" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_GLOO_WITH_OPENSSL "Use Gloo with OpenSSL. Only available if USE_GLOO is on." OFF
|
||||
"USE_GLOO AND LINUX AND NOT INTERN_BUILD_MOBILE" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON
|
||||
"USE_DISTRIBUTED" OFF)
|
||||
|
|
@ -327,6 +330,10 @@ if(WIN32)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(USE_GLOO_WITH_OPENSSL)
|
||||
set(USE_TCP_OPENSSL_LOAD ON CACHE STRING "")
|
||||
endif()
|
||||
|
||||
# Linux distributions do not want too many embedded sources, in that sense we
|
||||
# need to be able to build pytorch with an (almost) empty third_party
|
||||
# directory.
|
||||
|
|
|
|||
|
|
@ -631,12 +631,14 @@ class TestFile:
|
|||
self.test_suites[suite_name] = TestSuite(suite_name)
|
||||
if test_case.name in self.test_suites[suite_name].test_cases:
|
||||
# We expect duplicate tests for test_cpp_extensions_aot, distributed/test_distributed_fork,
|
||||
# and distributed/test_distributed_spawn. In these cases, we store the test case that took the longest,
|
||||
# and distributed/test_distributed_spawn and test_c10d_gloo.
|
||||
# In these cases, we store the test case that took the longest,
|
||||
# as in these jobs, the duplicate tests are run in parallel.
|
||||
# For other unexpected cases, we should raise a warning.
|
||||
if self.name == 'test_cpp_extensions_aot' or \
|
||||
self.name == 'distributed/test_distributed_fork' or \
|
||||
self.name == 'distributed/test_distributed_spawn' or \
|
||||
self.name == 'distributed/test_c10d_gloo' or \
|
||||
self.name == 'cpp': # The caffe2 cpp tests spawn duplicate test cases as well.
|
||||
time_difference = self.test_suites[suite_name].replace(test_case)
|
||||
self.total_time += time_difference
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@
|
|||
#include <gloo/transport/tcp/device.h>
|
||||
#endif
|
||||
|
||||
#if GLOO_HAVE_TRANSPORT_TCP_TLS
|
||||
#include <gloo/transport/tcp/tls/device.h>
|
||||
#endif
|
||||
|
||||
#if GLOO_HAVE_TRANSPORT_UV
|
||||
#include <gloo/transport/uv/device.h>
|
||||
#endif
|
||||
|
|
@ -59,6 +63,35 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
|
|||
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
|
||||
#endif
|
||||
|
||||
#if GLOO_HAVE_TRANSPORT_TCP_TLS
|
||||
static std::string cstr_to_std_string(const char* chars) {
|
||||
return std::string (chars != nullptr ? chars : "");
|
||||
}
|
||||
|
||||
static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
|
||||
const std::string& interface,
|
||||
const std::string& hostname) {
|
||||
TORCH_CHECK(
|
||||
!interface.empty() || !hostname.empty(),
|
||||
"GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname "
|
||||
"can't be empty");
|
||||
|
||||
::gloo::transport::tcp::attr attr;
|
||||
if (!interface.empty()) {
|
||||
attr.iface = interface;
|
||||
} else {
|
||||
attr.hostname = hostname;
|
||||
}
|
||||
const auto pkey = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
|
||||
const auto cert = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
|
||||
const auto caFile = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
|
||||
const auto caPath = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
|
||||
return ::gloo::transport::tcp::tls::CreateDevice(attr, pkey, cert, caFile, caPath);
|
||||
}
|
||||
|
||||
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice);
|
||||
#endif
|
||||
|
||||
#if GLOO_HAVE_TRANSPORT_UV
|
||||
static std::shared_ptr<::gloo::transport::Device> makeUVDevice(
|
||||
const std::string& interfaceName,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user