From 70d19e3b5d3757310caf32e1732cac2cd4ae0a63 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Sun, 7 Nov 2010 13:31:41 +0100
Subject: Added certificate verification API to TLS context.


diff --git a/Swiften/Base/ByteArray.h b/Swiften/Base/ByteArray.h
index 09698ad..b5cbfb0 100644
--- a/Swiften/Base/ByteArray.h
+++ b/Swiften/Base/ByteArray.h
@@ -36,6 +36,13 @@ namespace Swift {
 				}
 			}
 
+			ByteArray(const unsigned char* c, size_t n) {
+				if (n > 0) {
+					data_.resize(n);
+					memcpy(&data_[0], c, n);
+				}
+			}
+
 			const char* getData() const {
 				return data_.empty() ? NULL : &data_[0];
 			}
diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp
index 99154f6..dd6660f 100644
--- a/Swiften/StreamStack/TLSLayer.cpp
+++ b/Swiften/StreamStack/TLSLayer.cpp
@@ -38,4 +38,12 @@ bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) {
 	return context->setClientCertificate(certificate);
 }
 
+Certificate::ref TLSLayer::getPeerCertificate() const {
+	return context->getPeerCertificate();
+}
+
+boost::optional<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const {
+	return context->getPeerCertificateVerificationError();
+}
+
 }
diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h
index f8cda41..6fb825f 100644
--- a/Swiften/StreamStack/TLSLayer.h
+++ b/Swiften/StreamStack/TLSLayer.h
@@ -8,6 +8,8 @@
 
 #include "Swiften/Base/ByteArray.h"
 #include "Swiften/StreamStack/StreamLayer.h"
