diff options
-rw-r--r-- | Swiften/Base/ByteArray.h | 7 | ||||
-rw-r--r-- | Swiften/StreamStack/TLSLayer.cpp | 8 | ||||
-rw-r--r-- | Swiften/StreamStack/TLSLayer.h | 13 | ||||
-rw-r--r-- | Swiften/TLS/Certificate.h | 57 | ||||
-rw-r--r-- | Swiften/TLS/CertificateVerificationError.h | 25 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 74 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 3 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.cpp | 3 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 10 |
9 files changed, 193 insertions, 7 deletions
diff --git a/Swiften/Base/ByteArray.h b/Swiften/Base/ByteArray.h index 09698ad..b5cbfb0 100644 --- a/Swiften/Base/ByteArray.h +++ b/Swiften/Base/ByteArray.h @@ -36,6 +36,13 @@ namespace Swift { } } + ByteArray(const unsigned char* c, size_t n) { + if (n > 0) { + data_.resize(n); + memcpy(&data_[0], c, n); + } + } + const char* getData() const { return data_.empty() ? NULL : &data_[0]; } diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp index 99154f6..dd6660f 100644 --- a/Swiften/StreamStack/TLSLayer.cpp +++ b/Swiften/StreamStack/TLSLayer.cpp @@ -38,4 +38,12 @@ bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) { return context->setClientCertificate(certificate); } +Certificate::ref TLSLayer::getPeerCertificate() const { + return context->getPeerCertificate(); +} + +boost::optional<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const { + return context->getPeerCertificateVerificationError(); +} + } diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h index f8cda41..6fb825f 100644 --- a/Swiften/StreamStack/TLSLayer.h +++ b/Swiften/StreamStack/TLSLayer.h @@ -8,6 +8,8 @@ #include "Swiften/Base/ByteArray.h" #include "Swiften/StreamStack/StreamLayer.h" +#include "Swiften/TLS/Certificate.h" +#include "Swiften/TLS/CertificateVerificationError.h" namespace Swift { class TLSContext; @@ -19,11 +21,14 @@ namespace Swift { TLSLayer(TLSContextFactory*); ~TLSLayer(); - virtual void connect(); - virtual bool setClientCertificate(const PKCS12Certificate&); + void connect(); + bool setClientCertificate(const PKCS12Certificate&); - virtual void writeData(const ByteArray& data); - virtual void handleDataRead(const ByteArray& data); + Certificate::ref getPeerCertificate() const; + boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const; + + void writeData(const ByteArray& data); + void handleDataRead(const ByteArray& data); public: boost::signal<void ()> onError; diff --git a/Swiften/TLS/Certificate.h b/Swiften/TLS/Certificate.h new file mode 100644 index 0000000..cdd8b57 --- /dev/null +++ b/Swiften/TLS/Certificate.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> + +#include "Swiften/Base/String.h" + +namespace Swift { + class Certificate { + public: + typedef boost::shared_ptr<Certificate> ref; + + const String& getCommonName() const { + return commonName; + } + + void setCommonName(const String& commonName) { + this->commonName = commonName; + } + + const std::vector<String>& getSRVNames() const { + return srvNames; + } + + void addSRVName(const String& name) { + srvNames.push_back(name); + } + + const std::vector<String>& getDNSNames() const { + return dnsNames; + } + + void addDNSName(const String& name) { + dnsNames.push_back(name); + } + + const std::vector<String>& getXMPPAddresses() const { + return xmppAddresses; + } + + void addXMPPAddress(const String& addr) { + xmppAddresses.push_back(addr); + } + + private: + String commonName; + String srvName; + std::vector<String> dnsNames; + std::vector<String> xmppAddresses; + std::vector<String> srvNames; + }; +} diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h new file mode 100644 index 0000000..71895ff --- /dev/null +++ b/Swiften/TLS/CertificateVerificationError.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +namespace Swift { + class CertificateVerificationError { + public: + enum Type { + UnknownError, + }; + + CertificateVerificationError(Type type = UnknownError) : type(type) {} + + Type getType() const { + return type; + } + + private: + Type type; + }; +} diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 80575ca..234c831 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -7,6 +7,7 @@ #include <vector> #include <openssl/err.h> #include <openssl/pkcs12.h> +#include <openssl/x509v3.h> #include "Swiften/TLS/OpenSSL/OpenSSLContext.h" #include "Swiften/TLS/PKCS12Certificate.h" @@ -56,14 +57,14 @@ void OpenSSLContext::doConnect() { int connectResult = SSL_connect(handle_); int error = SSL_get_error(handle_, connectResult); switch (error) { - case SSL_ERROR_NONE: + case SSL_ERROR_NONE: { state_ = Connected; - onConnected(); - //X509* x = SSL_get_peer_certificate(handle_); //std::cout << x->name << std::endl; //const char* comp = SSL_get_current_compression(handle_); //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; + onConnected(); break; + } case SSL_ERROR_WANT_READ: sendPendingDataToNetwork(); break; @@ -162,4 +163,71 @@ bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) return true; } +Certificate::ref OpenSSLContext::getPeerCertificate() const { + boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free); + if (x509Cert) { + Certificate::ref certificate(new Certificate()); + + // Common name + X509_NAME* subjectName = X509_get_subject_name(x509Cert.get()); + if (subjectName) { + int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1); + if (cnLoc != -1) { + X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc); + ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry); + certificate->setCommonName(ByteArray(cnData->data, cnData->length).toString()); + } + } + + // subjectAltNames + int subjectAltNameLoc = X509_get_ext_by_NID(x509Cert.get(), NID_subject_alt_name, -1); + if(subjectAltNameLoc != -1) { + X509_EXTENSION* extension = X509_get_ext(x509Cert.get(), subjectAltNameLoc); + boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); + boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); + boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); + for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { + GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); + if (generalName->type == GEN_OTHERNAME) { + OTHERNAME* otherName = generalName->d.otherName; + if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { + // XmppAddr + if (otherName->value->type != V_ASN1_UTF8STRING) { + continue; + } + ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; + certificate->addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); + } + else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { + // SRVName + if (otherName->value->type != V_ASN1_IA5STRING) { + continue; + } + ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; + certificate->addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); + } + } + else if (generalName->type == GEN_DNS) { + // DNSName + certificate->addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); + } + } + } + return certificate; + } + else { + return Certificate::ref(); + } +} + +boost::optional<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { + long verifyResult = SSL_get_verify_result(handle_); + if (verifyResult != X509_V_OK) { + return CertificateVerificationError(); + } + else { + return boost::optional<CertificateVerificationError>(); + } +} + } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index a01e3e5..a0e73c4 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -27,6 +27,9 @@ namespace Swift { void handleDataFromNetwork(const ByteArray&); void handleDataFromApplication(const ByteArray&); + Certificate::ref getPeerCertificate() const; + boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const; + private: static void ensureLibraryInitialized(); diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 008bfc0..67fa903 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -8,6 +8,9 @@ namespace Swift { +const char* TLSContext::ID_ON_XMPPADDR_OID = "1.3.6.1.5.5.7.8.5"; +const char* TLSContext::ID_ON_DNSSRV_OID = "1.3.6.1.5.5.7.8.7"; + TLSContext::~TLSContext() { } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 9e911d4..2d05100 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -9,6 +9,8 @@ #include "Swiften/Base/boost_bsignals.h" #include "Swiften/Base/ByteArray.h" +#include "Swiften/TLS/Certificate.h" +#include "Swiften/TLS/CertificateVerificationError.h" namespace Swift { class PKCS12Certificate; @@ -18,11 +20,19 @@ namespace Swift { virtual ~TLSContext(); virtual void connect() = 0; + virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0; virtual void handleDataFromNetwork(const ByteArray&) = 0; virtual void handleDataFromApplication(const ByteArray&) = 0; + virtual Certificate::ref getPeerCertificate() const = 0; + virtual boost::optional<CertificateVerificationError> getPeerCertificateVerificationError() const = 0; + + protected: + static const char* ID_ON_XMPPADDR_OID; + static const char* ID_ON_DNSSRV_OID; + public: boost::signal<void (const ByteArray&)> onDataForNetwork; boost::signal<void (const ByteArray&)> onDataForApplication; |