summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp38
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h7
-rw-r--r--Swiften/TLS/TLSContext.cpp11
-rw-r--r--Swiften/TLS/TLSContext.h8
-rw-r--r--Swiften/TLS/UnitTest/ClientServerTest.cpp85
5 files changed, 145 insertions, 4 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index f90b4a8..47e7175 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -65,60 +65,65 @@ namespace {
};
std::unique_ptr<SSL_CTX> createSSL_CTX(OpenSSLContext::Mode mode) {
std::unique_ptr<SSL_CTX> sslCtx;
switch (mode) {
case OpenSSLContext::Mode::Client:
sslCtx = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()));
break;
case OpenSSLContext::Mode::Server:
sslCtx = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_server_method()));
break;
}
return sslCtx;
}
std::string openSSLInternalErrorToString() {
auto bio = std::shared_ptr<BIO>(BIO_new(BIO_s_mem()), BIO_free);
ERR_print_errors(bio.get());
std::string errorString;
errorString.resize(BIO_pending(bio.get()));
BIO_read(bio.get(), (void*)errorString.data(), errorString.size());
return errorString;
}
}
OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) {
ensureLibraryInitialized();
context_ = createSSL_CTX(mode_);
SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+ if (mode_ == Mode::Server) {
+ SSL_CTX_set_tlsext_servername_arg(context_.get(), this);
+ SSL_CTX_set_tlsext_servername_callback(context_.get(), OpenSSLContext::handleServerNameCallback);
+ }
+
// TODO: implement CRL checking
// TODO: download CRL (HTTP transport)
// TODO: cache CRL downloads for configurable time period
// TODO: implement OCSP support
// TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt
// Load system certs
#if defined(SWIFTEN_PLATFORM_WINDOWS)
X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT");
if (systemStore) {
PCCERT_CONTEXT certContext = NULL;
while (true) {
certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext);
if (!certContext) {
break;
}
OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded));
if (store && cert.getInternalX509()) {
X509_STORE_add_cert(store, cert.getInternalX509().get());
}
}
}
#elif !defined(SWIFTEN_PLATFORM_MACOSX)
SSL_CTX_set_default_verify_paths(context_.get());
#elif defined(SWIFTEN_PLATFORM_MACOSX) && !defined(SWIFTEN_PLATFORM_IPHONE)
// On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store.
// On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust.
// However, if there is a certificate error, it will always emit the "Invalid CA" error if we didn't add
// the certificates first. See
@@ -149,124 +154,153 @@ OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) {
OpenSSLContext::~OpenSSLContext() {
}
void OpenSSLContext::ensureLibraryInitialized() {
static OpenSSLInitializerFinalizer openSSLInit;
}
void OpenSSLContext::initAndSetBIOs() {
// Ownership of BIOs is transferred
readBIO_ = BIO_new(BIO_s_mem());
writeBIO_ = BIO_new(BIO_s_mem());
SSL_set_bio(handle_.get(), readBIO_, writeBIO_);
}
void OpenSSLContext::accept() {
assert(mode_ == Mode::Server);
handle_ = std::unique_ptr<SSL>(SSL_new(context_.get()));
if (!handle_) {
state_ = State::Error;
onError(std::make_shared<TLSError>());
return;
}
initAndSetBIOs();
state_ = State::Accepting;
doAccept();
}
void OpenSSLContext::connect() {
+ connect(std::string());
+}
+
+void OpenSSLContext::connect(const std::string& requestedServerName) {
assert(mode_ == Mode::Client);
handle_ = std::unique_ptr<SSL>(SSL_new(context_.get()));
if (!handle_) {
state_ = State::Error;
onError(std::make_shared<TLSError>());
return;
}
+ if (!requestedServerName.empty()) {
+ if (SSL_set_tlsext_host_name(handle_.get(), const_cast<char*>(requestedServerName.c_str())) != 1) {
+ SWIFT_LOG(error) << "Failed on SSL_set_tlsext_host_name()." << std::endl;
+ }
+ }
+
// Ownership of BIOs is transferred to the SSL_CTX instance in handle_.
initAndSetBIOs();
state_ = State::Connecting;
doConnect();
}
void OpenSSLContext::doAccept() {
auto acceptResult = SSL_accept(handle_.get());
auto error = SSL_get_error(handle_.get(), acceptResult);
switch (error) {
case SSL_ERROR_NONE: {
state_ = State::Connected;
//std::cout << x->name << std::endl;
//const char* comp = SSL_get_current_compression(handle_.get());
//std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl;
onConnected();
// The following call is important so the client knowns the handshake is finished.
sendPendingDataToNetwork();
break;
}
case SSL_ERROR_WANT_READ:
sendPendingDataToNetwork();
break;
case SSL_ERROR_WANT_WRITE:
sendPendingDataToNetwork();
break;
default:
SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl;
state_ = State::Error;
onError(std::make_shared<TLSError>());
+ sendPendingDataToNetwork();
}
}
void OpenSSLContext::doConnect() {
int connectResult = SSL_connect(handle_.get());
int error = SSL_get_error(handle_.get(), connectResult);
switch (error) {
case SSL_ERROR_NONE: {
state_ = State::Connected;
//std::cout << x->name << std::endl;
//const char* comp = SSL_get_current_compression(handle_.get());
//std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl;
onConnected();
break;
}
case SSL_ERROR_WANT_READ:
sendPendingDataToNetwork();
break;
default:
SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl;
state_ = State::Error;
onError(std::make_shared<TLSError>());
}
}
+int OpenSSLContext::handleServerNameCallback(SSL* ssl, int*, void* arg) {
+ if (ssl == nullptr)
+ return SSL_TLSEXT_ERR_NOACK;
+
+ const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+ if (servername) {
+ auto serverNameString = std::string(servername);
+ auto context = reinterpret_cast<OpenSSLContext*>(arg);
+ context->onServerNameRequested(serverNameString);
+
+ if (context->abortTLSHandshake_) {
+ context->abortTLSHandshake_ = false;
+ return SSL_TLSEXT_ERR_ALERT_FATAL;
+ }
+ }
+ return SSL_TLSEXT_ERR_OK;
+}
+
void OpenSSLContext::sendPendingDataToNetwork() {
int size = BIO_pending(writeBIO_);
if (size > 0) {
SafeByteArray data;
data.resize(size);
BIO_read(writeBIO_, vecptr(data), size);
onDataForNetwork(data);
}
}
void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) {
BIO_write(readBIO_, vecptr(data), data.size());
switch (state_) {
case State::Accepting:
doAccept();
break;
case State::Connecting:
doConnect();
break;
case State::Connected:
sendPendingDataToApplication();
break;
case State::Start: assert(false); break;
case State::Error: /*assert(false);*/ break;
}
}
void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) {
if (SSL_write(handle_.get(), vecptr(data), data.size()) >= 0) {
sendPendingDataToNetwork();
@@ -358,60 +392,64 @@ bool OpenSSLContext::setPrivateKey(const PrivateKey::ref& privateKey) {
BIO_write(bio.get(), vecptr(privateKey->getData()), int(privateKey->getData().size()));
SafeByteArray safePassword;
void* password = nullptr;
if (privateKey->getPassword()) {
safePassword = privateKey->getPassword().get();
safePassword.push_back(0);
password = safePassword.data();
}
auto resultKey = PEM_read_bio_PrivateKey(bio.get(), nullptr, empty_or_preset_password_cb, password);
if (resultKey) {
if (handle_) {
auto result = SSL_use_PrivateKey(handle_.get(), resultKey);;
if (result != 1) {
return false;
}
}
else {
auto result = SSL_CTX_use_PrivateKey(context_.get(), resultKey);
if (result != 1) {
return false;
}
}
}
else {
return false;
}
return true;
}
+void OpenSSLContext::setAbortTLSHandshake(bool abort) {
+ abortTLSHandshake_ = abort;
+}
+
bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
std::shared_ptr<PKCS12Certificate> pkcs12Certificate = std::dynamic_pointer_cast<PKCS12Certificate>(certificate);
if (!pkcs12Certificate || pkcs12Certificate->isNull()) {
return false;
}
// Create a PKCS12 structure
BIO* bio = BIO_new(BIO_s_mem());
BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size());
std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, nullptr), PKCS12_free);
BIO_free(bio);
if (!pkcs12) {
return false;
}
// Parse PKCS12
X509 *certPtr = nullptr;
EVP_PKEY* privateKeyPtr = nullptr;
STACK_OF(X509)* caCertsPtr = nullptr;
SafeByteArray password(pkcs12Certificate->getPassword());
password.push_back(0);
int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr);
if (result != 1) {
return false;
}
std::shared_ptr<X509> cert(certPtr, X509_free);
std::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free);
std::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack);
// Use the key & certificates
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 5f06811..4a94848 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -16,65 +16,68 @@
#include <Swiften/Base/ByteArray.h>
#include <Swiften/TLS/CertificateWithKey.h>
#include <Swiften/TLS/TLSContext.h>
namespace std {
template<>
class default_delete<SSL_CTX> {
public:
void operator()(SSL_CTX *ptr) {
SSL_CTX_free(ptr);
}
};
template<>
class default_delete<SSL> {
public:
void operator()(SSL *ptr) {
SSL_free(ptr);
}
};
}
namespace Swift {
class OpenSSLContext : public TLSContext, boost::noncopyable {
public:
OpenSSLContext(Mode mode);
virtual ~OpenSSLContext() override final;
void accept() override final;
void connect() override final;
+ void connect(const std::string& requestHostname) override final;
bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final;
bool setPrivateKey(const PrivateKey::ref& privateKey) override final;
bool setClientCertificate(CertificateWithKey::ref cert) override final;
+ void setAbortTLSHandshake(bool abort) override final;
void handleDataFromNetwork(const SafeByteArray&) override final;
void handleDataFromApplication(const SafeByteArray&) override final;
std::vector<Certificate::ref> getPeerCertificateChain() const override final;
std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final;
virtual ByteArray getFinishMessage() const override final;
private:
static void ensureLibraryInitialized();
-
+ static int handleServerNameCallback(SSL *ssl, int *ad, void *arg);
static CertificateVerificationError::Type getVerificationErrorTypeForResult(int);
void initAndSetBIOs();
void doAccept();
void doConnect();
void sendPendingDataToNetwork();
void sendPendingDataToApplication();
private:
enum class State { Start, Accepting, Connecting, Connected, Error };
- Mode mode_;
+ const Mode mode_;
State state_;
std::unique_ptr<SSL_CTX> context_;
std::unique_ptr<SSL> handle_;
BIO* readBIO_ = nullptr;
BIO* writeBIO_ = nullptr;
+ bool abortTLSHandshake_ = false;
};
}
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 56814d2..39fb5c9 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -1,35 +1,44 @@
/*
* Copyright (c) 2010-2018 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#include <Swiften/TLS/TLSContext.h>
#include <cassert>
namespace Swift {
TLSContext::~TLSContext() {
}
void TLSContext::accept() {
- assert(false);
+ assert(false);
+}
+
+void TLSContext::connect(const std::string& /* serverName */) {
+ assert(false);
}
bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) {
assert(false);
return false;
}
bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) {
assert(false);
return false;
}
+void TLSContext::setAbortTLSHandshake(bool /* abort */) {
+ assert(false);
+}
+
+
Certificate::ref TLSContext::getPeerCertificate() const {
std::vector<Certificate::ref> chain = getPeerCertificateChain();
return chain.empty() ? Certificate::ref() : chain[0];
}
}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 816f1c1..2655d4b 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -1,56 +1,64 @@
/*
* Copyright (c) 2010-2018 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#pragma once
#include <memory>
#include <boost/signals2.hpp>
#include <Swiften/Base/API.h>
#include <Swiften/Base/SafeByteArray.h>
#include <Swiften/TLS/Certificate.h>
#include <Swiften/TLS/CertificateVerificationError.h>
#include <Swiften/TLS/CertificateWithKey.h>
#include <Swiften/TLS/PrivateKey.h>
#include <Swiften/TLS/TLSError.h>
namespace Swift {
class SWIFTEN_API TLSContext {
public:
virtual ~TLSContext();
virtual void accept();
virtual void connect() = 0;
+ virtual void connect(const std::string& serverName);
virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */);
virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */);
virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;
+ /**
+ * This method can be used during the \ref onServerNameRequested signal,
+ * to report an error about an unknown host back to the requesting client.
+ */
+ virtual void setAbortTLSHandshake(bool /* abort */);
+
virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
virtual void handleDataFromApplication(const SafeByteArray&) = 0;
Certificate::ref getPeerCertificate() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0;
virtual ByteArray getFinishMessage() const = 0;
public:
enum class Mode {
Client,
Server
};
public:
boost::signals2::signal<void (const SafeByteArray&)> onDataForNetwork;
boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication;
boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError;
boost::signals2::signal<void ()> onConnected;
+ boost::signals2::signal<void (const std::string&)> onServerNameRequested;
};
}
diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp
index 9aa1762..692b3c0 100644
--- a/Swiften/TLS/UnitTest/ClientServerTest.cpp
+++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp
@@ -266,102 +266,115 @@ class ClientServerConnector {
connections_.push_back(clientContext_->onDataForNetwork.connect([&](const SafeByteArray& data) {
serverContext_->handleDataFromNetwork(data);
}));
connections_.push_back(serverContext_->onDataForNetwork.connect([&](const SafeByteArray& data) {
clientContext_->handleDataFromNetwork(data);
}));
}
private:
TLSContext* clientContext_;
TLSContext* serverContext_;
std::vector<boost::signals2::connection> connections_;
};
struct TLSDataForNetwork {
SafeByteArray data;
};
struct TLSDataForApplication {
SafeByteArray data;
};
struct TLSFault {
std::shared_ptr<Swift::TLSError> error;
};
struct TLSConnected {
std::vector<Certificate::ref> chain;
};
-using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>;
+struct TLSServerNameRequested {
+ std::string name;
+};
+
+using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>;
class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> {
public:
SafeByteArray operator()(const TLSDataForNetwork& tlsData) const {
return tlsData.data;
}
SafeByteArray operator()(const TLSDataForApplication& tlsData) const {
return tlsData.data;
}
SafeByteArray operator()(const TLSFault&) const {
return createSafeByteArray("");
}
SafeByteArray operator()(const TLSConnected&) const {
return createSafeByteArray("");
}
+
+ SafeByteArray operator()(const TLSServerNameRequested&) const {
+ return createSafeByteArray("");
+ }
+
};
class TLSEventToStringVisitor : public boost::static_visitor<std::string> {
public:
std::string operator()(const TLSDataForNetwork& event) const {
return std::string("TLSDataForNetwork(") + "size: " + std::to_string(event.data.size()) + ")";
}
std::string operator()(const TLSDataForApplication& event) const {
return std::string("TLSDataForApplication(") + "size: " + std::to_string(event.data.size()) + ")";
}
std::string operator()(const TLSFault&) const {
return "TLSFault()";
}
std::string operator()(const TLSConnected& event) const {
std::string certificates;
for (auto cert : event.chain) {
certificates += "\t" + cert->getSubjectName() + "\n";
}
return std::string("TLSConnected()") + "\n" + certificates;
}
+
+ std::string operator()(const TLSServerNameRequested& event) const {
+ return std::string("TLSServerNameRequested(") + "name: " + event.name + ")";
+ }
};
class TLSClientServerEventHistory {
public:
TLSClientServerEventHistory(TLSContext* client, TLSContext* server) {
connectContext(std::string("client"), client);
connectContext(std::string("server"), server);
}
__attribute__((unused))
void print() {
auto count = 0;
std::cout << "\n";
for (auto event : events) {
if (event.first == "server") {
std::cout << std::string(80, ' ');
}
std::cout << count << ". ";
std::cout << event.first << " : " << boost::apply_visitor(TLSEventToStringVisitor(), event.second) << std::endl;
count++;
}
}
private:
void connectContext(const std::string& name, TLSContext* context) {
connections_.push_back(context->onDataForNetwork.connect([=](const SafeByteArray& data) {
events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForNetwork{data}));
}));
connections_.push_back(context->onDataForApplication.connect([=](const SafeByteArray& data) {
events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForApplication{data}));
@@ -545,30 +558,100 @@ TEST(ClientServerTest, testSettingPrivateKeyWithWrongPassword) {
TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
ClientServerConnector connector(clientContext.get(), serverContext.get());
auto tlsFactories = std::make_shared<PlatformTLSFactories>();
ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"]))));
auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM), createSafeByteArray("foo"));
ASSERT_NE(nullptr, privateKey.get());
ASSERT_EQ(false, serverContext->setPrivateKey(privateKey));
}
TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) {
auto clientContext = createTLSContext(TLSContext::Mode::Client);
auto serverContext = createTLSContext(TLSContext::Mode::Server);
TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
ClientServerConnector connector(clientContext.get(), serverContext.get());
auto tlsFactories = std::make_shared<PlatformTLSFactories>();
ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"]))));
auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM));
ASSERT_NE(nullptr, privateKey.get());
ASSERT_EQ(false, serverContext->setPrivateKey(privateKey));
}
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) {
+ auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+ auto clientContext = createTLSContext(TLSContext::Mode::Client);
+ auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+ serverContext->onServerNameRequested.connect([&](const std::string& requestedName) {
+ if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) {
+ auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName]));
+ ASSERT_EQ(true, serverContext->setCertificateChain(certChain));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+ }
+ });
+
+ TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+ ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+ ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+ serverContext->accept();
+ clientContext->connect("montague.example");
+
+ clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client."));
+ serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server."));
+ ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+ return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication));
+ })->second)));
+ ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+ return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication));
+ })->second)));
+
+ ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName());
+}
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) {
+ auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+ auto clientContext = createTLSContext(TLSContext::Mode::Client);
+ auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+ serverContext->onServerNameRequested.connect([&](const std::string&) {
+ serverContext->setAbortTLSHandshake(true);
+ });
+
+ TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+ ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+ ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+ serverContext->accept();
+ clientContext->connect("montague.example");
+
+ ASSERT_EQ("server", events.events[1].first);
+ ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second));
+
+ ASSERT_EQ("client", events.events[3].first);
+ ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second));
+}