diff options
Diffstat (limited to 'Swiften/TLS/UnitTest/ClientServerTest.cpp')
-rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 85 |
1 files changed, 84 insertions, 1 deletions
diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 9aa1762..692b3c0 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -266,102 +266,115 @@ class ClientServerConnector { connections_.push_back(clientContext_->onDataForNetwork.connect([&](const SafeByteArray& data) { serverContext_->handleDataFromNetwork(data); })); connections_.push_back(serverContext_->onDataForNetwork.connect([&](const SafeByteArray& data) { clientContext_->handleDataFromNetwork(data); })); } private: TLSContext* clientContext_; TLSContext* serverContext_; std::vector<boost::signals2::connection> connections_; }; struct TLSDataForNetwork { SafeByteArray data; }; struct TLSDataForApplication { SafeByteArray data; }; struct TLSFault { std::shared_ptr<Swift::TLSError> error; }; struct TLSConnected { std::vector<Certificate::ref> chain; }; -using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>; +struct TLSServerNameRequested { + std::string name; +}; + +using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>; class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> { public: SafeByteArray operator()(const TLSDataForNetwork& tlsData) const { return tlsData.data; } SafeByteArray operator()(const TLSDataForApplication& tlsData) const { return tlsData.data; } SafeByteArray operator()(const TLSFault&) const { return createSafeByteArray(""); } SafeByteArray operator()(const TLSConnected&) const { return createSafeByteArray(""); } + + SafeByteArray operator()(const TLSServerNameRequested&) const { + return createSafeByteArray(""); + } + }; class TLSEventToStringVisitor : public boost::static_visitor<std::string> { public: std::string operator()(const TLSDataForNetwork& event) const { return std::string("TLSDataForNetwork(") + "size: " + std::to_string(event.data.size()) + ")"; } std::string operator()(const TLSDataForApplication& event) const { return std::string("TLSDataForApplication(") + "size: " + std::to_string(event.data.size()) + ")"; } std::string operator()(const TLSFault&) const { return "TLSFault()"; } std::string operator()(const TLSConnected& event) const { std::string certificates; for (auto cert : event.chain) { certificates += "\t" + cert->getSubjectName() + "\n"; } return std::string("TLSConnected()") + "\n" + certificates; } + + std::string operator()(const TLSServerNameRequested& event) const { + return std::string("TLSServerNameRequested(") + "name: " + event.name + ")"; + } }; class TLSClientServerEventHistory { public: TLSClientServerEventHistory(TLSContext* client, TLSContext* server) { connectContext(std::string("client"), client); connectContext(std::string("server"), server); } __attribute__((unused)) void print() { auto count = 0; std::cout << "\n"; for (auto event : events) { if (event.first == "server") { std::cout << std::string(80, ' '); } std::cout << count << ". "; std::cout << event.first << " : " << boost::apply_visitor(TLSEventToStringVisitor(), event.second) << std::endl; count++; } } private: void connectContext(const std::string& name, TLSContext* context) { connections_.push_back(context->onDataForNetwork.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForNetwork{data})); })); connections_.push_back(context->onDataForApplication.connect([=](const SafeByteArray& data) { events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForApplication{data})); @@ -545,30 +558,100 @@ TEST(ClientServerTest, testSettingPrivateKeyWithWrongPassword) { TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); auto tlsFactories = std::make_shared<PlatformTLSFactories>(); ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"])))); auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM), createSafeByteArray("foo")); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(false, serverContext->setPrivateKey(privateKey)); } TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) { auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); auto tlsFactories = std::make_shared<PlatformTLSFactories>(); ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"])))); auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM)); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(false, serverContext->setPrivateKey(privateKey)); } + +TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string& requestedName) { + if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) { + auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName])); + ASSERT_EQ(true, serverContext->setCertificateChain(certChain)); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + } + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client.")); + serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server.")); + ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){ + return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication)); + })->second))); + + ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName()); +} + +TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + serverContext->onServerNameRequested.connect([&](const std::string&) { + serverContext->setAbortTLSHandshake(true); + }); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect("montague.example"); + + ASSERT_EQ("server", events.events[1].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second)); + + ASSERT_EQ("client", events.events[3].first); + ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second)); +} |