diff options
author | Tobias Markmann <tm@ayena.de> | 2018-02-19 14:22:19 (GMT) |
---|---|---|
committer | Tobias Markmann <tm@ayena.de> | 2018-02-21 13:46:30 (GMT) |
commit | 5eed7fcd3c8d42837a013855114deb6cdcaf47d0 (patch) | |
tree | 3fe4373a89be286541449fafd20d4069dff24866 | |
parent | cc1d97fc393c4d6fd3c9ecacd35b3683a10de356 (diff) | |
download | swift-5eed7fcd3c8d42837a013855114deb6cdcaf47d0.zip swift-5eed7fcd3c8d42837a013855114deb6cdcaf47d0.tar.bz2 |
Add support for Server Name Indication to OpenSSLContext
Test-Information:
Builds and unit tests pass on macOS 10.13.3 with OpenSSL TLS
backend.
Change-Id: Ie8f4578c867a2e4bf84484cde4a7cff048566ca4
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 38 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 7 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.cpp | 11 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 8 | ||||
-rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 85 |
5 files changed, 145 insertions, 4 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index f90b4a8..47e7175 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -65,60 +65,65 @@ namespace { }; std::unique_ptr<SSL_CTX> createSSL_CTX(OpenSSLContext::Mode mode) { std::unique_ptr<SSL_CTX> sslCtx; switch (mode) { case OpenSSLContext::Mode::Client: sslCtx = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method())); break; case OpenSSLContext::Mode::Server: sslCtx = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_server_method())); break; } return sslCtx; } std::string openSSLInternalErrorToString() { auto bio = std::shared_ptr<BIO>(BIO_new(BIO_s_mem()), BIO_free); ERR_print_errors(bio.get()); std::string errorString; errorString.resize(BIO_pending(bio.get())); BIO_read(bio.get(), (void*)errorString.data(), errorString.size()); return errorString; } } OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) { ensureLibraryInitialized(); context_ = createSSL_CTX(mode_); SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + if (mode_ == Mode::Server) { + SSL_CTX_set_tlsext_servername_arg(context_.get(), this); + SSL_CTX_set_tlsext_servername_callback(context_.get(), OpenSSLContext::handleServerNameCallback); + } + // TODO: implement CRL checking // TODO: download CRL (HTTP transport) // TODO: cache CRL downloads for configurable time period // TODO: implement OCSP support // TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt // Load system certs #if defined(SWIFTEN_PLATFORM_WINDOWS) X509_STORE* store = SSL_CTX_get_cert_store(context_.get()); HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT"); if (systemStore) { PCCERT_CONTEXT certContext = NULL; while (true) { certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext); if (!certContext) { break; } OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded)); if (store && cert.getInternalX509()) { X509_STORE_add_cert(store, cert.getInternalX509().get()); } } } #elif !defined(SWIFTEN_PLATFORM_MACOSX) SSL_CTX_set_default_verify_paths(context_.get()); #elif defined(SWIFTEN_PLATFORM_MACOSX) && !defined(SWIFTEN_PLATFORM_IPHONE) // On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store. // On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust. // However, if there is a certificate error, it will always emit the "Invalid CA" error if we didn't add // the certificates first. See @@ -149,124 +154,153 @@ OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) { OpenSSLContext::~OpenSSLContext() { } void OpenSSLContext::ensureLibraryInitialized() { static OpenSSLInitializerFinalizer openSSLInit; } void OpenSSLContext::initAndSetBIOs() { // Ownership of BIOs is transferred readBIO_ = BIO_new(BIO_s_mem()); writeBIO_ = BIO_new(BIO_s_mem()); SSL_set_bio(handle_.get(), readBIO_, writeBIO_); } void OpenSSLContext::accept() { assert(mode_ == Mode::Server); handle_ = std::unique_ptr<SSL>(SSL_new(context_.get())); if (!handle_) { state_ = State::Error; onError(std::make_shared<TLSError>()); return; } initAndSetBIOs(); state_ = State::Accepting; doAccept(); } void OpenSSLContext::connect() { + connect(std::string()); +} + +void OpenSSLContext::connect(const std::string& requestedServerName) { assert(mode_ == Mode::Client); handle_ = std::unique_ptr<SSL>(SSL_new(context_.get())); if (!handle_) { state_ = State::Error; onError(std::make_shared<TLSError>()); return; } + if (!requestedServerName.empty()) { + if (SSL_set_tlsext_host_name(handle_.get(), const_cast<char*>(requestedServerName.c_str())) != 1) { + SWIFT_LOG(error) << "Failed on SSL_set_tlsext_host_name()." << std::endl; + } + } + // Ownership of BIOs is transferred to the SSL_CTX instance in handle_. initAndSetBIOs(); state_ = State::Connecting; doConnect(); } void OpenSSLContext::doAccept() { auto acceptResult = SSL_accept(handle_.get()); auto error = SSL_get_error(handle_.get(), acceptResult); switch (error) { case SSL_ERROR_NONE: { state_ = State::Connected; //std::cout << x->name << std::endl; //const char* comp = SSL_get_current_compression(handle_.get()); //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; onConnected(); // The following call is important so the client knowns the handshake is finished. sendPendingDataToNetwork(); break; } case SSL_ERROR_WANT_READ: sendPendingDataToNetwork(); break; case SSL_ERROR_WANT_WRITE: sendPendingDataToNetwork(); break; default: SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl; state_ = State::Error; onError(std::make_shared<TLSError>()); + sendPendingDataToNetwork(); } } void OpenSSLContext::doConnect() { int connectResult = SSL_connect(handle_.get()); int error = SSL_get_error(handle_.get(), connectResult); switch (error) { case SSL_ERROR_NONE: { state_ = State::Connected; //std::cout << x->name << std::endl; //const char* comp = SSL_get_current_compression(handle_.get()); //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; onConnected(); break; } case SSL_ERROR_WANT_READ: sendPendingDataToNetwork(); break; default: SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl; state_ = State::Error; onError(std::make_shared<TLSError>()); } } +int OpenSSLContext::handleServerNameCallback(SSL* ssl, int*, void* arg) { + if (ssl == nullptr) + return SSL_TLSEXT_ERR_NOACK; + + const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (servername) { + auto serverNameString = std::string(servername); + auto context = reinterpret_cast<OpenSSLContext*>(arg); + context->onServerNameRequested(serverNameString); + + if (context->abortTLSHandshake_) { + context->abortTLSHandshake_ = false; + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + } + return SSL_TLSEXT_ERR_OK; +} + void OpenSSLContext::sendPendingDataToNetwork() { int size = BIO_pending(writeBIO_); if (size > 0) { SafeByteArray data; data.resize(size); BIO_read(writeBIO_, vecptr(data), size); onDataForNetwork(data); } } void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) { BIO_write(readBIO_, vecptr(data), data.size()); switch (state_) { case State::Accepting: doAccept(); break; case State::Connecting: doConnect(); break; case State::Connected: sendPendingDataToApplication(); break; case State::Start: assert(false); break; case State::Error: /*assert(false);*/ break; } } void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) { if (SSL_write(handle_.get(), vecptr(data), data.size()) >= 0) { sendPendingDataToNetwork(); @@ -358,60 +392,64 @@ bool OpenSSLContext::setPrivateKey(const PrivateKey::ref& privateKey) { BIO_write(bio.get(), vecptr(privateKey->getData()), int(privateKey->getData().size())); SafeByteArray safePassword; void* password = nullptr; if (privateKey->getPassword()) { safePassword = privateKey->getPassword().get(); safePassword.push_back(0); password = safePassword.data(); } auto resultKey = PEM_read_bio_PrivateKey(bio.get(), nullptr, empty_or_preset_password_cb, password); if (resultKey) { if (handle_) { auto result = SSL_use_PrivateKey(handle_.get(), resultKey);; if (result != 1) { return false; } } else { auto result = SSL_CTX_use_PrivateKey(context_.get(), resultKey); if (result != 1) { return false; } } } else { return false; } return true; } +void OpenSSLContext::setAbortTLSHandshake(bool abort) { + abortTLSHandshake_ = abort; +} + bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) { std::shared_ptr<PKCS12Certificate> pkcs12Certificate = std::dynamic_pointer_cast<PKCS12Certificate>(certificate); if (!pkcs12Certificate || pkcs12Certificate->isNull()) { return false; } // Create a PKCS12 structure BIO* bio = BIO_new(BIO_s_mem()); BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size()); std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, nullptr), PKCS12_free); BIO_free(bio); if (!pkcs12) { return false; } // Parse PKCS12 X509 *certPtr = nullptr; EVP_PKEY* privateKeyPtr = nullptr; STACK_OF(X509)* caCertsPtr = nullptr; SafeByteArray password(pkcs12Certificate->getPassword()); password.push_back(0); int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr); if (result != 1) { return false; } std::shared_ptr<X509> cert(certPtr, X509_free); std::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free); std::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack); // Use the key & certificates diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 5f06811..4a94848 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -16,65 +16,68 @@ #include <Swiften/Base/ByteArray.h> #include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/TLSContext.h> namespace std { template<> class default_delete<SSL_CTX> { public: void operator()(SSL_CTX *ptr) { SSL_CTX_free(ptr); } }; template<> class default_delete<SSL> { public: void operator()(SSL *ptr) { SSL_free(ptr); } }; } namespace Swift { class OpenSSLContext : public TLSContext, boost::noncopyable { public: OpenSSLContext(Mode mode); virtual ~OpenSSLContext() override final; void accept() override final; void connect() override final; + void connect(const std::string& requestHostname) override final; bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final; bool setPrivateKey(const PrivateKey::ref& privateKey) override final; bool setClientCertificate(CertificateWithKey::ref cert) override final; + void setAbortTLSHandshake(bool abort) override final; void handleDataFromNetwork(const SafeByteArray&) override final; void handleDataFromApplication(const SafeByteArray&) override final; std::vector<Certificate::ref> getPeerCertificateChain() const override final; std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final; virtual ByteArray getFinishMessage() const override final; private: static void ensureLibraryInitialized(); - + static int handleServerNameCallback(SSL *ssl, int *ad, void *arg); static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); void initAndSetBIOs(); void doAccept(); void doConnect(); void sendPendingDataToNetwork(); void sendPendingDataToApplication(); private: enum class State { Start, Accepting, Connecting, Connected, Error }; - Mode mode_; + const Mode mode_; State state_; std::unique_ptr<SSL_CTX> context_; std::unique_ptr<SSL> handle_; BIO* readBIO_ = nullptr; BIO* writeBIO_ = nullptr; + bool abortTLSHandshake_ = false; }; } diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 56814d2..39fb5c9 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -1,35 +1,44 @@ /* * Copyright (c) 2010-2018 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/TLS/TLSContext.h> #include <cassert> namespace Swift { TLSContext::~TLSContext() { } void TLSContext::accept() { - assert(false); + assert(false); +} + +void TLSContext::connect(const std::string& /* serverName */) { + assert(false); } bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) { assert(false); return false; } bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) { assert(false); return false; } +void TLSContext::setAbortTLSHandshake(bool /* abort */) { + assert(false); +} + + Certificate::ref TLSContext::getPeerCertificate() const { std::vector<Certificate::ref> chain = getPeerCertificateChain(); return chain.empty() ? Certificate::ref() : chain[0]; } } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 816f1c1..2655d4b 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -1,56 +1,64 @@ /* * Copyright (c) 2010-2018 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once #include <memory> #include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/CertificateVerificationError.h> #include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/PrivateKey.h> #include <Swiften/TLS/TLSError.h> namespace Swift { class SWIFTEN_API TLSContext { public: virtual ~TLSContext(); virtual void accept(); virtual void connect() = 0; + virtual void connect(const std::string& serverName); virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */); virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */); virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; + /** + * This method can be used during the \ref onServerNameRequested signal, + * to report an error about an unknown host back to the requesting client. + */ + virtual void setAbortTLSHandshake(bool /* abort */); + virtual void handleDataFromNetwork(const SafeByteArray&) = 0; virtual void handleDataFromApplication(const SafeByteArray&) = 0; Certificate::ref getPeerCertificate() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; virtual ByteArray getFinishMessage() const = 0; public: enum class Mode { Client, Server }; public: boost::signals2::signal<void (const SafeByteArray&)> onDataForNetwork; boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication; boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError; boost::signals2::signal<void ()> onConnected; + boost::signals2::signal<void (const std::string&)> onServerNameRequested; }; } diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 9aa1762..692b3c0 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -266,102 +266,115 @@ class ClientServerConnector { connections_.push_back(clientContext_->onDataForNetwork.connect([&](const SafeByteArray& data) { serverContext_->handleDataFromNetwork(data); })); connections_.push_back(serverContext_->onDataForNetwork.connect([&](const SafeByteArray& data) { clientContext_->handleDataFromNetwork(data); })); } private: TLSContext* clientContext_; TLSContext* serverContext_; std::vector<boost::signals2::connection> connections_; }; struct TLSDataForNetwork { SafeByteArray data; }; struct TLSDataForApplication { SafeByteArray data; }; struct TLSFault { std::shared_ptr<Swift::TLSError> error; }; struct TLSConnected { std::vector<Certificate::ref> chain; }; -using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>; +struct TLSServerNameRequested { + std::string name; +}; + +using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>; class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> { public: SafeByteArray operator()(const TLSDataForNetwork& tlsData) const { return tlsData.data; } SafeByteArray operator()(const TLSDataForApplication& tlsData) const { return tlsData.data; } SafeByteArray operator()(const TLSFault&) const { return createSafeByteArray(""); } SafeByteArray operator()(const TLSConnected&) const { return createSafeByteArray(""); } + + SafeByteArray operator()(const TLSServerNameRequested&) const { + return createSafeByteArray(""); + } + }; class TLSEventToStringVisitor : public boost::static_visitor<std::string> { public: std::string operator()(const TLSDataForNetwork& event) const { return std::string("TLSDataForNetwork(") + "size: " + std::to_string(event.data.size()) + ")"; } std::string operator()(const TLSDataForApplication& event) const { return std::string("TLSDataForApplication(") + "size: " + std::to_string(event.data.size()) + ")"; } std::string operator()(const TLSFault&) const { return "TLSFault()"; } std::string operator()(const TLSConnected& event) const { std::string certificates; for (auto cert : event.chain) { certificates += "\t" + cert->getSubjectName() + "\n"; } return std::string("TLSConnected()") + "\n" + certificates; } + + std::string operator()(const TLSServerNameRequested& event) const { + return std::string("TLSServerNameRequested(") + "name: " + event.name + ")"; + } }; class TLSClientServerEventHistory { public: TLSClientServerEventHistory(TLSContext* client, TLSContext* server) { connectContext(std::string("client"), client); connectContext(std::string("server"), server); } __attribute__((unused)) void print() { auto count = 0; std::cout << "\n"; for (auto event : events) { if (event.first == "server") { std::cout << std::string(80, ' '); } std::cout << count << ". "; std::cout << event.first << " : " << boost::apply_visitor(TLSEventToStringVisitor(), event.second) << std::endl; count++; } } private: void connectContext(const std::string& name, TLSContext* context) { connections_.push_back(context->onDataForNetwork.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForNetwork{data})); })); connections_.push_back(context->onDataForApplication.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForApplication{data})); @@ -545,30 +558,100 @@ TEST(ClientServerTest, testSettingPrivateKeyWithWrongPassword) { TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); auto tlsFactories = std::make_shared<PlatformTLSFactories>(); ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"])))); auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM), createSafeByteArray("foo")); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(false, serverContext->setPrivateKey(privateKey)); } TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) { auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); auto tlsFactories = std::make_shared<PlatformTLSFactories>(); ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"])))); auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM)); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(false, serverContext->setPrivateKey(privateKey)); } + +TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string& requestedName) { + if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) { + auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName])); + ASSERT_EQ(true, serverContext->setCertificateChain(certChain)); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + } + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); + serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); + ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + + ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName()); +} + +TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string&) { + serverContext->setAbortTLSHandshake(true); + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + ASSERT_EQ("server", events.events[1].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second)); + + ASSERT_EQ("client", events.events[3].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second)); +} |