From 5eed7fcd3c8d42837a013855114deb6cdcaf47d0 Mon Sep 17 00:00:00 2001
From: Tobias Markmann <tm@ayena.de>
Date: Mon, 19 Feb 2018 15:22:19 +0100
Subject: Add support for Server Name Indication to OpenSSLContext

Test-Information:

Builds and unit tests pass on macOS 10.13.3 with OpenSSL TLS
backend.

Change-Id: Ie8f4578c867a2e4bf84484cde4a7cff048566ca4

diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index f90b4a8..47e7175 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -92,6 +92,11 @@ OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) {
     context_ = createSSL_CTX(mode_);
     SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
 
+    if (mode_ == Mode::Server) {
+        SSL_CTX_set_tlsext_servername_arg(context_.get(), this);
+        SSL_CTX_set_tlsext_servername_callback(context_.get(), OpenSSLContext::handleServerNameCallback);
+    }
+
     // TODO: implement CRL checking
     // TODO: download CRL (HTTP transport)
     // TODO: cache CRL downloads for configurable time period
@@ -176,6 +181,10 @@ void OpenSSLContext::accept() {
 }
 
 void OpenSSLContext::connect() {
+    connect(std::string());
+}
+
+void OpenSSLContext::connect(const std::string& requestedServerName) {
     assert(mode_ == Mode::Client);
     handle_ = std::unique_ptr<SSL>(SSL_new(context_.get()));
     if (!handle_) {
@@ -184,6 +193,12 @@ void OpenSSLContext::connect() {
         return;
     }
 
+    if (!requestedServerName.empty()) {
+        if (SSL_set_tlsext_host_name(handle_.get(), const_cast<char*>(requestedServerName.c_str())) != 1) {
+            SWIFT_LOG(error) << "Failed on SSL_set_tlsext_host_name()." << std::endl;
+        }
+    }
+
     // Ownership of BIOs is transferred to the SSL_CTX instance in handle_.
     initAndSetBIOs();
 
@@ -215,6 +230,7 @@ void OpenSSLContext::doAccept() {
             SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl;
             state_ = State::Error;
             onError(std::make_shared<TLSError>());
+            sendPendingDataToNetwork();
     }
 }
 
@@ -240,6 +256,24 @@ void OpenSSLContext::doConnect() {
     }
 }
 
+int OpenSSLContext::handleServerNameCallback(SSL* ssl, int*, void* arg) {
+    if (ssl == nullptr)
+        return SSL_TLSEXT_ERR_NOACK;
+
+    const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+    if (servername) {
+        auto serverNameString = std::string(servername);
+        auto context = reinterpret_cast<OpenSSLContext*>(arg);
+        context->onServerNameRequested(serverNameString);
+
+        if (context->abortTLSHandshake_) {
+            context->abortTLSHandshake_ = false;
+            return SSL_TLSEXT_ERR_ALERT_FATAL;
+        }
+    }
+    return SSL_TLSEXT_ERR_OK;
+}
+
 void OpenSSLContext::sendPendingDataToNetwork() {
     int size = BIO_pending(writeBIO_);
     if (size > 0) {
@@ -385,6 +419,10 @@ bool OpenSSLContext::setPrivateKey(const PrivateKey::ref& privateKey) {
     return true;
 }
 
+void OpenSSLContext::setAbortTLSHandshake(bool abort) {
+    abortTLSHandshake_ = abort;
+}
+
 bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
     std::shared_ptr<PKCS12Certificate> pkcs12Certificate = std::dynamic_pointer_cast<PKCS12Certificate>(certificate);
     if (!pkcs12Certificate || pkcs12Certificate->isNull()) {
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 5f06811..4a94848 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -43,10 +43,12 @@ namespace Swift {
 
             void accept() override final;
             void connect() override final;
+            void connect(const std::string& requestHostname) override final;
 
             bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final;
             bool setPrivateKey(const PrivateKey::ref& privateKey) override final;
             bool setClientCertificate(CertificateWithKey::ref cert) override final;
+            void setAbortTLSHandshake(bool abort) override final;
 
             void handleDataFromNetwork(const SafeByteArray&) override final;
             void handleDataFromApplication(const SafeByteArray&) override final;
@@ -58,7 +60,7 @@ namespace Swift {
 
         private:
             static void ensureLibraryInitialized();
-
+            static int handleServerNameCallback(SSL *ssl, int *ad, void *arg);
             static CertificateVerificationError::Type getVerificationErrorTypeForResult(int);
 
             void initAndSetBIOs();
@@ -70,11 +72,12 @@ namespace Swift {
         private:
             enum class State { Start, Accepting, Connecting, Connected, Error };
 
-            Mode mode_;
+            const Mode mode_;
             State state_;
             std::unique_ptr<SSL_CTX> context_;
             std::unique_ptr<SSL> handle_;
             BIO* readBIO_ = nullptr;
             BIO* writeBIO_ = nullptr;
+            bool abortTLSHandshake_ = false;
     };
 }
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 56814d2..39fb5c9 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -14,7 +14,11 @@ TLSContext::~TLSContext() {
 }
 
 void TLSContext::accept() {
-     assert(false);
+    assert(false);
+}
+
+void TLSContext::connect(const std::string& /* serverName */) {
+    assert(false);
 }
 
 bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) {
@@ -27,6 +31,11 @@ bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) {
     return false;
 }
 
+void TLSContext::setAbortTLSHandshake(bool /* abort */) {
+    assert(false);
+}
+
+
 Certificate::ref TLSContext::getPeerCertificate() const {
     std::vector<Certificate::ref> chain = getPeerCertificateChain();
     return chain.empty() ? Certificate::ref() : chain[0];
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 816f1c1..2655d4b 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -26,12 +26,19 @@ namespace Swift {
 
             virtual void accept();
             virtual void connect() = 0;
+            virtual void connect(const std::string& serverName);
 
             virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */);
             virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */);
 
             virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;
 
+            /**
+             *  This method can be used during the \ref onServerNameRequested signal,
+             *  to report an error about an unknown host back to the requesting client.
+             */
+            virtual void setAbortTLSHandshake(bool /* abort */);
+
             virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
             virtual void handleDataFromApplication(const SafeByteArray&) = 0;
 
@@ -52,5 +59,6 @@ namespace Swift {
             boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication;
             boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError;
             boost::signals2::signal<void ()> onConnected;
+            boost::signals2::signal<void (const std::string&)> onServerNameRequested;
     };
 }
diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp
index 9aa1762..692b3c0 100644
--- a/Swiften/TLS/UnitTest/ClientServerTest.cpp
+++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp
@@ -293,7 +293,11 @@ struct TLSConnected {
     std::vector<Certificate::ref> chain;
 };
 
-using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>;
+struct TLSServerNameRequested {
+    std::string name;
+};
+
+using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>;
 
 class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> {
     public:
@@ -312,6 +316,11 @@ class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArra
         SafeByteArray operator()(const TLSConnected&) const {
             return createSafeByteArray("");
         }
+
+        SafeByteArray operator()(const TLSServerNameRequested&) const {
+            return createSafeByteArray("");
+        }
+
 };
 
 class TLSEventToStringVisitor : public boost::static_visitor<std::string> {
@@ -335,6 +344,10 @@ class TLSEventToStringVisitor : public boost::static_visitor<std::string> {
             }
             return std::string("TLSConnected()") + "\n" + certificates;
         }
+
+        std::string operator()(const TLSServerNameRequested& event) const {
+            return std::string("TLSServerNameRequested(") + "name: " + event.name + ")";
+        }
 };
 
 class TLSClientServerEventHistory {
@@ -572,3 +585,73 @@ TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) {
     ASSERT_NE(nullptr, privateKey.get());
     ASSERT_EQ(false, serverContext->setPrivateKey(privateKey));
 }
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) {
+    auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+    auto clientContext = createTLSContext(TLSContext::Mode::Client);
+    auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+    serverContext->onServerNameRequested.connect([&](const std::string& requestedName) {
+        if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) {
+            auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName]));
+            ASSERT_EQ(true, serverContext->setCertificateChain(certChain));
+
+            auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName]));
+            ASSERT_NE(nullptr, privateKey.get());
+            ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+        }
+    });
+
+    TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+    ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+    ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+    auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+    ASSERT_NE(nullptr, privateKey.get());
+    ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+    serverContext->accept();
+    clientContext->connect("montague.example");
+
+    clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client."));
+    serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server."));
+    ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+        return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication));
+    })->second)));
+    ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+        return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication));
+    })->second)));
+
+    ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName());
+}
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) {
+    auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+    auto clientContext = createTLSContext(TLSContext::Mode::Client);
+    auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+    serverContext->onServerNameRequested.connect([&](const std::string&) {
+        serverContext->setAbortTLSHandshake(true);
+    });
+
+    TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+    ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+    ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+    auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+    ASSERT_NE(nullptr, privateKey.get());
+    ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+    serverContext->accept();
+    clientContext->connect("montague.example");
+
+    ASSERT_EQ("server", events.events[1].first);
+    ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second));
+
+    ASSERT_EQ("client", events.events[3].first);
+    ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second));
+}
-- 
cgit v0.10.2-6-g49f6