diff options
Diffstat (limited to 'Swiften/TLS')
45 files changed, 3188 insertions, 3137 deletions
diff --git a/Swiften/TLS/BlindCertificateTrustChecker.h b/Swiften/TLS/BlindCertificateTrustChecker.h index b21f7a6..76a7a02 100644 --- a/Swiften/TLS/BlindCertificateTrustChecker.h +++ b/Swiften/TLS/BlindCertificateTrustChecker.h @@ -10,18 +10,18 @@ #include <Swiften/TLS/CertificateTrustChecker.h> namespace Swift { - /** - * A certificate trust checker that trusts any ceritficate. - * - * This can be used to ignore any TLS certificate errors occurring - * during connection. - * - * \see Client::setAlwaysTrustCertificates() - */ - class SWIFTEN_API BlindCertificateTrustChecker : public CertificateTrustChecker { - public: - virtual bool isCertificateTrusted(const std::vector<Certificate::ref>&) { - return true; - } - }; + /** + * A certificate trust checker that trusts any ceritficate. + * + * This can be used to ignore any TLS certificate errors occurring + * during connection. + * + * \see Client::setAlwaysTrustCertificates() + */ + class SWIFTEN_API BlindCertificateTrustChecker : public CertificateTrustChecker { + public: + virtual bool isCertificateTrusted(const std::vector<Certificate::ref>&) { + return true; + } + }; } diff --git a/Swiften/TLS/CAPICertificate.cpp b/Swiften/TLS/CAPICertificate.cpp index f492c50..a46b9f6 100644 --- a/Swiften/TLS/CAPICertificate.cpp +++ b/Swiften/TLS/CAPICertificate.cpp @@ -1,350 +1,351 @@ /* - * Copyright (c) 2012-2015 Isode Limited. + * Copyright (c) 2012-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <Swiften/Network/TimerFactory.h> #include <Swiften/TLS/CAPICertificate.h> -#include <Swiften/StringCodecs/Hexify.h> -#include <Swiften/Base/Log.h> -#include <boost/bind.hpp> #include <boost/algorithm/string/predicate.hpp> +#include <boost/bind.hpp> + +#include <Swiften/Base/Log.h> +#include <Swiften/Network/TimerFactory.h> +#include <Swiften/StringCodecs/Hexify.h> // Size of the SHA1 hash #define SHA1_HASH_LEN 20 #define DEBUG_SCARD_STATUS(function, status) \ { \ - boost::shared_ptr<boost::system::error_code> errorCode = boost::make_shared<boost::system::error_code>(status, boost::system::system_category()); \ - SWIFT_LOG(debug) << std::hex << function << ": status: 0x" << status << ": " << errorCode->message() << std::endl; \ + std::shared_ptr<boost::system::error_code> errorCode = std::make_shared<boost::system::error_code>(status, boost::system::system_category()); \ + SWIFT_LOG(debug) << std::hex << function << ": status: 0x" << status << ": " << errorCode->message() << std::endl; \ } namespace Swift { -CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory) : - valid_(false), - uri_(capiUri), - certStoreHandle_(0), - scardContext_(0), - cardHandle_(0), - certStore_(), - certName_(), - smartCardReaderName_(), - timerFactory_(timerFactory), - lastPollingResult_(true) { - assert(timerFactory_); - - setUri(capiUri); +CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory) : + valid_(false), + uri_(capiUri), + certStoreHandle_(0), + scardContext_(0), + cardHandle_(0), + certStore_(), + certName_(), + smartCardReaderName_(), + timerFactory_(timerFactory), + lastPollingResult_(true) { + assert(timerFactory_); + + setUri(capiUri); } CAPICertificate::~CAPICertificate() { - if (smartCardTimer_) { - smartCardTimer_->stop(); - smartCardTimer_->onTick.disconnect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); - smartCardTimer_.reset(); - } - - if (certStoreHandle_) { - CertCloseStore(certStoreHandle_, 0); - } - - if (cardHandle_) { - LONG result = SCardDisconnect(cardHandle_, SCARD_LEAVE_CARD); - DEBUG_SCARD_STATUS("SCardDisconnect", result); - } - - if (scardContext_) { - LONG result = SCardReleaseContext(scardContext_); - DEBUG_SCARD_STATUS("SCardReleaseContext", result); - } + if (smartCardTimer_) { + smartCardTimer_->stop(); + smartCardTimer_->onTick.disconnect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); + smartCardTimer_.reset(); + } + + if (certStoreHandle_) { + CertCloseStore(certStoreHandle_, 0); + } + + if (cardHandle_) { + LONG result = SCardDisconnect(cardHandle_, SCARD_LEAVE_CARD); + DEBUG_SCARD_STATUS("SCardDisconnect", result); + } + + if (scardContext_) { + LONG result = SCardReleaseContext(scardContext_); + DEBUG_SCARD_STATUS("SCardReleaseContext", result); + } } bool CAPICertificate::isNull() const { - return uri_.empty() || !valid_; + return uri_.empty() || !valid_; } const std::string& CAPICertificate::getCertStoreName() const { - return certStore_; + return certStore_; } const std::string& CAPICertificate::getCertName() const { - return certName_; + return certName_; } const std::string& CAPICertificate::getSmartCardReaderName() const { - return smartCardReaderName_; + return smartCardReaderName_; } PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName) { - if (!boost::iequals(certName.substr(0, 5), "sha1:")) { + if (!boost::iequals(certName.substr(0, 5), "sha1:")) { - // Find client certificate. Note that this sample just searches for a - // certificate that contains the user name somewhere in the subject name. - return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /*dwFindFlags*/ 0, CERT_FIND_SUBJECT_STR_A, /* *pvFindPara*/certName.c_str(), /*pPrevCertContext*/ NULL); - } + // Find client certificate. Note that this sample just searches for a + // certificate that contains the user name somewhere in the subject name. + return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /*dwFindFlags*/ 0, CERT_FIND_SUBJECT_STR_A, /* *pvFindPara*/certName.c_str(), /*pPrevCertContext*/ NULL); + } - std::string hexstring = certName.substr(5); - ByteArray byteArray = Hexify::unhexify(hexstring); - CRYPT_HASH_BLOB HashBlob; + std::string hexstring = certName.substr(5); + ByteArray byteArray = Hexify::unhexify(hexstring); + CRYPT_HASH_BLOB HashBlob; - if (byteArray.size() != SHA1_HASH_LEN) { - return NULL; - } - HashBlob.cbData = SHA1_HASH_LEN; - HashBlob.pbData = static_cast<BYTE *>(vecptr(byteArray)); + if (byteArray.size() != SHA1_HASH_LEN) { + return NULL; + } + HashBlob.cbData = SHA1_HASH_LEN; + HashBlob.pbData = static_cast<BYTE *>(vecptr(byteArray)); - // Find client certificate. Note that this sample just searches for a - // certificate that contains the user name somewhere in the subject name. - return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /* dwFindFlags */ 0, CERT_FIND_HASH, &HashBlob, /* pPrevCertContext */ NULL); + // Find client certificate. Note that this sample just searches for a + // certificate that contains the user name somewhere in the subject name. + return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /* dwFindFlags */ 0, CERT_FIND_HASH, &HashBlob, /* pPrevCertContext */ NULL); } void CAPICertificate::setUri (const std::string& capiUri) { - valid_ = false; - - /* Syntax: "certstore:" <cert_store> ":" <hash> ":" <hash_of_cert> */ - - if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) { - return; - } - - /* Substring of subject: uses "storename" */ - std::string capiIdentity = capiUri.substr(10); - std::string newCertStoreName; - size_t pos = capiIdentity.find_first_of (':'); - - if (pos == std::string::npos) { - /* Using the default certificate store */ - newCertStoreName = "MY"; - certName_ = capiIdentity; - } - else { - newCertStoreName = capiIdentity.substr(0, pos); - certName_ = capiIdentity.substr(pos + 1); - } - - if (certStoreHandle_ != NULL) { - if (newCertStoreName != certStore_) { - CertCloseStore(certStoreHandle_, 0); - certStoreHandle_ = NULL; - } - } - - if (certStoreHandle_ == NULL) { - certStoreHandle_ = CertOpenSystemStore(0, newCertStoreName.c_str()); - if (!certStoreHandle_) { - return; - } - } - - certStore_ = newCertStoreName; - - PCCERT_CONTEXT certContext = findCertificateInStore (certStoreHandle_, certName_); - - if (!certContext) { - return; - } - - - /* Now verify that we can have access to the corresponding private key */ - - DWORD len; - CRYPT_KEY_PROV_INFO *pinfo; - HCRYPTPROV hprov; - HCRYPTKEY key; - - if (!CertGetCertificateContextProperty(certContext, - CERT_KEY_PROV_INFO_PROP_ID, - NULL, - &len)) { - CertFreeCertificateContext(certContext); - return; - } - - pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len)); - if (!pinfo) { - CertFreeCertificateContext(certContext); - return; - } - - if (!CertGetCertificateContextProperty(certContext, CERT_KEY_PROV_INFO_PROP_ID, pinfo, &len)) { - CertFreeCertificateContext(certContext); - free(pinfo); - return; - } - - CertFreeCertificateContext(certContext); - - // Now verify if we have access to the private key - if (!CryptAcquireContextW(&hprov, pinfo->pwszContainerName, pinfo->pwszProvName, pinfo->dwProvType, 0)) { - free(pinfo); - return; - } - - - char smartCardReader[1024]; - DWORD bufferLength = sizeof(smartCardReader); - if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartCardReader, &bufferLength, 0)) { - DWORD error = GetLastError(); - smartCardReaderName_ = ""; - } - else { - smartCardReaderName_ = smartCardReader; - - LONG result = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_); - DEBUG_SCARD_STATUS("SCardEstablishContext", result); - if (SCARD_S_SUCCESS == result) { - // Initiate monitoring for smartcard ejection - smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS); - } - else { - ///Need to handle an error here - } - } - - if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key)) { - CryptReleaseContext(hprov, 0); - free(pinfo); - return; - } - - CryptDestroyKey(key); - CryptReleaseContext(hprov, 0); - free(pinfo); - - if (smartCardTimer_) { - smartCardTimer_->onTick.connect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); - smartCardTimer_->start(); - } - - valid_ = true; + valid_ = false; + + /* Syntax: "certstore:" <cert_store> ":" <hash> ":" <hash_of_cert> */ + + if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) { + return; + } + + /* Substring of subject: uses "storename" */ + std::string capiIdentity = capiUri.substr(10); + std::string newCertStoreName; + size_t pos = capiIdentity.find_first_of (':'); + + if (pos == std::string::npos) { + /* Using the default certificate store */ + newCertStoreName = "MY"; + certName_ = capiIdentity; + } + else { + newCertStoreName = capiIdentity.substr(0, pos); + certName_ = capiIdentity.substr(pos + 1); + } + + if (certStoreHandle_ != NULL) { + if (newCertStoreName != certStore_) { + CertCloseStore(certStoreHandle_, 0); + certStoreHandle_ = NULL; + } + } + + if (certStoreHandle_ == NULL) { + certStoreHandle_ = CertOpenSystemStore(0, newCertStoreName.c_str()); + if (!certStoreHandle_) { + return; + } + } + + certStore_ = newCertStoreName; + + PCCERT_CONTEXT certContext = findCertificateInStore (certStoreHandle_, certName_); + + if (!certContext) { + return; + } + + + /* Now verify that we can have access to the corresponding private key */ + + DWORD len; + CRYPT_KEY_PROV_INFO *pinfo; + HCRYPTPROV hprov; + HCRYPTKEY key; + + if (!CertGetCertificateContextProperty(certContext, + CERT_KEY_PROV_INFO_PROP_ID, + NULL, + &len)) { + CertFreeCertificateContext(certContext); + return; + } + + pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len)); + if (!pinfo) { + CertFreeCertificateContext(certContext); + return; + } + + if (!CertGetCertificateContextProperty(certContext, CERT_KEY_PROV_INFO_PROP_ID, pinfo, &len)) { + CertFreeCertificateContext(certContext); + free(pinfo); + return; + } + + CertFreeCertificateContext(certContext); + + // Now verify if we have access to the private key + if (!CryptAcquireContextW(&hprov, pinfo->pwszContainerName, pinfo->pwszProvName, pinfo->dwProvType, 0)) { + free(pinfo); + return; + } + + + char smartCardReader[1024]; + DWORD bufferLength = sizeof(smartCardReader); + if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartCardReader, &bufferLength, 0)) { + DWORD error = GetLastError(); + smartCardReaderName_ = ""; + } + else { + smartCardReaderName_ = smartCardReader; + + LONG result = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_); + DEBUG_SCARD_STATUS("SCardEstablishContext", result); + if (SCARD_S_SUCCESS == result) { + // Initiate monitoring for smartcard ejection + smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS); + } + else { + ///Need to handle an error here + } + } + + if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key)) { + CryptReleaseContext(hprov, 0); + free(pinfo); + return; + } + + CryptDestroyKey(key); + CryptReleaseContext(hprov, 0); + free(pinfo); + + if (smartCardTimer_) { + smartCardTimer_->onTick.connect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); + smartCardTimer_->start(); + } + + valid_ = true; } static void smartcard_check_status (SCARDCONTEXT hContext, - const char* pReader, - SCARDHANDLE hCardHandle, /* Can be 0 on the first call */ - SCARDHANDLE* newCardHandle, /* The handle returned */ - DWORD* pdwState) { - DWORD shareMode = SCARD_SHARE_SHARED; - DWORD preferredProtocols = SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1; - DWORD dwAP; - LONG result; - - if (hCardHandle == 0) { - result = SCardConnect(hContext, pReader, shareMode, preferredProtocols, &hCardHandle, &dwAP); - DEBUG_SCARD_STATUS("SCardConnect", result); - if (SCARD_S_SUCCESS != result) { - hCardHandle = 0; - } - } - - char szReader[200]; - DWORD cch = sizeof(szReader); - BYTE bAttr[32]; - DWORD cByte = 32; - size_t countStatusAttempts = 0; - - while (hCardHandle && (countStatusAttempts < 2)) { - *pdwState = SCARD_UNKNOWN; - - result = SCardStatus(hCardHandle, /* Unfortunately we can't use NULL here */ szReader, &cch, pdwState, NULL, (LPBYTE)&bAttr, &cByte); - DEBUG_SCARD_STATUS("SCardStatus", result); - countStatusAttempts++; - - if ((SCARD_W_RESET_CARD == result) && (countStatusAttempts < 2)) { - result = SCardReconnect(hCardHandle, shareMode, preferredProtocols, SCARD_RESET_CARD, &dwAP); - DEBUG_SCARD_STATUS("SCardReconnect", result); - if (SCARD_S_SUCCESS != result) { - break; - } - } - else { - break; - } - } - - if (SCARD_S_SUCCESS != result) { - if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) { - *pdwState = SCARD_ABSENT; - } - else { - *pdwState = SCARD_UNKNOWN; - } - } - - if (newCardHandle == NULL) { - result = SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD); - DEBUG_SCARD_STATUS("SCardDisconnect", result); - } - else { - *newCardHandle = hCardHandle; - } + const char* pReader, + SCARDHANDLE hCardHandle, /* Can be 0 on the first call */ + SCARDHANDLE* newCardHandle, /* The handle returned */ + DWORD* pdwState) { + DWORD shareMode = SCARD_SHARE_SHARED; + DWORD preferredProtocols = SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1; + DWORD dwAP; + LONG result; + + if (hCardHandle == 0) { + result = SCardConnect(hContext, pReader, shareMode, preferredProtocols, &hCardHandle, &dwAP); + DEBUG_SCARD_STATUS("SCardConnect", result); + if (SCARD_S_SUCCESS != result) { + hCardHandle = 0; + } + } + + char szReader[200]; + DWORD cch = sizeof(szReader); + BYTE bAttr[32]; + DWORD cByte = 32; + size_t countStatusAttempts = 0; + + while (hCardHandle && (countStatusAttempts < 2)) { + *pdwState = SCARD_UNKNOWN; + + result = SCardStatus(hCardHandle, /* Unfortunately we can't use NULL here */ szReader, &cch, pdwState, NULL, (LPBYTE)&bAttr, &cByte); + DEBUG_SCARD_STATUS("SCardStatus", result); + countStatusAttempts++; + + if ((SCARD_W_RESET_CARD == result) && (countStatusAttempts < 2)) { + result = SCardReconnect(hCardHandle, shareMode, preferredProtocols, SCARD_RESET_CARD, &dwAP); + DEBUG_SCARD_STATUS("SCardReconnect", result); + if (SCARD_S_SUCCESS != result) { + break; + } + } + else { + break; + } + } + + if (SCARD_S_SUCCESS != result) { + if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) { + *pdwState = SCARD_ABSENT; + } + else { + *pdwState = SCARD_UNKNOWN; + } + } + + if (newCardHandle == NULL) { + result = SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD); + DEBUG_SCARD_STATUS("SCardDisconnect", result); + } + else { + *newCardHandle = hCardHandle; + } } bool CAPICertificate::checkIfSmartCardPresent () { - if (!smartCardReaderName_.empty()) { - DWORD dwState; - smartcard_check_status(scardContext_, smartCardReaderName_.c_str(), cardHandle_, &cardHandle_, &dwState); - - switch (dwState) { - case SCARD_ABSENT: - SWIFT_LOG(debug) << "Card absent." << std::endl; - break; - case SCARD_PRESENT: - SWIFT_LOG(debug) << "Card present." << std::endl; - break; - case SCARD_SWALLOWED: - SWIFT_LOG(debug) << "Card swallowed." << std::endl; - break; - case SCARD_POWERED: - SWIFT_LOG(debug) << "Card has power." << std::endl; - break; - case SCARD_NEGOTIABLE: - SWIFT_LOG(debug) << "Card reset and waiting PTS negotiation." << std::endl; - break; - case SCARD_SPECIFIC: - SWIFT_LOG(debug) << "Card has specific communication protocols set." << std::endl; - break; - default: - SWIFT_LOG(debug) << "Unknown or unexpected card state." << std::endl; - break; - } - - - - switch (dwState) { - case SCARD_ABSENT: - return false; - - case SCARD_PRESENT: - case SCARD_SWALLOWED: - case SCARD_POWERED: - case SCARD_NEGOTIABLE: - case SCARD_SPECIFIC: - return true; - - default: - return false; - } - } - else { - return false; - } + if (!smartCardReaderName_.empty()) { + DWORD dwState; + smartcard_check_status(scardContext_, smartCardReaderName_.c_str(), cardHandle_, &cardHandle_, &dwState); + + switch (dwState) { + case SCARD_ABSENT: + SWIFT_LOG(debug) << "Card absent." << std::endl; + break; + case SCARD_PRESENT: + SWIFT_LOG(debug) << "Card present." << std::endl; + break; + case SCARD_SWALLOWED: + SWIFT_LOG(debug) << "Card swallowed." << std::endl; + break; + case SCARD_POWERED: + SWIFT_LOG(debug) << "Card has power." << std::endl; + break; + case SCARD_NEGOTIABLE: + SWIFT_LOG(debug) << "Card reset and waiting PTS negotiation." << std::endl; + break; + case SCARD_SPECIFIC: + SWIFT_LOG(debug) << "Card has specific communication protocols set." << std::endl; + break; + default: + SWIFT_LOG(debug) << "Unknown or unexpected card state." << std::endl; + break; + } + + + + switch (dwState) { + case SCARD_ABSENT: + return false; + + case SCARD_PRESENT: + case SCARD_SWALLOWED: + case SCARD_POWERED: + case SCARD_NEGOTIABLE: + case SCARD_SPECIFIC: + return true; + + default: + return false; + } + } + else { + return false; + } } void CAPICertificate::handleSmartCardTimerTick() { - bool poll = checkIfSmartCardPresent(); - if (lastPollingResult_ && !poll) { - onCertificateCardRemoved(); - } - lastPollingResult_ = poll; - smartCardTimer_->start(); + bool poll = checkIfSmartCardPresent(); + if (lastPollingResult_ && !poll) { + onCertificateCardRemoved(); + } + lastPollingResult_ = poll; + smartCardTimer_->start(); } } diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h index aebfb41..0259db5 100644 --- a/Swiften/TLS/CAPICertificate.h +++ b/Swiften/TLS/CAPICertificate.h @@ -4,10 +4,16 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/Network/Timer.h> @@ -17,52 +23,52 @@ #include <WinCrypt.h> #include <Winscard.h> -#define SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS 1000 +#define SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS 1000 namespace Swift { - class TimerFactory; + class TimerFactory; - class SWIFTEN_API CAPICertificate : public Swift::CertificateWithKey { - public: - CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory); + class SWIFTEN_API CAPICertificate : public Swift::CertificateWithKey { + public: + CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory); - virtual ~CAPICertificate(); + virtual ~CAPICertificate(); - virtual bool isNull() const; + virtual bool isNull() const; - const std::string& getCertStoreName() const; + const std::string& getCertStoreName() const; - const std::string& getCertName() const; + const std::string& getCertName() const; - const std::string& getSmartCardReaderName() const; + const std::string& getSmartCardReaderName() const; - public: - boost::signal<void ()> onCertificateCardRemoved; + public: + boost::signals2::signal<void ()> onCertificateCardRemoved; - private: - void setUri (const std::string& capiUri); + private: + void setUri (const std::string& capiUri); - void handleSmartCardTimerTick(); + void handleSmartCardTimerTick(); - bool checkIfSmartCardPresent(); + bool checkIfSmartCardPresent(); - private: - bool valid_; - std::string uri_; + private: + bool valid_; + std::string uri_; - HCERTSTORE certStoreHandle_; - SCARDCONTEXT scardContext_; - SCARDHANDLE cardHandle_; + HCERTSTORE certStoreHandle_; + SCARDCONTEXT scardContext_; + SCARDHANDLE cardHandle_; - /* Parsed components of the uri_ */ - std::string certStore_; - std::string certName_; - std::string smartCardReaderName_; - boost::shared_ptr<Timer> smartCardTimer_; - TimerFactory* timerFactory_; + /* Parsed components of the uri_ */ + std::string certStore_; + std::string certName_; + std::string smartCardReaderName_; + std::shared_ptr<Timer> smartCardTimer_; + TimerFactory* timerFactory_; - bool lastPollingResult_; - }; + bool lastPollingResult_; + }; PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName); diff --git a/Swiften/TLS/Certificate.cpp b/Swiften/TLS/Certificate.cpp index fe84a74..c7d48b2 100644 --- a/Swiften/TLS/Certificate.cpp +++ b/Swiften/TLS/Certificate.cpp @@ -20,15 +20,15 @@ Certificate::~Certificate() { } std::string Certificate::getSHA1Fingerprint(Certificate::ref certificate, CryptoProvider* crypto) { - ByteArray hash = crypto->getSHA1Hash(certificate->toDER()); - std::ostringstream s; - for (size_t i = 0; i < hash.size(); ++i) { - if (i > 0) { - s << ":"; - } - s << Hexify::hexify(hash[i]); - } - return std::string(s.str()); + ByteArray hash = crypto->getSHA1Hash(certificate->toDER()); + std::ostringstream s; + for (size_t i = 0; i < hash.size(); ++i) { + if (i > 0) { + s << ":"; + } + s << Hexify::hexify(hash[i]); + } + return std::string(s.str()); } } diff --git a/Swiften/TLS/Certificate.h b/Swiften/TLS/Certificate.h index 00d618e..dbc61ad 100644 --- a/Swiften/TLS/Certificate.h +++ b/Swiften/TLS/Certificate.h @@ -1,46 +1,45 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once +#include <memory> #include <string> #include <vector> -#include <boost/shared_ptr.hpp> - #include <Swiften/Base/API.h> #include <Swiften/Base/ByteArray.h> namespace Swift { - class CryptoProvider; + class CryptoProvider; - class SWIFTEN_API Certificate { - public: - typedef boost::shared_ptr<Certificate> ref; + class SWIFTEN_API Certificate { + public: + typedef std::shared_ptr<Certificate> ref; - virtual ~Certificate(); + virtual ~Certificate(); - /** - * Returns the textual representation of the full Subject - * name. - */ - virtual std::string getSubjectName() const = 0; + /** + * Returns the textual representation of the full Subject + * name. + */ + virtual std::string getSubjectName() const = 0; - virtual std::vector<std::string> getCommonNames() const = 0; - virtual std::vector<std::string> getSRVNames() const = 0; - virtual std::vector<std::string> getDNSNames() const = 0; - virtual std::vector<std::string> getXMPPAddresses() const = 0; + virtual std::vector<std::string> getCommonNames() const = 0; + virtual std::vector<std::string> getSRVNames() const = 0; + virtual std::vector<std::string> getDNSNames() const = 0; + virtual std::vector<std::string> getXMPPAddresses() const = 0; - virtual ByteArray toDER() const = 0; + virtual ByteArray toDER() const = 0; - static std::string getSHA1Fingerprint(Certificate::ref, CryptoProvider* crypto); + static std::string getSHA1Fingerprint(Certificate::ref, CryptoProvider* crypto); - protected: - static const char* ID_ON_XMPPADDR_OID; - static const char* ID_ON_DNSSRV_OID; + protected: + static const char* ID_ON_XMPPADDR_OID; + static const char* ID_ON_DNSSRV_OID; - }; + }; } diff --git a/Swiften/TLS/CertificateFactory.h b/Swiften/TLS/CertificateFactory.h index 28d39bb..522a6e6 100644 --- a/Swiften/TLS/CertificateFactory.h +++ b/Swiften/TLS/CertificateFactory.h @@ -10,10 +10,10 @@ #include <Swiften/TLS/Certificate.h> namespace Swift { - class SWIFTEN_API CertificateFactory { - public: - virtual ~CertificateFactory(); + class SWIFTEN_API CertificateFactory { + public: + virtual ~CertificateFactory(); - virtual Certificate* createCertificateFromDER(const ByteArray& der) = 0; - }; + virtual Certificate* createCertificateFromDER(const ByteArray& der) = 0; + }; } diff --git a/Swiften/TLS/CertificateTrustChecker.h b/Swiften/TLS/CertificateTrustChecker.h index 744634f..dd2b3ec 100644 --- a/Swiften/TLS/CertificateTrustChecker.h +++ b/Swiften/TLS/CertificateTrustChecker.h @@ -1,12 +1,12 @@ /* - * Copyright (c) 2010 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> #include <string> #include <vector> @@ -14,21 +14,21 @@ #include <Swiften/TLS/Certificate.h> namespace Swift { - /** - * A class to implement a check for certificate trust. - */ - class SWIFTEN_API CertificateTrustChecker { - public: - virtual ~CertificateTrustChecker(); + /** + * A class to implement a check for certificate trust. + */ + class SWIFTEN_API CertificateTrustChecker { + public: + virtual ~CertificateTrustChecker(); - /** - * This method is called to find out whether a certificate (chain) is - * trusted. This usually happens when a certificate's validation - * fails, to check whether to proceed with the connection or not. - * - * certificateChain contains the chain of certificates. The first certificate - * is the subject certificate. - */ - virtual bool isCertificateTrusted(const std::vector<Certificate::ref>& certificateChain) = 0; - }; + /** + * This method is called to find out whether a certificate (chain) is + * trusted. This usually happens when a certificate's validation + * fails, to check whether to proceed with the connection or not. + * + * certificateChain contains the chain of certificates. The first certificate + * is the subject certificate. + */ + virtual bool isCertificateTrusted(const std::vector<Certificate::ref>& certificateChain) = 0; + }; } diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h index f1596dc..02b4cca 100644 --- a/Swiften/TLS/CertificateVerificationError.h +++ b/Swiften/TLS/CertificateVerificationError.h @@ -1,43 +1,44 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + #include <Swiften/Base/API.h> #include <Swiften/Base/Error.h> namespace Swift { - class SWIFTEN_API CertificateVerificationError : public Error { - public: - typedef boost::shared_ptr<CertificateVerificationError> ref; - - enum Type { - UnknownError, - Expired, - NotYetValid, - SelfSigned, - Rejected, - Untrusted, - InvalidPurpose, - PathLengthExceeded, - InvalidSignature, - InvalidCA, - InvalidServerIdentity, - Revoked, - RevocationCheckFailed - }; - - CertificateVerificationError(Type type = UnknownError) : type(type) {} - - Type getType() const { - return type; - } - - private: - Type type; - }; + class SWIFTEN_API CertificateVerificationError : public Error { + public: + typedef std::shared_ptr<CertificateVerificationError> ref; + + enum Type { + UnknownError, + Expired, + NotYetValid, + SelfSigned, + Rejected, + Untrusted, + InvalidPurpose, + PathLengthExceeded, + InvalidSignature, + InvalidCA, + InvalidServerIdentity, + Revoked, + RevocationCheckFailed + }; + + CertificateVerificationError(Type type = UnknownError) : type(type) {} + + Type getType() const { + return type; + } + + private: + Type type; + }; } diff --git a/Swiften/TLS/CertificateWithKey.h b/Swiften/TLS/CertificateWithKey.h index 687118a..8414938 100644 --- a/Swiften/TLS/CertificateWithKey.h +++ b/Swiften/TLS/CertificateWithKey.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -10,14 +10,14 @@ #include <Swiften/Base/SafeByteArray.h> namespace Swift { - class SWIFTEN_API CertificateWithKey { - public: - typedef boost::shared_ptr<CertificateWithKey> ref; - CertificateWithKey() {} + class SWIFTEN_API CertificateWithKey { + public: + typedef std::shared_ptr<CertificateWithKey> ref; + CertificateWithKey() {} - virtual ~CertificateWithKey() {} + virtual ~CertificateWithKey() {} - virtual bool isNull() const = 0; + virtual bool isNull() const = 0; - }; + }; } diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp index fc8dce5..17ac8cc 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2013 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -19,92 +19,92 @@ namespace Swift { -OpenSSLCertificate::OpenSSLCertificate(boost::shared_ptr<X509> cert) : cert(cert) { - parse(); +OpenSSLCertificate::OpenSSLCertificate(std::shared_ptr<X509> cert) : cert(cert) { + parse(); } OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) { #if OPENSSL_VERSION_NUMBER <= 0x009070cfL - unsigned char* p = const_cast<unsigned char*>(vecptr(der)); + unsigned char* p = const_cast<unsigned char*>(vecptr(der)); #else - const unsigned char* p = vecptr(der); + const unsigned char* p = vecptr(der); #endif - cert = boost::shared_ptr<X509>(d2i_X509(NULL, &p, der.size()), X509_free); - if (!cert) { - SWIFT_LOG(warning) << "Error creating certificate from DER data" << std::endl; - } - parse(); + cert = std::shared_ptr<X509>(d2i_X509(NULL, &p, der.size()), X509_free); + if (!cert) { + SWIFT_LOG(warning) << "Error creating certificate from DER data" << std::endl; + } + parse(); } ByteArray OpenSSLCertificate::toDER() const { - ByteArray result; - if (!cert) { - return result; - } - result.resize(i2d_X509(cert.get(), NULL)); - unsigned char* p = vecptr(result); - i2d_X509(cert.get(), &p); - return result; + ByteArray result; + if (!cert) { + return result; + } + result.resize(i2d_X509(cert.get(), NULL)); + unsigned char* p = vecptr(result); + i2d_X509(cert.get(), &p); + return result; } void OpenSSLCertificate::parse() { - if (!cert) { - return; - } - // Subject name - X509_NAME* subjectName = X509_get_subject_name(cert.get()); - if (subjectName) { - // Subject name - ByteArray subjectNameData; - subjectNameData.resize(256); - X509_NAME_oneline(X509_get_subject_name(cert.get()), reinterpret_cast<char*>(vecptr(subjectNameData)), static_cast<unsigned int>(subjectNameData.size())); - this->subjectName = byteArrayToString(subjectNameData); + if (!cert) { + return; + } + // Subject name + X509_NAME* subjectName = X509_get_subject_name(cert.get()); + if (subjectName) { + // Subject name + ByteArray subjectNameData; + subjectNameData.resize(256); + X509_NAME_oneline(X509_get_subject_name(cert.get()), reinterpret_cast<char*>(vecptr(subjectNameData)), static_cast<unsigned int>(subjectNameData.size())); + this->subjectName = byteArrayToString(subjectNameData); - // Common name - int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1); - while (cnLoc != -1) { - X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc); - ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry); - commonNames.push_back(byteArrayToString(createByteArray(reinterpret_cast<const char*>(cnData->data), cnData->length))); - cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, cnLoc); - } - } + // Common name + int cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, -1); + while (cnLoc != -1) { + X509_NAME_ENTRY* cnEntry = X509_NAME_get_entry(subjectName, cnLoc); + ASN1_STRING* cnData = X509_NAME_ENTRY_get_data(cnEntry); + commonNames.push_back(byteArrayToString(createByteArray(reinterpret_cast<const char*>(cnData->data), cnData->length))); + cnLoc = X509_NAME_get_index_by_NID(subjectName, NID_commonName, cnLoc); + } + } - // subjectAltNames - int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); - if(subjectAltNameLoc != -1) { - X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); - boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); - boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); - boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); - for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { - GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); - if (generalName->type == GEN_OTHERNAME) { - OTHERNAME* otherName = generalName->d.otherName; - if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { - // XmppAddr - if (otherName->value->type != V_ASN1_UTF8STRING) { - continue; - } - ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; - addXMPPAddress(byteArrayToString(createByteArray(reinterpret_cast<const char*>(ASN1_STRING_data(xmppAddrValue)), ASN1_STRING_length(xmppAddrValue)))); - } - else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { - // SRVName - if (otherName->value->type != V_ASN1_IA5STRING) { - continue; - } - ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; - addSRVName(byteArrayToString(createByteArray(reinterpret_cast<const char*>(ASN1_STRING_data(srvNameValue)), ASN1_STRING_length(srvNameValue)))); - } - } - else if (generalName->type == GEN_DNS) { - // DNSName - addDNSName(byteArrayToString(createByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)))); - } - } - } + // subjectAltNames + int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); + if(subjectAltNameLoc != -1) { + X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); + std::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); + std::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); + std::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); + for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { + GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); + if (generalName->type == GEN_OTHERNAME) { + OTHERNAME* otherName = generalName->d.otherName; + if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { + // XmppAddr + if (otherName->value->type != V_ASN1_UTF8STRING) { + continue; + } + ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; + addXMPPAddress(byteArrayToString(createByteArray(reinterpret_cast<const char*>(ASN1_STRING_data(xmppAddrValue)), ASN1_STRING_length(xmppAddrValue)))); + } + else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { + // SRVName + if (otherName->value->type != V_ASN1_IA5STRING) { + continue; + } + ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; + addSRVName(byteArrayToString(createByteArray(reinterpret_cast<const char*>(ASN1_STRING_data(srvNameValue)), ASN1_STRING_length(srvNameValue)))); + } + } + else if (generalName->type == GEN_DNS) { + // DNSName + addDNSName(byteArrayToString(createByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)))); + } + } + } } } diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h index 2cc047a..186caea 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h +++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h @@ -1,70 +1,71 @@ /* - * Copyright (c) 2010 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> +#include <string> + #include <openssl/ssl.h> -#include <string> #include <Swiften/TLS/Certificate.h> namespace Swift { - class OpenSSLCertificate : public Certificate { - public: - OpenSSLCertificate(boost::shared_ptr<X509>); - OpenSSLCertificate(const ByteArray& der); - - std::string getSubjectName() const { - return subjectName; - } - - std::vector<std::string> getCommonNames() const { - return commonNames; - } - - std::vector<std::string> getSRVNames() const { - return srvNames; - } - - std::vector<std::string> getDNSNames() const { - return dnsNames; - } - - std::vector<std::string> getXMPPAddresses() const { - return xmppAddresses; - } - - ByteArray toDER() const; - - boost::shared_ptr<X509> getInternalX509() const { - return cert; - } - - private: - void parse(); - - void addSRVName(const std::string& name) { - srvNames.push_back(name); - } - - void addDNSName(const std::string& name) { - dnsNames.push_back(name); - } - - void addXMPPAddress(const std::string& addr) { - xmppAddresses.push_back(addr); - } - - private: - boost::shared_ptr<X509> cert; - std::string subjectName; - std::vector<std::string> commonNames; - std::vector<std::string> dnsNames; - std::vector<std::string> xmppAddresses; - std::vector<std::string> srvNames; - }; + class OpenSSLCertificate : public Certificate { + public: + OpenSSLCertificate(std::shared_ptr<X509>); + OpenSSLCertificate(const ByteArray& der); + + std::string getSubjectName() const { + return subjectName; + } + + std::vector<std::string> getCommonNames() const { + return commonNames; + } + + std::vector<std::string> getSRVNames() const { + return srvNames; + } + + std::vector<std::string> getDNSNames() const { + return dnsNames; + } + + std::vector<std::string> getXMPPAddresses() const { + return xmppAddresses; + } + + ByteArray toDER() const; + + std::shared_ptr<X509> getInternalX509() const { + return cert; + } + + private: + void parse(); + + void addSRVName(const std::string& name) { + srvNames.push_back(name); + } + + void addDNSName(const std::string& name) { + dnsNames.push_back(name); + } + + void addXMPPAddress(const std::string& addr) { + xmppAddresses.push_back(addr); + } + + private: + std::shared_ptr<X509> cert; + std::string subjectName; + std::vector<std::string> commonNames; + std::vector<std::string> dnsNames; + std::vector<std::string> xmppAddresses; + std::vector<std::string> srvNames; + }; } diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h index bb8780e..c996cd5 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h +++ b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h @@ -10,10 +10,10 @@ #include <Swiften/TLS/OpenSSL/OpenSSLCertificate.h> namespace Swift { - class OpenSSLCertificateFactory : public CertificateFactory { - public: - virtual Certificate* createCertificateFromDER(const ByteArray& der) { - return new OpenSSLCertificate(der); - } - }; + class OpenSSLCertificateFactory : public CertificateFactory { + public: + virtual Certificate* createCertificateFromDER(const ByteArray& der) { + return new OpenSSLCertificate(der); + } + }; } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index a3e0e1d..cd6b6bc 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2013 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -13,7 +13,7 @@ #include <vector> #include <openssl/err.h> #include <openssl/pkcs12.h> -#include <boost/smart_ptr/make_shared.hpp> +#include <memory> #if defined(SWIFTEN_PLATFORM_MACOSX) #include <Security/Security.h> @@ -36,302 +36,302 @@ static const int MAX_FINISHED_SIZE = 4096; static const int SSL_READ_BUFFERSIZE = 8192; static void freeX509Stack(STACK_OF(X509)* stack) { - sk_X509_free(stack); + sk_X509_free(stack); } OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) { - ensureLibraryInitialized(); - context_ = SSL_CTX_new(SSLv23_client_method()); - SSL_CTX_set_options(context_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + ensureLibraryInitialized(); + context_ = SSL_CTX_new(SSLv23_client_method()); + SSL_CTX_set_options(context_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); - // TODO: implement CRL checking - // TODO: download CRL (HTTP transport) - // TODO: cache CRL downloads for configurable time period + // TODO: implement CRL checking + // TODO: download CRL (HTTP transport) + // TODO: cache CRL downloads for configurable time period - // TODO: implement OCSP support - // TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt - // Load system certs + // TODO: implement OCSP support + // TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt + // Load system certs #if defined(SWIFTEN_PLATFORM_WINDOWS) - X509_STORE* store = SSL_CTX_get_cert_store(context_); - HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT"); - if (systemStore) { - PCCERT_CONTEXT certContext = NULL; - while (true) { - certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext); - if (!certContext) { - break; - } - OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded)); - if (store && cert.getInternalX509()) { - X509_STORE_add_cert(store, cert.getInternalX509().get()); - } - } - } + X509_STORE* store = SSL_CTX_get_cert_store(context_); + HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT"); + if (systemStore) { + PCCERT_CONTEXT certContext = NULL; + while (true) { + certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext); + if (!certContext) { + break; + } + OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded)); + if (store && cert.getInternalX509()) { + X509_STORE_add_cert(store, cert.getInternalX509().get()); + } + } + } #elif !defined(SWIFTEN_PLATFORM_MACOSX) - SSL_CTX_load_verify_locations(context_, NULL, "/etc/ssl/certs"); + SSL_CTX_set_default_verify_paths(context_); #elif defined(SWIFTEN_PLATFORM_MACOSX) && !defined(SWIFTEN_PLATFORM_IPHONE) - // On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store. - // On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust. - // However, if there is a certificate error, it will always emit the "Invalid CA" error if we didn't add - // the certificates first. See - // http://opensource.apple.com/source/OpenSSL098/OpenSSL098-27/src/crypto/x509/x509_vfy_apple.c - // to understand why. We therefore add all certs from the system store ourselves. - X509_STORE* store = SSL_CTX_get_cert_store(context_); - CFArrayRef anchorCertificates; - if (SecTrustCopyAnchorCertificates(&anchorCertificates) == 0) { - for (int i = 0; i < CFArrayGetCount(anchorCertificates); ++i) { - SecCertificateRef cert = reinterpret_cast<SecCertificateRef>(const_cast<void*>(CFArrayGetValueAtIndex(anchorCertificates, i))); - CSSM_DATA certCSSMData; - if (SecCertificateGetData(cert, &certCSSMData) != 0 || certCSSMData.Length == 0) { - continue; - } - std::vector<unsigned char> certData; - certData.resize(certCSSMData.Length); - memcpy(&certData[0], certCSSMData.Data, certCSSMData.Length); - OpenSSLCertificate certificate(certData); - if (store && certificate.getInternalX509()) { - X509_STORE_add_cert(store, certificate.getInternalX509().get()); - } - } - CFRelease(anchorCertificates); - } + // On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store. + // On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust. + // However, if there is a certificate error, it will always emit the "Invalid CA" error if we didn't add + // the certificates first. See + // http://opensource.apple.com/source/OpenSSL098/OpenSSL098-27/src/crypto/x509/x509_vfy_apple.c + // to understand why. We therefore add all certs from the system store ourselves. + X509_STORE* store = SSL_CTX_get_cert_store(context_); + CFArrayRef anchorCertificates; + if (SecTrustCopyAnchorCertificates(&anchorCertificates) == 0) { + for (int i = 0; i < CFArrayGetCount(anchorCertificates); ++i) { + SecCertificateRef cert = reinterpret_cast<SecCertificateRef>(const_cast<void*>(CFArrayGetValueAtIndex(anchorCertificates, i))); + CSSM_DATA certCSSMData; + if (SecCertificateGetData(cert, &certCSSMData) != 0 || certCSSMData.Length == 0) { + continue; + } + std::vector<unsigned char> certData; + certData.resize(certCSSMData.Length); + memcpy(&certData[0], certCSSMData.Data, certCSSMData.Length); + OpenSSLCertificate certificate(certData); + if (store && certificate.getInternalX509()) { + X509_STORE_add_cert(store, certificate.getInternalX509().get()); + } + } + CFRelease(anchorCertificates); + } #endif } OpenSSLContext::~OpenSSLContext() { - SSL_free(handle_); - SSL_CTX_free(context_); + SSL_free(handle_); + SSL_CTX_free(context_); } void OpenSSLContext::ensureLibraryInitialized() { - static bool isLibraryInitialized = false; - if (!isLibraryInitialized) { - SSL_load_error_strings(); - SSL_library_init(); - OpenSSL_add_all_algorithms(); - - // Disable compression - /* - STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods(); - sk_SSL_COMP_zero(compressionMethods);*/ - - isLibraryInitialized = true; - } + static bool isLibraryInitialized = false; + if (!isLibraryInitialized) { + SSL_load_error_strings(); + SSL_library_init(); + OpenSSL_add_all_algorithms(); + + // Disable compression + /* + STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods(); + sk_SSL_COMP_zero(compressionMethods);*/ + + isLibraryInitialized = true; + } } void OpenSSLContext::connect() { - handle_ = SSL_new(context_); - // Ownership of BIOs is ransferred - readBIO_ = BIO_new(BIO_s_mem()); - writeBIO_ = BIO_new(BIO_s_mem()); - SSL_set_bio(handle_, readBIO_, writeBIO_); - - state_ = Connecting; - doConnect(); + handle_ = SSL_new(context_); + // Ownership of BIOs is ransferred + readBIO_ = BIO_new(BIO_s_mem()); + writeBIO_ = BIO_new(BIO_s_mem()); + SSL_set_bio(handle_, readBIO_, writeBIO_); + + state_ = Connecting; + doConnect(); } void OpenSSLContext::doConnect() { - int connectResult = SSL_connect(handle_); - int error = SSL_get_error(handle_, connectResult); - switch (error) { - case SSL_ERROR_NONE: { - state_ = Connected; - //std::cout << x->name << std::endl; - //const char* comp = SSL_get_current_compression(handle_); - //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; - onConnected(); - break; - } - case SSL_ERROR_WANT_READ: - sendPendingDataToNetwork(); - break; - default: - state_ = Error; - onError(boost::make_shared<TLSError>()); - } + int connectResult = SSL_connect(handle_); + int error = SSL_get_error(handle_, connectResult); + switch (error) { + case SSL_ERROR_NONE: { + state_ = Connected; + //std::cout << x->name << std::endl; + //const char* comp = SSL_get_current_compression(handle_); + //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; + onConnected(); + break; + } + case SSL_ERROR_WANT_READ: + sendPendingDataToNetwork(); + break; + default: + state_ = Error; + onError(std::make_shared<TLSError>()); + } } void OpenSSLContext::sendPendingDataToNetwork() { - int size = BIO_pending(writeBIO_); - if (size > 0) { - SafeByteArray data; - data.resize(size); - BIO_read(writeBIO_, vecptr(data), size); - onDataForNetwork(data); - } + int size = BIO_pending(writeBIO_); + if (size > 0) { + SafeByteArray data; + data.resize(size); + BIO_read(writeBIO_, vecptr(data), size); + onDataForNetwork(data); + } } void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) { - BIO_write(readBIO_, vecptr(data), data.size()); - switch (state_) { - case Connecting: - doConnect(); - break; - case Connected: - sendPendingDataToApplication(); - break; - case Start: assert(false); break; - case Error: /*assert(false);*/ break; - } + BIO_write(readBIO_, vecptr(data), data.size()); + switch (state_) { + case Connecting: + doConnect(); + break; + case Connected: + sendPendingDataToApplication(); + break; + case Start: assert(false); break; + case Error: /*assert(false);*/ break; + } } void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) { - if (SSL_write(handle_, vecptr(data), data.size()) >= 0) { - sendPendingDataToNetwork(); - } - else { - state_ = Error; - onError(boost::make_shared<TLSError>()); - } + if (SSL_write(handle_, vecptr(data), data.size()) >= 0) { + sendPendingDataToNetwork(); + } + else { + state_ = Error; + onError(std::make_shared<TLSError>()); + } } void OpenSSLContext::sendPendingDataToApplication() { - SafeByteArray data; - data.resize(SSL_READ_BUFFERSIZE); - int ret = SSL_read(handle_, vecptr(data), data.size()); - while (ret > 0) { - data.resize(ret); - onDataForApplication(data); - data.resize(SSL_READ_BUFFERSIZE); - ret = SSL_read(handle_, vecptr(data), data.size()); - } - if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) { - state_ = Error; - onError(boost::make_shared<TLSError>()); - } + SafeByteArray data; + data.resize(SSL_READ_BUFFERSIZE); + int ret = SSL_read(handle_, vecptr(data), data.size()); + while (ret > 0) { + data.resize(ret); + onDataForApplication(data); + data.resize(SSL_READ_BUFFERSIZE); + ret = SSL_read(handle_, vecptr(data), data.size()); + } + if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) { + state_ = Error; + onError(std::make_shared<TLSError>()); + } } bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) { - boost::shared_ptr<PKCS12Certificate> pkcs12Certificate = boost::dynamic_pointer_cast<PKCS12Certificate>(certificate); - if (!pkcs12Certificate || pkcs12Certificate->isNull()) { - return false; - } - - // Create a PKCS12 structure - BIO* bio = BIO_new(BIO_s_mem()); - BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size()); - boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free); - BIO_free(bio); - if (!pkcs12) { - return false; - } - - // Parse PKCS12 - X509 *certPtr = 0; - EVP_PKEY* privateKeyPtr = 0; - STACK_OF(X509)* caCertsPtr = 0; - SafeByteArray password(pkcs12Certificate->getPassword()); - password.push_back(0); - int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr); - if (result != 1) { - return false; - } - boost::shared_ptr<X509> cert(certPtr, X509_free); - boost::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free); - boost::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack); - - // Use the key & certificates - if (SSL_CTX_use_certificate(context_, cert.get()) != 1) { - return false; - } - if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) { - return false; - } - for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) { - SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i)); - } - return true; + std::shared_ptr<PKCS12Certificate> pkcs12Certificate = std::dynamic_pointer_cast<PKCS12Certificate>(certificate); + if (!pkcs12Certificate || pkcs12Certificate->isNull()) { + return false; + } + + // Create a PKCS12 structure + BIO* bio = BIO_new(BIO_s_mem()); + BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size()); + std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free); + BIO_free(bio); + if (!pkcs12) { + return false; + } + + // Parse PKCS12 + X509 *certPtr = 0; + EVP_PKEY* privateKeyPtr = 0; + STACK_OF(X509)* caCertsPtr = 0; + SafeByteArray password(pkcs12Certificate->getPassword()); + password.push_back(0); + int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr); + if (result != 1) { + return false; + } + std::shared_ptr<X509> cert(certPtr, X509_free); + std::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free); + std::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack); + + // Use the key & certificates + if (SSL_CTX_use_certificate(context_, cert.get()) != 1) { + return false; + } + if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) { + return false; + } + for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) { + SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i)); + } + return true; } std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const { - std::vector<Certificate::ref> result; - STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_); - for (int i = 0; i < sk_X509_num(chain); ++i) { - boost::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free); - - Certificate::ref cert = boost::make_shared<OpenSSLCertificate>(x509Cert); - result.push_back(cert); - } - return result; + std::vector<Certificate::ref> result; + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_); + for (int i = 0; i < sk_X509_num(chain); ++i) { + std::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free); + + Certificate::ref cert = std::make_shared<OpenSSLCertificate>(x509Cert); + result.push_back(cert); + } + return result; } -boost::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { - int verifyResult = SSL_get_verify_result(handle_); - if (verifyResult != X509_V_OK) { - return boost::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult)); - } - else { - return boost::shared_ptr<CertificateVerificationError>(); - } +std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { + int verifyResult = SSL_get_verify_result(handle_); + if (verifyResult != X509_V_OK) { + return std::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult)); + } + else { + return std::shared_ptr<CertificateVerificationError>(); + } } ByteArray OpenSSLContext::getFinishMessage() const { - ByteArray data; - data.resize(MAX_FINISHED_SIZE); - size_t size = SSL_get_finished(handle_, vecptr(data), data.size()); - data.resize(size); - return data; + ByteArray data; + data.resize(MAX_FINISHED_SIZE); + size_t size = SSL_get_finished(handle_, vecptr(data), data.size()); + data.resize(size); + return data; } CertificateVerificationError::Type OpenSSLContext::getVerificationErrorTypeForResult(int result) { - assert(result != 0); - switch (result) { - case X509_V_ERR_CERT_NOT_YET_VALID: - case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: - return CertificateVerificationError::NotYetValid; - - case X509_V_ERR_CERT_HAS_EXPIRED: - case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: - return CertificateVerificationError::Expired; - - case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: - case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: - return CertificateVerificationError::SelfSigned; - - case X509_V_ERR_CERT_UNTRUSTED: - return CertificateVerificationError::Untrusted; - - case X509_V_ERR_CERT_REJECTED: - return CertificateVerificationError::Rejected; - - case X509_V_ERR_INVALID_PURPOSE: - return CertificateVerificationError::InvalidPurpose; - - case X509_V_ERR_PATH_LENGTH_EXCEEDED: - return CertificateVerificationError::PathLengthExceeded; - - case X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: - case X509_V_ERR_CERT_SIGNATURE_FAILURE: - case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: - return CertificateVerificationError::InvalidSignature; - - case X509_V_ERR_INVALID_CA: - case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: - case X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: - case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: - return CertificateVerificationError::InvalidCA; - - case X509_V_ERR_SUBJECT_ISSUER_MISMATCH: - case X509_V_ERR_AKID_SKID_MISMATCH: - case X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: - case X509_V_ERR_KEYUSAGE_NO_CERTSIGN: - return CertificateVerificationError::UnknownError; - - // Unused / should not happen - case X509_V_ERR_CERT_REVOKED: - case X509_V_ERR_OUT_OF_MEM: - case X509_V_ERR_UNABLE_TO_GET_CRL: - case X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: - case X509_V_ERR_CRL_SIGNATURE_FAILURE: - case X509_V_ERR_CRL_NOT_YET_VALID: - case X509_V_ERR_CRL_HAS_EXPIRED: - case X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: - case X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: - case X509_V_ERR_CERT_CHAIN_TOO_LONG: - case X509_V_ERR_APPLICATION_VERIFICATION: - default: - return CertificateVerificationError::UnknownError; - } + assert(result != 0); + switch (result) { + case X509_V_ERR_CERT_NOT_YET_VALID: + case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: + return CertificateVerificationError::NotYetValid; + + case X509_V_ERR_CERT_HAS_EXPIRED: + case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: + return CertificateVerificationError::Expired; + + case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: + case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: + return CertificateVerificationError::SelfSigned; + + case X509_V_ERR_CERT_UNTRUSTED: + return CertificateVerificationError::Untrusted; + + case X509_V_ERR_CERT_REJECTED: + return CertificateVerificationError::Rejected; + + case X509_V_ERR_INVALID_PURPOSE: + return CertificateVerificationError::InvalidPurpose; + + case X509_V_ERR_PATH_LENGTH_EXCEEDED: + return CertificateVerificationError::PathLengthExceeded; + + case X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: + case X509_V_ERR_CERT_SIGNATURE_FAILURE: + case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: + return CertificateVerificationError::InvalidSignature; + + case X509_V_ERR_INVALID_CA: + case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: + case X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: + case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: + return CertificateVerificationError::InvalidCA; + + case X509_V_ERR_SUBJECT_ISSUER_MISMATCH: + case X509_V_ERR_AKID_SKID_MISMATCH: + case X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: + case X509_V_ERR_KEYUSAGE_NO_CERTSIGN: + return CertificateVerificationError::UnknownError; + + // Unused / should not happen + case X509_V_ERR_CERT_REVOKED: + case X509_V_ERR_OUT_OF_MEM: + case X509_V_ERR_UNABLE_TO_GET_CRL: + case X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: + case X509_V_ERR_CRL_SIGNATURE_FAILURE: + case X509_V_ERR_CRL_NOT_YET_VALID: + case X509_V_ERR_CRL_HAS_EXPIRED: + case X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: + case X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: + case X509_V_ERR_CERT_CHAIN_TOO_LONG: + case X509_V_ERR_APPLICATION_VERIFICATION: + default: + return CertificateVerificationError::UnknownError; + } } } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 73fe75c..e75b3c9 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -7,48 +7,48 @@ #pragma once #include <boost/noncopyable.hpp> +#include <boost/signals2.hpp> #include <openssl/ssl.h> #include <Swiften/Base/ByteArray.h> -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/TLSContext.h> namespace Swift { - class OpenSSLContext : public TLSContext, boost::noncopyable { - public: - OpenSSLContext(); - virtual ~OpenSSLContext(); + class OpenSSLContext : public TLSContext, boost::noncopyable { + public: + OpenSSLContext(); + virtual ~OpenSSLContext(); - void connect(); - bool setClientCertificate(CertificateWithKey::ref cert); + void connect(); + bool setClientCertificate(CertificateWithKey::ref cert); - void handleDataFromNetwork(const SafeByteArray&); - void handleDataFromApplication(const SafeByteArray&); + void handleDataFromNetwork(const SafeByteArray&); + void handleDataFromApplication(const SafeByteArray&); - std::vector<Certificate::ref> getPeerCertificateChain() const; - boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; + std::vector<Certificate::ref> getPeerCertificateChain() const; + std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; - virtual ByteArray getFinishMessage() const; + virtual ByteArray getFinishMessage() const; - private: - static void ensureLibraryInitialized(); + private: + static void ensureLibraryInitialized(); - static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); + static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); - void doConnect(); - void sendPendingDataToNetwork(); - void sendPendingDataToApplication(); + void doConnect(); + void sendPendingDataToNetwork(); + void sendPendingDataToApplication(); - private: - enum State { Start, Connecting, Connected, Error }; + private: + enum State { Start, Connecting, Connected, Error }; - State state_; - SSL_CTX* context_; - SSL* handle_; - BIO* readBIO_; - BIO* writeBIO_; - }; + State state_; + SSL_CTX* context_; + SSL* handle_; + BIO* readBIO_; + BIO* writeBIO_; + }; } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContextFactory.cpp b/Swiften/TLS/OpenSSL/OpenSSLContextFactory.cpp index 4981170..9f7b2aa 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContextFactory.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContextFactory.cpp @@ -1,34 +1,35 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/TLS/OpenSSL/OpenSSLContextFactory.h> -#include <Swiften/TLS/OpenSSL/OpenSSLContext.h> + #include <Swiften/Base/Log.h> +#include <Swiften/TLS/OpenSSL/OpenSSLContext.h> namespace Swift { bool OpenSSLContextFactory::canCreate() const { - return true; + return true; } TLSContext* OpenSSLContextFactory::createTLSContext(const TLSOptions&) { - return new OpenSSLContext(); + return new OpenSSLContext(); } void OpenSSLContextFactory::setCheckCertificateRevocation(bool check) { - if (check) { - SWIFT_LOG(warning) << "CRL Checking not supported for OpenSSL" << std::endl; - assert(false); - } + if (check) { + SWIFT_LOG(warning) << "CRL Checking not supported for OpenSSL" << std::endl; + assert(false); + } } void OpenSSLContextFactory::setDisconnectOnCardRemoval(bool check) { - if (check) { - SWIFT_LOG(warning) << "Smart cards not supported for OpenSSL" << std::endl; - } + if (check) { + SWIFT_LOG(warning) << "Smart cards not supported for OpenSSL" << std::endl; + } } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContextFactory.h b/Swiften/TLS/OpenSSL/OpenSSLContextFactory.h index 89033ad..e121a1a 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContextFactory.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContextFactory.h @@ -1,23 +1,23 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <Swiften/TLS/TLSContextFactory.h> - #include <cassert> +#include <Swiften/TLS/TLSContextFactory.h> + namespace Swift { - class OpenSSLContextFactory : public TLSContextFactory { - public: - bool canCreate() const; - virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); + class OpenSSLContextFactory : public TLSContextFactory { + public: + bool canCreate() const; + virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); - // Not supported - virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); - }; + // Not supported + virtual void setCheckCertificateRevocation(bool b); + virtual void setDisconnectOnCardRemoval(bool b); + }; } diff --git a/Swiften/TLS/PKCS12Certificate.h b/Swiften/TLS/PKCS12Certificate.h index 0fd3f56..4ed5040 100644 --- a/Swiften/TLS/PKCS12Certificate.h +++ b/Swiften/TLS/PKCS12Certificate.h @@ -1,61 +1,62 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once +#include <boost/filesystem/path.hpp> + #include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/CertificateWithKey.h> -#include <boost/filesystem/path.hpp> namespace Swift { - class SWIFTEN_API PKCS12Certificate : public Swift::CertificateWithKey { - public: - PKCS12Certificate() {} + class SWIFTEN_API PKCS12Certificate : public Swift::CertificateWithKey { + public: + PKCS12Certificate() {} - PKCS12Certificate(const boost::filesystem::path& filename, const SafeByteArray& password) : password_(password) { - readByteArrayFromFile(data_, filename); - } + PKCS12Certificate(const boost::filesystem::path& filename, const SafeByteArray& password) : password_(password) { + readByteArrayFromFile(data_, filename); + } - virtual ~PKCS12Certificate() {} + virtual ~PKCS12Certificate() {} - virtual bool isNull() const { - return data_.empty(); - } + virtual bool isNull() const { + return data_.empty(); + } - virtual bool isPrivateKeyExportable() const { + virtual bool isPrivateKeyExportable() const { /////Hopefully a PKCS12 is never missing a private key - return true; - } - - virtual const std::string& getCertStoreName() const { -///// assert(0); - throw std::exception(); - } - - virtual const std::string& getCertName() const { - /* We can return the original filename instead, if we care */ -///// assert(0); - throw std::exception(); - } - - virtual const ByteArray& getData() const { - return data_; - } - - void setData(const ByteArray& data) { - data_ = data; - } - - virtual const SafeByteArray& getPassword() const { - return password_; - } - - private: - ByteArray data_; - SafeByteArray password_; - }; + return true; + } + + virtual const std::string& getCertStoreName() const { +///// assert(0); + throw std::exception(); + } + + virtual const std::string& getCertName() const { + /* We can return the original filename instead, if we care */ +///// assert(0); + throw std::exception(); + } + + virtual const ByteArray& getData() const { + return data_; + } + + void setData(const ByteArray& data) { + data_ = data; + } + + virtual const SafeByteArray& getPassword() const { + return password_; + } + + private: + ByteArray data_; + SafeByteArray password_; + }; } diff --git a/Swiften/TLS/PlatformTLSFactories.cpp b/Swiften/TLS/PlatformTLSFactories.cpp index 588e0e1..81f560b 100644 --- a/Swiften/TLS/PlatformTLSFactories.cpp +++ b/Swiften/TLS/PlatformTLSFactories.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -9,48 +9,47 @@ #include <Swiften/Base/Platform.h> #include <Swiften/TLS/CertificateFactory.h> #include <Swiften/TLS/TLSContextFactory.h> - #ifdef HAVE_OPENSSL - #include <Swiften/TLS/OpenSSL/OpenSSLContextFactory.h> - #include <Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h> + #include <Swiften/TLS/OpenSSL/OpenSSLContextFactory.h> + #include <Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h> #endif #ifdef HAVE_SCHANNEL - #include <Swiften/TLS/Schannel/SchannelContextFactory.h> - #include <Swiften/TLS/Schannel/SchannelCertificateFactory.h> + #include <Swiften/TLS/Schannel/SchannelContextFactory.h> + #include <Swiften/TLS/Schannel/SchannelCertificateFactory.h> #endif #ifdef HAVE_SECURETRANSPORT - #include <Swiften/TLS/SecureTransport/SecureTransportContextFactory.h> - #include <Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h> + #include <Swiften/TLS/SecureTransport/SecureTransportContextFactory.h> + #include <Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h> #endif namespace Swift { -PlatformTLSFactories::PlatformTLSFactories() : contextFactory(NULL), certificateFactory(NULL) { +PlatformTLSFactories::PlatformTLSFactories() : contextFactory(nullptr), certificateFactory(nullptr) { #ifdef HAVE_OPENSSL - contextFactory = new OpenSSLContextFactory(); - certificateFactory = new OpenSSLCertificateFactory(); + contextFactory = new OpenSSLContextFactory(); + certificateFactory = new OpenSSLCertificateFactory(); #endif #ifdef HAVE_SCHANNEL - contextFactory = new SchannelContextFactory(); - certificateFactory = new SchannelCertificateFactory(); + contextFactory = new SchannelContextFactory(); + certificateFactory = new SchannelCertificateFactory(); #endif #ifdef HAVE_SECURETRANSPORT - contextFactory = new SecureTransportContextFactory(); - certificateFactory = new SecureTransportCertificateFactory(); -#endif + contextFactory = new SecureTransportContextFactory(); + certificateFactory = new SecureTransportCertificateFactory(); +#endif } PlatformTLSFactories::~PlatformTLSFactories() { - delete contextFactory; - delete certificateFactory; + delete contextFactory; + delete certificateFactory; } TLSContextFactory* PlatformTLSFactories::getTLSContextFactory() const { - return contextFactory; + return contextFactory; } CertificateFactory* PlatformTLSFactories::getCertificateFactory() const { - return certificateFactory; + return certificateFactory; } } diff --git a/Swiften/TLS/PlatformTLSFactories.h b/Swiften/TLS/PlatformTLSFactories.h index df23b32..3821521 100644 --- a/Swiften/TLS/PlatformTLSFactories.h +++ b/Swiften/TLS/PlatformTLSFactories.h @@ -9,19 +9,19 @@ #include <Swiften/Base/API.h> namespace Swift { - class TLSContextFactory; - class CertificateFactory; + class TLSContextFactory; + class CertificateFactory; - class SWIFTEN_API PlatformTLSFactories { - public: - PlatformTLSFactories(); - ~PlatformTLSFactories(); + class SWIFTEN_API PlatformTLSFactories { + public: + PlatformTLSFactories(); + ~PlatformTLSFactories(); - TLSContextFactory* getTLSContextFactory() const; - CertificateFactory* getCertificateFactory() const; + TLSContextFactory* getTLSContextFactory() const; + CertificateFactory* getCertificateFactory() const; - private: - TLSContextFactory* contextFactory; - CertificateFactory* certificateFactory; - }; + private: + TLSContextFactory* contextFactory; + CertificateFactory* certificateFactory; + }; } diff --git a/Swiften/TLS/SConscript b/Swiften/TLS/SConscript index f5eb053..68bf50b 100644 --- a/Swiften/TLS/SConscript +++ b/Swiften/TLS/SConscript @@ -1,43 +1,41 @@ Import("swiften_env") objects = swiften_env.SwiftenObject([ - "Certificate.cpp", - "CertificateFactory.cpp", - "CertificateTrustChecker.cpp", - "ServerIdentityVerifier.cpp", - "TLSContext.cpp", - "TLSContextFactory.cpp", - ]) - + "Certificate.cpp", + "CertificateFactory.cpp", + "CertificateTrustChecker.cpp", + "ServerIdentityVerifier.cpp", + "TLSContext.cpp", + "TLSContextFactory.cpp", + ]) + myenv = swiften_env.Clone() if myenv.get("HAVE_OPENSSL", 0) : - myenv.MergeFlags(myenv["OPENSSL_FLAGS"]) - objects += myenv.SwiftenObject([ - "OpenSSL/OpenSSLContext.cpp", - "OpenSSL/OpenSSLCertificate.cpp", - "OpenSSL/OpenSSLContextFactory.cpp", - ]) - myenv.Append(CPPDEFINES = "HAVE_OPENSSL") + myenv.MergeFlags(myenv["OPENSSL_FLAGS"]) + objects += myenv.SwiftenObject([ + "OpenSSL/OpenSSLContext.cpp", + "OpenSSL/OpenSSLCertificate.cpp", + "OpenSSL/OpenSSLContextFactory.cpp", + ]) + myenv.Append(CPPDEFINES = "HAVE_OPENSSL") elif myenv.get("HAVE_SCHANNEL", 0) : - swiften_env.Append(LIBS = ["Winscard"]) - objects += myenv.StaticObject([ - "CAPICertificate.cpp", - "Schannel/SchannelContext.cpp", - "Schannel/SchannelCertificate.cpp", - "Schannel/SchannelContextFactory.cpp", - ]) - myenv.Append(CPPDEFINES = "HAVE_SCHANNEL") + swiften_env.Append(LIBS = ["Winscard"]) + objects += myenv.SwiftenObject([ + "CAPICertificate.cpp", + "Schannel/SchannelContext.cpp", + "Schannel/SchannelCertificate.cpp", + "Schannel/SchannelContextFactory.cpp", + ]) + myenv.Append(CPPDEFINES = "HAVE_SCHANNEL") elif myenv.get("HAVE_SECURETRANSPORT", 0) : - #swiften_env.Append(LIBS = ["Winscard"]) - objects += myenv.StaticObject([ - "SecureTransport/SecureTransportContext.mm", - "SecureTransport/SecureTransportCertificate.mm", - "SecureTransport/SecureTransportContextFactory.cpp", - ]) - myenv.Append(CPPDEFINES = "HAVE_SECURETRANSPORT") + #swiften_env.Append(LIBS = ["Winscard"]) + objects += myenv.SwiftenObject([ + "SecureTransport/SecureTransportContext.mm", + "SecureTransport/SecureTransportCertificate.mm", + "SecureTransport/SecureTransportContextFactory.cpp", + ]) + myenv.Append(CPPDEFINES = "HAVE_SECURETRANSPORT") objects += myenv.SwiftenObject(["PlatformTLSFactories.cpp"]) - - swiften_env.Append(SWIFTEN_OBJECTS = [objects]) diff --git a/Swiften/TLS/Schannel/SchannelCertificate.cpp b/Swiften/TLS/Schannel/SchannelCertificate.cpp index 8aaec00..23c2479 100644 --- a/Swiften/TLS/Schannel/SchannelCertificate.cpp +++ b/Swiften/TLS/Schannel/SchannelCertificate.cpp @@ -4,6 +4,12 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include "Swiften/TLS/Schannel/SchannelCertificate.h" #include "Swiften/Base/ByteArray.h" @@ -20,176 +26,176 @@ namespace Swift { //------------------------------------------------------------------------ -SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt) -: m_cert(certCtxt) +SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt) +: m_cert(certCtxt) { - parse(); + parse(); } //------------------------------------------------------------------------ SchannelCertificate::SchannelCertificate(const ByteArray& der) { - if (!der.empty()) - { - // Convert the DER encoded certificate to a PCERT_CONTEXT - CERT_BLOB certBlob = {0}; - certBlob.cbData = der.size(); - certBlob.pbData = (BYTE*)&der[0]; - - if (!CryptQueryObject( - CERT_QUERY_OBJECT_BLOB, - &certBlob, - CERT_QUERY_CONTENT_FLAG_CERT, - CERT_QUERY_FORMAT_FLAG_ALL, - 0, - NULL, - NULL, - NULL, - NULL, - NULL, - (const void**)m_cert.Reset())) - { - // TODO: Because Swiften isn't exception safe, we have no way to indicate failure - } - } + if (!der.empty()) + { + // Convert the DER encoded certificate to a PCERT_CONTEXT + CERT_BLOB certBlob = {0}; + certBlob.cbData = der.size(); + certBlob.pbData = (BYTE*)&der[0]; + + if (!CryptQueryObject( + CERT_QUERY_OBJECT_BLOB, + &certBlob, + CERT_QUERY_CONTENT_FLAG_CERT, + CERT_QUERY_FORMAT_FLAG_ALL, + 0, + NULL, + NULL, + NULL, + NULL, + NULL, + (const void**)m_cert.Reset())) + { + // TODO: Because Swiften isn't exception safe, we have no way to indicate failure + } + } } //------------------------------------------------------------------------ -ByteArray SchannelCertificate::toDER() const +ByteArray SchannelCertificate::toDER() const { - ByteArray result; + ByteArray result; + + // Serialize the certificate. The CERT_CONTEXT is already DER encoded. + result.resize(m_cert->cbCertEncoded); + memcpy(&result[0], m_cert->pbCertEncoded, result.size()); - // Serialize the certificate. The CERT_CONTEXT is already DER encoded. - result.resize(m_cert->cbCertEncoded); - memcpy(&result[0], m_cert->pbCertEncoded, result.size()); - - return result; + return result; } //------------------------------------------------------------------------ std::string SchannelCertificate::wstrToStr(const std::wstring& wstr) { - if (wstr.empty()) - return ""; + if (wstr.empty()) + return ""; - // First request the size of the required UTF-8 buffer - int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL); - if (!numRequiredBytes) - return ""; + // First request the size of the required UTF-8 buffer + int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL); + if (!numRequiredBytes) + return ""; - // Allocate memory for the UTF-8 string - std::vector<char> utf8Str(numRequiredBytes); + // Allocate memory for the UTF-8 string + std::vector<char> utf8Str(numRequiredBytes); - int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL); - if (!numConverted) - return ""; + int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL); + if (!numConverted) + return ""; - std::string str(&utf8Str[0], numConverted); - return str; + std::string str(&utf8Str[0], numConverted); + return str; } //------------------------------------------------------------------------ -void SchannelCertificate::parse() +void SchannelCertificate::parse() { - // - // Subject name - // - DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0); - if (requiredSize > 1) - { - vector<char> rawSubjectName(requiredSize); - CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size()); - m_subjectName = std::string(&rawSubjectName[0]); - } - - // - // Common name - // - // Note: We only pull out one common name from the cert. - requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0); - if (requiredSize > 1) - { - vector<char> rawCommonName(requiredSize); - requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size()); - m_commonNames.push_back( std::string(&rawCommonName[0]) ); - } - - // - // Subject alternative names - // - PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension); - if (pExtensions) - { - CRYPT_DECODE_PARA decodePara = {0}; - decodePara.cbSize = sizeof(decodePara); - - CERT_ALT_NAME_INFO* pAltNameInfo = NULL; - DWORD altNameInfoSize = 0; - - BOOL status = CryptDecodeObjectEx( - X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, - szOID_SUBJECT_ALT_NAME2, - pExtensions->Value.pbData, - pExtensions->Value.cbData, - CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG, - &decodePara, - &pAltNameInfo, - &altNameInfoSize); - - if (status && pAltNameInfo) - { - for (int i = 0; i < pAltNameInfo->cAltEntry; i++) - { - if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME) - addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) ); - } - } - } - - // if (pExtensions) - // { - // vector<wchar_t> subjectAlt - // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, ) - // } - // - // // subjectAltNames - // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); - // if(subjectAltNameLoc != -1) { - // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); - // boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); - // boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); - // boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); - // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { - // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); - // if (generalName->type == GEN_OTHERNAME) { - // OTHERNAME* otherName = generalName->d.otherName; - // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { - // // XmppAddr - // if (otherName->value->type != V_ASN1_UTF8STRING) { - // continue; - // } - // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; - // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); - // } - // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { - // // SRVName - // if (otherName->value->type != V_ASN1_IA5STRING) { - // continue; - // } - // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; - // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); - // } - // } - // else if (generalName->type == GEN_DNS) { - // // DNSName - // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); - // } - // } - // } + // + // Subject name + // + DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawSubjectName(requiredSize); + CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size()); + m_subjectName = std::string(&rawSubjectName[0]); + } + + // + // Common name + // + // Note: We only pull out one common name from the cert. + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawCommonName(requiredSize); + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size()); + m_commonNames.push_back( std::string(&rawCommonName[0]) ); + } + + // + // Subject alternative names + // + PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension); + if (pExtensions) + { + CRYPT_DECODE_PARA decodePara = {0}; + decodePara.cbSize = sizeof(decodePara); + + CERT_ALT_NAME_INFO* pAltNameInfo = NULL; + DWORD altNameInfoSize = 0; + + BOOL status = CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + szOID_SUBJECT_ALT_NAME2, + pExtensions->Value.pbData, + pExtensions->Value.cbData, + CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG, + &decodePara, + &pAltNameInfo, + &altNameInfoSize); + + if (status && pAltNameInfo) + { + for (int i = 0; i < pAltNameInfo->cAltEntry; i++) + { + if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME) + addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) ); + } + } + } + + // if (pExtensions) + // { + // vector<wchar_t> subjectAlt + // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, ) + // } + // + // // subjectAltNames + // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); + // if(subjectAltNameLoc != -1) { + // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); + // std::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); + // std::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); + // std::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); + // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { + // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); + // if (generalName->type == GEN_OTHERNAME) { + // OTHERNAME* otherName = generalName->d.otherName; + // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { + // // XmppAddr + // if (otherName->value->type != V_ASN1_UTF8STRING) { + // continue; + // } + // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; + // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); + // } + // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { + // // SRVName + // if (otherName->value->type != V_ASN1_IA5STRING) { + // continue; + // } + // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; + // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); + // } + // } + // else if (generalName->type == GEN_DNS) { + // // DNSName + // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); + // } + // } + // } } //------------------------------------------------------------------------ diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h index 395d3ec..d3bd66c 100644 --- a/Swiften/TLS/Schannel/SchannelCertificate.h +++ b/Swiften/TLS/Schannel/SchannelCertificate.h @@ -4,83 +4,89 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> -#include "Swiften/Base/String.h" -#include "Swiften/TLS/Certificate.h" -#include "Swiften/TLS/Schannel/SchannelUtil.h" +#include <Swiften/Base/String.h> +#include <Swiften/TLS/Certificate.h> +#include <Swiften/TLS/Schannel/SchannelUtil.h> -namespace Swift +namespace Swift { - class SchannelCertificate : public Certificate - { - public: - typedef boost::shared_ptr<SchannelCertificate> ref; - - public: - SchannelCertificate(const ScopedCertContext& certCtxt); - SchannelCertificate(const ByteArray& der); - - std::string getSubjectName() const - { - return m_subjectName; - } - - std::vector<std::string> getCommonNames() const - { - return m_commonNames; - } - - std::vector<std::string> getSRVNames() const - { - return m_srvNames; - } - - std::vector<std::string> getDNSNames() const - { - return m_dnsNames; - } - - std::vector<std::string> getXMPPAddresses() const - { - return m_xmppAddresses; - } - - ScopedCertContext getCertContext() const - { - return m_cert; - } - - ByteArray toDER() const; - - private: - void parse(); - std::string wstrToStr(const std::wstring& wstr); - - void addSRVName(const std::string& name) - { - m_srvNames.push_back(name); - } - - void addDNSName(const std::string& name) - { - m_dnsNames.push_back(name); - } - - void addXMPPAddress(const std::string& addr) - { - m_xmppAddresses.push_back(addr); - } - - private: - ScopedCertContext m_cert; - - std::string m_subjectName; - std::vector<std::string> m_commonNames; - std::vector<std::string> m_dnsNames; - std::vector<std::string> m_xmppAddresses; - std::vector<std::string> m_srvNames; - }; + class SchannelCertificate : public Certificate + { + public: + typedef std::shared_ptr<SchannelCertificate> ref; + + public: + SchannelCertificate(const ScopedCertContext& certCtxt); + SchannelCertificate(const ByteArray& der); + + std::string getSubjectName() const + { + return m_subjectName; + } + + std::vector<std::string> getCommonNames() const + { + return m_commonNames; + } + + std::vector<std::string> getSRVNames() const + { + return m_srvNames; + } + + std::vector<std::string> getDNSNames() const + { + return m_dnsNames; + } + + std::vector<std::string> getXMPPAddresses() const + { + return m_xmppAddresses; + } + + ScopedCertContext getCertContext() const + { + return m_cert; + } + + ByteArray toDER() const; + + private: + void parse(); + std::string wstrToStr(const std::wstring& wstr); + + void addSRVName(const std::string& name) + { + m_srvNames.push_back(name); + } + + void addDNSName(const std::string& name) + { + m_dnsNames.push_back(name); + } + + void addXMPPAddress(const std::string& addr) + { + m_xmppAddresses.push_back(addr); + } + + private: + ScopedCertContext m_cert; + + std::string m_subjectName; + std::vector<std::string> m_commonNames; + std::vector<std::string> m_dnsNames; + std::vector<std::string> m_xmppAddresses; + std::vector<std::string> m_srvNames; + }; } diff --git a/Swiften/TLS/Schannel/SchannelCertificateFactory.h b/Swiften/TLS/Schannel/SchannelCertificateFactory.h index 5a2b208..be97c52 100644 --- a/Swiften/TLS/Schannel/SchannelCertificateFactory.h +++ b/Swiften/TLS/Schannel/SchannelCertificateFactory.h @@ -10,10 +10,10 @@ #include <Swiften/TLS/Schannel/SchannelCertificate.h> namespace Swift { - class SchannelCertificateFactory : public CertificateFactory { - public: - virtual Certificate* createCertificateFromDER(const ByteArray& der) { - return new SchannelCertificate(der); - } - }; + class SchannelCertificateFactory : public CertificateFactory { + public: + virtual Certificate* createCertificateFromDER(const ByteArray& der) { + return new SchannelCertificate(der); + } + }; } diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 62aa137..5799157 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -24,671 +24,671 @@ namespace Swift { //------------------------------------------------------------------------ SchannelContext::SchannelContext(bool tls1_0Workaround) : state_(Start), secContext_(0), myCertStore_(NULL), certStoreName_("MY"), certName_(), smartCardReader_(), checkCertificateRevocation_(true), tls1_0Workaround_(tls1_0Workaround), disconnectOnCardRemoval_(true) { - contextFlags_ = ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_CONFIDENTIALITY | - ISC_REQ_EXTENDED_ERROR | - ISC_REQ_INTEGRITY | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_USE_SUPPLIED_CREDS | - ISC_REQ_STREAM; - - ZeroMemory(&streamSizes_, sizeof(streamSizes_)); + contextFlags_ = ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_STREAM; + + ZeroMemory(&streamSizes_, sizeof(streamSizes_)); } //------------------------------------------------------------------------ SchannelContext::~SchannelContext() { - if (myCertStore_) CertCloseStore(myCertStore_, 0); + if (myCertStore_) CertCloseStore(myCertStore_, 0); } //------------------------------------------------------------------------ void SchannelContext::determineStreamSizes() { - QueryContextAttributes(contextHandle_, SECPKG_ATTR_STREAM_SIZES, &streamSizes_); + QueryContextAttributes(contextHandle_, SECPKG_ATTR_STREAM_SIZES, &streamSizes_); } //------------------------------------------------------------------------ void SchannelContext::connect() { - ScopedCertContext pCertContext; - - state_ = Connecting; - - // If a user name is specified, then attempt to find a client - // certificate. Otherwise, just create a NULL credential. - if (!certName_.empty()) { - if (myCertStore_ == NULL) { - myCertStore_ = CertOpenSystemStore(0, certStoreName_.c_str()); - if (!myCertStore_) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } - - pCertContext = findCertificateInStore( myCertStore_, certName_ ); - if (pCertContext == NULL) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } - - // We use an empty list for client certificates - PCCERT_CONTEXT clientCerts[1] = {0}; - - SCHANNEL_CRED sc = {0}; - sc.dwVersion = SCHANNEL_CRED_VERSION; - - if (tls1_0Workaround_) { - sc.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; - } - else { - sc.grbitEnabledProtocols = /*SP_PROT_SSL3_CLIENT | */SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; - } - - sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; - - if (pCertContext) { - sc.cCreds = 1; - sc.paCred = pCertContext.GetPointer(); - sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; - } - else { - sc.cCreds = 0; - sc.paCred = clientCerts; - sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; - } - - // Swiften performs the server name check for us - sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; - - SECURITY_STATUS status = AcquireCredentialsHandle( - NULL, - UNISP_NAME, - SECPKG_CRED_OUTBOUND, - NULL, - &sc, - NULL, - NULL, - credHandle_.Reset(), - NULL); - - if (status != SEC_E_OK) { - // We failed to obtain the credentials handle - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - - SecBuffer outBuffers[2]; - - // We let Schannel allocate the output buffer for us - outBuffers[0].pvBuffer = NULL; - outBuffers[0].cbBuffer = 0; - outBuffers[0].BufferType = SECBUFFER_TOKEN; - - // Contains alert data if an alert is generated - outBuffers[1].pvBuffer = NULL; - outBuffers[1].cbBuffer = 0; - outBuffers[1].BufferType = SECBUFFER_ALERT; - - // Make sure the output buffers are freed - ScopedSecBuffer scopedOutputData(&outBuffers[0]); - ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); - - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 2; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; - - // Create the initial security context - status = InitializeSecurityContext( - credHandle_, - NULL, - NULL, - contextFlags_, - 0, - 0, - NULL, - 0, - contextHandle_.Reset(), - &outBufferDesc, - &secContext_, - NULL); - - if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { - // We failed to initialize the security context - handleCertError(status); - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - - // Start the handshake - sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); - - if (status == SEC_E_OK) { - status = validateServerCertificate(); - if (status != SEC_E_OK) { - handleCertError(status); - } - - state_ = Connected; - determineStreamSizes(); - - onConnected(); - } + ScopedCertContext pCertContext; + + state_ = Connecting; + + // If a user name is specified, then attempt to find a client + // certificate. Otherwise, just create a NULL credential. + if (!certName_.empty()) { + if (myCertStore_ == NULL) { + myCertStore_ = CertOpenSystemStore(0, certStoreName_.c_str()); + if (!myCertStore_) { + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } + + pCertContext = findCertificateInStore( myCertStore_, certName_ ); + if (pCertContext == NULL) { + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } + + // We use an empty list for client certificates + PCCERT_CONTEXT clientCerts[1] = {0}; + + SCHANNEL_CRED sc = {0}; + sc.dwVersion = SCHANNEL_CRED_VERSION; + + if (tls1_0Workaround_) { + sc.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; + } + else { + sc.grbitEnabledProtocols = /*SP_PROT_SSL3_CLIENT | */SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; + } + + sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; + + if (pCertContext) { + sc.cCreds = 1; + sc.paCred = pCertContext.GetPointer(); + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + } + else { + sc.cCreds = 0; + sc.paCred = clientCerts; + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + } + + // Swiften performs the server name check for us + sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; + + SECURITY_STATUS status = AcquireCredentialsHandle( + NULL, + UNISP_NAME, + SECPKG_CRED_OUTBOUND, + NULL, + &sc, + NULL, + NULL, + credHandle_.Reset(), + NULL); + + if (status != SEC_E_OK) { + // We failed to obtain the credentials handle + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + // Create the initial security context + status = InitializeSecurityContext( + credHandle_, + NULL, + NULL, + contextFlags_, + 0, + 0, + NULL, + 0, + contextHandle_.Reset(), + &outBufferDesc, + &secContext_, + NULL); + + if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { + // We failed to initialize the security context + handleCertError(status); + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + + // Start the handshake + sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); + + if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) { + handleCertError(status); + } + + state_ = Connected; + determineStreamSizes(); + + onConnected(); + } } //------------------------------------------------------------------------ SECURITY_STATUS SchannelContext::validateServerCertificate() { - SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); - if (!pServerCert) { - return SEC_E_WRONG_PRINCIPAL; - } - - const LPSTR usage[] = - { - szOID_PKIX_KP_SERVER_AUTH, - szOID_SERVER_GATED_CRYPTO, - szOID_SGC_NETSCAPE - }; - - CERT_CHAIN_PARA chainParams = {0}; - chainParams.cbSize = sizeof(chainParams); - chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; - chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage); - chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage); - - DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT; - if (checkCertificateRevocation_) { - chainFlags |= CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; - } - - ScopedCertChainContext pChainContext; - - BOOL success = CertGetCertificateChain( - NULL, // Use the chain engine for the current user (assumes a user is logged in) - pServerCert->getCertContext(), - NULL, - pServerCert->getCertContext()->hCertStore, - &chainParams, - chainFlags, - NULL, - pChainContext.Reset()); - - if (!success) { - return GetLastError(); - } - - SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; - sslChainPolicy.cbSize = sizeof(sslChainPolicy); - sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; - sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? - sslChainPolicy.pwszServerName = NULL; - - CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; - certChainPolicy.cbSize = sizeof(certChainPolicy); - certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking? - certChainPolicy.pvExtraPolicyPara = &sslChainPolicy; - - CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; - certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); - - // Verify the chain - if (!CertVerifyCertificateChainPolicy( - CERT_CHAIN_POLICY_SSL, - pChainContext, - &certChainPolicy, - &certChainPolicyStatus)) { - return GetLastError(); - } - - if (certChainPolicyStatus.dwError != S_OK) { - return certChainPolicyStatus.dwError; - } - - return S_OK; + SchannelCertificate::ref pServerCert = std::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); + if (!pServerCert) { + return SEC_E_WRONG_PRINCIPAL; + } + + const LPSTR usage[] = + { + szOID_PKIX_KP_SERVER_AUTH, + szOID_SERVER_GATED_CRYPTO, + szOID_SGC_NETSCAPE + }; + + CERT_CHAIN_PARA chainParams = {0}; + chainParams.cbSize = sizeof(chainParams); + chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; + chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage); + chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage); + + DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT; + if (checkCertificateRevocation_) { + chainFlags |= CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; + } + + ScopedCertChainContext pChainContext; + + BOOL success = CertGetCertificateChain( + NULL, // Use the chain engine for the current user (assumes a user is logged in) + pServerCert->getCertContext(), + NULL, + pServerCert->getCertContext()->hCertStore, + &chainParams, + chainFlags, + NULL, + pChainContext.Reset()); + + if (!success) { + return GetLastError(); + } + + SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; + sslChainPolicy.cbSize = sizeof(sslChainPolicy); + sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; + sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + sslChainPolicy.pwszServerName = NULL; + + CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; + certChainPolicy.cbSize = sizeof(certChainPolicy); + certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + certChainPolicy.pvExtraPolicyPara = &sslChainPolicy; + + CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; + certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); + + // Verify the chain + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, + pChainContext, + &certChainPolicy, + &certChainPolicyStatus)) { + return GetLastError(); + } + + if (certChainPolicyStatus.dwError != S_OK) { + return certChainPolicyStatus.dwError; + } + + return S_OK; } //------------------------------------------------------------------------ void SchannelContext::appendNewData(const SafeByteArray& data) { - size_t originalSize = receivedData_.size(); - receivedData_.resize(originalSize + data.size()); - memcpy(&receivedData_[0] + originalSize, &data[0], data.size()); + size_t originalSize = receivedData_.size(); + receivedData_.resize(originalSize + data.size()); + memcpy(&receivedData_[0] + originalSize, &data[0], data.size()); } //------------------------------------------------------------------------ void SchannelContext::continueHandshake(const SafeByteArray& data) { - appendNewData(data); - - while (!receivedData_.empty()) { - SecBuffer inBuffers[2]; - - // Provide Schannel with the remote host's handshake data - inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); - inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); - inBuffers[0].BufferType = SECBUFFER_TOKEN; - - inBuffers[1].pvBuffer = NULL; - inBuffers[1].cbBuffer = 0; - inBuffers[1].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc inBufferDesc = {0}; - inBufferDesc.cBuffers = 2; - inBufferDesc.pBuffers = inBuffers; - inBufferDesc.ulVersion = SECBUFFER_VERSION; - - SecBuffer outBuffers[2]; - - // We let Schannel allocate the output buffer for us - outBuffers[0].pvBuffer = NULL; - outBuffers[0].cbBuffer = 0; - outBuffers[0].BufferType = SECBUFFER_TOKEN; - - // Contains alert data if an alert is generated - outBuffers[1].pvBuffer = NULL; - outBuffers[1].cbBuffer = 0; - outBuffers[1].BufferType = SECBUFFER_ALERT; - - // Make sure the output buffers are freed - ScopedSecBuffer scopedOutputData(&outBuffers[0]); - ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); - - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 2; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; - - SECURITY_STATUS status = InitializeSecurityContext( - credHandle_, - contextHandle_, - NULL, - contextFlags_, - 0, - 0, - &inBufferDesc, - 0, - NULL, - &outBufferDesc, - &secContext_, - NULL); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - // Wait for more data to arrive - break; - } - else if (status == SEC_I_CONTINUE_NEEDED) { - SecBuffer* pDataBuffer = &outBuffers[0]; - SecBuffer* pExtraBuffer = &inBuffers[1]; - - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { - sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); - } - - if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - receivedData_.clear(); - } - - break; - } - else if (status == SEC_E_OK) { - status = validateServerCertificate(); - if (status != SEC_E_OK) { - handleCertError(status); - } - - SecBuffer* pExtraBuffer = &inBuffers[1]; - - if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - receivedData_.clear(); - } - - state_ = Connected; - determineStreamSizes(); - - onConnected(); - } - else { - // We failed to initialize the security context - handleCertError(status); - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } + appendNewData(data); + + while (!receivedData_.empty()) { + SecBuffer inBuffers[2]; + + // Provide Schannel with the remote host's handshake data + inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); + inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); + inBuffers[0].BufferType = SECBUFFER_TOKEN; + + inBuffers[1].pvBuffer = NULL; + inBuffers[1].cbBuffer = 0; + inBuffers[1].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 2; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status = InitializeSecurityContext( + credHandle_, + contextHandle_, + NULL, + contextFlags_, + 0, + 0, + &inBufferDesc, + 0, + NULL, + &outBufferDesc, + &secContext_, + NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_CONTINUE_NEEDED) { + SecBuffer* pDataBuffer = &outBuffers[0]; + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { + sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } + + if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + receivedData_.clear(); + } + + break; + } + else if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) { + handleCertError(status); + } + + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + receivedData_.clear(); + } + + state_ = Connected; + determineStreamSizes(); + + onConnected(); + } + else { + // We failed to initialize the security context + handleCertError(status); + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } } //------------------------------------------------------------------------ void SchannelContext::handleCertError(SECURITY_STATUS status) { - if (status == SEC_E_UNTRUSTED_ROOT || - status == CERT_E_UNTRUSTEDROOT || - status == CRYPT_E_ISSUER_SERIALNUMBER || - status == CRYPT_E_SIGNER_NOT_FOUND || - status == CRYPT_E_NO_TRUSTED_SIGNER) { - verificationError_ = CertificateVerificationError::Untrusted; - } - else if (status == SEC_E_CERT_EXPIRED || - status == CERT_E_EXPIRED) { - verificationError_ = CertificateVerificationError::Expired; - } - else if (status == CRYPT_E_SELF_SIGNED) { - verificationError_ = CertificateVerificationError::SelfSigned; - } - else if (status == CRYPT_E_HASH_VALUE || - status == TRUST_E_CERT_SIGNATURE) { - verificationError_ = CertificateVerificationError::InvalidSignature; - } - else if (status == CRYPT_E_REVOKED) { - verificationError_ = CertificateVerificationError::Revoked; - } - else if (status == CRYPT_E_NO_REVOCATION_CHECK || - status == CRYPT_E_REVOCATION_OFFLINE) { - verificationError_ = CertificateVerificationError::RevocationCheckFailed; - } - else if (status == CERT_E_WRONG_USAGE) { - verificationError_ = CertificateVerificationError::InvalidPurpose; - } - else { - verificationError_ = CertificateVerificationError::UnknownError; - } + if (status == SEC_E_UNTRUSTED_ROOT || + status == CERT_E_UNTRUSTEDROOT || + status == CRYPT_E_ISSUER_SERIALNUMBER || + status == CRYPT_E_SIGNER_NOT_FOUND || + status == CRYPT_E_NO_TRUSTED_SIGNER) { + verificationError_ = CertificateVerificationError::Untrusted; + } + else if (status == SEC_E_CERT_EXPIRED || + status == CERT_E_EXPIRED) { + verificationError_ = CertificateVerificationError::Expired; + } + else if (status == CRYPT_E_SELF_SIGNED) { + verificationError_ = CertificateVerificationError::SelfSigned; + } + else if (status == CRYPT_E_HASH_VALUE || + status == TRUST_E_CERT_SIGNATURE) { + verificationError_ = CertificateVerificationError::InvalidSignature; + } + else if (status == CRYPT_E_REVOKED) { + verificationError_ = CertificateVerificationError::Revoked; + } + else if (status == CRYPT_E_NO_REVOCATION_CHECK || + status == CRYPT_E_REVOCATION_OFFLINE) { + verificationError_ = CertificateVerificationError::RevocationCheckFailed; + } + else if (status == CERT_E_WRONG_USAGE) { + verificationError_ = CertificateVerificationError::InvalidPurpose; + } + else { + verificationError_ = CertificateVerificationError::UnknownError; + } } //------------------------------------------------------------------------ void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) { - if (dataSize > 0 && pData) { - SafeByteArray byteArray(dataSize); - memcpy(&byteArray[0], pData, dataSize); + if (dataSize > 0 && pData) { + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); - onDataForNetwork(byteArray); - } + onDataForNetwork(byteArray); + } } //------------------------------------------------------------------------ void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) { - SafeByteArray byteArray(dataSize); - memcpy(&byteArray[0], pData, dataSize); + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); - onDataForApplication(byteArray); + onDataForApplication(byteArray); } //------------------------------------------------------------------------ void SchannelContext::handleDataFromApplication(const SafeByteArray& data) { - // Don't attempt to send data until we're fully connected - if (state_ == Connecting) { - return; - } + // Don't attempt to send data until we're fully connected + if (state_ == Connecting) { + return; + } - // Encrypt the data - encryptAndSendData(data); + // Encrypt the data + encryptAndSendData(data); } //------------------------------------------------------------------------ void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) { - switch (state_) { - case Connecting: - { - // We're still establishing the connection, so continue the handshake - continueHandshake(data); - } - break; - - case Connected: - { - // Decrypt the data - decryptAndProcessData(data); - } - break; - - default: - return; - } + switch (state_) { + case Connecting: + { + // We're still establishing the connection, so continue the handshake + continueHandshake(data); + } + break; + + case Connected: + { + // Decrypt the data + decryptAndProcessData(data); + } + break; + + default: + return; + } } //------------------------------------------------------------------------ -void SchannelContext::indicateError(boost::shared_ptr<TLSError> error) { - state_ = Error; - receivedData_.clear(); - onError(error); +void SchannelContext::indicateError(std::shared_ptr<TLSError> error) { + state_ = Error; + receivedData_.clear(); + onError(error); } //------------------------------------------------------------------------ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) { - SecBuffer inBuffers[4] = {0}; - - appendNewData(data); - - while (!receivedData_.empty()) { - // - // MSDN: - // When using the Schannel SSP with contexts that are not connection oriented, on input, - // the structure must contain four SecBuffer structures. Exactly one buffer must be of type - // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining - // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented - // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented - // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token - // must also be supplied. - // - inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); - inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); - inBuffers[0].BufferType = SECBUFFER_DATA; - - inBuffers[1].BufferType = SECBUFFER_EMPTY; - inBuffers[2].BufferType = SECBUFFER_EMPTY; - inBuffers[3].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc inBufferDesc = {0}; - inBufferDesc.cBuffers = 4; - inBufferDesc.pBuffers = inBuffers; - inBufferDesc.ulVersion = SECBUFFER_VERSION; - - size_t inData = receivedData_.size(); - SECURITY_STATUS status = DecryptMessage(contextHandle_, &inBufferDesc, 0, NULL); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - // Wait for more data to arrive - break; - } - else if (status == SEC_I_RENEGOTIATE) { - // TODO: Handle renegotiation scenarios - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - else if (status == SEC_I_CONTEXT_EXPIRED) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - else if (status != SEC_E_OK) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - - SecBuffer* pDataBuffer = NULL; - SecBuffer* pExtraBuffer = NULL; - for (int i = 0; i < 4; ++i) { - if (inBuffers[i].BufferType == SECBUFFER_DATA) { - pDataBuffer = &inBuffers[i]; - } - else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { - pExtraBuffer = &inBuffers[i]; - } - } - - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { - forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); - } - - // If there is extra data left over from the decryption operation, we call DecryptMessage() again - if (pExtraBuffer) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - // We're done - receivedData_.erase(receivedData_.begin(), receivedData_.begin() + inData); - } - } + SecBuffer inBuffers[4] = {0}; + + appendNewData(data); + + while (!receivedData_.empty()) { + // + // MSDN: + // When using the Schannel SSP with contexts that are not connection oriented, on input, + // the structure must contain four SecBuffer structures. Exactly one buffer must be of type + // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining + // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented + // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented + // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token + // must also be supplied. + // + inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); + inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); + inBuffers[0].BufferType = SECBUFFER_DATA; + + inBuffers[1].BufferType = SECBUFFER_EMPTY; + inBuffers[2].BufferType = SECBUFFER_EMPTY; + inBuffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 4; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + size_t inData = receivedData_.size(); + SECURITY_STATUS status = DecryptMessage(contextHandle_, &inBufferDesc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_RENEGOTIATE) { + // TODO: Handle renegotiation scenarios + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + else if (status == SEC_I_CONTEXT_EXPIRED) { + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + else if (status != SEC_E_OK) { + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + + SecBuffer* pDataBuffer = NULL; + SecBuffer* pExtraBuffer = NULL; + for (int i = 0; i < 4; ++i) { + if (inBuffers[i].BufferType == SECBUFFER_DATA) { + pDataBuffer = &inBuffers[i]; + } + else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { + pExtraBuffer = &inBuffers[i]; + } + } + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { + forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } + + // If there is extra data left over from the decryption operation, we call DecryptMessage() again + if (pExtraBuffer) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + // We're done + receivedData_.erase(receivedData_.begin(), receivedData_.begin() + inData); + } + } } //------------------------------------------------------------------------ void SchannelContext::encryptAndSendData(const SafeByteArray& data) { - if (streamSizes_.cbMaximumMessage == 0) { - return; - } + if (streamSizes_.cbMaximumMessage == 0) { + return; + } - SecBuffer outBuffers[4] = {0}; + SecBuffer outBuffers[4] = {0}; - // Calculate the largest required size of the send buffer - size_t messageBufferSize = (data.size() > streamSizes_.cbMaximumMessage) - ? streamSizes_.cbMaximumMessage - : data.size(); + // Calculate the largest required size of the send buffer + size_t messageBufferSize = (data.size() > streamSizes_.cbMaximumMessage) + ? streamSizes_.cbMaximumMessage + : data.size(); - // Allocate a packet for the encrypted data - SafeByteArray sendBuffer; - sendBuffer.resize(streamSizes_.cbHeader + messageBufferSize + streamSizes_.cbTrailer); + // Allocate a packet for the encrypted data + SafeByteArray sendBuffer; + sendBuffer.resize(streamSizes_.cbHeader + messageBufferSize + streamSizes_.cbTrailer); - size_t bytesSent = 0; - do { - size_t bytesLeftToSend = data.size() - bytesSent; + size_t bytesSent = 0; + do { + size_t bytesLeftToSend = data.size() - bytesSent; - // Calculate how much of the send buffer we'll be using for this chunk - size_t bytesToSend = (bytesLeftToSend > streamSizes_.cbMaximumMessage) - ? streamSizes_.cbMaximumMessage - : bytesLeftToSend; + // Calculate how much of the send buffer we'll be using for this chunk + size_t bytesToSend = (bytesLeftToSend > streamSizes_.cbMaximumMessage) + ? streamSizes_.cbMaximumMessage + : bytesLeftToSend; - // Copy the plain text data into the send buffer - memcpy(&sendBuffer[0] + streamSizes_.cbHeader, &data[0] + bytesSent, bytesToSend); + // Copy the plain text data into the send buffer + memcpy(&sendBuffer[0] + streamSizes_.cbHeader, &data[0] + bytesSent, bytesToSend); - outBuffers[0].pvBuffer = &sendBuffer[0]; - outBuffers[0].cbBuffer = streamSizes_.cbHeader; - outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; + outBuffers[0].pvBuffer = &sendBuffer[0]; + outBuffers[0].cbBuffer = streamSizes_.cbHeader; + outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; - outBuffers[1].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader; - outBuffers[1].cbBuffer = (unsigned long)bytesToSend; - outBuffers[1].BufferType = SECBUFFER_DATA; + outBuffers[1].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader; + outBuffers[1].cbBuffer = (unsigned long)bytesToSend; + outBuffers[1].BufferType = SECBUFFER_DATA; - outBuffers[2].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader + bytesToSend; - outBuffers[2].cbBuffer = streamSizes_.cbTrailer; - outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + outBuffers[2].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader + bytesToSend; + outBuffers[2].cbBuffer = streamSizes_.cbTrailer; + outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; - outBuffers[3].pvBuffer = 0; - outBuffers[3].cbBuffer = 0; - outBuffers[3].BufferType = SECBUFFER_EMPTY; + outBuffers[3].pvBuffer = 0; + outBuffers[3].cbBuffer = 0; + outBuffers[3].BufferType = SECBUFFER_EMPTY; - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 4; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 4; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; - SECURITY_STATUS status = EncryptMessage(contextHandle_, 0, &outBufferDesc, 0); - if (status != SEC_E_OK) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } + SECURITY_STATUS status = EncryptMessage(contextHandle_, 0, &outBufferDesc, 0); + if (status != SEC_E_OK) { + indicateError(std::make_shared<TLSError>(TLSError::UnknownError)); + return; + } - sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); - bytesSent += bytesToSend; + sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); + bytesSent += bytesToSend; - } while (bytesSent < data.size()); + } while (bytesSent < data.size()); } //------------------------------------------------------------------------ bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) { - boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); - if (!capiCertificate || capiCertificate->isNull()) { - return false; - } + std::shared_ptr<CAPICertificate> capiCertificate = std::dynamic_pointer_cast<CAPICertificate>(certificate); + if (!capiCertificate || capiCertificate->isNull()) { + return false; + } - userCertificate_ = capiCertificate; + userCertificate_ = capiCertificate; - // We assume that the Certificate Store Name/Certificate Name - // are valid at this point - certStoreName_ = capiCertificate->getCertStoreName(); - certName_ = capiCertificate->getCertName(); + // We assume that the Certificate Store Name/Certificate Name + // are valid at this point + certStoreName_ = capiCertificate->getCertStoreName(); + certName_ = capiCertificate->getCertName(); ////At the moment this is only useful for logging: - smartCardReader_ = capiCertificate->getSmartCardReaderName(); + smartCardReader_ = capiCertificate->getSmartCardReaderName(); - capiCertificate->onCertificateCardRemoved.connect(boost::bind(&SchannelContext::handleCertificateCardRemoved, this)); + capiCertificate->onCertificateCardRemoved.connect(boost::bind(&SchannelContext::handleCertificateCardRemoved, this)); - return true; + return true; } //------------------------------------------------------------------------ void SchannelContext::handleCertificateCardRemoved() { - if (disconnectOnCardRemoval_) { - indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved)); - } + if (disconnectOnCardRemoval_) { + indicateError(std::make_shared<TLSError>(TLSError::CertificateCardRemoved)); + } } //------------------------------------------------------------------------ std::vector<Certificate::ref> SchannelContext::getPeerCertificateChain() const { - std::vector<Certificate::ref> certificateChain; - ScopedCertContext pServerCert; - ScopedCertContext pIssuerCert; - ScopedCertContext pCurrentCert; - SECURITY_STATUS status = QueryContextAttributes(contextHandle_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); - - if (status != SEC_E_OK) { - return certificateChain; - } - certificateChain.push_back(boost::make_shared<SchannelCertificate>(pServerCert)); - - pCurrentCert = pServerCert; - while(pCurrentCert.GetPointer()) { - DWORD dwVerificationFlags = 0; - pIssuerCert = CertGetIssuerCertificateFromStore(pServerCert->hCertStore, pCurrentCert, NULL, &dwVerificationFlags ); - if (!(*pIssuerCert.GetPointer())) { - break; - } - certificateChain.push_back(boost::make_shared<SchannelCertificate>(pIssuerCert)); - - pCurrentCert = pIssuerCert; - pIssuerCert = NULL; - } - return certificateChain; + std::vector<Certificate::ref> certificateChain; + ScopedCertContext pServerCert; + ScopedCertContext pIssuerCert; + ScopedCertContext pCurrentCert; + SECURITY_STATUS status = QueryContextAttributes(contextHandle_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); + + if (status != SEC_E_OK) { + return certificateChain; + } + certificateChain.push_back(std::make_shared<SchannelCertificate>(pServerCert)); + + pCurrentCert = pServerCert; + while(pCurrentCert.GetPointer()) { + DWORD dwVerificationFlags = 0; + pIssuerCert = CertGetIssuerCertificateFromStore(pServerCert->hCertStore, pCurrentCert, NULL, &dwVerificationFlags ); + if (!(*pIssuerCert.GetPointer())) { + break; + } + certificateChain.push_back(std::make_shared<SchannelCertificate>(pIssuerCert)); + + pCurrentCert = pIssuerCert; + pIssuerCert = NULL; + } + return certificateChain; } //------------------------------------------------------------------------ CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { - return verificationError_ ? boost::make_shared<CertificateVerificationError>(*verificationError_) : CertificateVerificationError::ref(); + return verificationError_ ? std::make_shared<CertificateVerificationError>(*verificationError_) : CertificateVerificationError::ref(); } //------------------------------------------------------------------------ ByteArray SchannelContext::getFinishMessage() const { - SecPkgContext_Bindings bindings; - int ret = QueryContextAttributes(contextHandle_, SECPKG_ATTR_UNIQUE_BINDINGS, &bindings); - if (ret == SEC_E_OK) { - return createByteArray(((unsigned char*) bindings.Bindings) + bindings.Bindings->dwApplicationDataOffset + 11 /* tls-unique:*/, bindings.Bindings->cbApplicationDataLength - 11); - } - return ByteArray(); + SecPkgContext_Bindings bindings; + int ret = QueryContextAttributes(contextHandle_, SECPKG_ATTR_UNIQUE_BINDINGS, &bindings); + if (ret == SEC_E_OK) { + return createByteArray(((unsigned char*) bindings.Bindings) + bindings.Bindings->dwApplicationDataOffset + 11 /* tls-unique:*/, bindings.Bindings->cbApplicationDataLength - 11); + } + return ByteArray(); } //------------------------------------------------------------------------ void SchannelContext::setCheckCertificateRevocation(bool b) { - checkCertificateRevocation_ = b; + checkCertificateRevocation_ = b; } void SchannelContext::setDisconnectOnCardRemoval(bool b) { - disconnectOnCardRemoval_ = b; + disconnectOnCardRemoval_ = b; } diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index be30a7c..a3748fe 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -5,14 +5,14 @@ */ /* - * Copyright (c) 2012-2015 Isode Limited. + * Copyright (c) 2012-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> #include <Swiften/TLS/TLSContext.h> #include <Swiften/TLS/Schannel/SchannelUtil.h> @@ -28,85 +28,85 @@ #include <boost/noncopyable.hpp> -namespace Swift -{ - class CAPICertificate; - class SchannelContext : public TLSContext, boost::noncopyable - { - public: - typedef boost::shared_ptr<SchannelContext> sp_t; +namespace Swift +{ + class CAPICertificate; + class SchannelContext : public TLSContext, boost::noncopyable + { + public: + typedef std::shared_ptr<SchannelContext> sp_t; - public: - SchannelContext(bool tls1_0Workaround); + public: + SchannelContext(bool tls1_0Workaround); - virtual ~SchannelContext(); + virtual ~SchannelContext(); - // - // TLSContext - // - virtual void connect(); - virtual bool setClientCertificate(CertificateWithKey::ref cert); + // + // TLSContext + // + virtual void connect(); + virtual bool setClientCertificate(CertificateWithKey::ref cert); - virtual void handleDataFromNetwork(const SafeByteArray& data); - virtual void handleDataFromApplication(const SafeByteArray& data); + virtual void handleDataFromNetwork(const SafeByteArray& data); + virtual void handleDataFromApplication(const SafeByteArray& data); - virtual std::vector<Certificate::ref> getPeerCertificateChain() const; - virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; + virtual std::vector<Certificate::ref> getPeerCertificateChain() const; + virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; - virtual ByteArray getFinishMessage() const; + virtual ByteArray getFinishMessage() const; - virtual void setCheckCertificateRevocation(bool b); + virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); + virtual void setDisconnectOnCardRemoval(bool b); - private: - void determineStreamSizes(); - void continueHandshake(const SafeByteArray& data); - void indicateError(boost::shared_ptr<TLSError> error); - //FIXME: Remove - void indicateError() {indicateError(boost::make_shared<TLSError>());} - void handleCertError(SECURITY_STATUS status) ; + private: + void determineStreamSizes(); + void continueHandshake(const SafeByteArray& data); + void indicateError(std::shared_ptr<TLSError> error); + //FIXME: Remove + void indicateError() {indicateError(std::make_shared<TLSError>());} + void handleCertError(SECURITY_STATUS status) ; - void sendDataOnNetwork(const void* pData, size_t dataSize); - void forwardDataToApplication(const void* pData, size_t dataSize); + void sendDataOnNetwork(const void* pData, size_t dataSize); + void forwardDataToApplication(const void* pData, size_t dataSize); - void decryptAndProcessData(const SafeByteArray& data); - void encryptAndSendData(const SafeByteArray& data); + void decryptAndProcessData(const SafeByteArray& data); + void encryptAndSendData(const SafeByteArray& data); - void appendNewData(const SafeByteArray& data); - SECURITY_STATUS validateServerCertificate(); + void appendNewData(const SafeByteArray& data); + SECURITY_STATUS validateServerCertificate(); - void handleCertificateCardRemoved(); + void handleCertificateCardRemoved(); - private: - enum SchannelState - { - Start, - Connecting, - Connected, - Error + private: + enum SchannelState + { + Start, + Connecting, + Connected, + Error - }; + }; - SchannelState state_; - boost::optional<CertificateVerificationError> verificationError_; + SchannelState state_; + boost::optional<CertificateVerificationError> verificationError_; - ULONG secContext_; - ScopedCredHandle credHandle_; - ScopedCtxtHandle contextHandle_; - DWORD contextFlags_; - SecPkgContext_StreamSizes streamSizes_; + ULONG secContext_; + ScopedCredHandle credHandle_; + ScopedCtxtHandle contextHandle_; + DWORD contextFlags_; + SecPkgContext_StreamSizes streamSizes_; - std::vector<char> receivedData_; + std::vector<char> receivedData_; - HCERTSTORE myCertStore_; - std::string certStoreName_; - std::string certName_; + HCERTSTORE myCertStore_; + std::string certStoreName_; + std::string certName_; ////Not needed, most likely - std::string smartCardReader_; //Can be empty string for non SmartCard certificates - boost::shared_ptr<CAPICertificate> userCertificate_; - bool checkCertificateRevocation_; - bool tls1_0Workaround_; - bool disconnectOnCardRemoval_; - }; + std::string smartCardReader_; //Can be empty string for non SmartCard certificates + std::shared_ptr<CAPICertificate> userCertificate_; + bool checkCertificateRevocation_; + bool tls1_0Workaround_; + bool disconnectOnCardRemoval_; + }; } diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.cpp b/Swiften/TLS/Schannel/SchannelContextFactory.cpp index c2587c5..f78d386 100644 --- a/Swiften/TLS/Schannel/SchannelContextFactory.cpp +++ b/Swiften/TLS/Schannel/SchannelContextFactory.cpp @@ -5,13 +5,14 @@ */ /* - * Copyright (c) 2015 Isode Limited. + * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ -#include "Swiften/TLS/Schannel/SchannelContextFactory.h" -#include "Swiften/TLS/Schannel/SchannelContext.h" +#include <Swiften/TLS/Schannel/SchannelContextFactory.h> + +#include <Swiften/TLS/Schannel/SchannelContext.h> namespace Swift { @@ -19,22 +20,22 @@ SchannelContextFactory::SchannelContextFactory() : checkCertificateRevocation(tr } bool SchannelContextFactory::canCreate() const { - return true; + return true; } TLSContext* SchannelContextFactory::createTLSContext(const TLSOptions& tlsOptions) { - SchannelContext* context = new SchannelContext(tlsOptions.schannelTLS1_0Workaround); - context->setCheckCertificateRevocation(checkCertificateRevocation); - context->setDisconnectOnCardRemoval(disconnectOnCardRemoval); - return context; + SchannelContext* context = new SchannelContext(tlsOptions.schannelTLS1_0Workaround); + context->setCheckCertificateRevocation(checkCertificateRevocation); + context->setDisconnectOnCardRemoval(disconnectOnCardRemoval); + return context; } void SchannelContextFactory::setCheckCertificateRevocation(bool b) { - checkCertificateRevocation = b; + checkCertificateRevocation = b; } void SchannelContextFactory::setDisconnectOnCardRemoval(bool b) { - disconnectOnCardRemoval = b; + disconnectOnCardRemoval = b; } } diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.h b/Swiften/TLS/Schannel/SchannelContextFactory.h index 27b7dc9..142f193 100644 --- a/Swiften/TLS/Schannel/SchannelContextFactory.h +++ b/Swiften/TLS/Schannel/SchannelContextFactory.h @@ -15,19 +15,19 @@ #include <Swiften/TLS/TLSContextFactory.h> namespace Swift { - class SchannelContextFactory : public TLSContextFactory { - public: - SchannelContextFactory(); + class SchannelContextFactory : public TLSContextFactory { + public: + SchannelContextFactory(); - bool canCreate() const; - virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); + bool canCreate() const; + virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); - virtual void setCheckCertificateRevocation(bool b); + virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); + virtual void setDisconnectOnCardRemoval(bool b); - public: - bool checkCertificateRevocation; - bool disconnectOnCardRemoval; - }; + public: + bool checkCertificateRevocation; + bool disconnectOnCardRemoval; + }; } diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h index 4f73aac..ec71d9d 100644 --- a/Swiften/TLS/Schannel/SchannelUtil.h +++ b/Swiften/TLS/Schannel/SchannelUtil.h @@ -4,6 +4,12 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #define SECURITY_WIN32 @@ -14,412 +20,412 @@ #include <boost/noncopyable.hpp> -namespace Swift +namespace Swift { - // - // Convenience wrapper around the Schannel CredHandle struct. - // - class ScopedCredHandle - { - private: - struct HandleContext - { - HandleContext() - { - ZeroMemory(&m_h, sizeof(m_h)); - } - - HandleContext(const CredHandle& h) - { - memcpy(&m_h, &h, sizeof(m_h)); - } - - ~HandleContext() - { - ::FreeCredentialsHandle(&m_h); - } - - CredHandle m_h; - }; - - public: - ScopedCredHandle() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCredHandle(const CredHandle& h) - : m_pHandle( new HandleContext(h) ) - { - } - - // Copy constructor - explicit ScopedCredHandle(const ScopedCredHandle& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCredHandle() - { - m_pHandle.reset(); - } - - PCredHandle Reset() - { - CloseHandle(); - return &m_pHandle->m_h; - } - - operator PCredHandle() const - { - return &m_pHandle->m_h; - } - - ScopedCredHandle& operator=(const ScopedCredHandle& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_h != &sh.m_pHandle->m_h) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void CloseHandle() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel CtxtHandle struct. - // - class ScopedCtxtHandle - { - private: - struct HandleContext - { - HandleContext() - { - ZeroMemory(&m_h, sizeof(m_h)); - } - - ~HandleContext() - { - ::DeleteSecurityContext(&m_h); - } - - CtxtHandle m_h; - }; - - public: - ScopedCtxtHandle() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCtxtHandle(CredHandle h) - : m_pHandle( new HandleContext ) - { - } - - // Copy constructor - explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCtxtHandle() - { - m_pHandle.reset(); - } - - PCredHandle Reset() - { - CloseHandle(); - return &m_pHandle->m_h; - } - - operator PCredHandle() const - { - return &m_pHandle->m_h; - } - - ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_h != &sh.m_pHandle->m_h) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void CloseHandle() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel ScopedSecBuffer struct. - // - class ScopedSecBuffer : boost::noncopyable - { - public: - ScopedSecBuffer(PSecBuffer pSecBuffer) - : m_pSecBuffer(pSecBuffer) - { - } - - ~ScopedSecBuffer() - { - // Loop through all the output buffers and make sure we free them - if (m_pSecBuffer->pvBuffer) - FreeContextBuffer(m_pSecBuffer->pvBuffer); - } - - PSecBuffer AsPtr() - { - return m_pSecBuffer; - } - - PSecBuffer operator->() - { - return m_pSecBuffer; - } - - private: - PSecBuffer m_pSecBuffer; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel PCCERT_CONTEXT. - // - class ScopedCertContext - { - private: - struct HandleContext - { - HandleContext() - : m_pCertCtxt(NULL) - { - } - - HandleContext(PCCERT_CONTEXT pCert) - : m_pCertCtxt(pCert) - { - } - - ~HandleContext() - { - if (m_pCertCtxt) - CertFreeCertificateContext(m_pCertCtxt); - } - - PCCERT_CONTEXT m_pCertCtxt; - }; - - public: - ScopedCertContext() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCertContext(PCCERT_CONTEXT pCert) - : m_pHandle( new HandleContext(pCert) ) - { - } - - // Copy constructor - ScopedCertContext(const ScopedCertContext& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCertContext() - { - m_pHandle.reset(); - } - - PCCERT_CONTEXT* Reset() - { - FreeContext(); - return &m_pHandle->m_pCertCtxt; - } - - operator PCCERT_CONTEXT() const - { - return m_pHandle->m_pCertCtxt; - } - - PCCERT_CONTEXT* GetPointer() const - { - return &m_pHandle->m_pCertCtxt; - } - - PCCERT_CONTEXT operator->() const - { - return m_pHandle->m_pCertCtxt; - } - - ScopedCertContext& operator=(const ScopedCertContext& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt) - { - // Only update the internal handle if it's different - if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt) - m_pHandle.reset( new HandleContext(pCertCtxt) ); - - return *this; - } - - void FreeContext() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel HCERTSTORE. - // - class ScopedCertStore : boost::noncopyable - { - public: - ScopedCertStore(HCERTSTORE hCertStore) - : m_hCertStore(hCertStore) - { - } - - ~ScopedCertStore() - { - // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this - // cert store have been closed at this point. - if (m_hCertStore) - CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG); - } - - operator HCERTSTORE() const - { - return m_hCertStore; - } - - private: - HCERTSTORE m_hCertStore; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT. - // - class ScopedCertChainContext - { - private: - struct HandleContext - { - HandleContext() - : m_pCertChainCtxt(NULL) - { - } - - HandleContext(PCCERT_CHAIN_CONTEXT pCert) - : m_pCertChainCtxt(pCert) - { - } - - ~HandleContext() - { - if (m_pCertChainCtxt) - CertFreeCertificateChain(m_pCertChainCtxt); - } - - PCCERT_CHAIN_CONTEXT m_pCertChainCtxt; - }; - - public: - ScopedCertChainContext() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert) - : m_pHandle( new HandleContext(pCert) ) - { - } - - // Copy constructor - ScopedCertChainContext(const ScopedCertChainContext& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCertChainContext() - { - m_pHandle.reset(); - } - - PCCERT_CHAIN_CONTEXT* Reset() - { - FreeContext(); - return &m_pHandle->m_pCertChainCtxt; - } - - operator PCCERT_CHAIN_CONTEXT() const - { - return m_pHandle->m_pCertChainCtxt; - } - - PCCERT_CHAIN_CONTEXT operator->() const - { - return m_pHandle->m_pCertChainCtxt; - } - - ScopedCertChainContext& operator=(const ScopedCertChainContext& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void FreeContext() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; + // + // Convenience wrapper around the Schannel CredHandle struct. + // + class ScopedCredHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + HandleContext(const CredHandle& h) + { + memcpy(&m_h, &h, sizeof(m_h)); + } + + ~HandleContext() + { + ::FreeCredentialsHandle(&m_h); + } + + CredHandle m_h; + }; + + public: + ScopedCredHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCredHandle(const CredHandle& h) + : m_pHandle( new HandleContext(h) ) + { + } + + // Copy constructor + explicit ScopedCredHandle(const ScopedCredHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCredHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCredHandle& operator=(const ScopedCredHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + std::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CtxtHandle struct. + // + class ScopedCtxtHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + ~HandleContext() + { + ::DeleteSecurityContext(&m_h); + } + + CtxtHandle m_h; + }; + + public: + ScopedCtxtHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCtxtHandle(CredHandle h) + : m_pHandle( new HandleContext ) + { + } + + // Copy constructor + explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCtxtHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + std::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel ScopedSecBuffer struct. + // + class ScopedSecBuffer : boost::noncopyable + { + public: + ScopedSecBuffer(PSecBuffer pSecBuffer) + : m_pSecBuffer(pSecBuffer) + { + } + + ~ScopedSecBuffer() + { + // Loop through all the output buffers and make sure we free them + if (m_pSecBuffer->pvBuffer) + FreeContextBuffer(m_pSecBuffer->pvBuffer); + } + + PSecBuffer AsPtr() + { + return m_pSecBuffer; + } + + PSecBuffer operator->() + { + return m_pSecBuffer; + } + + private: + PSecBuffer m_pSecBuffer; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel PCCERT_CONTEXT. + // + class ScopedCertContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertCtxt(NULL) + { + } + + HandleContext(PCCERT_CONTEXT pCert) + : m_pCertCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertCtxt) + CertFreeCertificateContext(m_pCertCtxt); + } + + PCCERT_CONTEXT m_pCertCtxt; + }; + + public: + ScopedCertContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertContext(PCCERT_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + ScopedCertContext(const ScopedCertContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertContext() + { + m_pHandle.reset(); + } + + PCCERT_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertCtxt; + } + + operator PCCERT_CONTEXT() const + { + return m_pHandle->m_pCertCtxt; + } + + PCCERT_CONTEXT* GetPointer() const + { + return &m_pHandle->m_pCertCtxt; + } + + PCCERT_CONTEXT operator->() const + { + return m_pHandle->m_pCertCtxt; + } + + ScopedCertContext& operator=(const ScopedCertContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt) + { + // Only update the internal handle if it's different + if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt) + m_pHandle.reset( new HandleContext(pCertCtxt) ); + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + std::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel HCERTSTORE. + // + class ScopedCertStore : boost::noncopyable + { + public: + ScopedCertStore(HCERTSTORE hCertStore) + : m_hCertStore(hCertStore) + { + } + + ~ScopedCertStore() + { + // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this + // cert store have been closed at this point. + if (m_hCertStore) + CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG); + } + + operator HCERTSTORE() const + { + return m_hCertStore; + } + + private: + HCERTSTORE m_hCertStore; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT. + // + class ScopedCertChainContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertChainCtxt(NULL) + { + } + + HandleContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pCertChainCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertChainCtxt) + CertFreeCertificateChain(m_pCertChainCtxt); + } + + PCCERT_CHAIN_CONTEXT m_pCertChainCtxt; + }; + + public: + ScopedCertChainContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + ScopedCertChainContext(const ScopedCertChainContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertChainContext() + { + m_pHandle.reset(); + } + + PCCERT_CHAIN_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertChainCtxt; + } + + operator PCCERT_CHAIN_CONTEXT() const + { + return m_pHandle->m_pCertChainCtxt; + } + + PCCERT_CHAIN_CONTEXT operator->() const + { + return m_pHandle->m_pCertChainCtxt; + } + + ScopedCertChainContext& operator=(const ScopedCertChainContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + std::shared_ptr<HandleContext> m_pHandle; + }; } diff --git a/Swiften/TLS/SecureTransport/SecureTransportCertificate.h b/Swiften/TLS/SecureTransport/SecureTransportCertificate.h index b8d3728..7faf3be 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportCertificate.h +++ b/Swiften/TLS/SecureTransport/SecureTransportCertificate.h @@ -1,12 +1,13 @@ /* - * Copyright (c) 2015 Isode Limited. + * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + #include <boost/type_traits.hpp> #include <Security/SecCertificate.h> @@ -17,29 +18,29 @@ namespace Swift { class SecureTransportCertificate : public Certificate { public: - SecureTransportCertificate(SecCertificateRef certificate); - SecureTransportCertificate(const ByteArray& der); - virtual ~SecureTransportCertificate(); + SecureTransportCertificate(SecCertificateRef certificate); + SecureTransportCertificate(const ByteArray& der); + virtual ~SecureTransportCertificate(); - virtual std::string getSubjectName() const; - virtual std::vector<std::string> getCommonNames() const; - virtual std::vector<std::string> getSRVNames() const; - virtual std::vector<std::string> getDNSNames() const; - virtual std::vector<std::string> getXMPPAddresses() const; + virtual std::string getSubjectName() const; + virtual std::vector<std::string> getCommonNames() const; + virtual std::vector<std::string> getSRVNames() const; + virtual std::vector<std::string> getDNSNames() const; + virtual std::vector<std::string> getXMPPAddresses() const; - virtual ByteArray toDER() const; + virtual ByteArray toDER() const; private: - void parse(); - typedef boost::remove_pointer<SecCertificateRef>::type SecCertificate; + void parse(); + typedef boost::remove_pointer<SecCertificateRef>::type SecCertificate; private: - boost::shared_ptr<SecCertificate> certificateHandle_; - std::string subjectName_; - std::vector<std::string> commonNames_; - std::vector<std::string> srvNames_; - std::vector<std::string> dnsNames_; - std::vector<std::string> xmppAddresses_; + std::shared_ptr<SecCertificate> certificateHandle_; + std::string subjectName_; + std::vector<std::string> commonNames_; + std::vector<std::string> srvNames_; + std::vector<std::string> dnsNames_; + std::vector<std::string> xmppAddresses_; }; } diff --git a/Swiften/TLS/SecureTransport/SecureTransportCertificate.mm b/Swiften/TLS/SecureTransport/SecureTransportCertificate.mm index ed409bd..db0af89 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportCertificate.mm +++ b/Swiften/TLS/SecureTransport/SecureTransportCertificate.mm @@ -19,123 +19,133 @@ template <typename T, typename S> T bridge_cast(S source) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (__bridge T)(source); + return (__bridge T)(source); #pragma clang diagnostic pop } } +namespace { + +inline std::string ns2StdString(NSString* _Nullable nsString); +inline std::string ns2StdString(NSString* _Nullable nsString) { + std::string stdString; + if (nsString != nil) { + stdString = std::string([nsString cStringUsingEncoding:NSUTF8StringEncoding]); + } + return stdString; +} + +} + namespace Swift { SecureTransportCertificate::SecureTransportCertificate(SecCertificateRef certificate) { - assert(certificate); - CFRetain(certificate); - certificateHandle_ = boost::shared_ptr<SecCertificate>(certificate, CFRelease); - parse(); + assert(certificate); + CFRetain(certificate); + certificateHandle_ = std::shared_ptr<SecCertificate>(certificate, CFRelease); + parse(); } SecureTransportCertificate::SecureTransportCertificate(const ByteArray& der) { - CFDataRef derData = CFDataCreateWithBytesNoCopy(NULL, der.data(), static_cast<CFIndex>(der.size()), NULL); - // certificate will take ownership of derData and free it on its release. - SecCertificateRef certificate = SecCertificateCreateWithData(NULL, derData); - if (certificate) { - certificateHandle_ = boost::shared_ptr<SecCertificate>(certificate, CFRelease); - parse(); - } + CFDataRef derData = CFDataCreateWithBytesNoCopy(nullptr, der.data(), static_cast<CFIndex>(der.size()), nullptr); + // certificate will take ownership of derData and free it on its release. + SecCertificateRef certificate = SecCertificateCreateWithData(nullptr, derData); + if (certificate) { + certificateHandle_ = std::shared_ptr<SecCertificate>(certificate, CFRelease); + parse(); + } } SecureTransportCertificate::~SecureTransportCertificate() { } -#define NS2STDSTRING(a) (a == nil ? std::string() : std::string([a cStringUsingEncoding:NSUTF8StringEncoding])) - - void SecureTransportCertificate::parse() { - assert(certificateHandle_); - CFErrorRef error = NULL; - - // The SecCertificateCopyValues function is not part of the iOS Secure Transport API. - CFDictionaryRef valueDict = SecCertificateCopyValues(certificateHandle_.get(), 0, &error); - if (valueDict) { - // Handle subject. - CFStringRef subject = SecCertificateCopySubjectSummary(certificateHandle_.get()); - if (subject) { - NSString* subjectStr = bridge_cast<NSString*>(subject); - subjectName_ = NS2STDSTRING(subjectStr); - CFRelease(subject); - } - - // Handle a single Common Name. - CFStringRef commonName = NULL; - OSStatus error = SecCertificateCopyCommonName(certificateHandle_.get(), &commonName); - if (!error && commonName) { - NSString* commonNameStr = bridge_cast<NSString*>(commonName); - commonNames_.push_back(NS2STDSTRING(commonNameStr)); - } - if (commonName) { - CFRelease(commonName); - } - - // Handle Subject Alternative Names - NSDictionary* certDict = bridge_cast<NSDictionary*>(valueDict); - NSDictionary* subjectAltNamesDict = certDict[@"2.5.29.17"][@"value"]; - - for (NSDictionary* entry in subjectAltNamesDict) { - if ([entry[@"label"] isEqualToString:[NSString stringWithUTF8String:ID_ON_XMPPADDR_OID]]) { - xmppAddresses_.push_back(NS2STDSTRING(entry[@"value"])); - } - else if ([entry[@"label"] isEqualToString:[NSString stringWithUTF8String:ID_ON_DNSSRV_OID]]) { - srvNames_.push_back(NS2STDSTRING(entry[@"value"])); - } - else if ([entry[@"label"] isEqualToString:@"DNS Name"]) { - dnsNames_.push_back(NS2STDSTRING(entry[@"value"])); - } - } - CFRelease(valueDict); - } - - if (error) { - CFRelease(error); - } + assert(certificateHandle_); + CFErrorRef error = nullptr; + + // The SecCertificateCopyValues function is not part of the iOS Secure Transport API. + CFDictionaryRef valueDict = SecCertificateCopyValues(certificateHandle_.get(), nullptr, &error); + if (valueDict) { + // Handle subject. + CFStringRef subject = SecCertificateCopySubjectSummary(certificateHandle_.get()); + if (subject) { + NSString* subjectStr = bridge_cast<NSString*>(subject); + subjectName_ = ns2StdString(subjectStr); + CFRelease(subject); + } + + // Handle a single Common Name. + CFStringRef commonName = nullptr; + OSStatus error = SecCertificateCopyCommonName(certificateHandle_.get(), &commonName); + if (!error && commonName) { + NSString* commonNameStr = bridge_cast<NSString*>(commonName); + commonNames_.push_back(ns2StdString(commonNameStr)); + } + if (commonName) { + CFRelease(commonName); + } + + // Handle Subject Alternative Names + NSDictionary* certDict = bridge_cast<NSDictionary*>(valueDict); + NSDictionary* subjectAltNamesDict = certDict[@"2.5.29.17"][@"value"]; + + for (NSDictionary* entry in subjectAltNamesDict) { + if ([entry[@"label"] isEqualToString:static_cast<NSString * _Nonnull>([NSString stringWithUTF8String:ID_ON_XMPPADDR_OID])]) { + xmppAddresses_.push_back(ns2StdString(entry[@"value"])); + } + else if ([entry[@"label"] isEqualToString:static_cast<NSString * _Nonnull>([NSString stringWithUTF8String:ID_ON_DNSSRV_OID])]) { + srvNames_.push_back(ns2StdString(entry[@"value"])); + } + else if ([entry[@"label"] isEqualToString:@"DNS Name"]) { + dnsNames_.push_back(ns2StdString(entry[@"value"])); + } + } + CFRelease(valueDict); + } + + if (error) { + CFRelease(error); + } } std::string SecureTransportCertificate::getSubjectName() const { - return subjectName_; + return subjectName_; } std::vector<std::string> SecureTransportCertificate::getCommonNames() const { - return commonNames_; + return commonNames_; } std::vector<std::string> SecureTransportCertificate::getSRVNames() const { - return srvNames_; + return srvNames_; } std::vector<std::string> SecureTransportCertificate::getDNSNames() const { - return dnsNames_; + return dnsNames_; } std::vector<std::string> SecureTransportCertificate::getXMPPAddresses() const { - return xmppAddresses_; + return xmppAddresses_; } ByteArray SecureTransportCertificate::toDER() const { - ByteArray der; - if (certificateHandle_) { - CFDataRef derData = SecCertificateCopyData(certificateHandle_.get()); - if (derData) { - try { - size_t dataSize = boost::numeric_cast<size_t>(CFDataGetLength(derData)); - der.resize(dataSize); - CFDataGetBytes(derData, CFRangeMake(0,CFDataGetLength(derData)), der.data()); - } catch (...) { - } - CFRelease(derData); - } - } - return der; + ByteArray der; + if (certificateHandle_) { + CFDataRef derData = SecCertificateCopyData(certificateHandle_.get()); + if (derData) { + try { + size_t dataSize = boost::numeric_cast<size_t>(CFDataGetLength(derData)); + der.resize(dataSize); + CFDataGetBytes(derData, CFRangeMake(0,CFDataGetLength(derData)), der.data()); + } catch (...) { + } + CFRelease(derData); + } + } + return der; } } diff --git a/Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h b/Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h index 1f86541..3ea469d 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h +++ b/Swiften/TLS/SecureTransport/SecureTransportCertificateFactory.h @@ -10,11 +10,11 @@ #include <Swiften/TLS/SecureTransport/SecureTransportCertificate.h> namespace Swift { - + class SecureTransportCertificateFactory : public CertificateFactory { - public: - virtual Certificate* createCertificateFromDER(const ByteArray& der) { - return new SecureTransportCertificate(der); - } - }; + public: + virtual Certificate* createCertificateFromDER(const ByteArray& der) { + return new SecureTransportCertificate(der); + } + }; } diff --git a/Swiften/TLS/SecureTransport/SecureTransportContext.h b/Swiften/TLS/SecureTransport/SecureTransportContext.h index aa17c66..3942904 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportContext.h +++ b/Swiften/TLS/SecureTransport/SecureTransportContext.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015 Isode Limited. + * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -13,46 +13,46 @@ namespace Swift { class SecureTransportContext : public TLSContext { - public: - SecureTransportContext(bool checkCertificateRevocation); - virtual ~SecureTransportContext(); + public: + SecureTransportContext(bool checkCertificateRevocation); + virtual ~SecureTransportContext(); - virtual void connect(); + virtual void connect(); - virtual bool setClientCertificate(CertificateWithKey::ref cert); + virtual bool setClientCertificate(CertificateWithKey::ref cert); - virtual void handleDataFromNetwork(const SafeByteArray&); - virtual void handleDataFromApplication(const SafeByteArray&); + virtual void handleDataFromNetwork(const SafeByteArray&); + virtual void handleDataFromApplication(const SafeByteArray&); - virtual std::vector<Certificate::ref> getPeerCertificateChain() const; - virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; + virtual std::vector<Certificate::ref> getPeerCertificateChain() const; + virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; - virtual ByteArray getFinishMessage() const; - - private: - static OSStatus SSLSocketReadCallback(SSLConnectionRef connection, void *data, size_t *dataLength); - static OSStatus SSLSocketWriteCallback(SSLConnectionRef connection, const void *data, size_t *dataLength); + virtual ByteArray getFinishMessage() const; - private: - enum State { None, Handshake, HandshakeDone, Error}; - static std::string stateToString(State state); - void setState(State newState); + private: + static OSStatus SSLSocketReadCallback(SSLConnectionRef connection, void *data, size_t *dataLength); + static OSStatus SSLSocketWriteCallback(SSLConnectionRef connection, const void *data, size_t *dataLength); - static boost::shared_ptr<TLSError> nativeToTLSError(OSStatus error); - boost::shared_ptr<CertificateVerificationError> CSSMErrorToVerificationError(OSStatus resultCode); + private: + enum State { None, Handshake, HandshakeDone, Error}; + static std::string stateToString(State state); + void setState(State newState); - void processHandshake(); - void verifyServerCertificate(); + static std::shared_ptr<TLSError> nativeToTLSError(OSStatus error); + std::shared_ptr<CertificateVerificationError> CSSMErrorToVerificationError(OSStatus resultCode); - void fatalError(boost::shared_ptr<TLSError> error, boost::shared_ptr<CertificateVerificationError> certificateError); + void processHandshake(); + void verifyServerCertificate(); - private: - boost::shared_ptr<SSLContext> sslContext_; - SafeByteArray readingBuffer_; - State state_; - CertificateVerificationError::ref verificationError_; - CertificateWithKey::ref clientCertificate_; - bool checkCertificateRevocation_; + void fatalError(std::shared_ptr<TLSError> error, std::shared_ptr<CertificateVerificationError> certificateError); + + private: + std::shared_ptr<SSLContext> sslContext_; + SafeByteArray readingBuffer_; + State state_; + CertificateVerificationError::ref verificationError_; + CertificateWithKey::ref clientCertificate_; + bool checkCertificateRevocation_; }; } diff --git a/Swiften/TLS/SecureTransport/SecureTransportContext.mm b/Swiften/TLS/SecureTransport/SecureTransportContext.mm index ca6c5bb..1ed636b 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportContext.mm +++ b/Swiften/TLS/SecureTransport/SecureTransportContext.mm @@ -21,15 +21,15 @@ #import <Security/SecImportExport.h> namespace { - typedef boost::remove_pointer<CFArrayRef>::type CFArray; - typedef boost::remove_pointer<SecTrustRef>::type SecTrust; + typedef boost::remove_pointer<CFArrayRef>::type CFArray; + typedef boost::remove_pointer<SecTrustRef>::type SecTrust; } template <typename T, typename S> T bridge_cast(S source) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (__bridge T)(source); + return (__bridge T)(source); #pragma clang diagnostic pop } @@ -37,162 +37,162 @@ namespace Swift { namespace { - + CFArrayRef CreateClientCertificateChainAsCFArrayRef(CertificateWithKey::ref key) { - boost::shared_ptr<PKCS12Certificate> pkcs12 = boost::dynamic_pointer_cast<PKCS12Certificate>(key); - if (!key) { - return NULL; - } - - SafeByteArray safePassword = pkcs12->getPassword(); - CFIndex passwordSize = 0; - try { - passwordSize = boost::numeric_cast<CFIndex>(safePassword.size()); - } catch (...) { - return NULL; - } - - CFMutableArrayRef certChain = CFArrayCreateMutable(NULL, 0, 0); - - OSStatus securityError = errSecSuccess; - CFStringRef password = CFStringCreateWithBytes(kCFAllocatorDefault, safePassword.data(), passwordSize, kCFStringEncodingUTF8, false); - const void* keys[] = { kSecImportExportPassphrase }; - const void* values[] = { password }; - - CFDictionaryRef options = CFDictionaryCreate(NULL, keys, values, 1, NULL, NULL); - - CFArrayRef items = NULL; - CFDataRef pkcs12Data = bridge_cast<CFDataRef>([NSData dataWithBytes: static_cast<const void *>(pkcs12->getData().data()) length:pkcs12->getData().size()]); - securityError = SecPKCS12Import(pkcs12Data, options, &items); - CFRelease(options); - NSArray* nsItems = bridge_cast<NSArray*>(items); - - switch(securityError) { - case errSecSuccess: - break; - case errSecAuthFailed: - // Password did not work for decoding the certificate. - SWIFT_LOG(warning) << "Invalid password." << std::endl; - break; - case errSecDecode: - // Other decoding error. - SWIFT_LOG(warning) << "PKCS12 decoding error." << std::endl; - break; - default: - SWIFT_LOG(warning) << "Unknown error." << std::endl; - } - - if (securityError != errSecSuccess) { - if (items) { - CFRelease(items); - items = NULL; - } - CFRelease(certChain); - certChain = NULL; - } - - if (certChain) { - CFArrayAppendValue(certChain, nsItems[0][@"identity"]); - - for (CFIndex index = 0; index < CFArrayGetCount(bridge_cast<CFArrayRef>(nsItems[0][@"chain"])); index++) { - CFArrayAppendValue(certChain, CFArrayGetValueAtIndex(bridge_cast<CFArrayRef>(nsItems[0][@"chain"]), index)); - } - } - return certChain; + std::shared_ptr<PKCS12Certificate> pkcs12 = std::dynamic_pointer_cast<PKCS12Certificate>(key); + if (!key) { + return nullptr; + } + + SafeByteArray safePassword = pkcs12->getPassword(); + CFIndex passwordSize = 0; + try { + passwordSize = boost::numeric_cast<CFIndex>(safePassword.size()); + } catch (...) { + return nullptr; + } + + CFMutableArrayRef certChain = CFArrayCreateMutable(nullptr, 0, nullptr); + + OSStatus securityError = errSecSuccess; + CFStringRef password = CFStringCreateWithBytes(kCFAllocatorDefault, safePassword.data(), passwordSize, kCFStringEncodingUTF8, false); + const void* keys[] = { kSecImportExportPassphrase }; + const void* values[] = { password }; + + CFDictionaryRef options = CFDictionaryCreate(nullptr, keys, values, 1, nullptr, nullptr); + + CFArrayRef items = nullptr; + CFDataRef pkcs12Data = bridge_cast<CFDataRef>([NSData dataWithBytes: static_cast<const void *>(pkcs12->getData().data()) length:pkcs12->getData().size()]); + securityError = SecPKCS12Import(pkcs12Data, options, &items); + CFRelease(options); + NSArray* nsItems = bridge_cast<NSArray*>(items); + + switch(securityError) { + case errSecSuccess: + break; + case errSecAuthFailed: + // Password did not work for decoding the certificate. + SWIFT_LOG(warning) << "Invalid password." << std::endl; + break; + case errSecDecode: + // Other decoding error. + SWIFT_LOG(warning) << "PKCS12 decoding error." << std::endl; + break; + default: + SWIFT_LOG(warning) << "Unknown error." << std::endl; + } + + if (securityError != errSecSuccess) { + if (items) { + CFRelease(items); + items = nullptr; + } + CFRelease(certChain); + certChain = nullptr; + } + + if (certChain) { + CFArrayAppendValue(certChain, nsItems[0][@"identity"]); + + for (CFIndex index = 0; index < CFArrayGetCount(bridge_cast<CFArrayRef>(nsItems[0][@"chain"])); index++) { + CFArrayAppendValue(certChain, CFArrayGetValueAtIndex(bridge_cast<CFArrayRef>(nsItems[0][@"chain"]), index)); + } + } + return certChain; } } SecureTransportContext::SecureTransportContext(bool checkCertificateRevocation) : state_(None), checkCertificateRevocation_(checkCertificateRevocation) { - sslContext_ = boost::shared_ptr<SSLContext>(SSLCreateContext(NULL, kSSLClientSide, kSSLStreamType), CFRelease); - - OSStatus error = noErr; - // set IO callbacks - error = SSLSetIOFuncs(sslContext_.get(), &SecureTransportContext::SSLSocketReadCallback, &SecureTransportContext::SSLSocketWriteCallback); - if (error != noErr) { - SWIFT_LOG(error) << "Unable to set IO functions to SSL context." << std::endl; - sslContext_.reset(); - } - - error = SSLSetConnection(sslContext_.get(), this); - if (error != noErr) { - SWIFT_LOG(error) << "Unable to set connection to SSL context." << std::endl; - sslContext_.reset(); - } - - - error = SSLSetSessionOption(sslContext_.get(), kSSLSessionOptionBreakOnServerAuth, true); - if (error != noErr) { - SWIFT_LOG(error) << "Unable to set kSSLSessionOptionBreakOnServerAuth on session." << std::endl; - sslContext_.reset(); - } + sslContext_ = std::shared_ptr<SSLContext>(SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType), CFRelease); + + OSStatus error = noErr; + // set IO callbacks + error = SSLSetIOFuncs(sslContext_.get(), &SecureTransportContext::SSLSocketReadCallback, &SecureTransportContext::SSLSocketWriteCallback); + if (error != noErr) { + SWIFT_LOG(error) << "Unable to set IO functions to SSL context." << std::endl; + sslContext_.reset(); + } + + error = SSLSetConnection(sslContext_.get(), this); + if (error != noErr) { + SWIFT_LOG(error) << "Unable to set connection to SSL context." << std::endl; + sslContext_.reset(); + } + + + error = SSLSetSessionOption(sslContext_.get(), kSSLSessionOptionBreakOnServerAuth, true); + if (error != noErr) { + SWIFT_LOG(error) << "Unable to set kSSLSessionOptionBreakOnServerAuth on session." << std::endl; + sslContext_.reset(); + } } SecureTransportContext::~SecureTransportContext() { - if (sslContext_) { - SSLClose(sslContext_.get()); - } + if (sslContext_) { + SSLClose(sslContext_.get()); + } } std::string SecureTransportContext::stateToString(State state) { - std::string returnValue; - switch(state) { - case Handshake: - returnValue = "Handshake"; - break; - case HandshakeDone: - returnValue = "HandshakeDone"; - break; - case None: - returnValue = "None"; - break; - case Error: - returnValue = "Error"; - break; - } - return returnValue; + std::string returnValue; + switch(state) { + case Handshake: + returnValue = "Handshake"; + break; + case HandshakeDone: + returnValue = "HandshakeDone"; + break; + case None: + returnValue = "None"; + break; + case Error: + returnValue = "Error"; + break; + } + return returnValue; } void SecureTransportContext::setState(State newState) { - SWIFT_LOG(debug) << "Switch state from " << stateToString(state_) << " to " << stateToString(newState) << "." << std::endl; - state_ = newState; + SWIFT_LOG(debug) << "Switch state from " << stateToString(state_) << " to " << stateToString(newState) << "." << std::endl; + state_ = newState; } void SecureTransportContext::connect() { - SWIFT_LOG_ASSERT(state_ == None, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; - if (clientCertificate_) { - CFArrayRef certs = CreateClientCertificateChainAsCFArrayRef(clientCertificate_); - if (certs) { - boost::shared_ptr<CFArray> certRefs(certs, CFRelease); - OSStatus result = SSLSetCertificate(sslContext_.get(), certRefs.get()); - if (result != noErr) { - SWIFT_LOG(error) << "SSLSetCertificate failed with error " << result << "." << std::endl; - } - } - } - processHandshake(); + SWIFT_LOG_ASSERT(state_ == None, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; + if (clientCertificate_) { + CFArrayRef certs = CreateClientCertificateChainAsCFArrayRef(clientCertificate_); + if (certs) { + std::shared_ptr<CFArray> certRefs(certs, CFRelease); + OSStatus result = SSLSetCertificate(sslContext_.get(), certRefs.get()); + if (result != noErr) { + SWIFT_LOG(error) << "SSLSetCertificate failed with error " << result << "." << std::endl; + } + } + } + processHandshake(); } void SecureTransportContext::processHandshake() { - SWIFT_LOG_ASSERT(state_ == None || state_ == Handshake, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; - OSStatus error = SSLHandshake(sslContext_.get()); - if (error == errSSLWouldBlock) { - setState(Handshake); - } - else if (error == noErr) { - SWIFT_LOG(debug) << "TLS handshake successful." << std::endl; - setState(HandshakeDone); - onConnected(); - } - else if (error == errSSLPeerAuthCompleted) { - SWIFT_LOG(debug) << "Received server certificate. Start verification." << std::endl; - setState(Handshake); - verifyServerCertificate(); - } - else { - SWIFT_LOG(debug) << "Error returned from SSLHandshake call is " << error << "." << std::endl; - fatalError(nativeToTLSError(error), boost::make_shared<CertificateVerificationError>()); - } + SWIFT_LOG_ASSERT(state_ == None || state_ == Handshake, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; + OSStatus error = SSLHandshake(sslContext_.get()); + if (error == errSSLWouldBlock) { + setState(Handshake); + } + else if (error == noErr) { + SWIFT_LOG(debug) << "TLS handshake successful." << std::endl; + setState(HandshakeDone); + onConnected(); + } + else if (error == errSSLPeerAuthCompleted) { + SWIFT_LOG(debug) << "Received server certificate. Start verification." << std::endl; + setState(Handshake); + verifyServerCertificate(); + } + else { + SWIFT_LOG(debug) << "Error returned from SSLHandshake call is " << error << "." << std::endl; + fatalError(nativeToTLSError(error), std::make_shared<CertificateVerificationError>()); + } } @@ -200,296 +200,308 @@ void SecureTransportContext::processHandshake() { #pragma clang diagnostic ignored "-Wdeprecated-declarations" void SecureTransportContext::verifyServerCertificate() { - SecTrustRef trust = NULL; - OSStatus error = SSLCopyPeerTrust(sslContext_.get(), &trust); - if (error != noErr) { - fatalError(boost::make_shared<TLSError>(), boost::make_shared<CertificateVerificationError>()); - return; - } - boost::shared_ptr<SecTrust> trustRef = boost::shared_ptr<SecTrust>(trust, CFRelease); - - if (checkCertificateRevocation_) { - error = SecTrustSetOptions(trust, kSecTrustOptionRequireRevPerCert | kSecTrustOptionFetchIssuerFromNet); - if (error != noErr) { - fatalError(boost::make_shared<TLSError>(), boost::make_shared<CertificateVerificationError>()); - return; - } - } - - SecTrustResultType trustResult; - error = SecTrustEvaluate(trust, &trustResult); - if (error != errSecSuccess) { - fatalError(boost::make_shared<TLSError>(), boost::make_shared<CertificateVerificationError>()); - return; - } - - OSStatus cssmResult = 0; - switch(trustResult) { - case kSecTrustResultUnspecified: - SWIFT_LOG(warning) << "Successful implicit validation. Result unspecified." << std::endl; - break; - case kSecTrustResultProceed: - SWIFT_LOG(warning) << "Validation resulted in explicitly trusted." << std::endl; - break; - case kSecTrustResultRecoverableTrustFailure: - SWIFT_LOG(warning) << "recoverable trust failure" << std::endl; - error = SecTrustGetCssmResultCode(trust, &cssmResult); - if (error == errSecSuccess) { - verificationError_ = CSSMErrorToVerificationError(cssmResult); - if (cssmResult == CSSMERR_TP_VERIFY_ACTION_FAILED || cssmResult == CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK ) { - // Find out the reason why the verification failed. - CFArrayRef certChain; - CSSM_TP_APPLE_EVIDENCE_INFO* statusChain; - error = SecTrustGetResult(trustRef.get(), &trustResult, &certChain, &statusChain); - if (error == errSecSuccess) { - boost::shared_ptr<CFArray> certChainRef = boost::shared_ptr<CFArray>(certChain, CFRelease); - for (CFIndex index = 0; index < CFArrayGetCount(certChainRef.get()); index++) { - for (CFIndex n = 0; n < statusChain[index].NumStatusCodes; n++) { - // Even though Secure Transport reported CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK on the whole certificate - // chain, the actual cause can be that a revocation check for a specific cert returned CSSMERR_TP_CERT_REVOKED. - if (!verificationError_ || verificationError_->getType() == CertificateVerificationError::RevocationCheckFailed) { - verificationError_ = CSSMErrorToVerificationError(statusChain[index].StatusCodes[n]); - } - } - } - } - else { - - } - } - } - else { - verificationError_ = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); - } - break; - case kSecTrustResultOtherError: - verificationError_ = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); - break; - default: - SWIFT_LOG(warning) << "Unhandled trust result " << trustResult << "." << std::endl; - break; - } - - // We proceed with the TLS handshake here to give the application an opportunity - // to apply custom validation and trust management. The application is responsible - // to call \ref getPeerCertificateVerificationError directly after the \ref onConnected - // signal is called and before any application data is send to the context. - processHandshake(); + SecTrustRef trust = nullptr; + OSStatus error = SSLCopyPeerTrust(sslContext_.get(), &trust); + if (error != noErr) { + fatalError(std::make_shared<TLSError>(), std::make_shared<CertificateVerificationError>()); + return; + } + std::shared_ptr<SecTrust> trustRef = std::shared_ptr<SecTrust>(trust, CFRelease); + + if (checkCertificateRevocation_) { + error = SecTrustSetOptions(trust, kSecTrustOptionRequireRevPerCert | kSecTrustOptionFetchIssuerFromNet); + if (error != noErr) { + fatalError(std::make_shared<TLSError>(), std::make_shared<CertificateVerificationError>()); + return; + } + } + + SecTrustResultType trustResult; + error = SecTrustEvaluate(trust, &trustResult); + if (error != errSecSuccess) { + fatalError(std::make_shared<TLSError>(), std::make_shared<CertificateVerificationError>()); + return; + } + + OSStatus cssmResult = 0; + switch(trustResult) { + case kSecTrustResultUnspecified: + SWIFT_LOG(warning) << "Successful implicit validation. Result unspecified." << std::endl; + break; + case kSecTrustResultProceed: + SWIFT_LOG(warning) << "Validation resulted in explicitly trusted." << std::endl; + break; + case kSecTrustResultRecoverableTrustFailure: + SWIFT_LOG(warning) << "recoverable trust failure" << std::endl; + error = SecTrustGetCssmResultCode(trust, &cssmResult); + if (error == errSecSuccess) { + verificationError_ = CSSMErrorToVerificationError(cssmResult); + if (cssmResult == CSSMERR_TP_VERIFY_ACTION_FAILED || cssmResult == CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK ) { + // Find out the reason why the verification failed. + CFArrayRef certChain; + CSSM_TP_APPLE_EVIDENCE_INFO* statusChain; + error = SecTrustGetResult(trustRef.get(), &trustResult, &certChain, &statusChain); + if (error == errSecSuccess) { + std::shared_ptr<CFArray> certChainRef = std::shared_ptr<CFArray>(certChain, CFRelease); + for (CFIndex index = 0; index < CFArrayGetCount(certChainRef.get()); index++) { + for (CFIndex n = 0; n < statusChain[index].NumStatusCodes; n++) { + // Even though Secure Transport reported CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK on the whole certificate + // chain, the actual cause can be that a revocation check for a specific cert returned CSSMERR_TP_CERT_REVOKED. + if (!verificationError_ || verificationError_->getType() == CertificateVerificationError::RevocationCheckFailed) { + verificationError_ = CSSMErrorToVerificationError(statusChain[index].StatusCodes[n]); + } + } + } + } + else { + + } + } + } + else { + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + } + break; + case kSecTrustResultInvalid: + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + break; + case kSecTrustResultConfirm: + // TODO: Confirmation from the user is required before proceeding. + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + break; + case kSecTrustResultDeny: + // The user specified that the certificate should not be trusted. + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::Untrusted); + break; + case kSecTrustResultFatalTrustFailure: + // Trust denied; no simple fix is available. + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + break; + case kSecTrustResultOtherError: + verificationError_ = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + break; + } + + // We proceed with the TLS handshake here to give the application an opportunity + // to apply custom validation and trust management. The application is responsible + // to call \ref getPeerCertificateVerificationError directly after the \ref onConnected + // signal is called and before any application data is send to the context. + processHandshake(); } #pragma clang diagnostic pop bool SecureTransportContext::setClientCertificate(CertificateWithKey::ref cert) { - CFArrayRef nativeClientChain = CreateClientCertificateChainAsCFArrayRef(cert); - if (nativeClientChain) { - clientCertificate_ = cert; - CFRelease(nativeClientChain); - return true; - } - else { - return false; - } + CFArrayRef nativeClientChain = CreateClientCertificateChainAsCFArrayRef(cert); + if (nativeClientChain) { + clientCertificate_ = cert; + CFRelease(nativeClientChain); + return true; + } + else { + return false; + } } void SecureTransportContext::handleDataFromNetwork(const SafeByteArray& data) { - SWIFT_LOG(debug) << std::endl; - SWIFT_LOG_ASSERT(state_ == HandshakeDone || state_ == Handshake, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; - - append(readingBuffer_, data); - - size_t bytesRead = 0; - OSStatus error = noErr; - SafeByteArray applicationData; - - switch(state_) { - case None: - assert(false && "Invalid state 'None'."); - break; - case Handshake: - processHandshake(); - break; - case HandshakeDone: - while (error == noErr) { - applicationData.resize(readingBuffer_.size()); - error = SSLRead(sslContext_.get(), applicationData.data(), applicationData.size(), &bytesRead); - if (error == noErr) { - // Read successful. - } - else if (error == errSSLWouldBlock) { - // Secure Transport does not want more data. - break; - } - else { - SWIFT_LOG(error) << "SSLRead failed with error " << error << ", read bytes: " << bytesRead << "." << std::endl; - fatalError(boost::make_shared<TLSError>(), boost::make_shared<CertificateVerificationError>()); - return; - } - - if (bytesRead > 0) { - applicationData.resize(bytesRead); - onDataForApplication(applicationData); - } - else { - break; - } - } - break; - case Error: - SWIFT_LOG(debug) << "Igoring received data in error state." << std::endl; - break; - } + SWIFT_LOG(debug) << std::endl; + SWIFT_LOG_ASSERT(state_ == HandshakeDone || state_ == Handshake, error) << "current state '" << stateToString(state_) << " invalid." << std::endl; + + append(readingBuffer_, data); + + size_t bytesRead = 0; + OSStatus error = noErr; + SafeByteArray applicationData; + + switch(state_) { + case None: + assert(false && "Invalid state 'None'."); + break; + case Handshake: + processHandshake(); + break; + case HandshakeDone: + while (error == noErr) { + applicationData.resize(readingBuffer_.size()); + error = SSLRead(sslContext_.get(), applicationData.data(), applicationData.size(), &bytesRead); + if (error == noErr) { + // Read successful. + } + else if (error == errSSLWouldBlock) { + // Secure Transport does not want more data. + break; + } + else { + SWIFT_LOG(error) << "SSLRead failed with error " << error << ", read bytes: " << bytesRead << "." << std::endl; + fatalError(std::make_shared<TLSError>(), std::make_shared<CertificateVerificationError>()); + return; + } + + if (bytesRead > 0) { + applicationData.resize(bytesRead); + onDataForApplication(applicationData); + } + else { + break; + } + } + break; + case Error: + SWIFT_LOG(debug) << "Igoring received data in error state." << std::endl; + break; + } } void SecureTransportContext::handleDataFromApplication(const SafeByteArray& data) { - size_t processedBytes = 0; - OSStatus error = SSLWrite(sslContext_.get(), data.data(), data.size(), &processedBytes); - switch(error) { - case errSSLWouldBlock: - SWIFT_LOG(warning) << "Unexpected because the write callback does not block." << std::endl; - return; - case errSSLClosedGraceful: - case noErr: - return; - default: - SWIFT_LOG(warning) << "SSLWrite returned error code: " << error << ", processed bytes: " << processedBytes << std::endl; - fatalError(boost::make_shared<TLSError>(), boost::shared_ptr<CertificateVerificationError>()); - } + size_t processedBytes = 0; + OSStatus error = SSLWrite(sslContext_.get(), data.data(), data.size(), &processedBytes); + switch(error) { + case errSSLWouldBlock: + SWIFT_LOG(warning) << "Unexpected because the write callback does not block." << std::endl; + return; + case errSSLClosedGraceful: + case noErr: + return; + default: + SWIFT_LOG(warning) << "SSLWrite returned error code: " << error << ", processed bytes: " << processedBytes << std::endl; + fatalError(std::make_shared<TLSError>(), std::shared_ptr<CertificateVerificationError>()); + } } std::vector<Certificate::ref> SecureTransportContext::getPeerCertificateChain() const { - std::vector<Certificate::ref> peerCertificateChain; - - if (sslContext_) { - typedef boost::remove_pointer<SecTrustRef>::type SecTrust; - boost::shared_ptr<SecTrust> securityTrust; - - SecTrustRef secTrust = NULL;; - OSStatus error = SSLCopyPeerTrust(sslContext_.get(), &secTrust); - if (error == noErr) { - securityTrust = boost::shared_ptr<SecTrust>(secTrust, CFRelease); - - CFIndex chainSize = SecTrustGetCertificateCount(securityTrust.get()); - for (CFIndex n = 0; n < chainSize; n++) { - SecCertificateRef certificate = SecTrustGetCertificateAtIndex(securityTrust.get(), n); - if (certificate) { - peerCertificateChain.push_back(boost::make_shared<SecureTransportCertificate>(certificate)); - } - } - } - else { - SWIFT_LOG(warning) << "Failed to obtain peer trust structure; error = " << error << "." << std::endl; - } - } - - return peerCertificateChain; + std::vector<Certificate::ref> peerCertificateChain; + + if (sslContext_) { + typedef boost::remove_pointer<SecTrustRef>::type SecTrust; + std::shared_ptr<SecTrust> securityTrust; + + SecTrustRef secTrust = nullptr;; + OSStatus error = SSLCopyPeerTrust(sslContext_.get(), &secTrust); + if (error == noErr) { + securityTrust = std::shared_ptr<SecTrust>(secTrust, CFRelease); + + CFIndex chainSize = SecTrustGetCertificateCount(securityTrust.get()); + for (CFIndex n = 0; n < chainSize; n++) { + SecCertificateRef certificate = SecTrustGetCertificateAtIndex(securityTrust.get(), n); + if (certificate) { + peerCertificateChain.push_back(std::make_shared<SecureTransportCertificate>(certificate)); + } + } + } + else { + SWIFT_LOG(warning) << "Failed to obtain peer trust structure; error = " << error << "." << std::endl; + } + } + + return peerCertificateChain; } CertificateVerificationError::ref SecureTransportContext::getPeerCertificateVerificationError() const { - return verificationError_; + return verificationError_; } ByteArray SecureTransportContext::getFinishMessage() const { - SWIFT_LOG(warning) << "Access to TLS handshake finish message is not part of OS X Secure Transport APIs." << std::endl; - return ByteArray(); + SWIFT_LOG(warning) << "Access to TLS handshake finish message is not part of OS X Secure Transport APIs." << std::endl; + return ByteArray(); } /** - * This I/O callback simulates an asynchronous read to the read buffer of the context. If it is empty, it returns errSSLWouldBlock; else + * This I/O callback simulates an asynchronous read to the read buffer of the context. If it is empty, it returns errSSLWouldBlock; else * the data within the buffer is returned. */ OSStatus SecureTransportContext::SSLSocketReadCallback(SSLConnectionRef connection, void *data, size_t *dataLength) { - SecureTransportContext* context = const_cast<SecureTransportContext*>(static_cast<const SecureTransportContext*>(connection)); - OSStatus retValue = noErr; - - if (context->readingBuffer_.size() < *dataLength) { - // Would block because Secure Transport is trying to read more data than there currently is available in the buffer. - *dataLength = 0; - retValue = errSSLWouldBlock; - } - else { - size_t bufferLen = *dataLength; - size_t copyToBuffer = bufferLen < context->readingBuffer_.size() ? bufferLen : context->readingBuffer_.size(); - - memcpy(data, context->readingBuffer_.data(), copyToBuffer); - - context->readingBuffer_ = SafeByteArray(context->readingBuffer_.data() + copyToBuffer, context->readingBuffer_.data() + context->readingBuffer_.size()); - *dataLength = copyToBuffer; - } - return retValue; + SecureTransportContext* context = const_cast<SecureTransportContext*>(static_cast<const SecureTransportContext*>(connection)); + OSStatus retValue = noErr; + + if (context->readingBuffer_.size() < *dataLength) { + // Would block because Secure Transport is trying to read more data than there currently is available in the buffer. + *dataLength = 0; + retValue = errSSLWouldBlock; + } + else { + size_t bufferLen = *dataLength; + size_t copyToBuffer = bufferLen < context->readingBuffer_.size() ? bufferLen : context->readingBuffer_.size(); + + memcpy(data, context->readingBuffer_.data(), copyToBuffer); + + context->readingBuffer_ = SafeByteArray(context->readingBuffer_.data() + copyToBuffer, context->readingBuffer_.data() + context->readingBuffer_.size()); + *dataLength = copyToBuffer; + } + return retValue; } OSStatus SecureTransportContext::SSLSocketWriteCallback(SSLConnectionRef connection, const void *data, size_t *dataLength) { - SecureTransportContext* context = const_cast<SecureTransportContext*>(static_cast<const SecureTransportContext*>(connection)); - OSStatus retValue = noErr; - - SafeByteArray safeData; - safeData.resize(*dataLength); - memcpy(safeData.data(), data, safeData.size()); - - context->onDataForNetwork(safeData); - return retValue; + SecureTransportContext* context = const_cast<SecureTransportContext*>(static_cast<const SecureTransportContext*>(connection)); + OSStatus retValue = noErr; + + SafeByteArray safeData; + safeData.resize(*dataLength); + memcpy(safeData.data(), data, safeData.size()); + + context->onDataForNetwork(safeData); + return retValue; } -boost::shared_ptr<TLSError> SecureTransportContext::nativeToTLSError(OSStatus /* error */) { - boost::shared_ptr<TLSError> swiftenError; - swiftenError = boost::make_shared<TLSError>(); - return swiftenError; +std::shared_ptr<TLSError> SecureTransportContext::nativeToTLSError(OSStatus /* error */) { + std::shared_ptr<TLSError> swiftenError; + swiftenError = std::make_shared<TLSError>(); + return swiftenError; } -boost::shared_ptr<CertificateVerificationError> SecureTransportContext::CSSMErrorToVerificationError(OSStatus resultCode) { - boost::shared_ptr<CertificateVerificationError> error; - switch(resultCode) { - case CSSMERR_TP_NOT_TRUSTED: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_NOT_TRUSTED" << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::Untrusted); - break; - case CSSMERR_TP_CERT_NOT_VALID_YET: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_NOT_VALID_YET" << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::NotYetValid); - break; - case CSSMERR_TP_CERT_EXPIRED: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_EXPIRED" << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::Expired); - break; - case CSSMERR_TP_CERT_REVOKED: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_REVOKED" << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::Revoked); - break; - case CSSMERR_TP_VERIFY_ACTION_FAILED: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_VERIFY_ACTION_FAILED" << std::endl; - break; - case CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK" << std::endl; - if (checkCertificateRevocation_) { - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::RevocationCheckFailed); - } - break; - case CSSMERR_APPLETP_OCSP_UNAVAILABLE: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_OCSP_UNAVAILABLE" << std::endl; - if (checkCertificateRevocation_) { - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::RevocationCheckFailed); - } - break; - case CSSMERR_APPLETP_SSL_BAD_EXT_KEY_USE: - SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_SSL_BAD_EXT_KEY_USE" << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidPurpose); - break; - default: - SWIFT_LOG(warning) << "unhandled CSSM error: " << resultCode << ", CSSM_TP_BASE_TP_ERROR: " << CSSM_TP_BASE_TP_ERROR << std::endl; - error = boost::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); - break; - } - return error; +std::shared_ptr<CertificateVerificationError> SecureTransportContext::CSSMErrorToVerificationError(OSStatus resultCode) { + std::shared_ptr<CertificateVerificationError> error; + switch(resultCode) { + case CSSMERR_TP_NOT_TRUSTED: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_NOT_TRUSTED" << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::Untrusted); + break; + case CSSMERR_TP_CERT_NOT_VALID_YET: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_NOT_VALID_YET" << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::NotYetValid); + break; + case CSSMERR_TP_CERT_EXPIRED: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_EXPIRED" << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::Expired); + break; + case CSSMERR_TP_CERT_REVOKED: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_CERT_REVOKED" << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::Revoked); + break; + case CSSMERR_TP_VERIFY_ACTION_FAILED: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_TP_VERIFY_ACTION_FAILED" << std::endl; + break; + case CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_INCOMPLETE_REVOCATION_CHECK" << std::endl; + if (checkCertificateRevocation_) { + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::RevocationCheckFailed); + } + break; + case CSSMERR_APPLETP_OCSP_UNAVAILABLE: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_OCSP_UNAVAILABLE" << std::endl; + if (checkCertificateRevocation_) { + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::RevocationCheckFailed); + } + break; + case CSSMERR_APPLETP_SSL_BAD_EXT_KEY_USE: + SWIFT_LOG(debug) << "CSSM result code: CSSMERR_APPLETP_SSL_BAD_EXT_KEY_USE" << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidPurpose); + break; + default: + SWIFT_LOG(warning) << "unhandled CSSM error: " << resultCode << ", CSSM_TP_BASE_TP_ERROR: " << CSSM_TP_BASE_TP_ERROR << std::endl; + error = std::make_shared<CertificateVerificationError>(CertificateVerificationError::UnknownError); + break; + } + return error; } -void SecureTransportContext::fatalError(boost::shared_ptr<TLSError> error, boost::shared_ptr<CertificateVerificationError> certificateError) { - setState(Error); - if (sslContext_) { - SSLClose(sslContext_.get()); - } - verificationError_ = certificateError; - onError(error); +void SecureTransportContext::fatalError(std::shared_ptr<TLSError> error, std::shared_ptr<CertificateVerificationError> certificateError) { + setState(Error); + if (sslContext_) { + SSLClose(sslContext_.get()); + } + verificationError_ = certificateError; + onError(error); } } diff --git a/Swiften/TLS/SecureTransport/SecureTransportContextFactory.cpp b/Swiften/TLS/SecureTransport/SecureTransportContextFactory.cpp index ce19839..1fac1fb 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportContextFactory.cpp +++ b/Swiften/TLS/SecureTransport/SecureTransportContextFactory.cpp @@ -23,22 +23,22 @@ SecureTransportContextFactory::~SecureTransportContextFactory() { } bool SecureTransportContextFactory::canCreate() const { - return true; + return true; } TLSContext* SecureTransportContextFactory::createTLSContext(const TLSOptions& /* tlsOptions */) { - return new SecureTransportContext(checkCertificateRevocation_); + return new SecureTransportContext(checkCertificateRevocation_); } void SecureTransportContextFactory::setCheckCertificateRevocation(bool b) { - checkCertificateRevocation_ = b; + checkCertificateRevocation_ = b; } void SecureTransportContextFactory::setDisconnectOnCardRemoval(bool b) { - disconnectOnCardRemoval_ = b; - if (disconnectOnCardRemoval_) { - SWIFT_LOG(warning) << "Smart cards have not been tested yet" << std::endl; - } + disconnectOnCardRemoval_ = b; + if (disconnectOnCardRemoval_) { + SWIFT_LOG(warning) << "Smart cards have not been tested yet" << std::endl; + } } } diff --git a/Swiften/TLS/SecureTransport/SecureTransportContextFactory.h b/Swiften/TLS/SecureTransport/SecureTransportContextFactory.h index f490768..74c598f 100644 --- a/Swiften/TLS/SecureTransport/SecureTransportContextFactory.h +++ b/Swiften/TLS/SecureTransport/SecureTransportContextFactory.h @@ -11,19 +11,19 @@ namespace Swift { class SecureTransportContextFactory : public TLSContextFactory { - public: - SecureTransportContextFactory(); - virtual ~SecureTransportContextFactory(); + public: + SecureTransportContextFactory(); + virtual ~SecureTransportContextFactory(); - virtual bool canCreate() const; + virtual bool canCreate() const; - virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); - virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); + virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); + virtual void setCheckCertificateRevocation(bool b); + virtual void setDisconnectOnCardRemoval(bool b); - private: - bool checkCertificateRevocation_; - bool disconnectOnCardRemoval_; + private: + bool checkCertificateRevocation_; + bool disconnectOnCardRemoval_; }; } diff --git a/Swiften/TLS/ServerIdentityVerifier.cpp b/Swiften/TLS/ServerIdentityVerifier.cpp index 19d7489..226e94b 100644 --- a/Swiften/TLS/ServerIdentityVerifier.cpp +++ b/Swiften/TLS/ServerIdentityVerifier.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -8,88 +8,87 @@ #include <boost/algorithm/string.hpp> -#include <Swiften/Base/foreach.h> #include <Swiften/IDN/IDNConverter.h> namespace Swift { ServerIdentityVerifier::ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter) : domainValid(false) { - domain = jid.getDomain(); - boost::optional<std::string> domainResult = idnConverter->getIDNAEncoded(domain); - if (!!domainResult) { - encodedDomain = *domainResult; - domainValid = true; - } + domain = jid.getDomain(); + boost::optional<std::string> domainResult = idnConverter->getIDNAEncoded(domain); + if (!!domainResult) { + encodedDomain = *domainResult; + domainValid = true; + } } bool ServerIdentityVerifier::certificateVerifies(Certificate::ref certificate) { - bool hasSAN = false; + bool hasSAN = false; - if (certificate == NULL) { - return false; - } - // DNS names - std::vector<std::string> dnsNames = certificate->getDNSNames(); - foreach (const std::string& dnsName, dnsNames) { - if (matchesDomain(dnsName)) { - return true; - } - } - hasSAN |= !dnsNames.empty(); + if (certificate == nullptr) { + return false; + } + // DNS names + std::vector<std::string> dnsNames = certificate->getDNSNames(); + for (const auto& dnsName : dnsNames) { + if (matchesDomain(dnsName)) { + return true; + } + } + hasSAN |= !dnsNames.empty(); - // SRV names - std::vector<std::string> srvNames = certificate->getSRVNames(); - foreach (const std::string& srvName, srvNames) { - // Only match SRV names that begin with the service; this isn't required per - // spec, but we're being purist about this. - if (boost::starts_with(srvName, "_xmpp-client.") && matchesDomain(srvName.substr(std::string("_xmpp-client.").size(), srvName.npos))) { - return true; - } - } - hasSAN |= !srvNames.empty(); + // SRV names + std::vector<std::string> srvNames = certificate->getSRVNames(); + for (const auto& srvName : srvNames) { + // Only match SRV names that begin with the service; this isn't required per + // spec, but we're being purist about this. + if (boost::starts_with(srvName, "_xmpp-client.") && matchesDomain(srvName.substr(std::string("_xmpp-client.").size(), srvName.npos))) { + return true; + } + } + hasSAN |= !srvNames.empty(); - // XmppAddr - std::vector<std::string> xmppAddresses = certificate->getXMPPAddresses(); - foreach (const std::string& xmppAddress, xmppAddresses) { - if (matchesAddress(xmppAddress)) { - return true; - } - } - hasSAN |= !xmppAddresses.empty(); + // XmppAddr + std::vector<std::string> xmppAddresses = certificate->getXMPPAddresses(); + for (const auto& xmppAddress : xmppAddresses) { + if (matchesAddress(xmppAddress)) { + return true; + } + } + hasSAN |= !xmppAddresses.empty(); - // CommonNames. Only check this if there was no SAN (according to spec). - if (!hasSAN) { - std::vector<std::string> commonNames = certificate->getCommonNames(); - foreach (const std::string& commonName, commonNames) { - if (matchesDomain(commonName)) { - return true; - } - } - } + // CommonNames. Only check this if there was no SAN (according to spec). + if (!hasSAN) { + std::vector<std::string> commonNames = certificate->getCommonNames(); + for (const auto& commonName : commonNames) { + if (matchesDomain(commonName)) { + return true; + } + } + } - return false; + return false; } bool ServerIdentityVerifier::matchesDomain(const std::string& s) const { - if (!domainValid) { - return false; - } - if (boost::starts_with(s, "*.")) { - std::string matchString(s.substr(2, s.npos)); - std::string matchDomain = encodedDomain; - size_t dotIndex = matchDomain.find('.'); - if (dotIndex != matchDomain.npos) { - matchDomain = matchDomain.substr(dotIndex + 1, matchDomain.npos); - } - return matchString == matchDomain; - } - else { - return s == encodedDomain; - } + if (!domainValid) { + return false; + } + if (boost::starts_with(s, "*.")) { + std::string matchString(s.substr(2, s.npos)); + std::string matchDomain = encodedDomain; + size_t dotIndex = matchDomain.find('.'); + if (dotIndex != matchDomain.npos) { + matchDomain = matchDomain.substr(dotIndex + 1, matchDomain.npos); + } + return matchString == matchDomain; + } + else { + return s == encodedDomain; + } } bool ServerIdentityVerifier::matchesAddress(const std::string& s) const { - return s == domain; + return s == domain; } } diff --git a/Swiften/TLS/ServerIdentityVerifier.h b/Swiften/TLS/ServerIdentityVerifier.h index 79a3c17..f40c683 100644 --- a/Swiften/TLS/ServerIdentityVerifier.h +++ b/Swiften/TLS/ServerIdentityVerifier.h @@ -1,34 +1,34 @@ /* - * Copyright (c) 2010 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> - +#include <memory> #include <string> + #include <Swiften/Base/API.h> #include <Swiften/JID/JID.h> #include <Swiften/TLS/Certificate.h> namespace Swift { - class IDNConverter; + class IDNConverter; - class SWIFTEN_API ServerIdentityVerifier { - public: - ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter); + class SWIFTEN_API ServerIdentityVerifier { + public: + ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter); - bool certificateVerifies(Certificate::ref); + bool certificateVerifies(Certificate::ref); - private: - bool matchesDomain(const std::string&) const ; - bool matchesAddress(const std::string&) const; + private: + bool matchesDomain(const std::string&) const ; + bool matchesAddress(const std::string&) const; - private: - std::string domain; - std::string encodedDomain; - bool domainValid; - }; + private: + std::string domain; + std::string encodedDomain; + bool domainValid; + }; } diff --git a/Swiften/TLS/SimpleCertificate.h b/Swiften/TLS/SimpleCertificate.h index 88688c0..08cf1e3 100644 --- a/Swiften/TLS/SimpleCertificate.h +++ b/Swiften/TLS/SimpleCertificate.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ @@ -7,71 +7,72 @@ #pragma once #include <string> + #include <Swiften/Base/API.h> #include <Swiften/TLS/Certificate.h> namespace Swift { - class SWIFTEN_API SimpleCertificate : public Certificate { - public: - typedef boost::shared_ptr<SimpleCertificate> ref; - - void setSubjectName(const std::string& name) { - subjectName = name; - } - - std::string getSubjectName() const { - return subjectName; - } - - std::vector<std::string> getCommonNames() const { - return commonNames; - } - - void addCommonName(const std::string& name) { - commonNames.push_back(name); - } - - void addSRVName(const std::string& name) { - srvNames.push_back(name); - } - - void addDNSName(const std::string& name) { - dnsNames.push_back(name); - } - - void addXMPPAddress(const std::string& addr) { - xmppAddresses.push_back(addr); - } - - std::vector<std::string> getSRVNames() const { - return srvNames; - } - - std::vector<std::string> getDNSNames() const { - return dnsNames; - } - - std::vector<std::string> getXMPPAddresses() const { - return xmppAddresses; - } - - ByteArray toDER() const { - return der; - } - - void setDER(const ByteArray& der) { - this->der = der; - } - - private: - void parse(); - - private: - std::string subjectName; - ByteArray der; - std::vector<std::string> commonNames; - std::vector<std::string> dnsNames; - std::vector<std::string> xmppAddresses; - std::vector<std::string> srvNames; - }; + class SWIFTEN_API SimpleCertificate : public Certificate { + public: + typedef std::shared_ptr<SimpleCertificate> ref; + + void setSubjectName(const std::string& name) { + subjectName = name; + } + + std::string getSubjectName() const { + return subjectName; + } + + std::vector<std::string> getCommonNames() const { + return commonNames; + } + + void addCommonName(const std::string& name) { + commonNames.push_back(name); + } + + void addSRVName(const std::string& name) { + srvNames.push_back(name); + } + + void addDNSName(const std::string& name) { + dnsNames.push_back(name); + } + + void addXMPPAddress(const std::string& addr) { + xmppAddresses.push_back(addr); + } + + std::vector<std::string> getSRVNames() const { + return srvNames; + } + + std::vector<std::string> getDNSNames() const { + return dnsNames; + } + + std::vector<std::string> getXMPPAddresses() const { + return xmppAddresses; + } + + ByteArray toDER() const { + return der; + } + + void setDER(const ByteArray& der) { + this->der = der; + } + + private: + void parse(); + + private: + std::string subjectName; + ByteArray der; + std::vector<std::string> commonNames; + std::vector<std::string> dnsNames; + std::vector<std::string> xmppAddresses; + std::vector<std::string> srvNames; + }; } diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp index 9d8b166..2763547 100644 --- a/Swiften/TLS/TLSContext.cpp +++ b/Swiften/TLS/TLSContext.cpp @@ -12,8 +12,8 @@ TLSContext::~TLSContext() { } Certificate::ref TLSContext::getPeerCertificate() const { - std::vector<Certificate::ref> chain = getPeerCertificateChain(); - return chain.empty() ? Certificate::ref() : chain[0]; + std::vector<Certificate::ref> chain = getPeerCertificateChain(); + return chain.empty() ? Certificate::ref() : chain[0]; } } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index c5703e7..79e3485 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -1,44 +1,45 @@ /* - * Copyright (c) 2010-2015 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once -#include <Swiften/Base/boost_bsignals.h> -#include <boost/shared_ptr.hpp> +#include <memory> + +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/Certificate.h> -#include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/CertificateVerificationError.h> +#include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/TLSError.h> namespace Swift { - class SWIFTEN_API TLSContext { - public: - virtual ~TLSContext(); + class SWIFTEN_API TLSContext { + public: + virtual ~TLSContext(); - virtual void connect() = 0; + virtual void connect() = 0; - virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; + virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; - virtual void handleDataFromNetwork(const SafeByteArray&) = 0; - virtual void handleDataFromApplication(const SafeByteArray&) = 0; + virtual void handleDataFromNetwork(const SafeByteArray&) = 0; + virtual void handleDataFromApplication(const SafeByteArray&) = 0; - Certificate::ref getPeerCertificate() const; - virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; - virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; + Certificate::ref getPeerCertificate() const; + virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; + virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; - virtual ByteArray getFinishMessage() const = 0; + virtual ByteArray getFinishMessage() const = 0; - public: - boost::signal<void (const SafeByteArray&)> onDataForNetwork; - boost::signal<void (const SafeByteArray&)> onDataForApplication; - boost::signal<void (boost::shared_ptr<TLSError>)> onError; - boost::signal<void ()> onConnected; - }; + public: + boost::signals2::signal<void (const SafeByteArray&)> onDataForNetwork; + boost::signals2::signal<void (const SafeByteArray&)> onDataForApplication; + boost::signals2::signal<void (std::shared_ptr<TLSError>)> onError; + boost::signals2::signal<void ()> onConnected; + }; } diff --git a/Swiften/TLS/TLSContextFactory.h b/Swiften/TLS/TLSContextFactory.h index b67c34f..d2ffe15 100644 --- a/Swiften/TLS/TLSContextFactory.h +++ b/Swiften/TLS/TLSContextFactory.h @@ -10,16 +10,16 @@ #include <Swiften/TLS/TLSOptions.h> namespace Swift { - class TLSContext; + class TLSContext; - class SWIFTEN_API TLSContextFactory { - public: - virtual ~TLSContextFactory(); + class SWIFTEN_API TLSContextFactory { + public: + virtual ~TLSContextFactory(); - virtual bool canCreate() const = 0; + virtual bool canCreate() const = 0; - virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions) = 0; - virtual void setCheckCertificateRevocation(bool b) = 0; - virtual void setDisconnectOnCardRemoval(bool b) = 0; - }; + virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions) = 0; + virtual void setCheckCertificateRevocation(bool b) = 0; + virtual void setDisconnectOnCardRemoval(bool b) = 0; + }; } diff --git a/Swiften/TLS/TLSError.h b/Swiften/TLS/TLSError.h index 27e4b03..ae775e6 100644 --- a/Swiften/TLS/TLSError.h +++ b/Swiften/TLS/TLSError.h @@ -1,32 +1,33 @@ /* - * Copyright (c) 2012-2015 Isode Limited. + * Copyright (c) 2012-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once +#include <memory> + #include <Swiften/Base/API.h> -#include <boost/shared_ptr.hpp> #include <Swiften/Base/Error.h> namespace Swift { - class SWIFTEN_API TLSError : public Error { - public: - typedef boost::shared_ptr<TLSError> ref; + class SWIFTEN_API TLSError : public Error { + public: + typedef std::shared_ptr<TLSError> ref; - enum Type { - UnknownError, - CertificateCardRemoved - }; + enum Type { + UnknownError, + CertificateCardRemoved + }; - TLSError(Type type = UnknownError) : type(type) {} + TLSError(Type type = UnknownError) : type(type) {} - Type getType() const { - return type; - } + Type getType() const { + return type; + } - private: - Type type; - }; + private: + Type type; + }; } diff --git a/Swiften/TLS/TLSOptions.h b/Swiften/TLS/TLSOptions.h index ca84829..dd7e920 100644 --- a/Swiften/TLS/TLSOptions.h +++ b/Swiften/TLS/TLSOptions.h @@ -8,18 +8,18 @@ namespace Swift { - struct TLSOptions { - TLSOptions() : schannelTLS1_0Workaround(false) { + struct TLSOptions { + TLSOptions() : schannelTLS1_0Workaround(false) { - } + } - /** - * A bug in the Windows SChannel TLS stack, combined with - * overly-restrictive server stacks means it's sometimes necessary to - * not use TLS>1.0. This option has no effect unless compiled on - * Windows against SChannel (OpenSSL users are unaffected). - */ - bool schannelTLS1_0Workaround; + /** + * A bug in the Windows SChannel TLS stack, combined with + * overly-restrictive server stacks means it's sometimes necessary to + * not use TLS>1.0. This option has no effect unless compiled on + * Windows against SChannel (OpenSSL users are unaffected). + */ + bool schannelTLS1_0Workaround; - }; + }; } diff --git a/Swiften/TLS/UnitTest/CertificateTest.cpp b/Swiften/TLS/UnitTest/CertificateTest.cpp index 8e9c205..2483dae 100644 --- a/Swiften/TLS/UnitTest/CertificateTest.cpp +++ b/Swiften/TLS/UnitTest/CertificateTest.cpp @@ -1,34 +1,34 @@ /* - * Copyright (c) 2010-2013 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ -#include <Swiften/Base/ByteArray.h> +#include <memory> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/smart_ptr/make_shared.hpp> -#include <Swiften/TLS/Certificate.h> -#include <Swiften/TLS/SimpleCertificate.h> +#include <Swiften/Base/ByteArray.h> #include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Crypto/PlatformCryptoProvider.h> +#include <Swiften/TLS/Certificate.h> +#include <Swiften/TLS/SimpleCertificate.h> using namespace Swift; class CertificateTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(CertificateTest); - CPPUNIT_TEST(testGetSHA1Fingerprint); - CPPUNIT_TEST_SUITE_END(); + CPPUNIT_TEST_SUITE(CertificateTest); + CPPUNIT_TEST(testGetSHA1Fingerprint); + CPPUNIT_TEST_SUITE_END(); - public: - void testGetSHA1Fingerprint() { - SimpleCertificate::ref testling = boost::make_shared<SimpleCertificate>(); - testling->setDER(createByteArray("abcdefg")); + public: + void testGetSHA1Fingerprint() { + SimpleCertificate::ref testling = std::make_shared<SimpleCertificate>(); + testling->setDER(createByteArray("abcdefg")); - CPPUNIT_ASSERT_EQUAL(std::string("2f:b5:e1:34:19:fc:89:24:68:65:e7:a3:24:f4:76:ec:62:4e:87:40"), Certificate::getSHA1Fingerprint(testling, boost::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()).get())); - } + CPPUNIT_ASSERT_EQUAL(std::string("2f:b5:e1:34:19:fc:89:24:68:65:e7:a3:24:f4:76:ec:62:4e:87:40"), Certificate::getSHA1Fingerprint(testling, std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()).get())); + } }; CPPUNIT_TEST_SUITE_REGISTRATION(CertificateTest); diff --git a/Swiften/TLS/UnitTest/ServerIdentityVerifierTest.cpp b/Swiften/TLS/UnitTest/ServerIdentityVerifierTest.cpp index a82cd2e..30fe423 100644 --- a/Swiften/TLS/UnitTest/ServerIdentityVerifierTest.cpp +++ b/Swiften/TLS/UnitTest/ServerIdentityVerifierTest.cpp @@ -1,178 +1,178 @@ /* - * Copyright (c) 2010 Isode Limited. + * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ -#include <Swiften/Base/ByteArray.h> +#include <vector> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <vector> -#include <Swiften/TLS/ServerIdentityVerifier.h> -#include <Swiften/TLS/SimpleCertificate.h> +#include <Swiften/Base/ByteArray.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/TLS/ServerIdentityVerifier.h> +#include <Swiften/TLS/SimpleCertificate.h> using namespace Swift; class ServerIdentityVerifierTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(ServerIdentityVerifierTest); - CPPUNIT_TEST(testCertificateVerifies_WithoutMatchingDNSName); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSName); - CPPUNIT_TEST(testCertificateVerifies_WithSecondMatchingDNSName); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingInternationalDNSName); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSNameWithWildcard); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSNameWithWildcardMatchingNoComponents); - CPPUNIT_TEST(testCertificateVerifies_WithDNSNameWithWildcardMatchingTwoComponents); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithoutService); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithService); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithServiceAndWildcard); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithDifferentService); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingXmppAddr); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingXmppAddrWithWildcard); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingInternationalXmppAddr); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingCNWithoutSAN); - CPPUNIT_TEST(testCertificateVerifies_WithMatchingCNWithSAN); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - idnConverter = boost::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); - } - - void testCertificateVerifies_WithoutMatchingDNSName() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("foo.com"); - - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingDNSName() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("bar.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithSecondMatchingDNSName() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("foo.com"); - certificate->addDNSName("bar.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingInternationalDNSName() { - ServerIdentityVerifier testling(JID("foo@tron\xc3\xa7on.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("xn--tronon-zua.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingDNSNameWithWildcard() { - ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("*.bar.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingDNSNameWithWildcardMatchingNoComponents() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("*.bar.com"); - - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithDNSNameWithWildcardMatchingTwoComponents() { - ServerIdentityVerifier testling(JID("foo@xmpp.im.bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addDNSName("*.bar.com"); - - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingSRVNameWithoutService() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addSRVName("bar.com"); - - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingSRVNameWithService() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addSRVName("_xmpp-client.bar.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingSRVNameWithServiceAndWildcard() { - ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addSRVName("_xmpp-client.*.bar.com"); - - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingSRVNameWithDifferentService() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addSRVName("_xmpp-server.bar.com"); - - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } - - void testCertificateVerifies_WithMatchingXmppAddr() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addXMPPAddress("bar.com"); + CPPUNIT_TEST_SUITE(ServerIdentityVerifierTest); + CPPUNIT_TEST(testCertificateVerifies_WithoutMatchingDNSName); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSName); + CPPUNIT_TEST(testCertificateVerifies_WithSecondMatchingDNSName); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingInternationalDNSName); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSNameWithWildcard); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingDNSNameWithWildcardMatchingNoComponents); + CPPUNIT_TEST(testCertificateVerifies_WithDNSNameWithWildcardMatchingTwoComponents); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithoutService); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithService); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithServiceAndWildcard); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingSRVNameWithDifferentService); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingXmppAddr); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingXmppAddrWithWildcard); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingInternationalXmppAddr); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingCNWithoutSAN); + CPPUNIT_TEST(testCertificateVerifies_WithMatchingCNWithSAN); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); + } + + void testCertificateVerifies_WithoutMatchingDNSName() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("foo.com"); + + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingDNSName() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("bar.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithSecondMatchingDNSName() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("foo.com"); + certificate->addDNSName("bar.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingInternationalDNSName() { + ServerIdentityVerifier testling(JID("foo@tron\xc3\xa7on.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("xn--tronon-zua.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingDNSNameWithWildcard() { + ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("*.bar.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingDNSNameWithWildcardMatchingNoComponents() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("*.bar.com"); + + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithDNSNameWithWildcardMatchingTwoComponents() { + ServerIdentityVerifier testling(JID("foo@xmpp.im.bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addDNSName("*.bar.com"); + + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingSRVNameWithoutService() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addSRVName("bar.com"); + + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingSRVNameWithService() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addSRVName("_xmpp-client.bar.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingSRVNameWithServiceAndWildcard() { + ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addSRVName("_xmpp-client.*.bar.com"); + + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingSRVNameWithDifferentService() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addSRVName("_xmpp-server.bar.com"); + + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } + + void testCertificateVerifies_WithMatchingXmppAddr() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addXMPPAddress("bar.com"); - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } - void testCertificateVerifies_WithMatchingXmppAddrWithWildcard() { - ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addXMPPAddress("*.bar.com"); + void testCertificateVerifies_WithMatchingXmppAddrWithWildcard() { + ServerIdentityVerifier testling(JID("foo@im.bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addXMPPAddress("*.bar.com"); - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } - void testCertificateVerifies_WithMatchingInternationalXmppAddr() { - ServerIdentityVerifier testling(JID("foo@tron\xc3\xa7.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addXMPPAddress("tron\xc3\xa7.com"); + void testCertificateVerifies_WithMatchingInternationalXmppAddr() { + ServerIdentityVerifier testling(JID("foo@tron\xc3\xa7.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addXMPPAddress("tron\xc3\xa7.com"); - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } - void testCertificateVerifies_WithMatchingCNWithoutSAN() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addCommonName("bar.com"); + void testCertificateVerifies_WithMatchingCNWithoutSAN() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addCommonName("bar.com"); - CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); - } + CPPUNIT_ASSERT(testling.certificateVerifies(certificate)); + } - void testCertificateVerifies_WithMatchingCNWithSAN() { - ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); - SimpleCertificate::ref certificate(new SimpleCertificate()); - certificate->addSRVName("foo.com"); - certificate->addCommonName("bar.com"); + void testCertificateVerifies_WithMatchingCNWithSAN() { + ServerIdentityVerifier testling(JID("foo@bar.com/baz"), idnConverter.get()); + SimpleCertificate::ref certificate(new SimpleCertificate()); + certificate->addSRVName("foo.com"); + certificate->addCommonName("bar.com"); - CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); - } + CPPUNIT_ASSERT(!testling.certificateVerifies(certificate)); + } - boost::shared_ptr<IDNConverter> idnConverter; + std::shared_ptr<IDNConverter> idnConverter; }; CPPUNIT_TEST_SUITE_REGISTRATION(ServerIdentityVerifierTest); |