summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/TLS/UnitTest/ClientServerTest.cpp')
-rw-r--r--Swiften/TLS/UnitTest/ClientServerTest.cpp85
1 files changed, 84 insertions, 1 deletions
diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp
index 9aa1762..692b3c0 100644
--- a/Swiften/TLS/UnitTest/ClientServerTest.cpp
+++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp
@@ -266,102 +266,115 @@ class ClientServerConnector {
connections_.push_back(clientContext_->onDataForNetwork.connect([&](const SafeByteArray& data) {
serverContext_->handleDataFromNetwork(data);
}));
connections_.push_back(serverContext_->onDataForNetwork.connect([&](const SafeByteArray& data) {
clientContext_->handleDataFromNetwork(data);
}));
}
private:
TLSContext* clientContext_;
TLSContext* serverContext_;
std::vector<boost::signals2::connection> connections_;
};
struct TLSDataForNetwork {
SafeByteArray data;
};
struct TLSDataForApplication {
SafeByteArray data;
};
struct TLSFault {
std::shared_ptr<Swift::TLSError> error;
};
struct TLSConnected {
std::vector<Certificate::ref> chain;
};
-using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected>;
+struct TLSServerNameRequested {
+ std::string name;
+};
+
+using TLSEvent = boost::variant<TLSDataForNetwork, TLSDataForApplication, TLSFault, TLSConnected, TLSServerNameRequested>;
class TLSEventToSafeByteArrayVisitor : public boost::static_visitor<SafeByteArray> {
public:
SafeByteArray operator()(const TLSDataForNetwork& tlsData) const {
return tlsData.data;
}
SafeByteArray operator()(const TLSDataForApplication& tlsData) const {
return tlsData.data;
}
SafeByteArray operator()(const TLSFault&) const {
return createSafeByteArray("");
}
SafeByteArray operator()(const TLSConnected&) const {
return createSafeByteArray("");
}
+
+ SafeByteArray operator()(const TLSServerNameRequested&) const {
+ return createSafeByteArray("");
+ }
+
};
class TLSEventToStringVisitor : public boost::static_visitor<std::string> {
public:
std::string operator()(const TLSDataForNetwork& event) const {
return std::string("TLSDataForNetwork(") + "size: " + std::to_string(event.data.size()) + ")";
}
std::string operator()(const TLSDataForApplication& event) const {
return std::string("TLSDataForApplication(") + "size: " + std::to_string(event.data.size()) + ")";
}
std::string operator()(const TLSFault&) const {
return "TLSFault()";
}
std::string operator()(const TLSConnected& event) const {
std::string certificates;
for (auto cert : event.chain) {
certificates += "\t" + cert->getSubjectName() + "\n";
}
return std::string("TLSConnected()") + "\n" + certificates;
}
+
+ std::string operator()(const TLSServerNameRequested& event) const {
+ return std::string("TLSServerNameRequested(") + "name: " + event.name + ")";
+ }
};
class TLSClientServerEventHistory {
public:
TLSClientServerEventHistory(TLSContext* client, TLSContext* server) {
connectContext(std::string("client"), client);
connectContext(std::string("server"), server);
}
__attribute__((unused))
void print() {
auto count = 0;
std::cout << "\n";
for (auto event : events) {
if (event.first == "server") {
std::cout << std::string(80, ' ');
}
std::cout << count << ". ";
std::cout << event.first << " : " << boost::apply_visitor(TLSEventToStringVisitor(), event.second) << std::endl;
count++;
}
}
private:
void connectContext(const std::string& name, TLSContext* context) {
connections_.push_back(context->onDataForNetwork.connect([=](const SafeByteArray& data) {
events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForNetwork{data}));
}));
connections_.push_back(context->onDataForApplication.connect([=](const SafeByteArray& data) {
events.push_back(std::pair<std::string, TLSEvent>(name, TLSDataForApplication{data}));
@@ -545,30 +558,100 @@ TEST(ClientServerTest, testSettingPrivateKeyWithWrongPassword) {
TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
ClientServerConnector connector(clientContext.get(), serverContext.get());
auto tlsFactories = std::make_shared<PlatformTLSFactories>();
ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"]))));
auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM), createSafeByteArray("foo"));
ASSERT_NE(nullptr, privateKey.get());
ASSERT_EQ(false, serverContext->setPrivateKey(privateKey));
}
TEST(ClientServerTest, testSettingPrivateKeyWithoutRequiredPassword) {
auto clientContext = createTLSContext(TLSContext::Mode::Client);
auto serverContext = createTLSContext(TLSContext::Mode::Server);
TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
ClientServerConnector connector(clientContext.get(), serverContext.get());
auto tlsFactories = std::make_shared<PlatformTLSFactories>();
ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["montague.example"]))));
auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(montagueEncryptedPEM));
ASSERT_NE(nullptr, privateKey.get());
ASSERT_EQ(false, serverContext->setPrivateKey(privateKey));
}
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) {
+ auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+ auto clientContext = createTLSContext(TLSContext::Mode::Client);
+ auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+ serverContext->onServerNameRequested.connect([&](const std::string& requestedName) {
+ if (certificatePEM.find(requestedName) != certificatePEM.end() && privateKeyPEM.find(requestedName) != privateKeyPEM.end()) {
+ auto certChain = tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM[requestedName]));
+ ASSERT_EQ(true, serverContext->setCertificateChain(certChain));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM[requestedName]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+ }
+ });
+
+ TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+ ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+ ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+ serverContext->accept();
+ clientContext->connect("montague.example");
+
+ clientContext->handleDataFromApplication(createSafeByteArray("This is a test message from the client."));
+ serverContext->handleDataFromApplication(createSafeByteArray("This is a test message from the server."));
+ ASSERT_EQ("This is a test message from the client.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+ return event.first == "server" && (event.second.type() == typeid(TLSDataForApplication));
+ })->second)));
+ ASSERT_EQ("This is a test message from the server.", safeByteArrayToString(boost::apply_visitor(TLSEventToSafeByteArrayVisitor(), std::find_if(events.events.begin(), events.events.end(), [](std::pair<std::string, TLSEvent>& event){
+ return event.first == "client" && (event.second.type() == typeid(TLSDataForApplication));
+ })->second)));
+
+ ASSERT_EQ("/CN=montague.example", boost::get<TLSConnected>(events.events[5].second).chain[0]->getSubjectName());
+}
+
+TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) {
+ auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+ auto clientContext = createTLSContext(TLSContext::Mode::Client);
+ auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+ serverContext->onServerNameRequested.connect([&](const std::string&) {
+ serverContext->setAbortTLSHandshake(true);
+ });
+
+ TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+ ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+ ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+ serverContext->accept();
+ clientContext->connect("montague.example");
+
+ ASSERT_EQ("server", events.events[1].first);
+ ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second));
+
+ ASSERT_EQ("client", events.events[3].first);
+ ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second));
+}