diff options
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 38 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 7 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.cpp | 11 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 8 | ||||
-rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 85 |
5 files changed, 145 insertions, 4 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index f90b4a8..47e7175 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -92,6 +92,11 @@ OpenSSLContext::OpenSSLContext(Mode mode) : mode_(mode), state_(State::Start) { context_ = createSSL_CTX(mode_); SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + if (mode_ == Mode::Server) { + SSL_CTX_set_tlsext_servername_arg(context_.get(), this); + SSL_CTX_set_tlsext_servername_callback(context_.get(), OpenSSLContext::handleServerNameCallback); + } + // TODO: implement CRL checking // TODO: download CRL (HTTP transport) // TODO: cache CRL downloads for configurable time period @@ -176,6 +181,10 @@ void OpenSSLContext::accept() { } void OpenSSLContext::connect() { + connect(std::string()); +} + +void OpenSSLContext::connect(const std::string& requestedServerName) { assert(mode_ == Mode::Client); handle_ = std::unique_ptr<SSL>(SSL_new(context_.get())); if (!handle_) { @@ -184,6 +193,12 @@ void OpenSSLContext::connect() { return; } + if (!requestedServerName.empty()) { + if (SSL_set_tlsext_host_name(handle_.get(), const_cast<char*>(requestedServerName.c_str())) != 1) { + SWIFT_LOG(error) << "Failed on SSL_set_tlsext_host_name()." << std::endl; + } + } + // Ownership of BIOs is transferred to the SSL_CTX instance in handle_. initAndSetBIOs(); @@ -215,6 +230,7 @@ void OpenSSLContext::doAccept() { SWIFT_LOG(warning) << openSSLInternalErrorToString() << std::endl; state_ = State::Error; onError(std::make_shared<TLSError>()); + sendPendingDataToNetwork(); } } @@ -240,6 +256,24 @@ void OpenSSLContext::doConnect() { } } +int OpenSSLContext::handleServerNameCallback(SSL* ssl, int*, void* arg) { + if (ssl == nullptr) + return SSL_TLSEXT_ERR_NOACK; + + const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (servername) { + auto serverNameString = std::string(servername); + auto context = reinterpret_cast<OpenSSLContext*>(arg); + context->onServerNameRequested(serverNameString); + + if (context->abortTLSHandshake_) { + context->abortTLSHandshake_ = false; + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + } + return SSL_TLSEXT_ERR_OK; +} + void OpenSSLContext::sendPendingDataToNetwork() { int size = BIO_pending(writeBIO_); if (size > 0) { @@ -385,6 +419,10 @@ bool OpenSSLContext::setPrivateKey(const PrivateKey::ref& privateKey) { return true; } +void OpenSSLContext::setAbortTLSHandshake(bool abort) { + abortTLSHandshake_ = abort; +} + bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) { std::shared_ptr<PKCS12Certificate> pkcs12Certificate = std::dynamic_pointer_cast<PKCS12Certificate>(certificate); if (!pkcs12Certificate || pkcs12Certificate->isNull()) { diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 5f06811..4a94848 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -43,10 +43,12 @@ namespace Swift { void accept() override final; void connect() override final; + void connect(const std::string& requestHostname) override final; bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final; bool setPrivateKey(const PrivateKey::ref& privateKey) override final; bool setClientCertificate(CertificateWithKey::ref cert) override final; + void setAbortTLSHandshake(bool abort) override final; void handleDataFromNetwork(const SafeByteArray&) override final; void handleDataFromApplication(const SafeByteArray&) override final; @@ -58,7 +60,7 @@ namespace Swift { private: static void ensureLibraryInitialized(); - + static int handleServerNameCallback(SSL *ssl, int *ad, void *arg); static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); void initAndSetBIOs(); @@ -70,11 +72,12 @@ namespace Swift { private: enum class State { Start, Accepting, Connecting, Connected, Error }; - Mode mode_; + const Mode mode_; State state_; std::unique_ptr<SSL_CTX> context_; std::unique_ptr<SSL> handle_; BIO* readBIO_ = nullptr; BIO* writeBIO_ = nullptr; + bool abortTLSHandshake_ = false; }; } diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 56814d2..39fb5c9 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -14,7 +14,11 @@ TLSContext::~TLSContext() { } void TLSContext::accept() { - assert(false); + assert(false); +} + +void TLSContext::connect(const std::string& /* serverName */) { + assert(false); } bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) { @@ -27,6 +31,11 @@ bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) { return false; } +void TLSContext::setAbortTLSHandshake(bool /* abort */) { + assert(false); +} + + Certificate::ref TLSContext::getPeerCertificate() const { std::vector<Certificate::ref> chain = getPeerCertificateChain(); return chain.empty() ? Certificate::ref() : chain[0]; diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 816f1c1..2655d4b 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -26,12 +26,19 @@ namespace Swift { virtual void accept(); virtual void connect() = 0; + virtual void connect(const std::string& serverName); virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */); virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */); virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; + /** + * This method can be used during the \ref onServerNameRequested signal, + * to report an error about an unknown host back to the requesting client. + */ + virtual void setAbortTLSHandshake(bool /* abort */); + virtual void handleDataFromNetwork(const SafeByteArray&) = 0; virtual void handleDataFromApplication(const SafeByteArray&) = 0; @@ -52,5 +59,6 @@ namespace Swift { boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication; boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError; boost::signals2::signal<void ()> onConnected; + boost::signals2::signal<void (const std::string&)> onServerNameRequested; }; } diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 9aa1762..692b3c0 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -293,7 +293,11 @@ struct TLSConnected { std::vector<Certificate::ref> chain; }; -using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>; +struct TLSServerNameRequested { + std::string name; +}; + +using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>; class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> { public: @@ -312,6 +316,11 @@ class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArra SafeByteArray operator()(const TLSConnected&) const { return createSafeByteArray(""); } + + SafeByteArray operator()(const TLSServerNameRequested&) const { + return createSafeByteArray(""); + } + }; class TLSEventToStringVisitor : public boost::static_visitor<std::string> { @@ -335,6 +344,10 @@ class TLSEventToStringVisitor : public boost::static_visitor<std::string> { } return std::string("TLSConnected()") + "\n" + certificates; } + + std::string operator()(const TLSServerNameRequested& event) const { + return std::string("TLSServerNameRequested(") + "name: " + event.name + ")"; + } }; class TLSClientServerEventHistory { @@ -572,3 +585,73 @@ TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) { ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(false, serverContext->setPrivateKey(privateKey)); } + +TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string& requestedName) { + if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) { + auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName])); + ASSERT_EQ(true, serverContext->setCertificateChain(certChain)); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + } + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); + serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); + ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + + ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName()); +} + +TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string&) { + serverContext->setAbortTLSHandshake(true); + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + ASSERT_EQ("server", events.events[1].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second)); + + ASSERT_EQ("client", events.events[3].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second)); +} |