diff options
| -rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 4 | ||||
| -rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 83 |
2 files changed, 54 insertions, 33 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index e9889bc..5692e74 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -223,18 +223,19 @@ void OpenSSLContext::doAccept() { switch (error) { case SSL_ERROR_NONE: { state_ = State::Connected; //std::cout << x->name << std::endl; //const char* comp = SSL_get_current_compression(handle_.get()); //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; onConnected(); // The following call is important so the client knowns the handshake is finished. sendPendingDataToNetwork(); + sendPendingDataToApplication(); break; } case SSL_ERROR_WANT_READ: sendPendingDataToNetwork(); break; case SSL_ERROR_WANT_WRITE: sendPendingDataToNetwork(); break; default: @@ -248,18 +249,21 @@ void OpenSSLContext::doConnect() { int connectResult = SSL_connect(handle_.get()); int error = SSL_get_error(handle_.get(), connectResult); switch (error) { case SSL_ERROR_NONE: { state_ = State::Connected; //std::cout << x->name << std::endl; //const char* comp = SSL_get_current_compression(handle_.get()); //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; onConnected(); + // The following is needed since OpenSSL 1.1.1 for the server to be able to calculate the + // TLS finish message. + sendPendingDataToNetwork(); break; } case SSL_ERROR_WANT_READ: sendPendingDataToNetwork(); break; default: state_ = State::Error; onError(std::make_shared<TLSError>()); onError(std::make_shared<TLSError>(TLSError::ConnectFailed, openSSLInternalErrorToString())); diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 24bd7c5..a356dcf 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -1,11 +1,11 @@ /* - * Copyright (c) 2010-2018 Isode Limited. + * Copyright (c) 2010-2019 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <map> #include <memory> #include <utility> #include <vector> @@ -402,19 +402,18 @@ class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArra } SafeByteArray operator()(const TLSConnected&) const { return createSafeByteArray(""); } SafeByteArray operator()(const TLSServerNameRequested&) const { return createSafeByteArray(""); } - }; class TLSEventToStringVisitor : public boost::static_visitor<std::string> { public: std::string operator()(const TLSDataForNetwork& event) const { return std::string("TLSDataForNetwork(") + "size: " + std::to_string(event.data.size()) + ")"; } std::string operator()(const TLSDataForApplication& event) const { @@ -453,18 +452,35 @@ class TLSClientServerEventHistory { if (event.first == "server") { std::cout << std::string(80, ' '); } std::cout << count << ". "; std::cout << event.first << " : " << boost::apply_visitor(TLSEventToStringVisitor(), event.second) << std::endl; count++; } } + template<class TLSEventType> + boost::optional<TLSEventType> getEvent(const std::string& peer, size_t number = 0) { + for (const auto& pair : events) { + if (pair.first == peer) { + if (pair.second.type() == typeid(TLSEventType)) { + if (number == 0) { + return boost::optional<TLSEventType>(boost::get<TLSEventType>(pair.second)); + } + else { + number--; + } + } + } + } + return {}; + } + private: void connectContext(const std::string& name, TLSContext* context) { connections_.push_back(context->onDataForNetwork.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForNetwork{data})); })); connections_.push_back(context->onDataForApplication.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForApplication{data})); })); connections_.push_back(context->onError.connect([=](std::shared_ptr<Swift::TLSError> error) { @@ -596,24 +612,24 @@ TEST(ClientServerTest, testClientServerBasicCommunication) { ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); serverContext->accept(); clientContext->connect(); clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); + auto firstMessageFromClient = events.getEvent<TLSDataForApplication>("server"); + ASSERT_EQ(true, firstMessageFromClient.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(firstMessageFromClient->data)); + auto firstMessageFromServer = events.getEvent<TLSDataForApplication>("client"); + ASSERT_EQ(true, firstMessageFromServer.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(firstMessageFromServer->data)); } TEST(ClientServerTest, testClientServerBasicCommunicationEncryptedPrivateKeyRightPassword) { auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); @@ -626,24 +642,24 @@ TEST(ClientServerTest, testClientServerBasicCommunicationEncryptedPrivateKeyRigh ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); serverContext->accept(); clientContext->connect(); clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); + auto firstMessageFromClient = events.getEvent<TLSDataForApplication>("server"); + ASSERT_EQ(true, firstMessageFromClient.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(firstMessageFromClient->data)); + auto firstMessageFromServer = events.getEvent<TLSDataForApplication>("client"); + ASSERT_EQ(true, firstMessageFromServer.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(firstMessageFromServer->data)); } TEST(ClientServerTest, testClientServerBasicCommunicationWithChainedCert) { auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); @@ -733,26 +749,27 @@ TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) { auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); serverContext->accept(); clientContext->connect("montague.example"); clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); - ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName()); + auto firstMessageFromClient = events.getEvent<TLSDataForApplication>("server"); + ASSERT_EQ(true, firstMessageFromClient.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(firstMessageFromClient->data)); + auto firstMessageFromServer = events.getEvent<TLSDataForApplication>("client"); + ASSERT_EQ(true, firstMessageFromServer.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(firstMessageFromServer->data)); + + ASSERT_EQ("/CN=montague.example", events.getEvent<TLSConnected>("client")->chain[0]->getSubjectName()); } TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { auto tlsFactories = std::make_shared<PlatformTLSFactories>(); auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); serverContext->onServerNameRequested.connect([&](const std::string&) { serverContext->setAbortTLSHandshake(true); @@ -819,24 +836,24 @@ TEST(ClientServerTest, testClientServerBasicCommunicationWith2048BitDHParams) { ASSERT_EQ(true, serverContext->setDiffieHellmanParameters(tlsFactories->getTLSContextFactory()->convertDHParametersFromPEMToDER(dhParamsOpenSslDer2048))); serverContext->accept(); clientContext->connect(); clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); + auto firstMessageFromClient = events.getEvent<TLSDataForApplication>("server"); + ASSERT_EQ(true, firstMessageFromClient.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(firstMessageFromClient->data)); + auto firstMessageFromServer = events.getEvent<TLSDataForApplication>("client"); + ASSERT_EQ(true, firstMessageFromServer.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(firstMessageFromServer->data)); } TEST(ClientServerTest, testClientServerBasicCommunicationWith1024BitDHParams) { auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); @@ -851,16 +868,16 @@ TEST(ClientServerTest, testClientServerBasicCommunicationWith1024BitDHParams) { ASSERT_EQ(true, serverContext->setDiffieHellmanParameters(tlsFactories->getTLSContextFactory()->convertDHParametersFromPEMToDER(dhParamsOpenSslDer1024))); serverContext->accept(); clientContext->connect(); clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); - ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ - return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); - })->second))); + auto firstMessageFromClient = events.getEvent<TLSDataForApplication>("server"); + ASSERT_EQ(true, firstMessageFromClient.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the client.")), safeByteArrayToString(firstMessageFromClient->data)); + auto firstMessageFromServer = events.getEvent<TLSDataForApplication>("client"); + ASSERT_EQ(true, firstMessageFromServer.is_initialized()); + ASSERT_EQ(safeByteArrayToString(createSafeByteArray("This is a test message from the server.")), safeByteArrayToString(firstMessageFromServer->data)); } |
Swift