summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKevin Smith <git@kismith.co.uk>2012-03-23 16:00:24 (GMT)
committerKevin Smith <git@kismith.co.uk>2012-04-12 13:49:48 (GMT)
commit0bf6afc5c01b9eb3024a8cfd04bfd743890db4f6 (patch)
treeca480f6b8e27afa97ade97ca7a13b11502b21f31 /Swiften/TLS
parentd5f885dd9aa65d18145a99826a1c30aeb62aca8e (diff)
downloadswift-contrib-0bf6afc5c01b9eb3024a8cfd04bfd743890db4f6.zip
swift-contrib-0bf6afc5c01b9eb3024a8cfd04bfd743890db4f6.tar.bz2
Tidy up of assorted Schannel/CAPI stuffs.
Makes Swift disconnect if a smartcard used for auth is removed. Fixes compilation. Changes code style in a few places.
Diffstat (limited to 'Swiften/TLS')
-rw-r--r--Swiften/TLS/CAPICertificate.cpp278
-rw-r--r--Swiften/TLS/CAPICertificate.h6
-rw-r--r--Swiften/TLS/SConscript1
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp282
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.h23
5 files changed, 259 insertions, 331 deletions
diff --git a/Swiften/TLS/CAPICertificate.cpp b/Swiften/TLS/CAPICertificate.cpp
index b33ebcf..0083b6f 100644
--- a/Swiften/TLS/CAPICertificate.cpp
+++ b/Swiften/TLS/CAPICertificate.cpp
@@ -3,31 +3,35 @@
* Licensed under the simplified BSD license.
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
+
#pragma once
#include <Swiften/Network/TimerFactory.h>
#include <Swiften/TLS/CAPICertificate.h>
#include <Swiften/StringCodecs/Hexify.h>
+#include <Swiften/Base/Log.h>
#include <boost/bind.hpp>
#include <boost/algorithm/string/predicate.hpp>
// Size of the SHA1 hash
-#define SHA1_HASH_LEN 20
+#define SHA1_HASH_LEN 20
namespace Swift {
-CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory)
- : valid_(false),
- uri_(capiUri),
- certStoreHandle_(0),
- scardContext_(0),
- cardHandle_(0),
- certStore_(),
- certName_(),
- smartCardReaderName_(),
- timerFactory_(timerFactory) {
+CAPICertificate::CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory) :
+ valid_(false),
+ uri_(capiUri),
+ certStoreHandle_(0),
+ scardContext_(0),
+ cardHandle_(0),
+ certStore_(),
+ certName_(),
+ smartCardReaderName_(),
+ timerFactory_(timerFactory),
+ lastPollingResult_(true) {
+ assert(timerFactory_);
setUri(capiUri);
}
@@ -69,19 +73,11 @@ const std::string& CAPICertificate::getSmartCardReaderName() const {
}
PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName) {
- PCCERT_CONTEXT pCertContext = NULL;
-
if (!boost::iequals(certName.substr(0, 5), "sha1:")) {
// Find client certificate. Note that this sample just searches for a
// certificate that contains the user name somewhere in the subject name.
- pCertContext = CertFindCertificateInStore(certStoreHandle,
- X509_ASN_ENCODING,
- 0, // dwFindFlags
- CERT_FIND_SUBJECT_STR_A,
- certName.c_str(), // *pvFindPara
- NULL ); // pPrevCertContext
- return pCertContext;
+ return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /*dwFindFlags*/ 0, CERT_FIND_SUBJECT_STR_A, /* *pvFindPara*/certName.c_str(), /*pPrevCertContext*/ NULL);
}
@@ -97,19 +93,12 @@ PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::st
// Find client certificate. Note that this sample just searches for a
// certificate that contains the user name somewhere in the subject name.
- pCertContext = CertFindCertificateInStore(certStoreHandle,
- X509_ASN_ENCODING,
- 0, // dwFindFlags
- CERT_FIND_HASH,
- &HashBlob,
- NULL ); // pPrevCertContext
-
- return pCertContext;
+ return CertFindCertificateInStore(certStoreHandle, X509_ASN_ENCODING, /* dwFindFlags */ 0, CERT_FIND_HASH, &HashBlob, /* pPrevCertContext */ NULL);
+
}
void CAPICertificate::setUri (const std::string& capiUri) {
-
valid_ = false;
/* Syntax: "certstore:" <cert_store> ":" <hash> ":" <hash_of_cert> */
@@ -119,41 +108,39 @@ void CAPICertificate::setUri (const std::string& capiUri) {
}
/* Substring of subject: uses "storename" */
- std::string capi_identity = capiUri.substr(10);
- std::string new_certStore_name;
- size_t pos = capi_identity.find_first_of (':');
+ std::string capiIdentity = capiUri.substr(10);
+ std::string newCertStoreName;
+ size_t pos = capiIdentity.find_first_of (':');
if (pos == std::string::npos) {
/* Using the default certificate store */
- new_certStore_name = "MY";
- certName_ = capi_identity;
+ newCertStoreName = "MY";
+ certName_ = capiIdentity;
}
else {
- new_certStore_name = capi_identity.substr(0, pos);
- certName_ = capi_identity.substr(pos + 1);
+ newCertStoreName = capiIdentity.substr(0, pos);
+ certName_ = capiIdentity.substr(pos + 1);
}
- PCCERT_CONTEXT pCertContext = NULL;
-
if (certStoreHandle_ != NULL) {
- if (new_certStore_name != certStore_) {
+ if (newCertStoreName != certStore_) {
CertCloseStore(certStoreHandle_, 0);
certStoreHandle_ = NULL;
}
}
if (certStoreHandle_ == NULL) {
- certStoreHandle_ = CertOpenSystemStore(0, new_certStore_name.c_str());
+ certStoreHandle_ = CertOpenSystemStore(0, newCertStoreName.c_str());
if (!certStoreHandle_) {
return;
}
}
- certStore_ = new_certStore_name;
+ certStore_ = newCertStoreName;
- pCertContext = findCertificateInStore (certStoreHandle_, certName_);
+ PCCERT_CONTEXT certContext = findCertificateInStore (certStoreHandle_, certName_);
- if (!pCertContext) {
+ if (!certContext) {
return;
}
@@ -165,61 +152,50 @@ void CAPICertificate::setUri (const std::string& capiUri) {
HCRYPTPROV hprov;
HCRYPTKEY key;
- if (!CertGetCertificateContextProperty(pCertContext,
+ if (!CertGetCertificateContextProperty(certContext,
CERT_KEY_PROV_INFO_PROP_ID,
NULL,
&len)) {
- CertFreeCertificateContext(pCertContext);
+ CertFreeCertificateContext(certContext);
return;
}
pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len));
if (!pinfo) {
- CertFreeCertificateContext(pCertContext);
+ CertFreeCertificateContext(certContext);
return;
}
- if (!CertGetCertificateContextProperty(pCertContext,
- CERT_KEY_PROV_INFO_PROP_ID,
- pinfo,
- &len)) {
- CertFreeCertificateContext(pCertContext);
+ if (!CertGetCertificateContextProperty(certContext, CERT_KEY_PROV_INFO_PROP_ID, pinfo, &len)) {
+ CertFreeCertificateContext(certContext);
free(pinfo);
return;
}
- CertFreeCertificateContext(pCertContext);
+ CertFreeCertificateContext(certContext);
// Now verify if we have access to the private key
- if (!CryptAcquireContextW(&hprov,
- pinfo->pwszContainerName,
- pinfo->pwszProvName,
- pinfo->dwProvType,
- 0)) {
+ if (!CryptAcquireContextW(&hprov, pinfo->pwszContainerName, pinfo->pwszProvName, pinfo->dwProvType, 0)) {
free(pinfo);
return;
}
- char smartcard_reader[1024];
- DWORD buflen;
-
- buflen = sizeof(smartcard_reader);
- if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartcard_reader, &buflen, 0)) {
- DWORD error;
-
- error = GetLastError();
+ char smartCardReader[1024];
+ DWORD bufferLength = sizeof(smartCardReader);
+ if (!CryptGetProvParam(hprov, PP_SMARTCARD_READER, (BYTE *)&smartCardReader, &bufferLength, 0)) {
+ DWORD error = GetLastError();
smartCardReaderName_ = "";
- } else {
- LONG lRet;
-
- smartCardReaderName_ = smartcard_reader;
+ }
+ else {
+ smartCardReaderName_ = smartCardReader;
- lRet = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_);
- if (SCARD_S_SUCCESS == lRet) {
+ LONG result = SCardEstablishContext(SCARD_SCOPE_USER, NULL, NULL, &scardContext_);
+ if (SCARD_S_SUCCESS == result) {
// Initiate monitoring for smartcard ejection
- smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQ);
- } else {
+ smartCardTimer_ = timerFactory_->createTimer(SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS);
+ }
+ else {
///Need to handle an error here
}
}
@@ -242,125 +218,115 @@ void CAPICertificate::setUri (const std::string& capiUri) {
valid_ = true;
}
-static void smartcard_check_status (SCARDCONTEXT hContext,
- const char * pReader,
- SCARDHANDLE hCardHandle, // Can be 0 on the first call
- SCARDHANDLE * newCardHandle, // The handle returned
- DWORD * pdwState) {
- LONG lReturn;
- DWORD dwAP;
- char szReader[200];
- DWORD cch = sizeof(szReader);
- BYTE bAttr[32];
- DWORD cByte = 32;
-
+static void smartcard_check_status (SCARDCONTEXT hContext,
+ const char* pReader,
+ SCARDHANDLE hCardHandle, /* Can be 0 on the first call */
+ SCARDHANDLE* newCardHandle, /* The handle returned */
+ DWORD* pdwState) {
if (hCardHandle == 0) {
- lReturn = SCardConnect(hContext,
- pReader,
- SCARD_SHARE_SHARED,
- SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1,
- &hCardHandle,
- &dwAP);
- if ( SCARD_S_SUCCESS != lReturn ) {
+ DWORD dwAP;
+ LONG result = SCardConnect(hContext, pReader, SCARD_SHARE_SHARED, SCARD_PROTOCOL_T0 | SCARD_PROTOCOL_T1, &hCardHandle, &dwAP);
+ if (SCARD_S_SUCCESS != result) {
hCardHandle = 0;
- if (SCARD_E_NO_SMARTCARD == lReturn || SCARD_W_REMOVED_CARD == lReturn) {
+ if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) {
*pdwState = SCARD_ABSENT;
- } else {
+ }
+ else {
*pdwState = SCARD_UNKNOWN;
}
- goto done;
+
+ if (newCardHandle == NULL) {
+ (void) SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD);
+ hCardHandle = 0;
+ }
+ else {
+ *newCardHandle = hCardHandle;
+ }
}
}
- lReturn = SCardStatus(hCardHandle,
- szReader, // Unfortunately we can't use NULL here
- &cch,
- pdwState,
- NULL,
- (LPBYTE)&bAttr,
- &cByte);
+ char szReader[200];
+ DWORD cch = sizeof(szReader);
+ BYTE bAttr[32];
+ DWORD cByte = 32;
+ LONG result = SCardStatus(hCardHandle, /* Unfortunately we can't use NULL here */ szReader, &cch, pdwState, NULL, (LPBYTE)&bAttr, &cByte);
- if ( SCARD_S_SUCCESS != lReturn ) {
- if (SCARD_E_NO_SMARTCARD == lReturn || SCARD_W_REMOVED_CARD == lReturn) {
+ if (SCARD_S_SUCCESS != result) {
+ if (SCARD_E_NO_SMARTCARD == result || SCARD_W_REMOVED_CARD == result) {
*pdwState = SCARD_ABSENT;
- } else {
+ }
+ else {
*pdwState = SCARD_UNKNOWN;
}
}
-done:
if (newCardHandle == NULL) {
(void) SCardDisconnect(hCardHandle, SCARD_LEAVE_CARD);
hCardHandle = 0;
- } else {
+ }
+ else {
*newCardHandle = hCardHandle;
}
}
bool CAPICertificate::checkIfSmartCardPresent () {
-
- DWORD dwState;
-
if (!smartCardReaderName_.empty()) {
- smartcard_check_status (scardContext_,
- smartCardReaderName_.c_str(),
- cardHandle_,
- &cardHandle_,
- &dwState);
-////DEBUG
- switch ( dwState ) {
- case SCARD_ABSENT:
- printf("Card absent.\n");
- break;
- case SCARD_PRESENT:
- printf("Card present.\n");
- break;
- case SCARD_SWALLOWED:
- printf("Card swallowed.\n");
- break;
- case SCARD_POWERED:
- printf("Card has power.\n");
- break;
- case SCARD_NEGOTIABLE:
- printf("Card reset and waiting PTS negotiation.\n");
- break;
- case SCARD_SPECIFIC:
- printf("Card has specific communication protocols set.\n");
- break;
- default:
- printf("Unknown or unexpected card state.\n");
- break;
+ DWORD dwState;
+ smartcard_check_status(scardContext_, smartCardReaderName_.c_str(), cardHandle_, &cardHandle_, &dwState);
+
+ switch (dwState) {
+ case SCARD_ABSENT:
+ SWIFT_LOG("DEBUG") << "Card absent." << std::endl;
+ break;
+ case SCARD_PRESENT:
+ SWIFT_LOG("DEBUG") << "Card present." << std::endl;
+ break;
+ case SCARD_SWALLOWED:
+ SWIFT_LOG("DEBUG") << "Card swallowed." << std::endl;
+ break;
+ case SCARD_POWERED:
+ SWIFT_LOG("DEBUG") << "Card has power." << std::endl;
+ break;
+ case SCARD_NEGOTIABLE:
+ SWIFT_LOG("DEBUG") << "Card reset and waiting PTS negotiation." << std::endl;
+ break;
+ case SCARD_SPECIFIC:
+ SWIFT_LOG("DEBUG") << "Card has specific communication protocols set." << std::endl;
+ break;
+ default:
+ SWIFT_LOG("DEBUG") << "Unknown or unexpected card state." << std::endl;
+ break;
}
- switch ( dwState ) {
- case SCARD_ABSENT:
- return false;
+ switch (dwState) {
+ case SCARD_ABSENT:
+ return false;
- case SCARD_PRESENT:
- case SCARD_SWALLOWED:
- case SCARD_POWERED:
- case SCARD_NEGOTIABLE:
- case SCARD_SPECIFIC:
- return true;
+ case SCARD_PRESENT:
+ case SCARD_SWALLOWED:
+ case SCARD_POWERED:
+ case SCARD_NEGOTIABLE:
+ case SCARD_SPECIFIC:
+ return true;
- default:
- return false;
+ default:
+ return false;
}
- } else {
+ }
+ else {
return false;
}
}
void CAPICertificate::handleSmartCardTimerTick() {
-
- if (checkIfSmartCardPresent() == false) {
- smartCardTimer_->stop();
+ bool poll = checkIfSmartCardPresent();
+ if (lastPollingResult_ && !poll) {
onCertificateCardRemoved();
- } else {
- smartCardTimer_->start();
- }
+ }
+ lastPollingResult_ = poll;
+ smartCardTimer_->start();
}
}
diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h
index c8c00fe..5f24b7e 100644
--- a/Swiften/TLS/CAPICertificate.h
+++ b/Swiften/TLS/CAPICertificate.h
@@ -16,15 +16,13 @@
#include <WinCrypt.h>
#include <Winscard.h>
-/* In ms */
-#define SMARTCARD_EJECTION_CHECK_FREQ 1000
+#define SMARTCARD_EJECTION_CHECK_FREQUENCY_MILLISECONDS 1000
namespace Swift {
class TimerFactory;
class CAPICertificate : public Swift::CertificateWithKey {
public:
-////Allow timerFactory to be NULL?
CAPICertificate(const std::string& capiUri, TimerFactory* timerFactory);
virtual ~CAPICertificate();
@@ -61,6 +59,8 @@ namespace Swift {
std::string smartCardReaderName_;
boost::shared_ptr<Timer> smartCardTimer_;
TimerFactory* timerFactory_;
+
+ bool lastPollingResult_;
};
PCCERT_CONTEXT findCertificateInStore (HCERTSTORE certStoreHandle, const std::string &certName);
diff --git a/Swiften/TLS/SConscript b/Swiften/TLS/SConscript
index 0e95b8b..fb327b9 100644
--- a/Swiften/TLS/SConscript
+++ b/Swiften/TLS/SConscript
@@ -19,6 +19,7 @@ if myenv.get("HAVE_OPENSSL", 0) :
])
myenv.Append(CPPDEFINES = "HAVE_OPENSSL")
elif myenv.get("HAVE_SCHANNEL", 0) :
+ swiften_env.Append(LIBS = ["Winscard"])
objects += myenv.StaticObject([
"CAPICertificate.cpp",
"Schannel/SchannelContext.cpp",
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index 8e952ea..6169ad7 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -4,79 +4,69 @@
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
+/*
+ * Copyright (c) 2012 Kevin Smith
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
#include <boost/bind.hpp>
#include <Swiften/TLS/Schannel/SchannelContext.h>
#include <Swiften/TLS/Schannel/SchannelCertificate.h>
#include <Swiften/TLS/CAPICertificate.h>
-#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID
+#include <WinHTTP.h> /* For SECURITY_FLAG_IGNORE_CERT_CN_INVALID */
namespace Swift {
//------------------------------------------------------------------------
-SchannelContext::SchannelContext()
-: m_state(Start)
-, m_secContext(0)
-, m_my_cert_store(NULL)
-, m_cert_store_name("MY")
-, m_cert_name()
-, m_smartcard_reader()
-{
- m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
- ISC_REQ_CONFIDENTIALITY |
- ISC_REQ_EXTENDED_ERROR |
- ISC_REQ_INTEGRITY |
- ISC_REQ_REPLAY_DETECT |
- ISC_REQ_SEQUENCE_DETECT |
- ISC_REQ_USE_SUPPLIED_CREDS |
- ISC_REQ_STREAM;
+SchannelContext::SchannelContext() : m_state(Start), m_secContext(0), m_my_cert_store(NULL), m_cert_store_name("MY"), m_cert_name(), m_smartcard_reader() {
+ m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_INTEGRITY |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_USE_SUPPLIED_CREDS |
+ ISC_REQ_STREAM;
ZeroMemory(&m_streamSizes, sizeof(m_streamSizes));
}
//------------------------------------------------------------------------
-SchannelContext::~SchannelContext()
-{
+SchannelContext::~SchannelContext() {
if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0);
}
//------------------------------------------------------------------------
-void SchannelContext::determineStreamSizes()
-{
+void SchannelContext::determineStreamSizes() {
QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
}
//------------------------------------------------------------------------
-void SchannelContext::connect()
-{
+void SchannelContext::connect() {
ScopedCertContext pCertContext;
m_state = Connecting;
// If a user name is specified, then attempt to find a client
// certificate. Otherwise, just create a NULL credential.
- if (!m_cert_name.empty())
- {
- if (m_my_cert_store == NULL)
- {
+ if (!m_cert_name.empty()) {
+ if (m_my_cert_store == NULL) {
m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str());
- if (!m_my_cert_store)
- {
-///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() );
- indicateError();
+ if (!m_my_cert_store) {
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
}
pCertContext = findCertificateInStore( m_my_cert_store, m_cert_name );
- if (pCertContext == NULL)
- {
-///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError());
- indicateError();
+ if (pCertContext == NULL) {
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
}
@@ -91,17 +81,15 @@ void SchannelContext::connect()
sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION;
- if (pCertContext)
- {
+ if (pCertContext) {
sc.cCreds = 1;
sc.paCred = pCertContext.GetPointer();
sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
}
- else
- {
- sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
+ else {
+ sc.cCreds = 0;
sc.paCred = clientCerts;
- sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;
+ sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
}
// Swiften performs the server name check for us
@@ -118,10 +106,9 @@ void SchannelContext::connect()
m_credHandle.Reset(),
NULL);
- if (status != SEC_E_OK)
- {
+ if (status != SEC_E_OK) {
// We failed to obtain the credentials handle
- indicateError();
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
@@ -161,22 +148,21 @@ void SchannelContext::connect()
&m_secContext,
NULL);
- if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED)
- {
+ if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) {
// We failed to initialize the security context
handleCertError(status);
- indicateError();
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
// Start the handshake
sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer);
- if (status == SEC_E_OK)
- {
+ if (status == SEC_E_OK) {
status = validateServerCertificate();
- if (status != SEC_E_OK)
+ if (status != SEC_E_OK) {
handleCertError(status);
+ }
m_state = Connected;
determineStreamSizes();
@@ -187,11 +173,11 @@ void SchannelContext::connect()
//------------------------------------------------------------------------
-SECURITY_STATUS SchannelContext::validateServerCertificate()
-{
+SECURITY_STATUS SchannelContext::validateServerCertificate() {
SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() );
- if (!pServerCert)
+ if (!pServerCert) {
return SEC_E_WRONG_PRINCIPAL;
+ }
const LPSTR usage[] =
{
@@ -220,8 +206,9 @@ SECURITY_STATUS SchannelContext::validateServerCertificate()
NULL,
pChainContext.Reset());
- if (!success)
+ if (!success) {
return GetLastError();
+ }
SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0};
sslChainPolicy.cbSize = sizeof(sslChainPolicy);
@@ -242,34 +229,31 @@ SECURITY_STATUS SchannelContext::validateServerCertificate()
CERT_CHAIN_POLICY_SSL,
pChainContext,
&certChainPolicy,
- &certChainPolicyStatus))
- {
+ &certChainPolicyStatus)) {
return GetLastError();
}
- if (certChainPolicyStatus.dwError != S_OK)
+ if (certChainPolicyStatus.dwError != S_OK) {
return certChainPolicyStatus.dwError;
+ }
return S_OK;
}
//------------------------------------------------------------------------
-void SchannelContext::appendNewData(const SafeByteArray& data)
-{
+void SchannelContext::appendNewData(const SafeByteArray& data) {
size_t originalSize = m_receivedData.size();
- m_receivedData.resize( originalSize + data.size() );
- memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() );
+ m_receivedData.resize(originalSize + data.size());
+ memcpy(&m_receivedData[0] + originalSize, &data[0], data.size());
}
//------------------------------------------------------------------------
-void SchannelContext::continueHandshake(const SafeByteArray& data)
-{
+void SchannelContext::continueHandshake(const SafeByteArray& data) {
appendNewData(data);
- while (!m_receivedData.empty())
- {
+ while (!m_receivedData.empty()) {
SecBuffer inBuffers[2];
// Provide Schannel with the remote host's handshake data
@@ -321,49 +305,51 @@ void SchannelContext::continueHandshake(const SafeByteArray& data)
&m_secContext,
NULL);
- if (status == SEC_E_INCOMPLETE_MESSAGE)
- {
+ if (status == SEC_E_INCOMPLETE_MESSAGE) {
// Wait for more data to arrive
break;
}
- else if (status == SEC_I_CONTINUE_NEEDED)
- {
+ else if (status == SEC_I_CONTINUE_NEEDED) {
SecBuffer* pDataBuffer = &outBuffers[0];
SecBuffer* pExtraBuffer = &inBuffers[1];
- if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) {
sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+ }
- if (pExtraBuffer->BufferType == SECBUFFER_EXTRA)
+ if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) {
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
- else
+ }
+ else {
m_receivedData.clear();
+ }
break;
}
- else if (status == SEC_E_OK)
- {
+ else if (status == SEC_E_OK) {
status = validateServerCertificate();
- if (status != SEC_E_OK)
+ if (status != SEC_E_OK) {
handleCertError(status);
+ }
SecBuffer* pExtraBuffer = &inBuffers[1];
- if (pExtraBuffer && pExtraBuffer->cbBuffer > 0)
+ if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) {
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
- else
+ }
+ else {
m_receivedData.clear();
+ }
m_state = Connected;
determineStreamSizes();
onConnected();
}
- else
- {
+ else {
// We failed to initialize the security context
handleCertError(status);
- indicateError();
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
}
@@ -377,45 +363,36 @@ void SchannelContext::handleCertError(SECURITY_STATUS status)
status == CERT_E_UNTRUSTEDROOT ||
status == CRYPT_E_ISSUER_SERIALNUMBER ||
status == CRYPT_E_SIGNER_NOT_FOUND ||
- status == CRYPT_E_NO_TRUSTED_SIGNER)
- {
+ status == CRYPT_E_NO_TRUSTED_SIGNER) {
m_verificationError = CertificateVerificationError::Untrusted;
}
else if (status == SEC_E_CERT_EXPIRED ||
- status == CERT_E_EXPIRED)
- {
+ status == CERT_E_EXPIRED) {
m_verificationError = CertificateVerificationError::Expired;
}
- else if (status == CRYPT_E_SELF_SIGNED)
- {
+ else if (status == CRYPT_E_SELF_SIGNED) {
m_verificationError = CertificateVerificationError::SelfSigned;
}
else if (status == CRYPT_E_HASH_VALUE ||
- status == TRUST_E_CERT_SIGNATURE)
- {
+ status == TRUST_E_CERT_SIGNATURE) {
m_verificationError = CertificateVerificationError::InvalidSignature;
}
- else if (status == CRYPT_E_REVOKED)
- {
+ else if (status == CRYPT_E_REVOKED) {
m_verificationError = CertificateVerificationError::Revoked;
}
else if (status == CRYPT_E_NO_REVOCATION_CHECK ||
- status == CRYPT_E_REVOCATION_OFFLINE)
- {
+ status == CRYPT_E_REVOCATION_OFFLINE) {
m_verificationError = CertificateVerificationError::RevocationCheckFailed;
}
- else
- {
+ else {
m_verificationError = CertificateVerificationError::UnknownError;
}
}
//------------------------------------------------------------------------
-void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize)
-{
- if (dataSize > 0 && pData)
- {
+void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) {
+ if (dataSize > 0 && pData) {
SafeByteArray byteArray(dataSize);
memcpy(&byteArray[0], pData, dataSize);
@@ -425,8 +402,7 @@ void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize)
//------------------------------------------------------------------------
-void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize)
-{
+void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) {
SafeByteArray byteArray(dataSize);
memcpy(&byteArray[0], pData, dataSize);
@@ -435,11 +411,11 @@ void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSiz
//------------------------------------------------------------------------
-void SchannelContext::handleDataFromApplication(const SafeByteArray& data)
-{
+void SchannelContext::handleDataFromApplication(const SafeByteArray& data) {
// Don't attempt to send data until we're fully connected
- if (m_state == Connecting)
+ if (m_state == Connecting) {
return;
+ }
// Encrypt the data
encryptAndSendData(data);
@@ -447,10 +423,8 @@ void SchannelContext::handleDataFromApplication(const SafeByteArray& data)
//------------------------------------------------------------------------
-void SchannelContext::handleDataFromNetwork(const SafeByteArray& data)
-{
- switch (m_state)
- {
+void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) {
+ switch (m_state) {
case Connecting:
{
// We're still establishing the connection, so continue the handshake
@@ -472,23 +446,20 @@ void SchannelContext::handleDataFromNetwork(const SafeByteArray& data)
//------------------------------------------------------------------------
-void SchannelContext::indicateError()
-{
+void SchannelContext::indicateError(boost::shared_ptr<TLSError> error) {
m_state = Error;
m_receivedData.clear();
- onError(boost::make_shared<TLSError>());
+ onError(error);
}
//------------------------------------------------------------------------
-void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
-{
+void SchannelContext::decryptAndProcessData(const SafeByteArray& data) {
SecBuffer inBuffers[4] = {0};
appendNewData(data);
- while (!m_receivedData.empty())
- {
+ while (!m_receivedData.empty()) {
//
// MSDN:
// When using the Schannel SSP with contexts that are not connection oriented, on input,
@@ -515,49 +486,44 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
size_t inData = m_receivedData.size();
SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL);
- if (status == SEC_E_INCOMPLETE_MESSAGE)
- {
+ if (status == SEC_E_INCOMPLETE_MESSAGE) {
// Wait for more data to arrive
break;
}
- else if (status == SEC_I_RENEGOTIATE)
- {
+ else if (status == SEC_I_RENEGOTIATE) {
// TODO: Handle renegotiation scenarios
- indicateError();
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
break;
}
- else if (status == SEC_I_CONTEXT_EXPIRED)
- {
- indicateError();
+ else if (status == SEC_I_CONTEXT_EXPIRED) {
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
break;
}
- else if (status != SEC_E_OK)
- {
- indicateError();
+ else if (status != SEC_E_OK) {
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
break;
}
SecBuffer* pDataBuffer = NULL;
SecBuffer* pExtraBuffer = NULL;
- for (int i = 0; i < 4; ++i)
- {
- if (inBuffers[i].BufferType == SECBUFFER_DATA)
+ for (int i = 0; i < 4; ++i) {
+ if (inBuffers[i].BufferType == SECBUFFER_DATA) {
pDataBuffer = &inBuffers[i];
-
- else if (inBuffers[i].BufferType == SECBUFFER_EXTRA)
+ }
+ else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) {
pExtraBuffer = &inBuffers[i];
+ }
}
- if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) {
forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+ }
// If there is extra data left over from the decryption operation, we call DecryptMessage() again
- if (pExtraBuffer)
- {
+ if (pExtraBuffer) {
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
}
- else
- {
+ else {
// We're done
m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData);
}
@@ -566,10 +532,10 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
//------------------------------------------------------------------------
-void SchannelContext::encryptAndSendData(const SafeByteArray& data)
-{
- if (m_streamSizes.cbMaximumMessage == 0)
+void SchannelContext::encryptAndSendData(const SafeByteArray& data) {
+ if (m_streamSizes.cbMaximumMessage == 0) {
return;
+ }
SecBuffer outBuffers[4] = {0};
@@ -583,8 +549,7 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer);
size_t bytesSent = 0;
- do
- {
+ do {
size_t bytesLeftToSend = data.size() - bytesSent;
// Calculate how much of the send buffer we'll be using for this chunk
@@ -617,9 +582,8 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
outBufferDesc.ulVersion = SECBUFFER_VERSION;
SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0);
- if (status != SEC_E_OK)
- {
- indicateError();
+ if (status != SEC_E_OK) {
+ indicateError(boost::make_shared<TLSError>(TLSError::UnknownError));
return;
}
@@ -631,13 +595,14 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
//------------------------------------------------------------------------
-bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate)
-{
+bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) {
boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate);
if (!capiCertificate || capiCertificate->isNull()) {
return false;
}
+ userCertificate = capiCertificate;
+
// We assume that the Certificate Store Name/Certificate Name
// are valid at this point
m_cert_store_name = capiCertificate->getCertStoreName();
@@ -652,41 +617,26 @@ bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate)
//------------------------------------------------------------------------
void SchannelContext::handleCertificateCardRemoved() {
- //ToDo: Might want to log the reason ("certificate card ejected")
- indicateError();
+ indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved));
}
//------------------------------------------------------------------------
-Certificate::ref SchannelContext::getPeerCertificate() const
-{
- SchannelCertificate::ref pCertificate;
-
+Certificate::ref SchannelContext::getPeerCertificate() const {
ScopedCertContext pServerCert;
SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
- if (status != SEC_E_OK)
- return pCertificate;
-
- pCertificate.reset( new SchannelCertificate(pServerCert) );
- return pCertificate;
+ return status == SEC_E_OK ? boost::make_shared<SchannelCertificate>(pServerCert) : SchannelCertificate::ref();
}
//------------------------------------------------------------------------
-CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const
-{
- boost::shared_ptr<CertificateVerificationError> pCertError;
-
- if (m_verificationError)
- pCertError.reset( new CertificateVerificationError(*m_verificationError) );
-
- return pCertError;
+CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const {
+ return m_verificationError ? boost::make_shared<CertificateVerificationError>(*m_verificationError) : CertificateVerificationError::ref();
}
//------------------------------------------------------------------------
-ByteArray SchannelContext::getFinishMessage() const
-{
+ByteArray SchannelContext::getFinishMessage() const {
// TODO: Implement
ByteArray emptyArray;
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index bce7415..7c2601b 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -4,14 +4,21 @@
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
+/*
+ * Copyright (c) 2012 Kevin Smith
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
#pragma once
-#include "Swiften/Base/boost_bsignals.h"
+#include <Swiften/Base/boost_bsignals.h>
-#include "Swiften/TLS/TLSContext.h"
-#include "Swiften/TLS/Schannel/SchannelUtil.h"
-#include "Swiften/TLS/CertificateWithKey.h"
-#include "Swiften/Base/ByteArray.h"
+#include <Swiften/TLS/TLSContext.h>
+#include <Swiften/TLS/Schannel/SchannelUtil.h>
+#include <Swiften/TLS/CertificateWithKey.h>
+#include <Swiften/Base/ByteArray.h>
+#include <Swiften/TLS/TLSError.h>
#define SECURITY_WIN32
#include <Windows.h>
@@ -23,6 +30,7 @@
namespace Swift
{
+ class CAPICertificate;
class SchannelContext : public TLSContext, boost::noncopyable
{
public:
@@ -50,7 +58,9 @@ namespace Swift
private:
void determineStreamSizes();
void continueHandshake(const SafeByteArray& data);
- void indicateError();
+ void indicateError(boost::shared_ptr<TLSError> error);
+ //FIXME: Remove
+ void indicateError() {indicateError(boost::make_shared<TLSError>());}
void handleCertError(SECURITY_STATUS status) ;
void sendDataOnNetwork(const void* pData, size_t dataSize);
@@ -90,5 +100,6 @@ namespace Swift
std::string m_cert_name;
////Not needed, most likely
std::string m_smartcard_reader; //Can be empty string for non SmartCard certificates
+ boost::shared_ptr<CAPICertificate> userCertificate;
};
}