summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Swift/Controllers/Storages/CertificateStorageTrustChecker.h2
-rw-r--r--Swiften/Client/ClientSession.cpp11
-rw-r--r--Swiften/Client/ClientSession.h2
-rw-r--r--Swiften/TLS/BlindCertificateTrustChecker.h2
-rw-r--r--Swiften/TLS/CertificateTrustChecker.h8
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp10
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h1
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp8
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.h1
-rw-r--r--Swiften/TLS/TLSContext.cpp5
-rw-r--r--Swiften/TLS/TLSContext.h2
11 files changed, 18 insertions, 34 deletions
diff --git a/Swift/Controllers/Storages/CertificateStorageTrustChecker.h b/Swift/Controllers/Storages/CertificateStorageTrustChecker.h
index a73590a..df15575 100644
--- a/Swift/Controllers/Storages/CertificateStorageTrustChecker.h
+++ b/Swift/Controllers/Storages/CertificateStorageTrustChecker.h
@@ -12,19 +12,19 @@
namespace Swift {
/**
* A certificate trust checker that trusts certificates in a certificate storage.
*/
class CertificateStorageTrustChecker : public CertificateTrustChecker {
public:
CertificateStorageTrustChecker(CertificateStorage* storage) : storage(storage) {
}
- virtual bool isCertificateTrusted(Certificate::ref, const std::vector<Certificate::ref>& certificateChain) {
+ virtual bool isCertificateTrusted(const std::vector<Certificate::ref>& certificateChain) {
lastCertificateChain = std::vector<Certificate::ref>(certificateChain.begin(), certificateChain.end());
return certificateChain.empty() ? false : storage->hasCertificate(certificateChain[0]);
}
const std::vector<Certificate::ref>& getLastCertificateChain() const {
return lastCertificateChain;
}
private:
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index c2dc3ae..7e1f517 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -364,37 +364,36 @@ void ClientSession::sendCredentials(const SafeByteArray& password) {
assert(WaitingForCredentials);
state = Authenticating;
authenticator->setCredentials(localJID.getNode(), password);
stream->writeElement(boost::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
void ClientSession::handleTLSEncrypted() {
checkState(Encrypting);
- Certificate::ref certificate = stream->getPeerCertificate();
std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain();
boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
if (verificationError) {
- checkTrustOrFinish(certificate, certificateChain, verificationError);
+ checkTrustOrFinish(certificateChain, verificationError);
}
else {
ServerIdentityVerifier identityVerifier(localJID);
- if (identityVerifier.certificateVerifies(certificate)) {
+ if (!certificateChain.empty() && identityVerifier.certificateVerifies(certificateChain[0])) {
continueAfterTLSEncrypted();
}
else {
- checkTrustOrFinish(certificate, certificateChain, boost::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidServerIdentity));
+ checkTrustOrFinish(certificateChain, boost::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidServerIdentity));
}
}
}
-void ClientSession::checkTrustOrFinish(Certificate::ref certificate, const std::vector<Certificate::ref>& certificateChain, boost::shared_ptr<CertificateVerificationError> error) {
- if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificate, certificateChain)) {
+void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, boost::shared_ptr<CertificateVerificationError> error) {
+ if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificateChain)) {
continueAfterTLSEncrypted();
}
else {
finishSession(error);
}
}
void ClientSession::continueAfterTLSEncrypted() {
state = WaitingForStreamStart;
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index 9c4b980..66a90ed 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -148,19 +148,19 @@ namespace Swift {
void handleTLSEncrypted();
bool checkState(State);
void continueSessionInitialization();
void requestAck();
void handleStanzaAcked(boost::shared_ptr<Stanza> stanza);
void ack(unsigned int handledStanzasCount);
void continueAfterTLSEncrypted();
- void checkTrustOrFinish(Certificate::ref certificate, const std::vector<Certificate::ref>& certificateChain, boost::shared_ptr<CertificateVerificationError> error);
+ void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, boost::shared_ptr<CertificateVerificationError> error);
private:
JID localJID;
State state;
boost::shared_ptr<SessionStream> stream;
bool allowPLAINOverNonTLS;
bool useStreamCompression;
UseTLS useTLS;
bool useAcks;
diff --git a/Swiften/TLS/BlindCertificateTrustChecker.h b/Swiften/TLS/BlindCertificateTrustChecker.h
index 9ed7ff2..d91ec25 100644
--- a/Swiften/TLS/BlindCertificateTrustChecker.h
+++ b/Swiften/TLS/BlindCertificateTrustChecker.h
@@ -13,14 +13,14 @@ namespace Swift {
* A certificate trust checker that trusts any ceritficate.
*
* This can be used to ignore any TLS certificate errors occurring
* during connection.
*
* \see Client::setAlwaysTrustCertificates()
*/
class BlindCertificateTrustChecker : public CertificateTrustChecker {
public:
- virtual bool isCertificateTrusted(Certificate::ref, const std::vector<Certificate::ref>&) {
+ virtual bool isCertificateTrusted(const std::vector<Certificate::ref>&) {
return true;
}
};
}
diff --git a/Swiften/TLS/CertificateTrustChecker.h b/Swiften/TLS/CertificateTrustChecker.h
index 91cc530..2ba6b40 100644
--- a/Swiften/TLS/CertificateTrustChecker.h
+++ b/Swiften/TLS/CertificateTrustChecker.h
@@ -15,19 +15,19 @@
namespace Swift {
/**
* A class to implement a check for certificate trust.
*/
class CertificateTrustChecker {
public:
virtual ~CertificateTrustChecker();
/**
- * This method is called to find out whether a certificate is
+ * This method is called to find out whether a certificate (chain) is
* trusted. This usually happens when a certificate's validation
* fails, to check whether to proceed with the connection or not.
*
- * certificateChain contains the chain of certificates, if available.
- * This chain includes certificate.
+ * certificateChain contains the chain of certificates. The first certificate
+ * is the subject certificate.
*/
- virtual bool isCertificateTrusted(Certificate::ref certificate, const std::vector<Certificate::ref>& certificateChain) = 0;
+ virtual bool isCertificateTrusted(const std::vector<Certificate::ref>& certificateChain) = 0;
};
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 58a8d05..2364c2e 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -226,28 +226,18 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) {
return false;
}
for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) {
SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i));
}
return true;
}
-Certificate::ref OpenSSLContext::getPeerCertificate() const {
- boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free);
- if (x509Cert) {
- return boost::make_shared<OpenSSLCertificate>(x509Cert);
- }
- else {
- return Certificate::ref();
- }
-}
-
std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {
std::vector<Certificate::ref> result;
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_);
for (int i = 0; i < sk_X509_num(chain); ++i) {
boost::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free);
Certificate::ref cert = boost::make_shared<OpenSSLCertificate>(x509Cert);
result.push_back(cert);
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index cee4f79..d4327ca 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -21,19 +21,18 @@ namespace Swift {
OpenSSLContext();
~OpenSSLContext();
void connect();
bool setClientCertificate(CertificateWithKey::ref cert);
void handleDataFromNetwork(const SafeByteArray&);
void handleDataFromApplication(const SafeByteArray&);
- Certificate::ref getPeerCertificate() const;
std::vector<Certificate::ref> getPeerCertificateChain() const;
boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
virtual ByteArray getFinishMessage() const;
private:
static void ensureLibraryInitialized();
static CertificateVerificationError::Type getVerificationErrorTypeForResult(int);
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index 997d760..b4b2843 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -619,26 +619,18 @@ bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate)
}
//------------------------------------------------------------------------
void SchannelContext::handleCertificateCardRemoved() {
indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved));
}
//------------------------------------------------------------------------
-Certificate::ref SchannelContext::getPeerCertificate() const {
- ScopedCertContext pServerCert;
- SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
- return status == SEC_E_OK ? boost::make_shared<SchannelCertificate>(pServerCert) : SchannelCertificate::ref();
-}
-
-//------------------------------------------------------------------------
-
std::vector<Certificate::ref> SchannelContext::getPeerCertificateChain() const {
std::vector<Certificate::ref> certificateChain;
ScopedCertContext pServerCert;
ScopedCertContext pIssuerCert;
ScopedCertContext pCurrentCert;
SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
if (status != SEC_E_OK) {
return certificateChain;
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index 2d65a8a..8603498 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -44,19 +44,18 @@ namespace Swift
//
// TLSContext
//
virtual void connect();
virtual bool setClientCertificate(CertificateWithKey::ref cert);
virtual void handleDataFromNetwork(const SafeByteArray& data);
virtual void handleDataFromApplication(const SafeByteArray& data);
- virtual Certificate::ref getPeerCertificate() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const;
virtual ByteArray getFinishMessage() const;
virtual void setCheckCertificateRevocation(bool b);
private:
void determineStreamSizes();
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 026ae70..d461d91 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -5,10 +5,15 @@
*/
#include <Swiften/TLS/TLSContext.h>
namespace Swift {
TLSContext::~TLSContext() {
}
+Certificate::ref TLSContext::getPeerCertificate() const {
+ std::vector<Certificate::ref> chain = getPeerCertificateChain();
+ return chain.empty() ? Certificate::ref() : chain[0];
+}
+
}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 388f8ee..5fee021 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -22,19 +22,19 @@ namespace Swift {
virtual ~TLSContext();
virtual void connect() = 0;
virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;
virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
virtual void handleDataFromApplication(const SafeByteArray&) = 0;
- virtual Certificate::ref getPeerCertificate() const = 0;
+ Certificate::ref getPeerCertificate() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0;
virtual ByteArray getFinishMessage() const = 0;
public:
boost::signal<void (const SafeByteArray&)> onDataForNetwork;
boost::signal<void (const SafeByteArray&)> onDataForApplication;
boost::signal<void (boost::shared_ptr<TLSError>)> onError;