diff options
-rw-r--r-- | Swift/Controllers/MainController.cpp | 2 | ||||
-rw-r--r-- | Swift/QtUI/CAPICertificateSelector.cpp | 101 | ||||
-rw-r--r-- | Swift/QtUI/SConscript | 1 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 2 | ||||
-rw-r--r-- | Swiften/SConscript | 4 | ||||
-rw-r--r-- | Swiften/TLS/CAPICertificate.cpp | 278 | ||||
-rw-r--r-- | Swiften/TLS/CAPICertificate.h | 6 | ||||
-rw-r--r-- | Swiften/TLS/SConscript | 1 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 282 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 23 |
10 files changed, 323 insertions, 377 deletions
diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp index b0a1778..28fdb2b 100644 --- a/Swift/Controllers/MainController.cpp +++ b/Swift/Controllers/MainController.cpp @@ -544,18 +544,20 @@ void MainController::handleDisconnected(const boost::optional<ClientError>& erro case ClientError::CertificateNotYetValidError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate is not yet valid"); break; case ClientError::CertificateSelfSignedError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate is self-signed"); break; case ClientError::CertificateRejectedError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate has been rejected"); break; case ClientError::CertificateUntrustedError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate is not trusted"); break; case ClientError::InvalidCertificatePurposeError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate cannot be used for encrypting your connection"); break; case ClientError::CertificatePathLengthExceededError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate path length constraint exceeded"); break; case ClientError::InvalidCertificateSignatureError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Invalid certificate signature"); break; case ClientError::InvalidCAError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Invalid Certificate Authority"); break; case ClientError::InvalidServerIdentityError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate does not match the host identity"); break; + case ClientError::RevokedError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Certificate has been revoked"); break; + case ClientError::RevocationCheckFailedError: certificateErrorMessage = QT_TRANSLATE_NOOP("", "Unable to determine certificate revocation state"); break; } bool forceReconnectAfterCertificateTrust = false; if (!certificateErrorMessage.empty()) { Certificate::ref certificate = certificateTrustChecker_->getLastCertificate(); if (loginWindow_->askUserToTrustCertificatePermanently(certificateErrorMessage, certificate)) { certificateStorage_->addCertificate(certificate); forceReconnectAfterCertificateTrust = true; } else { diff --git a/Swift/QtUI/CAPICertificateSelector.cpp b/Swift/QtUI/CAPICertificateSelector.cpp index 0d4768c..cc69956 100644 --- a/Swift/QtUI/CAPICertificateSelector.cpp +++ b/Swift/QtUI/CAPICertificateSelector.cpp @@ -1,103 +1,116 @@ /* * Copyright (c) 2012 Isode Limited, London, England. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ #include <string> -#include "CAPICertificateSelector.h" +#include <Swift/QtUI/CAPICertificateSelector.h> #define SECURITY_WIN32 #include <Windows.h> #include <WinCrypt.h> #include <cryptuiapi.h> + #include <Swiften/StringCodecs/Hexify.h> #include <boost/algorithm/string.hpp> +#include <Swift/Controllers/Intl.h> +#include <Swift/QtUI/QtSwiftUtil.h> +#include <Swiften/Base/Log.h> namespace Swift { -#define cert_dlg_title L"TLS Client Certificate Selection" -#define cert_dlg_prompt L"Select a certificate to use for authentication" /////Hmm, maybe we should not exlude the "location" column -#define exclude_columns CRYPTUI_SELECT_LOCATION_COLUMN \ - |CRYPTUI_SELECT_INTENDEDUSE_COLUMN +#define exclude_columns CRYPTUI_SELECT_LOCATION_COLUMN | CRYPTUI_SELECT_INTENDEDUSE_COLUMN -// Size of the SHA1 hash -#define SHA1_HASH_LEN 20 +#define SHA1_HASH_LENGTH 20 static std::string getCertUri(PCCERT_CONTEXT cert, const char * cert_store_name) { - DWORD cbHash = SHA1_HASH_LEN; - BYTE aHash[SHA1_HASH_LEN]; - std::string ret("certstore:"); + DWORD cbHash = SHA1_HASH_LENGTH; + BYTE aHash[SHA1_HASH_LENGTH]; + std::string result("certstore:"); - ret += cert_store_name; - ret += ":sha1:"; + result += cert_store_name; + result += ":sha1:"; - if (CertGetCertificateContextProperty(cert, - CERT_HASH_PROP_ID, - aHash, - &cbHash) == FALSE ) { + if (CertGetCertificateContextProperty(cert, CERT_HASH_PROP_ID, aHash, &cbHash) == FALSE ) { return ""; } ByteArray byteArray = createByteArray((char *)(&aHash[0]), cbHash); - ret += Hexify::hexify(byteArray); + result += Hexify::hexify(byteArray); - return ret; + return result; } std::string selectCAPICertificate() { + const char* certStoreName = "MY"; - const char * cert_store_name = "MY"; - PCCERT_CONTEXT cert; - DWORD store_flags; - HCERTSTORE hstore; - HWND hwnd; - - store_flags = CERT_STORE_OPEN_EXISTING_FLAG | - CERT_STORE_READONLY_FLAG | - CERT_SYSTEM_STORE_CURRENT_USER; + DWORD storeFlags = CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG | CERT_SYSTEM_STORE_CURRENT_USER; - hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, store_flags, cert_store_name); + HCERTSTORE hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, storeFlags, certStoreName); if (!hstore) { return ""; } - -////Does this handle need to be freed as well? - hwnd = GetForegroundWindow(); + HWND hwnd = GetForegroundWindow(); if (!hwnd) { hwnd = GetActiveWindow(); } + std::string certificateDialogTitle = QT_TRANSLATE_NOOP("", "TLS Client Certificate Selection"); + std::string certificateDialogPrompt = QT_TRANSLATE_NOOP("", "Select a certificate to use for authentication"); + + int titleLength = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, certificateDialogTitle.c_str(), -1, NULL, 0); + int promptLength = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, certificateDialogPrompt.c_str(), -1, NULL, 0); + + wchar_t* titleChars = new wchar_t[titleLength]; + wchar_t* promptChars = new wchar_t[promptLength]; + + //titleChars[titleLength] = '\0'; + //promptChars[promptLength] = '\0'; + + titleLength = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, certificateDialogTitle.c_str(), -1, titleChars, titleLength); + promptLength = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, certificateDialogPrompt.c_str(), -1, promptChars, promptLength); + + if (titleLength == 0 || promptLength == 0) { + int error = GetLastError(); + switch (error) { + case ERROR_INSUFFICIENT_BUFFER: SWIFT_LOG("error") << "Insufficient buffer for rendering cert dialog" << std::endl;break; + case ERROR_INVALID_FLAGS: SWIFT_LOG("error") << "Invalid flags for rendering cert dialog" << std::endl;break; + case ERROR_INVALID_PARAMETER: SWIFT_LOG("error") << "Invalid parameter for rendering cert dialog" << std::endl;break; + case ERROR_NO_UNICODE_TRANSLATION: SWIFT_LOG("error") << "Invalid unicode for rendering cert dialog" << std::endl;break; + default: SWIFT_LOG("error") << "Unexpected multibyte conversion errorcode" << std::endl; + + } + } + + + /* Call Windows dialog to select a suitable certificate */ - cert = CryptUIDlgSelectCertificateFromStore(hstore, - hwnd, - cert_dlg_title, - cert_dlg_prompt, - exclude_columns, - 0, - NULL); + PCCERT_CONTEXT cert = CryptUIDlgSelectCertificateFromStore(hstore, hwnd, titleChars, promptChars, exclude_columns, 0, NULL); + + delete[] titleChars; + delete[] promptChars; if (hstore) { CertCloseStore(hstore, 0); } - if (cert) { - std::string ret = getCertUri(cert, cert_store_name); + std::string result; + if (cert) { + result = getCertUri(cert, certStoreName); CertFreeCertificateContext(cert); - - return ret; - } else { - return ""; } + + return result; } bool isCAPIURI(std::string uri) { return (boost::iequals(uri.substr(0, 10), "certstore:")); } } diff --git a/Swift/QtUI/SConscript b/Swift/QtUI/SConscript index 0622cc6..0971577 100644 --- a/Swift/QtUI/SConscript +++ b/Swift/QtUI/SConscript @@ -35,19 +35,18 @@ if myenv.get("HAVE_GROWL", False) : myenv.UseFlags(myenv["GROWL_FLAGS"]) myenv.Append(CPPDEFINES = ["HAVE_GROWL"]) if myenv["swift_mobile"] : myenv.Append(CPPDEFINES = ["SWIFT_MOBILE"]) if myenv.get("HAVE_SNARL", False) : myenv.UseFlags(myenv["SNARL_FLAGS"]) myenv.Append(CPPDEFINES = ["HAVE_SNARL"]) if env["PLATFORM"] == "win32" : myenv.Append(LIBS = ["cryptui"]) - myenv.Append(LIBS = ["Winscard"]) myenv.UseFlags(myenv["PLATFORM_FLAGS"]) myenv.Tool("qt4", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.Tool("nsis", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.Tool("wix", toolpath = ["#/BuildTools/SCons/Tools"]) qt4modules = ['QtCore', 'QtGui', 'QtWebKit'] if env["PLATFORM"] == "posix" : qt4modules += ["QtDBus"] myenv.EnableQt4Modules(qt4modules, debug = False) diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 45d80aa..8a922ba 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -217,19 +217,19 @@ void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) { clientError = ClientError(ClientError::StreamError); break; } } else if (boost::shared_ptr<TLSError> actualError = boost::dynamic_pointer_cast<TLSError>(error)) { switch(actualError->getType()) { case TLSError::CertificateCardRemoved: clientError = ClientError(ClientError::CertificateCardRemoved); break; - default: + case TLSError::UnknownError: clientError = ClientError(ClientError::TLSError); break; } } else if (boost::shared_ptr<SessionStream::SessionStreamError> actualError = boost::dynamic_pointer_cast<SessionStream::SessionStreamError>(error)) { switch(actualError->type) { case SessionStream::SessionStreamError::ParseError: clientError = ClientError(ClientError::XMLError); break; diff --git a/Swiften/SConscript b/Swiften/SConscript index 41ec947..6308a80 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -37,18 +37,22 @@ if env["SCONS_STAGE"] == "flags" : swiften_env["LIBS"] = [swiften_env["SWIFTEN_LIBRARY"]] dep_env = env.Clone() for module in swiften_dep_modules : if env.get(module + "_BUNDLED", False) : swiften_env.UseFlags(env.get(module + "_FLAGS", {})) else : dep_env.UseFlags(env.get(module + "_FLAGS", {})) dep_env.UseFlags(dep_env["PLATFORM_FLAGS"]) + if env.get("HAVE_SCHANNEL", 0) : + dep_env.Append(LIBS = ["Winscard"]) + + for var, e in [("SWIFTEN_FLAGS", swiften_env), ("SWIFTEN_DEP_FLAGS", dep_env)] : env[var] = { "CPPDEFINES": e.get("CPPDEFINES", []), "CPPPATH": e.get("CPPPATH", []), "CPPFLAGS": e.get("CPPFLAGS", []), "LIBPATH": e.get("LIBPATH", []), "LIBS": e.get("LIBS", []), "FRAMEWORKS": e.get("FRAMEWORKS", []), } diff --git a/Swiften/TLS/CAPICertificate.cpp b/Swiften/TLS/CAPICertificate.cpp index b33ebcf..0083b6f 100644 --- a/Swiften/TLS/CAPICertificate.cpp +++ b/Swiften/TLS/CAPICertificate.cpp @@ -1,39 +1,43 @@ /* * Copyright (c) 2012 Isode Limited, London, England. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ + #pragma once #include <Swiften/Network/TimerFactory.h> #include <Swiften/TLS/CAPICertificate.h> #include <Swiften/StringCodecs/Hexify.h> +#include <Swiften/Base/Log.h> #include <boost/bind.hpp> #include <boost/algorithm/string/predicate.hpp> // Size of the SHA1 hash -#define SHA1_HASH_LEN 20 +#define SHA1_HASH_LEN 20 namespace Swift { -CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory) - : valid_(false), - uri_(capiUri), - certStoreHandle_(0), - scardContext_(0), - cardHandle_(0), - certStore_(), - certName_(), - smartCardReaderName_(), - timerFactory_(timerFactory) { +CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory) : + valid_(false), + uri_(capiUri), + certStoreHandle_(0), + scardContext_(0), + cardHandle_(0), + certStore_(), + certName_(), + smartCardReaderName_(), + timerFactory_(timerFactory), + lastPollingResult_(true) { + assert(timerFactory_); setUri(capiUri); } CAPICertificate::~CAPICertificate() { if (smartCardTimer_) { smartCardTimer_->stop(); smartCardTimer_->onTick.disconnect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); smartCardTimer_.reset(); @@ -63,169 +67,141 @@ const std::string& CAPICertificate::getCertStoreName() const { const std::string& CAPICertificate::getCertName() const { return certName_; } const std::string& CAPICertificate::getSmartCardReaderName() const { return smartCardReaderName_; } PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName) { - PCCERT_CONTEXT pCertContext = NULL; - if (!boost::iequals(certName.substr(0, 5), "sha1:")) { // Find client certificate. Note that this sample just searches for a // certificate that contains the user name somewhere in the subject name. - pCertContext = CertFindCertificateInStore(certStoreHandle, - X509_ASN_ENCODING, - 0, // dwFindFlags - CERT_FIND_SUBJECT_STR_A, - certName.c_str(), // *pvFindPara - NULL ); // pPrevCertContext - return pCertContext; + return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /*dwFindFlags*/ 0, CERT_FIND_SUBJECT_STR_A, /* *pvFindPara*/certName.c_str(), /*pPrevCertContext*/ NULL); } std::string hexstring = certName.substr(5); ByteArray byteArray = Hexify::unhexify(hexstring); CRYPT_HASH_BLOB HashBlob; if (byteArray.size() != SHA1_HASH_LEN) { return NULL; } HashBlob.cbData = SHA1_HASH_LEN; HashBlob.pbData = static_cast<BYTE *>(vecptr(byteArray)); // Find client certificate. Note that this sample just searches for a // certificate that contains the user name somewhere in the subject name. - pCertContext = CertFindCertificateInStore(certStoreHandle, - X509_ASN_ENCODING, - 0, // dwFindFlags - CERT_FIND_HASH, - &HashBlob, - NULL ); // pPrevCertContext - - return pCertContext; + return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /* dwFindFlags */ 0, CERT_FIND_HASH, &HashBlob, /* pPrevCertContext */ NULL); + } void CAPICertificate::setUri (const std::string& capiUri) { - valid_ = false; /* Syntax: "certstore:" <cert_store> ":" <hash> ":" <hash_of_cert> */ if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) { return; } /* Substring of subject: uses "storename" */ - std::string capi_identity = capiUri.substr(10); - std::string new_certStore_name; - size_t pos = capi_identity.find_first_of (':'); + std::string capiIdentity = capiUri.substr(10); + std::string newCertStoreName; + size_t pos = capiIdentity.find_first_of (':'); if (pos == std::string::npos) { /* Using the default certificate store */ - new_certStore_name = "MY"; - certName_ = capi_identity; + newCertStoreName = "MY"; + certName_ = capiIdentity; } else { - new_certStore_name = capi_identity.substr(0, pos); - certName_ = capi_identity.substr(pos + 1); + newCertStoreName = capiIdentity.substr(0, pos); + certName_ = capiIdentity.substr(pos + 1); } - PCCERT_CONTEXT pCertContext = NULL; - if (certStoreHandle_ != NULL) { - if (new_certStore_name != certStore_) { + if (newCertStoreName != certStore_) { CertCloseStore(certStoreHandle_, 0); certStoreHandle_ = NULL; } } if (certStoreHandle_ == NULL) { - certStoreHandle_ = CertOpenSystemStore(0, new_certStore_name.c_str()); + certStoreHandle_ = CertOpenSystemStore(0, newCertStoreName.c_str()); if (!certStoreHandle_) { return; } } - certStore_ = new_certStore_name; + certStore_ = newCertStoreName; - pCertContext = findCertificateInStore (certStoreHandle_, certName_); + PCCERT_CONTEXT certContext = findCertificateInStore (certStoreHandle_, certName_); - if (!pCertContext) { + if (!certContext) { return; } /* Now verify that we can have access to the corresponding private key */ DWORD len; CRYPT_KEY_PROV_INFO *pinfo; HCRYPTPROV hprov; HCRYPTKEY key; - if (!CertGetCertificateContextProperty(pCertContext, + if (!CertGetCertificateContextProperty(certContext, CERT_KEY_PROV_INFO_PROP_ID, NULL, &len)) { - CertFreeCertificateContext(pCertContext); + CertFreeCertificateContext(certContext); return; } pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len)); if (!pinfo) { - CertFreeCertificateContext(pCertContext); + CertFreeCertificateContext(certContext); return; } - if (!CertGetCertificateContextProperty(pCertContext, - CERT_KEY_PROV_INFO_PROP_ID, - pinfo, - &len)) { - CertFreeCertificateContext(pCertContext); + if (!CertGetCertificateContextProperty(certContext, CERT_KEY_PROV_INFO_PROP_ID, pinfo, &len)) { + CertFreeCertificateContext(certContext); free(pinfo); return; } - CertFreeCertificateContext(pCertContext); + CertFreeCertificateContext(certContext); // Now verify if we have access to the private key - if (!CryptAcquireContextW(&hprov, - pinfo->pwszContainerName, - pinfo->pwszProvName, - pinfo->dwProvType, - 0)) { + if (!CryptAcquireContextW(&hprov, pinfo->pwszContainerName, pinfo->pwszProvName, pinfo->dwProvType, 0)) { free(pinfo); return; } - char smartcard_reader[1024]; - DWORD buflen; - - buflen = sizeof(smartcard_reader); - if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartcard_reader, &buflen, 0)) { - DWORD error; - - error = GetLastError(); + char smartCardReader[1024]; + DWORD bufferLength = sizeof(smartCardReader); + if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartCardReader, &bufferLength, 0)) { + DWORD error = GetLastError(); smartCardReaderName_ = ""; - } else { - LONG lRet; - - smartCardReaderName_ = smartcard_reader; + } + else { + smartCardReaderName_ = smartCardReader; - lRet = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_); - if (SCARD_S_SUCCESS == lRet) { + LONG result = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_); + if (SCARD_S_SUCCESS == result) { // Initiate monitoring for smartcard ejection - smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQ); - } else { + smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS); + } + else { ///Need to handle an error here } } if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key)) { CryptReleaseContext(hprov, 0); free(pinfo); return; } @@ -236,131 +212,121 @@ void CAPICertificate::setUri (const std::string& capiUri) { if (smartCardTimer_) { smartCardTimer_->onTick.connect(boost::bind(&CAPICertificate::handleSmartCardTimerTick, this)); smartCardTimer_->start(); } valid_ = true; } -static void smartcard_check_status (SCARDCONTEXT hContext, - const char * pReader, - SCARDHANDLE hCardHandle, // Can be 0 on the first call - SCARDHANDLE * newCardHandle, // The handle returned - DWORD * pdwState) { - LONG lReturn; - DWORD dwAP; - char szReader[200]; - DWORD cch = sizeof(szReader); - BYTE bAttr[32]; - DWORD cByte = 32; - +static void smartcard_check_status (SCARDCONTEXT hContext, + const char* pReader, + SCARDHANDLE hCardHandle, /* Can be 0 on the first call */ + SCARDHANDLE* newCardHandle, /* The handle returned */ + DWORD* pdwState) { if (hCardHandle == 0) { - lReturn = SCardConnect(hContext, - pReader, - SCARD_SHARE_SHARED, - SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1, - &hCardHandle, - &dwAP); - if ( SCARD_S_SUCCESS != lReturn ) { + DWORD dwAP; + LONG result = SCardConnect(hContext, pReader, SCARD_SHARE_SHARED, SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1, &hCardHandle, &dwAP); + if (SCARD_S_SUCCESS != result) { hCardHandle = 0; - if (SCARD_E_NO_SMARTCARD == lReturn || SCARD_W_REMOVED_CARD == lReturn) { + if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) { *pdwState = SCARD_ABSENT; - } else { + } + else { *pdwState = SCARD_UNKNOWN; } - goto done; + + if (newCardHandle == NULL) { + (void) SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD); + hCardHandle = 0; + } + else { + *newCardHandle = hCardHandle; + } } } - lReturn = SCardStatus(hCardHandle, - szReader, // Unfortunately we can't use NULL here - &cch, - pdwState, - NULL, - (LPBYTE)&bAttr, - &cByte); + char szReader[200]; + DWORD cch = sizeof(szReader); + BYTE bAttr[32]; + DWORD cByte = 32; + LONG result = SCardStatus(hCardHandle, /* Unfortunately we can't use NULL here */ szReader, &cch, pdwState, NULL, (LPBYTE)&bAttr, &cByte); - if ( SCARD_S_SUCCESS != lReturn ) { - if (SCARD_E_NO_SMARTCARD == lReturn || SCARD_W_REMOVED_CARD == lReturn) { + if (SCARD_S_SUCCESS != result) { + if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) { *pdwState = SCARD_ABSENT; - } else { + } + else { *pdwState = SCARD_UNKNOWN; } } -done: if (newCardHandle == NULL) { (void) SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD); hCardHandle = 0; - } else { + } + else { *newCardHandle = hCardHandle; } } bool CAPICertificate::checkIfSmartCardPresent () { - - DWORD dwState; - if (!smartCardReaderName_.empty()) { - smartcard_check_status (scardContext_, - smartCardReaderName_.c_str(), - cardHandle_, - &cardHandle_, - &dwState); -////DEBUG - switch ( dwState ) { - case SCARD_ABSENT: - printf("Card absent.\n"); - break; - case SCARD_PRESENT: - printf("Card present.\n"); - break; - case SCARD_SWALLOWED: - printf("Card swallowed.\n"); - break; - case SCARD_POWERED: - printf("Card has power.\n"); - break; - case SCARD_NEGOTIABLE: - printf("Card reset and waiting PTS negotiation.\n"); - break; - case SCARD_SPECIFIC: - printf("Card has specific communication protocols set.\n"); - break; - default: - printf("Unknown or unexpected card state.\n"); - break; + DWORD dwState; + smartcard_check_status(scardContext_, smartCardReaderName_.c_str(), cardHandle_, &cardHandle_, &dwState); + + switch (dwState) { + case SCARD_ABSENT: + SWIFT_LOG("DEBUG") << "Card absent." << std::endl; + break; + case SCARD_PRESENT: + SWIFT_LOG("DEBUG") << "Card present." << std::endl; + break; + case SCARD_SWALLOWED: + SWIFT_LOG("DEBUG") << "Card swallowed." << std::endl; + break; + case SCARD_POWERED: + SWIFT_LOG("DEBUG") << "Card has power." << std::endl; + break; + case SCARD_NEGOTIABLE: + SWIFT_LOG("DEBUG") << "Card reset and waiting PTS negotiation." << std::endl; + break; + case SCARD_SPECIFIC: + SWIFT_LOG("DEBUG") << "Card has specific communication protocols set." << std::endl; + break; + default: + SWIFT_LOG("DEBUG") << "Unknown or unexpected card state." << std::endl; + break; } - switch ( dwState ) { - case SCARD_ABSENT: - return false; + switch (dwState) { + case SCARD_ABSENT: + return false; - case SCARD_PRESENT: - case SCARD_SWALLOWED: - case SCARD_POWERED: - case SCARD_NEGOTIABLE: - case SCARD_SPECIFIC: - return true; + case SCARD_PRESENT: + case SCARD_SWALLOWED: + case SCARD_POWERED: + case SCARD_NEGOTIABLE: + case SCARD_SPECIFIC: + return true; - default: - return false; + default: + return false; } - } else { + } + else { return false; } } void CAPICertificate::handleSmartCardTimerTick() { - - if (checkIfSmartCardPresent() == false) { - smartCardTimer_->stop(); + bool poll = checkIfSmartCardPresent(); + if (lastPollingResult_ && !poll) { onCertificateCardRemoved(); - } else { - smartCardTimer_->start(); - } + } + lastPollingResult_ = poll; + smartCardTimer_->start(); } } diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h index c8c00fe..5f24b7e 100644 --- a/Swiften/TLS/CAPICertificate.h +++ b/Swiften/TLS/CAPICertificate.h @@ -10,27 +10,25 @@ #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/Network/Timer.h> #define SECURITY_WIN32 #include <Windows.h> #include <WinCrypt.h> #include <Winscard.h> -/* In ms */ -#define SMARTCARD_EJECTION_CHECK_FREQ 1000 +#define SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS 1000 namespace Swift { class TimerFactory; class CAPICertificate : public Swift::CertificateWithKey { public: -////Allow timerFactory to be NULL? CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory); virtual ~CAPICertificate(); virtual bool isNull() const; const std::string& getCertStoreName() const; const std::string& getCertName() const; @@ -55,14 +53,16 @@ namespace Swift { SCARDCONTEXT scardContext_; SCARDHANDLE cardHandle_; /* Parsed components of the uri_ */ std::string certStore_; std::string certName_; std::string smartCardReaderName_; boost::shared_ptr<Timer> smartCardTimer_; TimerFactory* timerFactory_; + + bool lastPollingResult_; }; PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName); } diff --git a/Swiften/TLS/SConscript b/Swiften/TLS/SConscript index 0e95b8b..fb327b9 100644 --- a/Swiften/TLS/SConscript +++ b/Swiften/TLS/SConscript @@ -13,18 +13,19 @@ myenv = swiften_env.Clone() if myenv.get("HAVE_OPENSSL", 0) : myenv.MergeFlags(myenv["OPENSSL_FLAGS"]) objects += myenv.SwiftenObject([ "OpenSSL/OpenSSLContext.cpp", "OpenSSL/OpenSSLCertificate.cpp", "OpenSSL/OpenSSLContextFactory.cpp", ]) myenv.Append(CPPDEFINES = "HAVE_OPENSSL") elif myenv.get("HAVE_SCHANNEL", 0) : + swiften_env.Append(LIBS = ["Winscard"]) objects += myenv.StaticObject([ "CAPICertificate.cpp", "Schannel/SchannelContext.cpp", "Schannel/SchannelCertificate.cpp", "Schannel/SchannelContextFactory.cpp", ]) myenv.Append(CPPDEFINES = "HAVE_SCHANNEL") objects += myenv.SwiftenObject(["PlatformTLSFactories.cpp"]) diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 8e952ea..6169ad7 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -1,133 +1,120 @@ /* * Copyright (c) 2011 Soren Dreijer * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2012 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + #include <boost/bind.hpp> #include <Swiften/TLS/Schannel/SchannelContext.h> #include <Swiften/TLS/Schannel/SchannelCertificate.h> #include <Swiften/TLS/CAPICertificate.h> -#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID +#include <WinHTTP.h> /* For SECURITY_FLAG_IGNORE_CERT_CN_INVALID */ namespace Swift { //------------------------------------------------------------------------ -SchannelContext::SchannelContext() -: m_state(Start) -, m_secContext(0) -, m_my_cert_store(NULL) -, m_cert_store_name("MY") -, m_cert_name() -, m_smartcard_reader() -{ - m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_CONFIDENTIALITY | - ISC_REQ_EXTENDED_ERROR | - ISC_REQ_INTEGRITY | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_USE_SUPPLIED_CREDS | - ISC_REQ_STREAM; +SchannelContext::SchannelContext() : m_state(Start), m_secContext(0), m_my_cert_store(NULL), m_cert_store_name("MY"), m_cert_name(), m_smartcard_reader() { + m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_STREAM; ZeroMemory(&m_streamSizes, sizeof(m_streamSizes)); } //------------------------------------------------------------------------ -SchannelContext::~SchannelContext() -{ +SchannelContext::~SchannelContext() { if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0); } //------------------------------------------------------------------------ -void SchannelContext::determineStreamSizes() -{ +void SchannelContext::determineStreamSizes() { QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); } //------------------------------------------------------------------------ -void SchannelContext::connect() -{ +void SchannelContext::connect() { ScopedCertContext pCertContext; m_state = Connecting; // If a user name is specified, then attempt to find a client // certificate. Otherwise, just create a NULL credential. - if (!m_cert_name.empty()) - { - if (m_my_cert_store == NULL) - { + if (!m_cert_name.empty()) { + if (m_my_cert_store == NULL) { m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str()); - if (!m_my_cert_store) - { -///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() ); - indicateError(); + if (!m_my_cert_store) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } pCertContext = findCertificateInStore( m_my_cert_store, m_cert_name ); - if (pCertContext == NULL) - { -///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); - indicateError(); + if (pCertContext == NULL) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } // We use an empty list for client certificates PCCERT_CONTEXT clientCerts[1] = {0}; SCHANNEL_CRED sc = {0}; sc.dwVersion = SCHANNEL_CRED_VERSION; /////SSL3? sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; - if (pCertContext) - { + if (pCertContext) { sc.cCreds = 1; sc.paCred = pCertContext.GetPointer(); sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; } - else - { - sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us + else { + sc.cCreds = 0; sc.paCred = clientCerts; - sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; } // Swiften performs the server name check for us sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; SECURITY_STATUS status = AcquireCredentialsHandle( NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, NULL, &sc, NULL, NULL, m_credHandle.Reset(), NULL); - if (status != SEC_E_OK) - { + if (status != SEC_E_OK) { // We failed to obtain the credentials handle - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } SecBuffer outBuffers[2]; // We let Schannel allocate the output buffer for us outBuffers[0].pvBuffer = NULL; outBuffers[0].cbBuffer = 0; outBuffers[0].BufferType = SECBUFFER_TOKEN; @@ -155,49 +142,48 @@ void SchannelContext::connect() 0, 0, NULL, 0, m_ctxtHandle.Reset(), &outBufferDesc, &m_secContext, NULL); - if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) - { + if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { // We failed to initialize the security context handleCertError(status); - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } // Start the handshake sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); - if (status == SEC_E_OK) - { + if (status == SEC_E_OK) { status = validateServerCertificate(); - if (status != SEC_E_OK) + if (status != SEC_E_OK) { handleCertError(status); + } m_state = Connected; determineStreamSizes(); onConnected(); } } //------------------------------------------------------------------------ -SECURITY_STATUS SchannelContext::validateServerCertificate() -{ +SECURITY_STATUS SchannelContext::validateServerCertificate() { SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); - if (!pServerCert) + if (!pServerCert) { return SEC_E_WRONG_PRINCIPAL; + } const LPSTR usage[] = { szOID_PKIX_KP_SERVER_AUTH, szOID_SERVER_GATED_CRYPTO, szOID_SGC_NETSCAPE }; CERT_CHAIN_PARA chainParams = {0}; @@ -214,20 +200,21 @@ SECURITY_STATUS SchannelContext::validateServerCertificate() NULL, // Use the chain engine for the current user (assumes a user is logged in) pServerCert->getCertContext(), NULL, NULL, &chainParams, chainFlags, NULL, pChainContext.Reset()); - if (!success) + if (!success) { return GetLastError(); + } SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; sslChainPolicy.cbSize = sizeof(sslChainPolicy); sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? sslChainPolicy.pwszServerName = NULL; CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; certChainPolicy.cbSize = sizeof(certChainPolicy); @@ -236,46 +223,43 @@ SECURITY_STATUS SchannelContext::validateServerCertificate() CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); // Verify the chain if (!CertVerifyCertificateChainPolicy( CERT_CHAIN_POLICY_SSL, pChainContext, &certChainPolicy, - &certChainPolicyStatus)) - { + &certChainPolicyStatus)) { return GetLastError(); } - if (certChainPolicyStatus.dwError != S_OK) + if (certChainPolicyStatus.dwError != S_OK) { return certChainPolicyStatus.dwError; + } return S_OK; } //------------------------------------------------------------------------ -void SchannelContext::appendNewData(const SafeByteArray& data) -{ +void SchannelContext::appendNewData(const SafeByteArray& data) { size_t originalSize = m_receivedData.size(); - m_receivedData.resize( originalSize + data.size() ); - memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() ); + m_receivedData.resize(originalSize + data.size()); + memcpy(&m_receivedData[0] + originalSize, &data[0], data.size()); } //------------------------------------------------------------------------ -void SchannelContext::continueHandshake(const SafeByteArray& data) -{ +void SchannelContext::continueHandshake(const SafeByteArray& data) { appendNewData(data); - while (!m_receivedData.empty()) - { + while (!m_receivedData.empty()) { SecBuffer inBuffers[2]; // Provide Schannel with the remote host's handshake data inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]); inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size(); inBuffers[0].BufferType = SECBUFFER_TOKEN; inBuffers[1].pvBuffer = NULL; inBuffers[1].cbBuffer = 0; @@ -315,148 +299,138 @@ void SchannelContext::continueHandshake(const SafeByteArray& data) 0, 0, &inBufferDesc, 0, NULL, &outBufferDesc, &m_secContext, NULL); - if (status == SEC_E_INCOMPLETE_MESSAGE) - { + if (status == SEC_E_INCOMPLETE_MESSAGE) { // Wait for more data to arrive break; } - else if (status == SEC_I_CONTINUE_NEEDED) - { + else if (status == SEC_I_CONTINUE_NEEDED) { SecBuffer* pDataBuffer = &outBuffers[0]; SecBuffer* pExtraBuffer = &inBuffers[1]; - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } - if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) + if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); - else + } + else { m_receivedData.clear(); + } break; } - else if (status == SEC_E_OK) - { + else if (status == SEC_E_OK) { status = validateServerCertificate(); - if (status != SEC_E_OK) + if (status != SEC_E_OK) { handleCertError(status); + } SecBuffer* pExtraBuffer = &inBuffers[1]; - if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) + if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); - else + } + else { m_receivedData.clear(); + } m_state = Connected; determineStreamSizes(); onConnected(); } - else - { + else { // We failed to initialize the security context handleCertError(status); - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } } //------------------------------------------------------------------------ void SchannelContext::handleCertError(SECURITY_STATUS status) { if (status == SEC_E_UNTRUSTED_ROOT || status == CERT_E_UNTRUSTEDROOT || status == CRYPT_E_ISSUER_SERIALNUMBER || status == CRYPT_E_SIGNER_NOT_FOUND || - status == CRYPT_E_NO_TRUSTED_SIGNER) - { + status == CRYPT_E_NO_TRUSTED_SIGNER) { m_verificationError = CertificateVerificationError::Untrusted; } else if (status == SEC_E_CERT_EXPIRED || - status == CERT_E_EXPIRED) - { + status == CERT_E_EXPIRED) { m_verificationError = CertificateVerificationError::Expired; } - else if (status == CRYPT_E_SELF_SIGNED) - { + else if (status == CRYPT_E_SELF_SIGNED) { m_verificationError = CertificateVerificationError::SelfSigned; } else if (status == CRYPT_E_HASH_VALUE || - status == TRUST_E_CERT_SIGNATURE) - { + status == TRUST_E_CERT_SIGNATURE) { m_verificationError = CertificateVerificationError::InvalidSignature; } - else if (status == CRYPT_E_REVOKED) - { + else if (status == CRYPT_E_REVOKED) { m_verificationError = CertificateVerificationError::Revoked; } else if (status == CRYPT_E_NO_REVOCATION_CHECK || - status == CRYPT_E_REVOCATION_OFFLINE) - { + status == CRYPT_E_REVOCATION_OFFLINE) { m_verificationError = CertificateVerificationError::RevocationCheckFailed; } - else - { + else { m_verificationError = CertificateVerificationError::UnknownError; } } //------------------------------------------------------------------------ -void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) -{ - if (dataSize > 0 && pData) - { +void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) { + if (dataSize > 0 && pData) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); onDataForNetwork(byteArray); } } //------------------------------------------------------------------------ -void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) -{ +void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); onDataForApplication(byteArray); } //------------------------------------------------------------------------ -void SchannelContext::handleDataFromApplication(const SafeByteArray& data) -{ +void SchannelContext::handleDataFromApplication(const SafeByteArray& data) { // Don't attempt to send data until we're fully connected - if (m_state == Connecting) + if (m_state == Connecting) { return; + } // Encrypt the data encryptAndSendData(data); } //------------------------------------------------------------------------ -void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) -{ - switch (m_state) - { +void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) { + switch (m_state) { case Connecting: { // We're still establishing the connection, so continue the handshake continueHandshake(data); } break; case Connected: { @@ -466,35 +440,32 @@ void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) break; default: return; } } //------------------------------------------------------------------------ -void SchannelContext::indicateError() -{ +void SchannelContext::indicateError(boost::shared_ptr<TLSError> error) { m_state = Error; m_receivedData.clear(); - onError(boost::make_shared<TLSError>()); + onError(error); } //------------------------------------------------------------------------ -void SchannelContext::decryptAndProcessData(const SafeByteArray& data) -{ +void SchannelContext::decryptAndProcessData(const SafeByteArray& data) { SecBuffer inBuffers[4] = {0}; appendNewData(data); - while (!m_receivedData.empty()) - { + while (!m_receivedData.empty()) { // // MSDN: // When using the Schannel SSP with contexts that are not connection oriented, on input, // the structure must contain four SecBuffer structures. Exactly one buffer must be of type // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token // must also be supplied. @@ -509,88 +480,82 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) SecBufferDesc inBufferDesc = {0}; inBufferDesc.cBuffers = 4; inBufferDesc.pBuffers = inBuffers; inBufferDesc.ulVersion = SECBUFFER_VERSION; size_t inData = m_receivedData.size(); SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL); - if (status == SEC_E_INCOMPLETE_MESSAGE) - { + if (status == SEC_E_INCOMPLETE_MESSAGE) { // Wait for more data to arrive break; } - else if (status == SEC_I_RENEGOTIATE) - { + else if (status == SEC_I_RENEGOTIATE) { // TODO: Handle renegotiation scenarios - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } - else if (status == SEC_I_CONTEXT_EXPIRED) - { - indicateError(); + else if (status == SEC_I_CONTEXT_EXPIRED) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } - else if (status != SEC_E_OK) - { - indicateError(); + else if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } SecBuffer* pDataBuffer = NULL; SecBuffer* pExtraBuffer = NULL; - for (int i = 0; i < 4; ++i) - { - if (inBuffers[i].BufferType == SECBUFFER_DATA) + for (int i = 0; i < 4; ++i) { + if (inBuffers[i].BufferType == SECBUFFER_DATA) { pDataBuffer = &inBuffers[i]; - - else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) + } + else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { pExtraBuffer = &inBuffers[i]; + } } - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } // If there is extra data left over from the decryption operation, we call DecryptMessage() again - if (pExtraBuffer) - { + if (pExtraBuffer) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); } - else - { + else { // We're done m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData); } } } //------------------------------------------------------------------------ -void SchannelContext::encryptAndSendData(const SafeByteArray& data) -{ - if (m_streamSizes.cbMaximumMessage == 0) +void SchannelContext::encryptAndSendData(const SafeByteArray& data) { + if (m_streamSizes.cbMaximumMessage == 0) { return; + } SecBuffer outBuffers[4] = {0}; // Calculate the largest required size of the send buffer size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage) ? m_streamSizes.cbMaximumMessage : data.size(); // Allocate a packet for the encrypted data SafeByteArray sendBuffer; sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer); size_t bytesSent = 0; - do - { + do { size_t bytesLeftToSend = data.size() - bytesSent; // Calculate how much of the send buffer we'll be using for this chunk size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage) ? m_streamSizes.cbMaximumMessage : bytesLeftToSend; // Copy the plain text data into the send buffer memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend); @@ -611,88 +576,73 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) outBuffers[3].cbBuffer = 0; outBuffers[3].BufferType = SECBUFFER_EMPTY; SecBufferDesc outBufferDesc = {0}; outBufferDesc.cBuffers = 4; outBufferDesc.pBuffers = outBuffers; outBufferDesc.ulVersion = SECBUFFER_VERSION; SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0); - if (status != SEC_E_OK) - { - indicateError(); + if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); bytesSent += bytesToSend; } while (bytesSent < data.size()); } //------------------------------------------------------------------------ -bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) -{ +bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) { boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); if (!capiCertificate || capiCertificate->isNull()) { return false; } + userCertificate = capiCertificate; + // We assume that the Certificate Store Name/Certificate Name // are valid at this point m_cert_store_name = capiCertificate->getCertStoreName(); m_cert_name = capiCertificate->getCertName(); ////At the moment this is only useful for logging: m_smartcard_reader = capiCertificate->getSmartCardReaderName(); capiCertificate->onCertificateCardRemoved.connect(boost::bind(&SchannelContext::handleCertificateCardRemoved, this)); return true; } //------------------------------------------------------------------------ void SchannelContext::handleCertificateCardRemoved() { - //ToDo: Might want to log the reason ("certificate card ejected") - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved)); } //------------------------------------------------------------------------ -Certificate::ref SchannelContext::getPeerCertificate() const -{ - SchannelCertificate::ref pCertificate; - +Certificate::ref SchannelContext::getPeerCertificate() const { ScopedCertContext pServerCert; SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); - if (status != SEC_E_OK) - return pCertificate; - - pCertificate.reset( new SchannelCertificate(pServerCert) ); - return pCertificate; + return status == SEC_E_OK ? boost::make_shared<SchannelCertificate>(pServerCert) : SchannelCertificate::ref(); } //------------------------------------------------------------------------ -CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const -{ - boost::shared_ptr<CertificateVerificationError> pCertError; - - if (m_verificationError) - pCertError.reset( new CertificateVerificationError(*m_verificationError) ); - - return pCertError; +CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { + return m_verificationError ? boost::make_shared<CertificateVerificationError>(*m_verificationError) : CertificateVerificationError::ref(); } //------------------------------------------------------------------------ -ByteArray SchannelContext::getFinishMessage() const -{ +ByteArray SchannelContext::getFinishMessage() const { // TODO: Implement ByteArray emptyArray; return emptyArray; } //------------------------------------------------------------------------ } diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index bce7415..7c2601b 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -1,34 +1,42 @@ /* * Copyright (c) 2011 Soren Dreijer * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2012 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + #pragma once -#include "Swiften/Base/boost_bsignals.h" +#include <Swiften/Base/boost_bsignals.h> -#include "Swiften/TLS/TLSContext.h" -#include "Swiften/TLS/Schannel/SchannelUtil.h" -#include "Swiften/TLS/CertificateWithKey.h" -#include "Swiften/Base/ByteArray.h" +#include <Swiften/TLS/TLSContext.h> +#include <Swiften/TLS/Schannel/SchannelUtil.h> +#include <Swiften/TLS/CertificateWithKey.h> +#include <Swiften/Base/ByteArray.h> +#include <Swiften/TLS/TLSError.h> #define SECURITY_WIN32 #include <Windows.h> #include <Schannel.h> #include <security.h> #include <schnlsp.h> #include <boost/noncopyable.hpp> namespace Swift { + class CAPICertificate; class SchannelContext : public TLSContext, boost::noncopyable { public: typedef boost::shared_ptr<SchannelContext> sp_t; public: SchannelContext(); ~SchannelContext(); @@ -44,19 +52,21 @@ namespace Swift virtual Certificate::ref getPeerCertificate() const; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; virtual ByteArray getFinishMessage() const; private: void determineStreamSizes(); void continueHandshake(const SafeByteArray& data); - void indicateError(); + void indicateError(boost::shared_ptr<TLSError> error); + //FIXME: Remove + void indicateError() {indicateError(boost::make_shared<TLSError>());} void handleCertError(SECURITY_STATUS status) ; void sendDataOnNetwork(const void* pData, size_t dataSize); void forwardDataToApplication(const void* pData, size_t dataSize); void decryptAndProcessData(const SafeByteArray& data); void encryptAndSendData(const SafeByteArray& data); void appendNewData(const SafeByteArray& data); @@ -84,11 +94,12 @@ namespace Swift SecPkgContext_StreamSizes m_streamSizes; std::vector<char> m_receivedData; HCERTSTORE m_my_cert_store; std::string m_cert_store_name; std::string m_cert_name; ////Not needed, most likely std::string m_smartcard_reader; //Can be empty string for non SmartCard certificates + boost::shared_ptr<CAPICertificate> userCertificate; }; } |