diff options
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 10 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 1 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.cpp | 6 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 2 | ||||
-rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 23 |
5 files changed, 40 insertions, 2 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 47e7175..6c27e22 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -490,11 +490,19 @@ std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificate ByteArray OpenSSLContext::getFinishMessage() const { ByteArray data; data.resize(MAX_FINISHED_SIZE); - size_t size = SSL_get_finished(handle_.get(), vecptr(data), data.size()); + auto size = SSL_get_finished(handle_.get(), vecptr(data), data.size()); data.resize(size); return data; } +ByteArray OpenSSLContext::getPeerFinishMessage() const { + ByteArray data; + data.resize(MAX_FINISHED_SIZE); + auto size = SSL_get_peer_finished(handle_.get(), vecptr(data), data.size()); + data.resize(size); + return data; + } + CertificateVerificationError::Type OpenSSLContext::getVerificationErrorTypeForResult(int result) { assert(result != 0); switch (result) { diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 4a94848..bf897a7 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -57,6 +57,7 @@ namespace Swift { std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final; virtual ByteArray getFinishMessage() const override final; + virtual ByteArray getPeerFinishMessage() const override final; private: static void ensureLibraryInitialized(); diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 39fb5c9..8246dde 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -35,10 +35,14 @@ void TLSContext::setAbortTLSHandshake(bool /* abort */) { assert(false); } - Certificate::ref TLSContext::getPeerCertificate() const { std::vector<Certificate::ref> chain = getPeerCertificateChain(); return chain.empty() ? Certificate::ref() : chain[0]; } +ByteArray TLSContext::getPeerFinishMessage() const { + assert(false); + return ByteArray(); +} + } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 2655d4b..653e8d2 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -47,6 +47,8 @@ namespace Swift { virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; virtual ByteArray getFinishMessage() const = 0; + virtual ByteArray getPeerFinishMessage() const; + public: enum class Mode { diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 692b3c0..5777856 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -655,3 +655,26 @@ TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { ASSERT_EQ("client", events.events[3].first); ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second)); } + +TEST(ClientServerTest, testClientServerEqualFinishedMessage) { + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect(); + + ASSERT_EQ(serverContext->getPeerFinishMessage(), clientContext->getFinishMessage()); + ASSERT_EQ(clientContext->getPeerFinishMessage(), serverContext->getFinishMessage()); +} |