summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp10
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h1
-rw-r--r--Swiften/TLS/TLSContext.cpp6
-rw-r--r--Swiften/TLS/TLSContext.h2
-rw-r--r--Swiften/TLS/UnitTest/ClientServerTest.cpp23
5 files changed, 40 insertions, 2 deletions
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 47e7175..6c27e22 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -463,65 +463,73 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
SSL_CTX_add_extra_chain_cert(context_.get(), sk_X509_value(caCerts.get(), i));
}
return true;
}
std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {
std::vector<Certificate::ref> result;
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_.get());
for (int i = 0; i < sk_X509_num(chain); ++i) {
std::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free);
Certificate::ref cert = std::make_shared<OpenSSLCertificate>(x509Cert);
result.push_back(cert);
}
return result;
}
std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
int verifyResult = SSL_get_verify_result(handle_.get());
if (verifyResult != X509_V_OK) {
return std::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult));
}
else {
return std::shared_ptr<CertificateVerificationError>();
}
}
ByteArray OpenSSLContext::getFinishMessage() const {
ByteArray data;
data.resize(MAX_FINISHED_SIZE);
- size_t size = SSL_get_finished(handle_.get(), vecptr(data), data.size());
+ auto size = SSL_get_finished(handle_.get(), vecptr(data), data.size());
data.resize(size);
return data;
}
+ByteArray OpenSSLContext::getPeerFinishMessage() const {
+ ByteArray data;
+ data.resize(MAX_FINISHED_SIZE);
+ auto size = SSL_get_peer_finished(handle_.get(), vecptr(data), data.size());
+ data.resize(size);
+ return data;
+ }
+
CertificateVerificationError::Type OpenSSLContext::getVerificationErrorTypeForResult(int result) {
assert(result != 0);
switch (result) {
case X509_V_ERR_CERT_NOT_YET_VALID:
case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
return CertificateVerificationError::NotYetValid;
case X509_V_ERR_CERT_HAS_EXPIRED:
case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
return CertificateVerificationError::Expired;
case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
return CertificateVerificationError::SelfSigned;
case X509_V_ERR_CERT_UNTRUSTED:
return CertificateVerificationError::Untrusted;
case X509_V_ERR_CERT_REJECTED:
return CertificateVerificationError::Rejected;
case X509_V_ERR_INVALID_PURPOSE:
return CertificateVerificationError::InvalidPurpose;
case X509_V_ERR_PATH_LENGTH_EXCEEDED:
return CertificateVerificationError::PathLengthExceeded;
case X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE:
case X509_V_ERR_CERT_SIGNATURE_FAILURE:
case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE:
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 4a94848..bf897a7 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -30,54 +30,55 @@ namespace std {
class default_delete<SSL> {
public:
void operator()(SSL *ptr) {
SSL_free(ptr);
}
};
}
namespace Swift {
class OpenSSLContext : public TLSContext, boost::noncopyable {
public:
OpenSSLContext(Mode mode);
virtual ~OpenSSLContext() override final;
void accept() override final;
void connect() override final;
void connect(const std::string& requestHostname) override final;
bool setCertificateChain(const std::vector<Certificate::ref>& certificateChain) override final;
bool setPrivateKey(const PrivateKey::ref& privateKey) override final;
bool setClientCertificate(CertificateWithKey::ref cert) override final;
void setAbortTLSHandshake(bool abort) override final;
void handleDataFromNetwork(const SafeByteArray&) override final;
void handleDataFromApplication(const SafeByteArray&) override final;
std::vector<Certificate::ref> getPeerCertificateChain() const override final;
std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final;
virtual ByteArray getFinishMessage() const override final;
+ virtual ByteArray getPeerFinishMessage() const override final;
private:
static void ensureLibraryInitialized();
static int handleServerNameCallback(SSL *ssl, int *ad, void *arg);
static CertificateVerificationError::Type getVerificationErrorTypeForResult(int);
void initAndSetBIOs();
void doAccept();
void doConnect();
void sendPendingDataToNetwork();
void sendPendingDataToApplication();
private:
enum class State { Start, Accepting, Connecting, Connected, Error };
const Mode mode_;
State state_;
std::unique_ptr<SSL_CTX> context_;
std::unique_ptr<SSL> handle_;
BIO* readBIO_ = nullptr;
BIO* writeBIO_ = nullptr;
bool abortTLSHandshake_ = false;
};
}
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 39fb5c9..8246dde 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -8,37 +8,41 @@
#include <cassert>
namespace Swift {
TLSContext::~TLSContext() {
}
void TLSContext::accept() {
assert(false);
}
void TLSContext::connect(const std::string& /* serverName */) {
assert(false);
}
bool TLSContext::setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */) {
assert(false);
return false;
}
bool TLSContext::setPrivateKey(const PrivateKey::ref& /* privateKey */) {
assert(false);
return false;
}
void TLSContext::setAbortTLSHandshake(bool /* abort */) {
assert(false);
}
-
Certificate::ref TLSContext::getPeerCertificate() const {
std::vector<Certificate::ref> chain = getPeerCertificateChain();
return chain.empty() ? Certificate::ref() : chain[0];
}
+ByteArray TLSContext::getPeerFinishMessage() const {
+ assert(false);
+ return ByteArray();
+}
+
}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 2655d4b..653e8d2 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -20,45 +20,47 @@
namespace Swift {
class SWIFTEN_API TLSContext {
public:
virtual ~TLSContext();
virtual void accept();
virtual void connect() = 0;
virtual void connect(const std::string& serverName);
virtual bool setCertificateChain(const std::vector<Certificate::ref>& /* certificateChain */);
virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */);
virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;
/**
* This method can be used during the \ref onServerNameRequested signal,
* to report an error about an unknown host back to the requesting client.
*/
virtual void setAbortTLSHandshake(bool /* abort */);
virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
virtual void handleDataFromApplication(const SafeByteArray&) = 0;
Certificate::ref getPeerCertificate() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0;
virtual ByteArray getFinishMessage() const = 0;
+ virtual ByteArray getPeerFinishMessage() const;
+
public:
enum class Mode {
Client,
Server
};
public:
boost::signals2::signal<void (const SafeByteArray&)> onDataForNetwork;
boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication;
boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError;
boost::signals2::signal<void ()> onConnected;
boost::signals2::signal<void (const std::string&)> onServerNameRequested;
};
}
diff --git a/Swiften/TLS/UnitTest/ClientServerTest.cpp b/Swiften/TLS/UnitTest/ClientServerTest.cpp
index 692b3c0..5777856 100644
--- a/Swiften/TLS/UnitTest/ClientServerTest.cpp
+++ b/Swiften/TLS/UnitTest/ClientServerTest.cpp
@@ -628,30 +628,53 @@ TEST(ClientServerTest, testClientServerSNIRequestedHostAvailable) {
}
TEST(ClientServerTest, testClientServerSNIRequestedHostUnavailable) {
auto tlsFactories = std::make_shared<PlatformTLSFactories>();
auto clientContext = createTLSContext(TLSContext::Mode::Client);
auto serverContext = createTLSContext(TLSContext::Mode::Server);
serverContext->onServerNameRequested.connect([&](const std::string&) {
serverContext->setAbortTLSHandshake(true);
});
TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
ClientServerConnector connector(clientContext.get(), serverContext.get());
ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
ASSERT_NE(nullptr, privateKey.get());
ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
serverContext->accept();
clientContext->connect("montague.example");
ASSERT_EQ("server", events.events[1].first);
ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[1].second));
ASSERT_EQ("client", events.events[3].first);
ASSERT_EQ("TLSFault()", boost::apply_visitor(TLSEventToStringVisitor(), events.events[3].second));
}
+
+TEST(ClientServerTest, testClientServerEqualFinishedMessage) {
+ auto clientContext = createTLSContext(TLSContext::Mode::Client);
+ auto serverContext = createTLSContext(TLSContext::Mode::Server);
+
+ TLSClientServerEventHistory events(clientContext.get(), serverContext.get());
+
+ ClientServerConnector connector(clientContext.get(), serverContext.get());
+
+ auto tlsFactories = std::make_shared<PlatformTLSFactories>();
+
+ ASSERT_EQ(true, serverContext->setCertificateChain(tlsFactories->getCertificateFactory()->createCertificateChain(createByteArray(certificatePEM["capulet.example"]))));
+
+ auto privateKey = tlsFactories->getCertificateFactory()->createPrivateKey(createSafeByteArray(privateKeyPEM["capulet.example"]));
+ ASSERT_NE(nullptr, privateKey.get());
+ ASSERT_EQ(true, serverContext->setPrivateKey(privateKey));
+
+ serverContext->accept();
+ clientContext->connect();
+
+ ASSERT_EQ(serverContext->getPeerFinishMessage(), clientContext->getFinishMessage());
+ ASSERT_EQ(clientContext->getPeerFinishMessage(), serverContext->getFinishMessage());
+}