diff options
Diffstat (limited to 'Swiften/TLS/Schannel/SchannelContext.cpp')
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 282 |
1 files changed, 116 insertions, 166 deletions
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 8e952ea..6169ad7 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -4,79 +4,69 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2012 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + #include <boost/bind.hpp> #include <Swiften/TLS/Schannel/SchannelContext.h> #include <Swiften/TLS/Schannel/SchannelCertificate.h> #include <Swiften/TLS/CAPICertificate.h> -#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID +#include <WinHTTP.h> /* For SECURITY_FLAG_IGNORE_CERT_CN_INVALID */ namespace Swift { //------------------------------------------------------------------------ -SchannelContext::SchannelContext() -: m_state(Start) -, m_secContext(0) -, m_my_cert_store(NULL) -, m_cert_store_name("MY") -, m_cert_name() -, m_smartcard_reader() -{ - m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_CONFIDENTIALITY | - ISC_REQ_EXTENDED_ERROR | - ISC_REQ_INTEGRITY | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_USE_SUPPLIED_CREDS | - ISC_REQ_STREAM; +SchannelContext::SchannelContext() : m_state(Start), m_secContext(0), m_my_cert_store(NULL), m_cert_store_name("MY"), m_cert_name(), m_smartcard_reader() { + m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_STREAM; ZeroMemory(&m_streamSizes, sizeof(m_streamSizes)); } //------------------------------------------------------------------------ -SchannelContext::~SchannelContext() -{ +SchannelContext::~SchannelContext() { if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0); } //------------------------------------------------------------------------ -void SchannelContext::determineStreamSizes() -{ +void SchannelContext::determineStreamSizes() { QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); } //------------------------------------------------------------------------ -void SchannelContext::connect() -{ +void SchannelContext::connect() { ScopedCertContext pCertContext; m_state = Connecting; // If a user name is specified, then attempt to find a client // certificate. Otherwise, just create a NULL credential. - if (!m_cert_name.empty()) - { - if (m_my_cert_store == NULL) - { + if (!m_cert_name.empty()) { + if (m_my_cert_store == NULL) { m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str()); - if (!m_my_cert_store) - { -///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() ); - indicateError(); + if (!m_my_cert_store) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } pCertContext = findCertificateInStore( m_my_cert_store, m_cert_name ); - if (pCertContext == NULL) - { -///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); - indicateError(); + if (pCertContext == NULL) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } @@ -91,17 +81,15 @@ void SchannelContext::connect() sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; - if (pCertContext) - { + if (pCertContext) { sc.cCreds = 1; sc.paCred = pCertContext.GetPointer(); sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; } - else - { - sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us + else { + sc.cCreds = 0; sc.paCred = clientCerts; - sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; } // Swiften performs the server name check for us @@ -118,10 +106,9 @@ void SchannelContext::connect() m_credHandle.Reset(), NULL); - if (status != SEC_E_OK) - { + if (status != SEC_E_OK) { // We failed to obtain the credentials handle - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } @@ -161,22 +148,21 @@ void SchannelContext::connect() &m_secContext, NULL); - if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) - { + if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { // We failed to initialize the security context handleCertError(status); - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } // Start the handshake sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); - if (status == SEC_E_OK) - { + if (status == SEC_E_OK) { status = validateServerCertificate(); - if (status != SEC_E_OK) + if (status != SEC_E_OK) { handleCertError(status); + } m_state = Connected; determineStreamSizes(); @@ -187,11 +173,11 @@ void SchannelContext::connect() //------------------------------------------------------------------------ -SECURITY_STATUS SchannelContext::validateServerCertificate() -{ +SECURITY_STATUS SchannelContext::validateServerCertificate() { SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); - if (!pServerCert) + if (!pServerCert) { return SEC_E_WRONG_PRINCIPAL; + } const LPSTR usage[] = { @@ -220,8 +206,9 @@ SECURITY_STATUS SchannelContext::validateServerCertificate() NULL, pChainContext.Reset()); - if (!success) + if (!success) { return GetLastError(); + } SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; sslChainPolicy.cbSize = sizeof(sslChainPolicy); @@ -242,34 +229,31 @@ SECURITY_STATUS SchannelContext::validateServerCertificate() CERT_CHAIN_POLICY_SSL, pChainContext, &certChainPolicy, - &certChainPolicyStatus)) - { + &certChainPolicyStatus)) { return GetLastError(); } - if (certChainPolicyStatus.dwError != S_OK) + if (certChainPolicyStatus.dwError != S_OK) { return certChainPolicyStatus.dwError; + } return S_OK; } //------------------------------------------------------------------------ -void SchannelContext::appendNewData(const SafeByteArray& data) -{ +void SchannelContext::appendNewData(const SafeByteArray& data) { size_t originalSize = m_receivedData.size(); - m_receivedData.resize( originalSize + data.size() ); - memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() ); + m_receivedData.resize(originalSize + data.size()); + memcpy(&m_receivedData[0] + originalSize, &data[0], data.size()); } //------------------------------------------------------------------------ -void SchannelContext::continueHandshake(const SafeByteArray& data) -{ +void SchannelContext::continueHandshake(const SafeByteArray& data) { appendNewData(data); - while (!m_receivedData.empty()) - { + while (!m_receivedData.empty()) { SecBuffer inBuffers[2]; // Provide Schannel with the remote host's handshake data @@ -321,49 +305,51 @@ void SchannelContext::continueHandshake(const SafeByteArray& data) &m_secContext, NULL); - if (status == SEC_E_INCOMPLETE_MESSAGE) - { + if (status == SEC_E_INCOMPLETE_MESSAGE) { // Wait for more data to arrive break; } - else if (status == SEC_I_CONTINUE_NEEDED) - { + else if (status == SEC_I_CONTINUE_NEEDED) { SecBuffer* pDataBuffer = &outBuffers[0]; SecBuffer* pExtraBuffer = &inBuffers[1]; - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } - if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) + if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); - else + } + else { m_receivedData.clear(); + } break; } - else if (status == SEC_E_OK) - { + else if (status == SEC_E_OK) { status = validateServerCertificate(); - if (status != SEC_E_OK) + if (status != SEC_E_OK) { handleCertError(status); + } SecBuffer* pExtraBuffer = &inBuffers[1]; - if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) + if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); - else + } + else { m_receivedData.clear(); + } m_state = Connected; determineStreamSizes(); onConnected(); } - else - { + else { // We failed to initialize the security context handleCertError(status); - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } } @@ -377,45 +363,36 @@ void SchannelContext::handleCertError(SECURITY_STATUS status) status == CERT_E_UNTRUSTEDROOT || status == CRYPT_E_ISSUER_SERIALNUMBER || status == CRYPT_E_SIGNER_NOT_FOUND || - status == CRYPT_E_NO_TRUSTED_SIGNER) - { + status == CRYPT_E_NO_TRUSTED_SIGNER) { m_verificationError = CertificateVerificationError::Untrusted; } else if (status == SEC_E_CERT_EXPIRED || - status == CERT_E_EXPIRED) - { + status == CERT_E_EXPIRED) { m_verificationError = CertificateVerificationError::Expired; } - else if (status == CRYPT_E_SELF_SIGNED) - { + else if (status == CRYPT_E_SELF_SIGNED) { m_verificationError = CertificateVerificationError::SelfSigned; } else if (status == CRYPT_E_HASH_VALUE || - status == TRUST_E_CERT_SIGNATURE) - { + status == TRUST_E_CERT_SIGNATURE) { m_verificationError = CertificateVerificationError::InvalidSignature; } - else if (status == CRYPT_E_REVOKED) - { + else if (status == CRYPT_E_REVOKED) { m_verificationError = CertificateVerificationError::Revoked; } else if (status == CRYPT_E_NO_REVOCATION_CHECK || - status == CRYPT_E_REVOCATION_OFFLINE) - { + status == CRYPT_E_REVOCATION_OFFLINE) { m_verificationError = CertificateVerificationError::RevocationCheckFailed; } - else - { + else { m_verificationError = CertificateVerificationError::UnknownError; } } //------------------------------------------------------------------------ -void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) -{ - if (dataSize > 0 && pData) - { +void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) { + if (dataSize > 0 && pData) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); @@ -425,8 +402,7 @@ void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) //------------------------------------------------------------------------ -void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) -{ +void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); @@ -435,11 +411,11 @@ void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSiz //------------------------------------------------------------------------ -void SchannelContext::handleDataFromApplication(const SafeByteArray& data) -{ +void SchannelContext::handleDataFromApplication(const SafeByteArray& data) { // Don't attempt to send data until we're fully connected - if (m_state == Connecting) + if (m_state == Connecting) { return; + } // Encrypt the data encryptAndSendData(data); @@ -447,10 +423,8 @@ void SchannelContext::handleDataFromApplication(const SafeByteArray& data) //------------------------------------------------------------------------ -void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) -{ - switch (m_state) - { +void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) { + switch (m_state) { case Connecting: { // We're still establishing the connection, so continue the handshake @@ -472,23 +446,20 @@ void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) //------------------------------------------------------------------------ -void SchannelContext::indicateError() -{ +void SchannelContext::indicateError(boost::shared_ptr<TLSError> error) { m_state = Error; m_receivedData.clear(); - onError(boost::make_shared<TLSError>()); + onError(error); } //------------------------------------------------------------------------ -void SchannelContext::decryptAndProcessData(const SafeByteArray& data) -{ +void SchannelContext::decryptAndProcessData(const SafeByteArray& data) { SecBuffer inBuffers[4] = {0}; appendNewData(data); - while (!m_receivedData.empty()) - { + while (!m_receivedData.empty()) { // // MSDN: // When using the Schannel SSP with contexts that are not connection oriented, on input, @@ -515,49 +486,44 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) size_t inData = m_receivedData.size(); SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL); - if (status == SEC_E_INCOMPLETE_MESSAGE) - { + if (status == SEC_E_INCOMPLETE_MESSAGE) { // Wait for more data to arrive break; } - else if (status == SEC_I_RENEGOTIATE) - { + else if (status == SEC_I_RENEGOTIATE) { // TODO: Handle renegotiation scenarios - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } - else if (status == SEC_I_CONTEXT_EXPIRED) - { - indicateError(); + else if (status == SEC_I_CONTEXT_EXPIRED) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } - else if (status != SEC_E_OK) - { - indicateError(); + else if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); break; } SecBuffer* pDataBuffer = NULL; SecBuffer* pExtraBuffer = NULL; - for (int i = 0; i < 4; ++i) - { - if (inBuffers[i].BufferType == SECBUFFER_DATA) + for (int i = 0; i < 4; ++i) { + if (inBuffers[i].BufferType == SECBUFFER_DATA) { pDataBuffer = &inBuffers[i]; - - else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) + } + else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { pExtraBuffer = &inBuffers[i]; + } } - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } // If there is extra data left over from the decryption operation, we call DecryptMessage() again - if (pExtraBuffer) - { + if (pExtraBuffer) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); } - else - { + else { // We're done m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData); } @@ -566,10 +532,10 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) //------------------------------------------------------------------------ -void SchannelContext::encryptAndSendData(const SafeByteArray& data) -{ - if (m_streamSizes.cbMaximumMessage == 0) +void SchannelContext::encryptAndSendData(const SafeByteArray& data) { + if (m_streamSizes.cbMaximumMessage == 0) { return; + } SecBuffer outBuffers[4] = {0}; @@ -583,8 +549,7 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer); size_t bytesSent = 0; - do - { + do { size_t bytesLeftToSend = data.size() - bytesSent; // Calculate how much of the send buffer we'll be using for this chunk @@ -617,9 +582,8 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) outBufferDesc.ulVersion = SECBUFFER_VERSION; SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0); - if (status != SEC_E_OK) - { - indicateError(); + if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); return; } @@ -631,13 +595,14 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) //------------------------------------------------------------------------ -bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) -{ +bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) { boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); if (!capiCertificate || capiCertificate->isNull()) { return false; } + userCertificate = capiCertificate; + // We assume that the Certificate Store Name/Certificate Name // are valid at this point m_cert_store_name = capiCertificate->getCertStoreName(); @@ -652,41 +617,26 @@ bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) //------------------------------------------------------------------------ void SchannelContext::handleCertificateCardRemoved() { - //ToDo: Might want to log the reason ("certificate card ejected") - indicateError(); + indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved)); } //------------------------------------------------------------------------ -Certificate::ref SchannelContext::getPeerCertificate() const -{ - SchannelCertificate::ref pCertificate; - +Certificate::ref SchannelContext::getPeerCertificate() const { ScopedCertContext pServerCert; SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); - if (status != SEC_E_OK) - return pCertificate; - - pCertificate.reset( new SchannelCertificate(pServerCert) ); - return pCertificate; + return status == SEC_E_OK ? boost::make_shared<SchannelCertificate>(pServerCert) : SchannelCertificate::ref(); } //------------------------------------------------------------------------ -CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const -{ - boost::shared_ptr<CertificateVerificationError> pCertError; - - if (m_verificationError) - pCertError.reset( new CertificateVerificationError(*m_verificationError) ); - - return pCertError; +CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { + return m_verificationError ? boost::make_shared<CertificateVerificationError>(*m_verificationError) : CertificateVerificationError::ref(); } //------------------------------------------------------------------------ -ByteArray SchannelContext::getFinishMessage() const -{ +ByteArray SchannelContext::getFinishMessage() const { // TODO: Implement ByteArray emptyArray; |