diff options
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 10 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 1 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.cpp | 6 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 2 | ||||
-rw-r--r-- | Swiften/TLS/UnitTest/ClientServerTest.cpp | 23 |
5 files changed, 40 insertions, 2 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 47e7175..6c27e22 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -463,65 +463,73 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) { SSL_CTX_add_extra_chain_cert(context_.get(), sk_X509_value(caCerts.get(), i)); } return true; } std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const { std::vector<Certificate::ref> result; STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_.get()); for (int i = 0; i < sk_X509_num(chain); ++i) { std::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free); Certificate::ref cert = std::make_shared<OpenSSLCertificate>(x509Cert); result.push_back(cert); } return result; } std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { int verifyResult = SSL_get_verify_result(handle_.get()); if (verifyResult != X509_V_OK) { return std::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult)); } else { return std::shared_ptr<CertificateVerificationError>(); } } ByteArray OpenSSLContext::getFinishMessage() const { ByteArray data; data.resize(MAX_FINISHED_SIZE); - size_t size = SSL_get_finished(handle_.get(), vecptr(data), data.size()); + auto size = SSL_get_finished(handle_.get(), vecptr(data), data.size()); data.resize(size); return data; } +ByteArray OpenSSLContext::getPeerFinishMessage() const { + ByteArray data; + data.resize(MAX_FINISHED_SIZE); + auto size = SSL_get_peer_finished(handle_.get(), vecptr(data), data.size()); + data.resize(size); + return data; + } + CertificateVerificationError::Type OpenSSLContext::getVerificationErrorTypeForResult(int result) { assert(result != 0); switch (result) { case X509_V_ERR_CERT_NOT_YET_VALID: case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: return CertificateVerificationError::NotYetValid; case X509_V_ERR_CERT_HAS_EXPIRED: case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: return CertificateVerificationError::Expired; case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: return CertificateVerificationError::SelfSigned; case X509_V_ERR_CERT_UNTRUSTED: return CertificateVerificationError::Untrusted; case X509_V_ERR_CERT_REJECTED: return CertificateVerificationError::Rejected; case X509_V_ERR_INVALID_PURPOSE: return CertificateVerificationError::InvalidPurpose; case X509_V_ERR_PATH_LENGTH_EXCEEDED: return CertificateVerificationError::PathLengthExceeded; case X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: case X509_V_ERR_CERT_SIGNATURE_FAILURE: case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 4a94848..bf897a7 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -30,54 +30,55 @@ namespace std { class default_delete<SSL> { public: void operator()(SSL *ptr) { SSL_free(ptr); } }; } namespace Swift { class OpenSSLContext : public TLSContext, boost::noncopyable { public: OpenSSLContext(Mode mode); virtual ~OpenSSLContext() override final; void accept() override final; void connect() override final; void connect(const std::string& requestHostname) override final; bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final; bool setPrivateKey(const PrivateKey::ref& privateKey) override final; bool setClientCertificate(CertificateWithKey::ref cert) override final; void setAbortTLSHandshake(bool abort) override final; void handleDataFromNetwork(const SafeByteArray&) override final; void handleDataFromApplication(const SafeByteArray&) override final; std::vector<Certificate::ref> getPeerCertificateChain() const override final; std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final; virtual ByteArray getFinishMessage() const override final; + virtual ByteArray getPeerFinishMessage() const override final; private: static void ensureLibraryInitialized(); static int handleServerNameCallback(SSL *ssl, int *ad, void *arg); static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); void initAndSetBIOs(); void doAccept(); void doConnect(); void sendPendingDataToNetwork(); void sendPendingDataToApplication(); private: enum class State { Start, Accepting, Connecting, Connected, Error }; const Mode mode_; State state_; std::unique_ptr<SSL_CTX> context_; std::unique_ptr<SSL> handle_; BIO* readBIO_ = nullptr; BIO* writeBIO_ = nullptr; bool abortTLSHandshake_ = false; }; } diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 39fb5c9..8246dde 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -8,37 +8,41 @@ #include <cassert> namespace Swift { TLSContext::~TLSContext() { } void TLSContext::accept() { assert(false); } void TLSContext::connect(const std::string& /* serverName */) { assert(false); } bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) { assert(false); return false; } bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) { assert(false); return false; } void TLSContext::setAbortTLSHandshake(bool /* abort */) { assert(false); } - Certificate::ref TLSContext::getPeerCertificate() const { std::vector<Certificate::ref> chain = getPeerCertificateChain(); return chain.empty() ? Certificate::ref() : chain[0]; } +ByteArray TLSContext::getPeerFinishMessage() const { + assert(false); + return ByteArray(); +} + } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 2655d4b..653e8d2 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -20,45 +20,47 @@ namespace Swift { class SWIFTEN_API TLSContext { public: virtual ~TLSContext(); virtual void accept(); virtual void connect() = 0; virtual void connect(const std::string& serverName); virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */); virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */); virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; /** * This method can be used during the \ref onServerNameRequested signal, * to report an error about an unknown host back to the requesting client. */ virtual void setAbortTLSHandshake(bool /* abort */); virtual void handleDataFromNetwork(const SafeByteArray&) = 0; virtual void handleDataFromApplication(const SafeByteArray&) = 0; Certificate::ref getPeerCertificate() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; virtual ByteArray getFinishMessage() const = 0; + virtual ByteArray getPeerFinishMessage() const; + public: enum class Mode { Client, Server }; public: boost::signals2::signal<void (const SafeByteArray&)> onDataForNetwork; boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication; boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError; boost::signals2::signal<void ()> onConnected; boost::signals2::signal<void (const std::string&)> onServerNameRequested; }; } diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp index 692b3c0..5777856 100644 --- a/Swiften/TLS/UnitTest/ClientServerTest.cpp +++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp @@ -628,30 +628,53 @@ TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) { } TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) { auto tlsFactories = std::make_shared<PlatformTLSFactories>(); auto clientContext = createTLSContext(TLSContext::Mode::Client); auto serverContext = createTLSContext(TLSContext::Mode::Server); serverContext->onServerNameRequested.connect([&](const std::string&) { serverContext->setAbortTLSHandshake(true); }); TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); ClientServerConnector connector(clientContext.get(), serverContext.get()); ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); ASSERT_NE(nullptr, privateKey.get()); ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); serverContext->accept(); clientContext->connect("montague.example"); ASSERT_EQ("server", events.events[1].first); ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second)); ASSERT_EQ("client", events.events[3].first); ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second)); } + +TEST(ClientServerTest, testClientServerEqualFinishedMessage) { + auto clientContext = createTLSContext(TLSContext::Mode::Client); + auto serverContext = createTLSContext(TLSContext::Mode::Server); + + TLSClientServerEventHistory events(clientContext.get(), serverContext.get()); + + ClientServerConnector connector(clientContext.get(), serverContext.get()); + + auto tlsFactories = std::make_shared<PlatformTLSFactories>(); + + ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"])))); + + auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"])); + ASSERT_NE(nullptr, privateKey.get()); + ASSERT_EQ(true, serverContext->setPrivateKey(privateKey)); + + serverContext->accept(); + clientContext->connect(); + + ASSERT_EQ(serverContext->getPeerFinishMessage(), clientContext->getFinishMessage()); + ASSERT_EQ(clientContext->getPeerFinishMessage(), serverContext->getFinishMessage()); +} |