From a594eb3fef7e047d1eca7959d7734d4d10fd1eb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Sun, 7 Nov 2010 21:07:06 +0100
Subject: Refactoring certificates & certificate checking.


diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp
index 093f987..cc51ec3 100644
--- a/Swift/Controllers/MainController.cpp
+++ b/Swift/Controllers/MainController.cpp
@@ -356,6 +356,7 @@ void MainController::performLoginFromCachedCredentials() {
 	if (!client_) {
 		storages_ = storagesFactory_->createStorages(jid_);
 		client_ = new Swift::Client(eventLoop_, jid_, password_, storages_);
+		client_->setAlwaysTrustCertificates();
 		client_->onDataRead.connect(boost::bind(&XMLConsoleController::handleDataRead, xmlConsoleController_, _1));
 		client_->onDataWritten.connect(boost::bind(&XMLConsoleController::handleDataWritten, xmlConsoleController_, _1));
 		client_->onDisconnected.connect(boost::bind(&MainController::handleDisconnected, this, _1));
@@ -416,6 +417,19 @@ void MainController::handleDisconnected(const boost::optional<ClientError>& erro
 			case ClientError::TLSError: message = "Encryption error"; break;
 			case ClientError::ClientCertificateLoadError: message = "Error loading certificate (Invalid password?)"; break;
 			case ClientError::ClientCertificateError: message = "Certificate not authorized"; break;
+
+			case ClientError::UnknownCertificateError:
+			case ClientError::CertificateExpiredError:
+			case ClientError::CertificateNotYetValidError:
+			case ClientError::CertificateSelfSignedError:
+			case ClientError::CertificateRejectedError:
+			case ClientError::CertificateUntrustedError:
+			case ClientError::InvalidCertificatePurposeError:
+			case ClientError::CertificatePathLengthExceededError:
+			case ClientError::InvalidCertificateSignatureError:
+			case ClientError::InvalidCAError:
+				// TODO
+				message = "Certificate error"; break;
 		}
 		if (!rosterController_) { //hasn't been logged in yet
 			signOut();
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index 18cff9a..3b2c102 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -23,6 +23,7 @@
 #include "Swiften/Disco/ClientDiscoManager.h"
 #include "Swiften/Client/NickResolver.h"
 #include "Swiften/Presence/SubscriptionManager.h"
+#include "Swiften/TLS/BlindCertificateTrustChecker.h"
 
 namespace Swift {
 
@@ -53,9 +54,13 @@ Client::Client(EventLoop* eventLoop, const JID& jid, const String& password, Sto
 	entityCapsManager = new EntityCapsManager(capsManager, getStanzaChannel());
 
 	nickResolver = new NickResolver(jid.toBare(), roster, vcardManager, mucRegistry);
+
+	blindCertificateTrustChecker = new BlindCertificateTrustChecker();
 }
 
 Client::~Client() {
+	delete blindCertificateTrustChecker;
+
 	delete nickResolver;
 
 	delete entityCapsManager;
@@ -116,4 +121,8 @@ EntityCapsProvider* Client::getEntityCapsProvider() const {
 	return entityCapsManager;
 }
 
+void Client::setAlwaysTrustCertificates() {
+	setCertificateTrustChecker(blindCertificateTrustChecker);
+}
+
 }
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
index a6cf059..1a6700e 100644
--- a/Swiften/Client/Client.h
+++ b/Swiften/Client/Client.h
@@ -10,6 +10,7 @@
 
 namespace Swift {
 	class SoftwareVersionResponder;
+	class BlindCertificateTrustChecker;
 	class XMPPRoster;
 	class XMPPRosterImpl;
 	class MUCManager;
@@ -126,6 +127,8 @@ namespace Swift {
 				return discoManager;
 			}
 
+			void setAlwaysTrustCertificates();
+		
 		public:
 			/**
 			 * This signal is emitted when a JID changes presence.
@@ -156,5 +159,6 @@ namespace Swift {
 			SubscriptionManager* subscriptionManager;
 			MUCManager* mucManager;
 			ClientDiscoManager* discoManager;
+			BlindCertificateTrustChecker* blindCertificateTrustChecker;
 	};
 }
diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h
index 6ac8a6d..1c775e4 100644
--- a/Swiften/Client/ClientError.h
+++ b/Swiften/Client/ClientError.h
@@ -25,7 +25,19 @@ namespace Swift {
 				SessionStartError,
 				TLSError,
 				ClientCertificateLoadError,
-				ClientCertificateError
+				ClientCertificateError,
+
+				// Certificate verification errors
+				UnknownCertificateError,
+				CertificateExpiredError,
+				CertificateNotYetValidError,
+				CertificateSelfSignedError,
+				CertificateRejectedError,
+				CertificateUntrustedError,
+				InvalidCertificatePurposeError,
+				CertificatePathLengthExceededError,
+				InvalidCertificateSignatureError,
+				InvalidCAError,
 			};
 
 			ClientError(Type type = UnknownError) : type_(type) {}
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 7170a20..a199a84 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -11,7 +11,6 @@
 #include <boost/uuid/uuid_io.hpp>
 #include <boost/uuid/uuid_generators.hpp>
 
-#include "Swiften/TLS/SecurityError.h"
 #include "Swiften/Elements/ProtocolHeader.h"
 #include "Swiften/Elements/StreamFeatures.h"
 #include "Swiften/Elements/StartTLSRequest.h"
@@ -37,6 +36,7 @@
 #include "Swiften/SASL/SCRAMSHA1ClientAuthenticator.h"
 #include "Swiften/SASL/DIGESTMD5ClientAuthenticator.h"
 #include "Swiften/Session/SessionStream.h"
+#include "Swiften/TLS/CertificateTrustChecker.h"
 
 namespace Swift {
 
@@ -50,7 +50,8 @@ ClientSession::ClientSession(
 			needSessionStart(false),
 			needResourceBind(false),
 			needAcking(false),
-			authenticator(NULL) {
+			authenticator(NULL),
+			certificateTrustChecker(NULL) {
 }
 
 ClientSession::~ClientSession() {
@@ -323,19 +324,18 @@ void ClientSession::sendCredentials(const String& password) {
 	stream->writeElement(boost::shared_ptr<AuthRequest>(new AuthRequest(authenticator->getName(), authenticator->getResponse())));
 }
 
-void ClientSession::continueAfterSecurityError() {
-	checkState(WaitingForContinueAfterSecurityError);
-	continueAfterTLSEncrypted();
-}
-
 void ClientSession::handleTLSEncrypted() {
 	checkState(Encrypting);
 
 	Certificate::ref certificate = stream->getPeerCertificate();
-	boost::optional<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
+	boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
 	if (verificationError) {
-		state = WaitingForContinueAfterSecurityError;
-		onSecurityError(SecurityError(*verificationError));
+		if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificate, localJID.getDomain())) {
+			continueAfterTLSEncrypted();
+		}
+		else {
+			finishSession(verificationError);
+		}
 	}
 	else {
 		continueAfterTLSEncrypted();
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index b14a6ec..6acd9a3 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -20,7 +20,7 @@
 
 namespace Swift {
 	class ClientAuthenticator;
-	class SecurityError;
+	class CertificateTrustChecker;
 
 	class ClientSession : public boost::enable_shared_from_this<ClientSession> {
 		public:
@@ -31,7 +31,6 @@ namespace Swift {
 				Compressing,
 				WaitingForEncrypt,
 				Encrypting,
-				WaitingForContinueAfterSecurityError,
 				WaitingForCredentials,
 				Authenticating,
 				EnablingSessionManagement,
@@ -83,11 +82,13 @@ namespace Swift {
 
 			void sendCredentials(const String& password);
 			void sendStanza(boost::shared_ptr<Stanza>);
-			void continueAfterSecurityError();
+
+			void setCertificateTrustChecker(CertificateTrustChecker* checker) {
+				certificateTrustChecker = checker;
+			}
 
 		public:
 			boost::signal<void ()> onNeedCredentials;
-			boost::signal<void (const SecurityError&)> onSecurityError;
 			boost::signal<void ()> onInitialized;
 			boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished;
 			boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
@@ -132,5 +133,6 @@ namespace Swift {
 			ClientAuthenticator* authenticator;
 			boost::shared_ptr<StanzaAckRequester> stanzaAckRequester_;
 			boost::shared_ptr<StanzaAckResponder> stanzaAckResponder_;
+			CertificateTrustChecker* certificateTrustChecker;
 	};
 }
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index 4202483..7bde017 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -12,6 +12,7 @@
 #include "Swiften/Network/BoostIOServiceThread.h"
 #include "Swiften/Client/ClientSession.h"
 #include "Swiften/TLS/PlatformTLSContextFactory.h"
+#include "Swiften/TLS/CertificateVerificationError.h"
 #include "Swiften/Network/Connector.h"
 #include "Swiften/Network/BoostConnectionFactory.h"
 #include "Swiften/Network/BoostTimerFactory.h"
@@ -23,7 +24,7 @@
 
 namespace Swift {
 
-CoreClient::CoreClient(EventLoop* eventLoop, const JID& jid, const String& password) : resolver_(eventLoop), jid_(jid), password_(password), eventLoop(eventLoop), disconnectRequested_(false), ignoreSecurityErrors(true) {
+CoreClient::CoreClient(EventLoop* eventLoop, const JID& jid, const String& password) : resolver_(eventLoop), jid_(jid), password_(password), eventLoop(eventLoop), disconnectRequested_(false), certificateTrustChecker(NULL) {
 	stanzaChannel_ = new ClientSessionStanzaChannel();
 	stanzaChannel_->onMessageReceived.connect(boost::ref(onMessageReceived));
 	stanzaChannel_->onPresenceReceived.connect(boost::ref(onPresenceReceived));
@@ -90,10 +91,10 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio
 		sessionStream_->initialize();
 
 		session_ = ClientSession::create(jid_, sessionStream_);
+		session_->setCertificateTrustChecker(certificateTrustChecker);
 		stanzaChannel_->setSession(session_);
 		session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
 		session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this));
-		session_->onSecurityError.connect(boost::bind(&CoreClient::handleSecurityError, this, _1));
 		session_->start();
 	}
 }
@@ -115,7 +116,6 @@ void CoreClient::setCertificate(const String& certificate) {
 }
 
 void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) {
-	session_->onSecurityError.disconnect(boost::bind(&CoreClient::handleSecurityError, this, _1));
 	session_->onFinished.disconnect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
 	session_->onNeedCredentials.disconnect(boost::bind(&CoreClient::handleNeedCredentials, this));
 	session_.reset();
@@ -180,6 +180,30 @@ void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) {
 					break;
 			}
 		}
+		else if (boost::shared_ptr<CertificateVerificationError> verificationError = boost::dynamic_pointer_cast<CertificateVerificationError>(error)) {
+			switch(verificationError->getType()) {
+				case CertificateVerificationError::UnknownError: 
+					clientError = ClientError(ClientError::UnknownCertificateError);
+				case CertificateVerificationError::Expired: 
+					clientError = ClientError(ClientError::CertificateExpiredError);
+				case CertificateVerificationError::NotYetValid: 
+					clientError = ClientError(ClientError::CertificateNotYetValidError);
+				case CertificateVerificationError::SelfSigned: 
+					clientError = ClientError(ClientError::CertificateSelfSignedError);
+				case CertificateVerificationError::Rejected: 
+					clientError = ClientError(ClientError::CertificateRejectedError);
+				case CertificateVerificationError::Untrusted: 
+					clientError = ClientError(ClientError::CertificateUntrustedError);
+				case CertificateVerificationError::InvalidPurpose: 
+					clientError = ClientError(ClientError::InvalidCertificatePurposeError);
+				case CertificateVerificationError::PathLengthExceeded: 
+					clientError = ClientError(ClientError::CertificatePathLengthExceededError);
+				case CertificateVerificationError::InvalidSignature: 
+					clientError = ClientError(ClientError::InvalidCertificateSignatureError);
+				case CertificateVerificationError::InvalidCA: 
+					clientError = ClientError(ClientError::InvalidCAError);
+			}
+		}
 		actualError = boost::optional<ClientError>(clientError);
 	}
 	onDisconnected(actualError);
@@ -216,17 +240,8 @@ bool CoreClient::isActive() const {
 	return session_ || connector_;
 }
 
-void CoreClient::handleSecurityError(const SecurityError& error) {
-	if (ignoreSecurityErrors) {
-		session_->continueAfterSecurityError();
-	}
-	else {
-		onSecurityError(error);
-	}
-}
-
-void CoreClient::continueAfterSecurityError() {
-	session_->continueAfterSecurityError();
+void CoreClient::setCertificateTrustChecker(CertificateTrustChecker* checker) {
+	certificateTrustChecker = checker;
 }
 
 }
diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h
index 3176a51..780201d 100644
--- a/Swiften/Client/CoreClient.h
+++ b/Swiften/Client/CoreClient.h
@@ -32,7 +32,7 @@ namespace Swift {
 	class ClientSession;
 	class BasicSessionStream;
 	class EventLoop;
-	class SecurityError;
+	class CertificateTrustChecker;
 
 	/** 
 	 * The central class for communicating with an XMPP server.
@@ -72,14 +72,6 @@ namespace Swift {
 			void connect(const String& host);
 			
 			/**
-			 * Instructs the client to continue initializing the session
-			 * after a security error has occurred (and as such ignoring the error)
-			 *
-			 * \see onSecurityError
-			 */
-			void continueAfterSecurityError();
-
-			/**
 			 * Sends a message.
 			 */
 			void sendMessage(Message::ref);
@@ -140,29 +132,10 @@ namespace Swift {
 				return stanzaChannel_;
 			}
 
-			/**
-			 * Sets whether security errors should be ignored or not.
-			 *
-			 * If this is set to 'true', onSecurityError will not be called when a security
-			 * error occurs, and connecting will continue.
-			 *
-			 * Defaults to true.
-			 */
-			void setIgnoreSecurityErrors(bool b) {
-				ignoreSecurityErrors = b;
-			}
+			void setCertificateTrustChecker(CertificateTrustChecker*);
 
 		public:
 			/**
-			 * Emitted when an error occurred while establishing a secure connection.
-			 *
-			 * If the error is to be ignored, call continueAfterSecurityError(), otherwise call
-			 * finish().
-			 * This signal is not emitted when setIgnoreSecurityErrors() is set to true.
-			 */
-			boost::signal<void (const SecurityError&)> onSecurityError;
-
-			/**
 			 * Emitted when the client was disconnected from the network.
 			 *
 			 * If the connection was due to a non-recoverable error, the type
@@ -217,7 +190,6 @@ namespace Swift {
 			void handleNeedCredentials();
 			void handleDataRead(const String&);
 			void handleDataWritten(const String&);
-			void handleSecurityError(const SecurityError& securityError);
 
 		private:
 			PlatformDomainNameResolver resolver_;
@@ -237,6 +209,6 @@ namespace Swift {
 			boost::shared_ptr<ClientSession> session_;
 			String certificate_;
 			bool disconnectRequested_;
-			bool ignoreSecurityErrors;
+			CertificateTrustChecker* certificateTrustChecker;
 	};
 }
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index 43a8bf3..11e4992 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -287,8 +287,8 @@ class ClientSessionTest : public CppUnit::TestFixture {
 					return Certificate::ref();
 				}
 
-				virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const {
-					return boost::optional<CertificateVerificationError>();
+				virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const {
+					return boost::shared_ptr<CertificateVerificationError>();
 				}
 
 				virtual void addZLibCompression() {
diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
index b6b57dd..4fe8e87 100644
--- a/Swiften/Component/UnitTest/ComponentSessionTest.cpp
+++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
@@ -127,8 +127,8 @@ class ComponentSessionTest : public CppUnit::TestFixture {
 					return Certificate::ref();
 				}
 
-				virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const {
-					return boost::optional<CertificateVerificationError>();
+				virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const {
+					return boost::shared_ptr<CertificateVerificationError>();
 				}
 
 				virtual void addZLibCompression() {
diff --git a/Swiften/Examples/SendMessage/SendMessage.cpp b/Swiften/Examples/SendMessage/SendMessage.cpp
index fe020aa..567a351 100644
--- a/Swiften/Examples/SendMessage/SendMessage.cpp
+++ b/Swiften/Examples/SendMessage/SendMessage.cpp
@@ -58,6 +58,7 @@ int main(int argc, char* argv[]) {
 	}
 
 	client = new Swift::Client(&eventLoop, JID(jid), String(argv[argi++]));
+	client->setAlwaysTrustCertificates();
 
 	recipient = JID(argv[argi++]);
 	messageBody = std::string(argv[argi++]);
diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp
index 4e48339..dd63056 100644
--- a/Swiften/QA/ClientTest/ClientTest.cpp
+++ b/Swiften/QA/ClientTest/ClientTest.cpp
@@ -58,6 +58,7 @@ int main(int, char**) {
 	client = new Swift::Client(&eventLoop, JID(jid), String(pass));
 	ClientXMLTracer* tracer = new ClientXMLTracer(client);
 	client->onConnected.connect(&handleConnected);
+	client->setAlwaysTrustCertificates();
 	client->connect();
 
 	{
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index 65a241c..32424bc 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -89,7 +89,7 @@ Certificate::ref BasicSessionStream::getPeerCertificate() const {
 	return tlsLayer->getPeerCertificate();
 }
 
-boost::optional<CertificateVerificationError> BasicSessionStream::getPeerCertificateVerificationError() const {
+boost::shared_ptr<CertificateVerificationError> BasicSessionStream::getPeerCertificateVerificationError() const {
 	return tlsLayer->getPeerCertificateVerificationError();
 }
 
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index 8addeb6..fbaa937 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -53,7 +53,7 @@ namespace Swift {
 			virtual void addTLSEncryption();
 			virtual bool isTLSEncrypted();
 			virtual Certificate::ref getPeerCertificate() const;
-			virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const;
+			virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
 
 			virtual void setWhitespacePingEnabled(bool);
 
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index 1bf9090..d648f91 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -61,7 +61,7 @@ namespace Swift {
 			}
 
 			virtual Certificate::ref getPeerCertificate() const = 0;
-			virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
+			virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
 
 			boost::signal<void (const ProtocolHeader&)> onStreamStartReceived;
 			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp
index dd6660f..8cb06fc 100644
--- a/Swiften/StreamStack/TLSLayer.cpp
+++ b/Swiften/StreamStack/TLSLayer.cpp
@@ -42,7 +42,7 @@ Certificate::ref TLSLayer::getPeerCertificate() const {
 	return context->getPeerCertificate();
 }
 
-boost::optional<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const {
+boost::shared_ptr<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const {
 	return context->getPeerCertificateVerificationError();
 }
 
diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h
index 6fb825f..a69f789 100644
--- a/Swiften/StreamStack/TLSLayer.h
+++ b/Swiften/StreamStack/TLSLayer.h
@@ -25,7 +25,7 @@ namespace Swift {
 			bool setClientCertificate(const PKCS12Certificate&);
 
 			Certificate::ref getPeerCertificate() const;
-			boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const;
+			boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
 
 			void writeData(const ByteArray& data);
 			void handleDataRead(const ByteArray& data);
diff --git a/Swiften/TLS/BlindCertificateTrustChecker.h b/Swiften/TLS/BlindCertificateTrustChecker.h
new file mode 100644
index 0000000..26a7f94
--- /dev/null
+++ b/Swiften/TLS/BlindCertificateTrustChecker.h
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include "Swiften/TLS/CertificateTrustChecker.h"
+
+namespace Swift {
+	class BlindCertificateTrustChecker : public CertificateTrustChecker {
+		public:
+			virtual bool isCertificateTrusted(Certificate::ref, const String&) {
+				return true;
+			}
+	};
+}
diff --git a/Swiften/TLS/Certificate.cpp b/Swiften/TLS/Certificate.cpp
new file mode 100644
index 0000000..7d61b22
--- /dev/null
+++ b/Swiften/TLS/Certificate.cpp
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include "Swiften/TLS/Certificate.h"
+
+namespace Swift {
+
+const char* Certificate::ID_ON_XMPPADDR_OID = "1.3.6.1.5.5.7.8.5";
+const char* Certificate::ID_ON_DNSSRV_OID = "1.3.6.1.5.5.7.8.7";
+
+Certificate::~Certificate() {
+}
+
+}
diff --git a/Swiften/TLS/Certificate.h b/Swiften/TLS/Certificate.h
index 21ea0bd..3a812a8 100644
--- a/Swiften/TLS/Certificate.h
+++ b/Swiften/TLS/Certificate.h
@@ -15,42 +15,16 @@ namespace Swift {
 		public:
 			typedef boost::shared_ptr<Certificate> ref;
 
-			const String& getCommonName() const {
-				return commonName;
-			}
-
-			void setCommonName(const String& commonName) {
-				this->commonName = commonName;
-			}
-
-			const std::vector<String>& getSRVNames() const {
-				return srvNames;
-			}
-
-			void addSRVName(const String& name) {
-				srvNames.push_back(name);
-			}
-
-			const std::vector<String>& getDNSNames() const {
-				return dnsNames;
-			}
-
-			void addDNSName(const String& name) {
-				dnsNames.push_back(name);
-			}
-
-			const std::vector<String>& getXMPPAddresses() const {
-				return xmppAddresses;
-			}
-
-			void addXMPPAddress(const String& addr) {
-				xmppAddresses.push_back(addr);
-			}
-
-		private:
-			String commonName;
-			std::vector<String> dnsNames;
-			std::vector<String> xmppAddresses;
-			std::vector<String> srvNames;
+			virtual ~Certificate();
+
+			virtual String getCommonName() const = 0;
+			virtual std::vector<String> getSRVNames() const = 0;
+			virtual std::vector<String> getDNSNames() const = 0;
+			virtual std::vector<String> getXMPPAddresses() const = 0;
+
+		protected:
+			static const char* ID_ON_XMPPADDR_OID;
+			static const char* ID_ON_DNSSRV_OID;
+
 	};
 }
diff --git a/Swiften/TLS/CertificateTrustChecker.cpp b/Swiften/TLS/CertificateTrustChecker.cpp
new file mode 100644
index 0000000..f4f921d
--- /dev/null
+++ b/Swiften/TLS/CertificateTrustChecker.cpp
@@ -0,0 +1,14 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include "Swiften/TLS/CertificateTrustChecker.h"
+
+namespace Swift {
+
+CertificateTrustChecker::~CertificateTrustChecker() {
+}
+
+}
diff --git a/Swiften/TLS/CertificateTrustChecker.h b/Swiften/TLS/CertificateTrustChecker.h
new file mode 100644
index 0000000..070c4bb
--- /dev/null
+++ b/Swiften/TLS/CertificateTrustChecker.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/TLS/Certificate.h"
+
+namespace Swift {
+	class CertificateTrustChecker {
+		public:
+			virtual ~CertificateTrustChecker();
+
+			virtual bool isCertificateTrusted(Certificate::ref certificate, const String& domain) = 0;
+	};
+}
diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h
index 76b4aff..f1bd091 100644
--- a/Swiften/TLS/CertificateVerificationError.h
+++ b/Swiften/TLS/CertificateVerificationError.h
@@ -6,8 +6,10 @@
 
 #pragma once
 
+#include "Swiften/Base/Error.h"
+
 namespace Swift {
-	class CertificateVerificationError {
+	class CertificateVerificationError : public Error {
 		public:
 			enum Type {
 				UnknownError,
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
new file mode 100644
index 0000000..5d9aac2
--- /dev/null
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include "Swiften/TLS/OpenSSL/OpenSSLCertificate.h"
+
+#include <openssl/x509v3.h>
+
+#include "Swiften/Base/ByteArray.h"
+
+#pragma GCC diagnostic ignored "-Wold-style-cast"
+
+namespace Swift {
+
+OpenSSLCertificate::OpenSSLCertificate(boost::shared_ptr<X509> cert) : cert(cert) {
+	// Common name
+	X509_NAME* subjectName = X509_get_subject_name(cert.get());
+	if (subjectName) {
+		int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1);
+		if (cnLoc != -1) {
+			X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc);
+			ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry);
+			setCommonName(ByteArray(cnData->data, cnData->length).toString());
+		}
+	}
+
+	// subjectAltNames
+	int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1);
+	if(subjectAltNameLoc != -1) {
+		X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc);
+		boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free);
+		boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free);
+		boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free);
+		for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) {
+			GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i);
+			if (generalName->type == GEN_OTHERNAME) {
+				OTHERNAME* otherName = generalName->d.otherName;
+				if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) {
+					// XmppAddr
+					if (otherName->value->type != V_ASN1_UTF8STRING) {
+						continue;
+					}
+					ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string;
+					addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString());
+				}
+				else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) {
+					// SRVName
+					if (otherName->value->type != V_ASN1_IA5STRING) {
+						continue;
+					}
+					ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string;
+					addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString());
+				}
+			}
+			else if (generalName->type == GEN_DNS) {
+				// DNSName
+				addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString());
+			}
+		}
+	}
+}
+
+}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
new file mode 100644
index 0000000..4708120
--- /dev/null
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <boost/shared_ptr.hpp>
+#include <openssl/ssl.h>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/TLS/Certificate.h"
+
+namespace Swift {
+	class OpenSSLCertificate : public Certificate {
+		public:
+			OpenSSLCertificate(boost::shared_ptr<X509>);
+
+			String getCommonName() const {
+				return commonName;
+			}
+
+			std::vector<String> getSRVNames() const {
+				return srvNames;
+			}
+
+			std::vector<String> getDNSNames() const {
+				return dnsNames;
+			}
+
+			std::vector<String> getXMPPAddresses() const {
+				return xmppAddresses;
+			}
+
+		private:
+			void addSRVName(const String& name) {
+				srvNames.push_back(name);
+			}
+
+			void addDNSName(const String& name) {
+				dnsNames.push_back(name);
+			}
+
+			void addXMPPAddress(const String& addr) {
+				xmppAddresses.push_back(addr);
+			}
+
+			void setCommonName(const String& commonName) {
+				this->commonName = commonName;
+			}
+
+		private:
+			boost::shared_ptr<X509> cert;
+			String commonName;
+			std::vector<String> dnsNames;
+			std::vector<String> xmppAddresses;
+			std::vector<String> srvNames;
+	};
+}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index c78d5a1..41c98c1 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -7,9 +7,9 @@
 #include <vector>
 #include <openssl/err.h>
 #include <openssl/pkcs12.h>
-#include <openssl/x509v3.h>
 
 #include "Swiften/TLS/OpenSSL/OpenSSLContext.h"
+#include "Swiften/TLS/OpenSSL/OpenSSLCertificate.h"
 #include "Swiften/TLS/PKCS12Certificate.h"
 
 #pragma GCC diagnostic ignored "-Wold-style-cast"
@@ -166,67 +166,20 @@ bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate)
 Certificate::ref OpenSSLContext::getPeerCertificate() const {
 	boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free);
 	if (x509Cert) {
-		Certificate::ref certificate(new Certificate());
-
-		// Common name
-		X509_NAME* subjectName = X509_get_subject_name(x509Cert.get());
-		if (subjectName) {
-			int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1);
-			if (cnLoc != -1) {
-				X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc);
-				ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry);
-				certificate->setCommonName(ByteArray(cnData->data, cnData->length).toString());
-			}
-		}
-
-		// subjectAltNames
-		int subjectAltNameLoc = X509_get_ext_by_NID(x509Cert.get(), NID_subject_alt_name, -1);
-		if(subjectAltNameLoc != -1) {
-			X509_EXTENSION* extension = X509_get_ext(x509Cert.get(), subjectAltNameLoc);
-			boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free);
-			boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free);
-			boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free);
-			for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) {
-				GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i);
-				if (generalName->type == GEN_OTHERNAME) {
-					OTHERNAME* otherName = generalName->d.otherName;
-					if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) {
-						// XmppAddr
-						if (otherName->value->type != V_ASN1_UTF8STRING) {
-							continue;
-						}
-						ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string;
-						certificate->addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString());
-					}
-					else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) {
-						// SRVName
-						if (otherName->value->type != V_ASN1_IA5STRING) {
-							continue;
-						}
-						ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string;
-						certificate->addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString());
-					}
-				}
-				else if (generalName->type == GEN_DNS) {
-					// DNSName
-					certificate->addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString());
-				}
-			}
-		}
-		return certificate;
+		return Certificate::ref(new OpenSSLCertificate(x509Cert));
 	}
 	else {
 		return Certificate::ref();
 	}
 }
 
-boost::optional<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
+boost::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
 	int verifyResult = SSL_get_verify_result(handle_);
 	if (verifyResult != X509_V_OK) {
-		return CertificateVerificationError(getVerificationErrorTypeForResult(verifyResult));
+		return boost::shared_ptr<CertificateVerificationError>(new CertificateVerificationError(getVerificationErrorTypeForResult(verifyResult)));
 	}
 	else {
-		return boost::optional<CertificateVerificationError>();
+		return boost::shared_ptr<CertificateVerificationError>();
 	}
 }
 
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 31141a5..9cb287d 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -28,7 +28,7 @@ namespace Swift {
 			void handleDataFromApplication(const ByteArray&);
 
 			Certificate::ref getPeerCertificate() const;
-			boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const;
+			boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
 
 		private:
 			static void ensureLibraryInitialized();	
diff --git a/Swiften/TLS/SConscript b/Swiften/TLS/SConscript
index b84dbc0..bb33239 100644
--- a/Swiften/TLS/SConscript
+++ b/Swiften/TLS/SConscript
@@ -1,14 +1,16 @@
 Import("swiften_env")
 
 objects = swiften_env.StaticObject([
+			"Certificate.cpp",
+			"CertificateTrustChecker.cpp",
 			"TLSContext.cpp",
 			"TLSContextFactory.cpp",
-			"SecurityError.cpp",
 		])
 		
 if swiften_env.get("HAVE_OPENSSL", 0) :
 	objects += swiften_env.StaticObject([
 			"OpenSSL/OpenSSLContext.cpp",
+			"OpenSSL/OpenSSLCertificate.cpp",
 			"OpenSSL/OpenSSLContextFactory.cpp",
 		])
 		
diff --git a/Swiften/TLS/SecurityError.cpp b/Swiften/TLS/SecurityError.cpp
deleted file mode 100644
index 03aadf0..0000000
--- a/Swiften/TLS/SecurityError.cpp
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright (c) 2010 Remko Tronçon
- * Licensed under the GNU General Public License v3.
- * See Documentation/Licenses/GPLv3.txt for more information.
- */
-
-#include "Swiften/TLS/SecurityError.h"
-#include "Swiften/TLS/CertificateVerificationError.h"
-
-namespace Swift {
-
-SecurityError::SecurityError(Type type) : type(type) {
-}
-
-SecurityError::SecurityError(const CertificateVerificationError& verificationError) {
-	type = UnknownError;
-	switch(verificationError.getType()) {
-		case CertificateVerificationError::UnknownError: type = UnknownError;
-		case CertificateVerificationError::Expired: type = Expired;
-		case CertificateVerificationError::NotYetValid: type = NotYetValid;
-		case CertificateVerificationError::SelfSigned: type = SelfSigned;
-		case CertificateVerificationError::Rejected: type = Rejected;
-		case CertificateVerificationError::Untrusted: type = Untrusted;
-		case CertificateVerificationError::InvalidPurpose: type = InvalidPurpose;
-		case CertificateVerificationError::PathLengthExceeded: type = PathLengthExceeded;
-		case CertificateVerificationError::InvalidSignature: type = InvalidSignature;
-		case CertificateVerificationError::InvalidCA: type = InvalidCA;
-	}
-}
-
-}
diff --git a/Swiften/TLS/SecurityError.h b/Swiften/TLS/SecurityError.h
deleted file mode 100644
index 55ac7d5..0000000
--- a/Swiften/TLS/SecurityError.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) 2010 Remko Tronçon
- * Licensed under the GNU General Public License v3.
- * See Documentation/Licenses/GPLv3.txt for more information.
- */
-
-#pragma once
-
-namespace Swift {
-	class CertificateVerificationError;
-
-	class SecurityError {
-		public:
-			enum Type {
-				// From CertificateVerificationError
-				UnknownError,
-				Expired,
-				NotYetValid,
-				SelfSigned,
-				Rejected,
-				Untrusted,
-				InvalidPurpose,
-				PathLengthExceeded,
-				InvalidSignature,
-				InvalidCA,
-
-				// Identity verification
-				InvalidIdentity,
-			};
-
-			SecurityError(Type type);
-			SecurityError(const CertificateVerificationError& verificationError);
-
-
-			Type getType() const { 
-				return type; 
-			}
-
-		private:
-			Type type;
-	};
-}
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 67fa903..008bfc0 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -8,9 +8,6 @@
 
 namespace Swift {
 
-const char* TLSContext::ID_ON_XMPPADDR_OID = "1.3.6.1.5.5.7.8.5";
-const char* TLSContext::ID_ON_DNSSRV_OID = "1.3.6.1.5.5.7.8.7";
-
 TLSContext::~TLSContext() {
 }
 
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 2d05100..47f2697 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -7,6 +7,7 @@
 #pragma once
 
 #include "Swiften/Base/boost_bsignals.h"
+#include <boost/shared_ptr.hpp>
 
 #include "Swiften/Base/ByteArray.h"
 #include "Swiften/TLS/Certificate.h"
@@ -27,11 +28,7 @@ namespace Swift {
 			virtual void handleDataFromApplication(const ByteArray&) = 0;
 
 			virtual Certificate::ref getPeerCertificate() const = 0;
-			virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
-
-		protected:
-			static const char* ID_ON_XMPPADDR_OID;
-			static const char* ID_ON_DNSSRV_OID;
+			virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
 
 		public:
 			boost::signal<void (const ByteArray&)> onDataForNetwork;
-- 
cgit v0.10.2-6-g49f6