+#include "Swiften/TLS/Certificate.h"
+#include "Swiften/TLS/CertificateVerificationError.h"
 
 namespace Swift {
 	class TLSContext;
@@ -19,11 +21,14 @@ namespace Swift {
 			TLSLayer(TLSContextFactory*);
 			~TLSLayer();
 
-			virtual void connect();
-			virtual bool setClientCertificate(const PKCS12Certificate&);
+			void connect();
+			bool setClientCertificate(const PKCS12Certificate&);
 
-			virtual void writeData(const ByteArray& data);
-			virtual void handleDataRead(const ByteArray& data);
+			Certificate::ref getPeerCertificate() const;
+			boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const;
+
+			void writeData(const ByteArray& data);
+			void handleDataRead(const ByteArray& data);
 
 		public:
 			boost::signal<void ()> onError;
diff --git a/Swiften/TLS/Certificate.h b/Swiften/TLS/Certificate.h
new file mode 100644
index 0000000..cdd8b57
--- /dev/null
+++ b/Swiften/TLS/Certificate.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Base/String.h"
+
+namespace Swift {
+	class Certificate {
+		public:
+			typedef boost::shared_ptr<Certificate> ref;
+
+			const String& getCommonName() const {
+				return commonName;
+			}
+
+			void setCommonName(const String& commonName) {
+				this->commonName = commonName;
+			}
+
+			const std::vector<String>& getSRVNames() const {
+				return srvNames;
+			}
+
+			void addSRVName(const String& name) {
+				srvNames.push_back(name);
+			}
+
+			const std::vector<String>& getDNSNames() const {
+				return dnsNames;
+			}
+
+			void addDNSName(const String& name) {
+				dnsNames.push_back(name);
+			}
+
+			const std::vector<String>& getXMPPAddresses() const {
+				return xmppAddresses;
+			}
+
+			void addXMPPAddress(const String& addr) {
+				xmppAddresses.push_back(addr);
+			}
+
+		private:
+			String commonName;
+			String srvName;
+			std::vector<String> dnsNames;
+			std::vector<String> xmppAddresses;
+			std::vector<String> srvNames;
+	};
+}
diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h
new file mode 100644
index 0000000..71895ff
--- /dev/null
+++ b/Swiften/TLS/CertificateVerificationError.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+namespace Swift {
+	class CertificateVerificationError {
+		public:
+			enum Type {
+				UnknownError,
+			};
+
+			CertificateVerificationError(Type type = UnknownError) : type(type) {}
+
+			Type getType() const { 
+				return type; 
+			}
+
+		private:
+			Type type;
+	};
+}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 80575ca..234c831 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -7,6 +7,7 @@
 #include <vector>
 #include <openssl/err.h>
 #include <openssl/pkcs12.h>
+#include <openssl/x509v3.h>
 
 #include "Swiften/TLS/OpenSSL/OpenSSLContext.h"
 #include "Swiften/TLS/PKCS12Certificate.h"
@@ -56,14 +57,14 @@ void OpenSSLContext::doConnect() {
 	int connectResult = SSL_connect(handle_);
 	int error = SSL_get_error(handle_, connectResult);
 	switch (error) {
-		case SSL_ERROR_NONE:
+		case SSL_ERROR_NONE: {
 			state_ = Connected;
-			onConnected();
-			//X509* x = SSL_get_peer_certificate(handle_);
 			//std::cout << x->name << std::endl;
 			//const char* comp = SSL_get_current_compression(handle_);
 			//std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl;
+			onConnected();
 			break;
+		}
 		case SSL_ERROR_WANT_READ:
 			sendPendingDataToNetwork();
 			break;
@@ -162,4 +163,71 @@ bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate)
 	return true;
 }
 
+Certificate::ref OpenSSLContext::getPeerCertificate() const {
+	boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free);
+	if (x509Cert) {
+		Certificate::ref certificate(new Certificate());
+
+		// Common name
+		X509_NAME* subjectName = X509_get_subject_name(x509Cert.get());
+		if (subjectName) {
+			int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1);
+			if (cnLoc != -1) {
+				X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc);
+				ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry);
+				certificate->setCommonName(ByteArray(cnData->data, cnData->length).toString());
+			}
+		}
+
+		// subjectAltNames
+		int subjectAltNameLoc = X509_get_ext_by_NID(x509Cert.get(), NID_subject_alt_name, -1);
+		if(subjectAltNameLoc != -1) {
+			X509_EXTENSION* extension = X509_get_ext(x509Cert.get(), subjectAltNameLoc);
+			boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free);
+			boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free);
+			boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free);
+			for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) {
+				GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i);
+				if (generalName->type == GEN_OTHERNAME) {
+					OTHERNAME* otherName = generalName->d.otherName;
+					if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) {
+						// XmppAddr
+						if (otherName->value->type != V_ASN1_UTF8STRING) {
+							continue;
+						}
+						ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string;
+						certificate->addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString());
+					}
+					else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) {
+						// SRVName
+						if (otherName->value->type != V_ASN1_IA5STRING) {
+							continue;
+						}
+						ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string;
+						certificate->addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString());
+					}
+				}
+				else if (generalName->type == GEN_DNS) {
+					// DNSName
+					certificate->addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString());
+				}
+			}
+		}
+		return certificate;
+	}
+	else {
+		return Certificate::ref();
+	}
+}
+
+boost::optional<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
+	long verifyResult = SSL_get_verify_result(handle_);
+	if (verifyResult != X509_V_OK) {
+		return CertificateVerificationError();
+	}
+	else {
+		return boost::optional<CertificateVerificationError>();
+	}
+}
+
 }
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index a01e3e5..a0e73c4 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -27,6 +27,9 @@ namespace Swift {
 			void handleDataFromNetwork(const ByteArray&);
 			void handleDataFromApplication(const ByteArray&);
 
+			Certificate::ref getPeerCertificate() const;
+			boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const;
+
 		private:
 			static void ensureLibraryInitialized();	
 
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 008bfc0..67fa903 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -8,6 +8,9 @@
 
 namespace Swift {
 
+const char* TLSContext::ID_ON_XMPPADDR_OID = "1.3.6.1.5.5.7.8.5";
+const char* TLSContext::ID_ON_DNSSRV_OID = "1.3.6.1.5.5.7.8.7";
+
 TLSContext::~TLSContext() {
 }
 
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 9e911d4..2d05100 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -9,6 +9,8 @@
 #include "Swiften/Base/boost_bsignals.h"
 
 #include "Swiften/Base/ByteArray.h"
+#include "Swiften/TLS/Certificate.h"
+#include "Swiften/TLS/CertificateVerificationError.h"
 
 namespace Swift {
 	class PKCS12Certificate;
@@ -18,11 +20,19 @@ namespace Swift {
 			virtual ~TLSContext();
 
 			virtual void connect() = 0;
+
 			virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0;
 
 			virtual void handleDataFromNetwork(const ByteArray&) = 0;
 			virtual void handleDataFromApplication(const ByteArray&) = 0;
 
+			virtual Certificate::ref getPeerCertificate() const = 0;
+			virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
+
+		protected:
+			static const char* ID_ON_XMPPADDR_OID;
+			static const char* ID_ON_DNSSRV_OID;
+
 		public:
 			boost::signal<void (const ByteArray&)> onDataForNetwork;
 			boost::signal<void (const ByteArray&)> onDataForApplication;
-- 
cgit v0.10.2-6-g49f6