diff options
Diffstat (limited to 'Swiften/TLS/Schannel')
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificate.cpp | 197 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificate.h | 81 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificateFactory.h | 19 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 503 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 81 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContextFactory.cpp | 20 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContextFactory.h | 17 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelUtil.h | 294 |
8 files changed, 1212 insertions, 0 deletions
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.cpp b/Swiften/TLS/Schannel/SchannelCertificate.cpp new file mode 100644 index 0000000..8aaec00 --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelCertificate.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#include "Swiften/TLS/Schannel/SchannelCertificate.h" +#include "Swiften/Base/ByteArray.h" + +#define SECURITY_WIN32 +#include <Windows.h> +#include <Schannel.h> +#include <security.h> +#include <schnlsp.h> +#include <Wincrypt.h> + +using std::vector; + +namespace Swift { + +//------------------------------------------------------------------------ + +SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt) +: m_cert(certCtxt) +{ + parse(); +} + +//------------------------------------------------------------------------ + +SchannelCertificate::SchannelCertificate(const ByteArray& der) +{ + if (!der.empty()) + { + // Convert the DER encoded certificate to a PCERT_CONTEXT + CERT_BLOB certBlob = {0}; + certBlob.cbData = der.size(); + certBlob.pbData = (BYTE*)&der[0]; + + if (!CryptQueryObject( + CERT_QUERY_OBJECT_BLOB, + &certBlob, + CERT_QUERY_CONTENT_FLAG_CERT, + CERT_QUERY_FORMAT_FLAG_ALL, + 0, + NULL, + NULL, + NULL, + NULL, + NULL, + (const void**)m_cert.Reset())) + { + // TODO: Because Swiften isn't exception safe, we have no way to indicate failure + } + } +} + +//------------------------------------------------------------------------ + +ByteArray SchannelCertificate::toDER() const +{ + ByteArray result; + + // Serialize the certificate. The CERT_CONTEXT is already DER encoded. + result.resize(m_cert->cbCertEncoded); + memcpy(&result[0], m_cert->pbCertEncoded, result.size()); + + return result; +} + +//------------------------------------------------------------------------ + +std::string SchannelCertificate::wstrToStr(const std::wstring& wstr) +{ + if (wstr.empty()) + return ""; + + // First request the size of the required UTF-8 buffer + int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL); + if (!numRequiredBytes) + return ""; + + // Allocate memory for the UTF-8 string + std::vector<char> utf8Str(numRequiredBytes); + + int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL); + if (!numConverted) + return ""; + + std::string str(&utf8Str[0], numConverted); + return str; +} + +//------------------------------------------------------------------------ + +void SchannelCertificate::parse() +{ + // + // Subject name + // + DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawSubjectName(requiredSize); + CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size()); + m_subjectName = std::string(&rawSubjectName[0]); + } + + // + // Common name + // + // Note: We only pull out one common name from the cert. + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawCommonName(requiredSize); + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size()); + m_commonNames.push_back( std::string(&rawCommonName[0]) ); + } + + // + // Subject alternative names + // + PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension); + if (pExtensions) + { + CRYPT_DECODE_PARA decodePara = {0}; + decodePara.cbSize = sizeof(decodePara); + + CERT_ALT_NAME_INFO* pAltNameInfo = NULL; + DWORD altNameInfoSize = 0; + + BOOL status = CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + szOID_SUBJECT_ALT_NAME2, + pExtensions->Value.pbData, + pExtensions->Value.cbData, + CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG, + &decodePara, + &pAltNameInfo, + &altNameInfoSize); + + if (status && pAltNameInfo) + { + for (int i = 0; i < pAltNameInfo->cAltEntry; i++) + { + if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME) + addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) ); + } + } + } + + // if (pExtensions) + // { + // vector<wchar_t> subjectAlt + // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, ) + // } + // + // // subjectAltNames + // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); + // if(subjectAltNameLoc != -1) { + // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); + // boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); + // boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); + // boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); + // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { + // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); + // if (generalName->type == GEN_OTHERNAME) { + // OTHERNAME* otherName = generalName->d.otherName; + // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { + // // XmppAddr + // if (otherName->value->type != V_ASN1_UTF8STRING) { + // continue; + // } + // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; + // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); + // } + // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { + // // SRVName + // if (otherName->value->type != V_ASN1_IA5STRING) { + // continue; + // } + // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; + // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); + // } + // } + // else if (generalName->type == GEN_DNS) { + // // DNSName + // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); + // } + // } + // } +} + +//------------------------------------------------------------------------ + +} diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h new file mode 100644 index 0000000..f531cff --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelCertificate.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> + +#include "Swiften/Base/String.h" +#include "Swiften/TLS/Certificate.h" +#include "Swiften/TLS/Schannel/SchannelUtil.h" + +namespace Swift +{ + class SchannelCertificate : public Certificate + { + public: + typedef boost::shared_ptr<SchannelCertificate> ref; + + public: + SchannelCertificate(const ScopedCertContext& certCtxt); + SchannelCertificate(const ByteArray& der); + + std::string getSubjectName() const + { + return m_subjectName; + } + + std::vector<std::string> getCommonNames() const + { + return m_commonNames; + } + + std::vector<std::string> getSRVNames() const + { + return m_srvNames; + } + + std::vector<std::string> getDNSNames() const + { + return m_dnsNames; + } + + std::vector<std::string> getXMPPAddresses() const + { + return m_xmppAddresses; + } + + ByteArray toDER() const; + + private: + void parse(); + std::string wstrToStr(const std::wstring& wstr); + + void addSRVName(const std::string& name) + { + m_srvNames.push_back(name); + } + + void addDNSName(const std::string& name) + { + m_dnsNames.push_back(name); + } + + void addXMPPAddress(const std::string& addr) + { + m_xmppAddresses.push_back(addr); + } + + private: + ScopedCertContext m_cert; + + std::string m_subjectName; + std::vector<std::string> m_commonNames; + std::vector<std::string> m_dnsNames; + std::vector<std::string> m_xmppAddresses; + std::vector<std::string> m_srvNames; + }; +} diff --git a/Swiften/TLS/Schannel/SchannelCertificateFactory.h b/Swiften/TLS/Schannel/SchannelCertificateFactory.h new file mode 100644 index 0000000..d09bb54 --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelCertificateFactory.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <Swiften/TLS/CertificateFactory.h> +#include <Swiften/TLS/Schannel/SchannelCertificate.h> + +namespace Swift { + class SchannelCertificateFactory : public CertificateFactory { + public: + virtual Certificate::ref createCertificateFromDER(const ByteArray& der) { + return Certificate::ref(new SchannelCertificate(der)); + } + }; +} diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp new file mode 100644 index 0000000..6771d4a --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -0,0 +1,503 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#include "Swiften/TLS/Schannel/SchannelContext.h" +#include "Swiften/TLS/Schannel/SchannelCertificate.h" + +namespace Swift { + +//------------------------------------------------------------------------ + +SchannelContext::SchannelContext() +: m_state(Start) +, m_secContext(0) +, m_verificationError(CertificateVerificationError::UnknownError) +{ + m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_STREAM; + + ZeroMemory(&m_streamSizes, sizeof(m_streamSizes)); +} + +//------------------------------------------------------------------------ + +void SchannelContext::determineStreamSizes() +{ + QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); +} + +//------------------------------------------------------------------------ + +void SchannelContext::connect() +{ + m_state = Connecting; + + // We use an empty list for client certificates + PCCERT_CONTEXT clientCerts[1] = {0}; + + SCHANNEL_CRED sc = {0}; + sc.dwVersion = SCHANNEL_CRED_VERSION; + sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us + sc.paCred = clientCerts; + sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; + sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN; + + // Swiften performs the server name check for us + sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; + + SECURITY_STATUS status = AcquireCredentialsHandle( + NULL, + UNISP_NAME, + SECPKG_CRED_OUTBOUND, + NULL, + &sc, + NULL, + NULL, + m_credHandle.Reset(), + NULL); + + if (status != SEC_E_OK) + { + // We failed to obtain the credentials handle + indicateError(); + return; + } + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + // Create the initial security context + status = InitializeSecurityContext( + m_credHandle, + NULL, + NULL, + m_ctxtFlags, + 0, + 0, + NULL, + 0, + m_ctxtHandle.Reset(), + &outBufferDesc, + &m_secContext, + NULL); + + if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) + { + // We failed to initialize the security context + indicateError(); + return; + } + + // Start the handshake + sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); + + if (status == SEC_E_OK) + { + m_state = Connected; + determineStreamSizes(); + + onConnected(); + } +} + +//------------------------------------------------------------------------ + +void SchannelContext::appendNewData(const SafeByteArray& data) +{ + size_t originalSize = m_receivedData.size(); + m_receivedData.resize( originalSize + data.size() ); + memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() ); +} + +//------------------------------------------------------------------------ + +void SchannelContext::continueHandshake(const SafeByteArray& data) +{ + appendNewData(data); + + while (!m_receivedData.empty()) + { + SecBuffer inBuffers[2]; + + // Provide Schannel with the remote host's handshake data + inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]); + inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size(); + inBuffers[0].BufferType = SECBUFFER_TOKEN; + + inBuffers[1].pvBuffer = NULL; + inBuffers[1].cbBuffer = 0; + inBuffers[1].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 2; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status = InitializeSecurityContext( + m_credHandle, + m_ctxtHandle, + NULL, + m_ctxtFlags, + 0, + 0, + &inBufferDesc, + 0, + NULL, + &outBufferDesc, + &m_secContext, + NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) + { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_CONTINUE_NEEDED) + { + SecBuffer* pDataBuffer = &outBuffers[0]; + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + + if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) + m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); + else + m_receivedData.clear(); + + break; + } + else if (status == SEC_E_OK) + { + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) + m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); + else + m_receivedData.clear(); + + m_state = Connected; + determineStreamSizes(); + + onConnected(); + } + else + { + // We failed to initialize the security context + indicateError(); + return; + } + } +} + +//------------------------------------------------------------------------ + +void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) +{ + if (dataSize > 0 && pData) + { + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); + + onDataForNetwork(byteArray); + } +} + +//------------------------------------------------------------------------ + +void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) +{ + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); + + onDataForApplication(byteArray); +} + +//------------------------------------------------------------------------ + +void SchannelContext::handleDataFromApplication(const SafeByteArray& data) +{ + // Don't attempt to send data until we're fully connected + if (m_state == Connecting) + return; + + // Encrypt the data + encryptAndSendData(data); +} + +//------------------------------------------------------------------------ + +void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) +{ + switch (m_state) + { + case Connecting: + { + // We're still establishing the connection, so continue the handshake + continueHandshake(data); + } + break; + + case Connected: + { + // Decrypt the data + decryptAndProcessData(data); + } + break; + + default: + return; + } +} + +//------------------------------------------------------------------------ + +void SchannelContext::indicateError() +{ + m_state = Error; + m_receivedData.clear(); + onError(); +} + +//------------------------------------------------------------------------ + +void SchannelContext::decryptAndProcessData(const SafeByteArray& data) +{ + SecBuffer inBuffers[4] = {0}; + + appendNewData(data); + + while (!m_receivedData.empty()) + { + // + // MSDN: + // When using the Schannel SSP with contexts that are not connection oriented, on input, + // the structure must contain four SecBuffer structures. Exactly one buffer must be of type + // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining + // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented + // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented + // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token + // must also be supplied. + // + inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]); + inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size(); + inBuffers[0].BufferType = SECBUFFER_DATA; + + inBuffers[1].BufferType = SECBUFFER_EMPTY; + inBuffers[2].BufferType = SECBUFFER_EMPTY; + inBuffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 4; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + size_t inData = m_receivedData.size(); + SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) + { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_RENEGOTIATE) + { + // TODO: Handle renegotiation scenarios + indicateError(); + break; + } + else if (status == SEC_I_CONTEXT_EXPIRED) + { + indicateError(); + break; + } + else if (status != SEC_E_OK) + { + indicateError(); + break; + } + + SecBuffer* pDataBuffer = NULL; + SecBuffer* pExtraBuffer = NULL; + for (int i = 0; i < 4; ++i) + { + if (inBuffers[i].BufferType == SECBUFFER_DATA) + pDataBuffer = &inBuffers[i]; + + else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) + pExtraBuffer = &inBuffers[i]; + } + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) + forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + + // If there is extra data left over from the decryption operation, we call DecryptMessage() again + if (pExtraBuffer) + { + m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); + } + else + { + // We're done + m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData); + } + } +} + +//------------------------------------------------------------------------ + +void SchannelContext::encryptAndSendData(const SafeByteArray& data) +{ + SecBuffer outBuffers[4] = {0}; + + // Calculate the largest required size of the send buffer + size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage) + ? m_streamSizes.cbMaximumMessage + : data.size(); + + // Allocate a packet for the encrypted data + SafeByteArray sendBuffer; + sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer); + + size_t bytesSent = 0; + do + { + size_t bytesLeftToSend = data.size() - bytesSent; + + // Calculate how much of the send buffer we'll be using for this chunk + size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage) + ? m_streamSizes.cbMaximumMessage + : bytesLeftToSend; + + // Copy the plain text data into the send buffer + memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend); + + outBuffers[0].pvBuffer = &sendBuffer[0]; + outBuffers[0].cbBuffer = m_streamSizes.cbHeader; + outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; + + outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader; + outBuffers[1].cbBuffer = (unsigned long)bytesToSend; + outBuffers[1].BufferType = SECBUFFER_DATA; + + outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend; + outBuffers[2].cbBuffer = m_streamSizes.cbTrailer; + outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + + outBuffers[3].pvBuffer = 0; + outBuffers[3].cbBuffer = 0; + outBuffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 4; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0); + if (status != SEC_E_OK) + { + indicateError(); + return; + } + + sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); + bytesSent += bytesToSend; + + } while (bytesSent < data.size()); +} + +//------------------------------------------------------------------------ + +bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate) +{ + return false; +} + +//------------------------------------------------------------------------ + +Certificate::ref SchannelContext::getPeerCertificate() const +{ + SchannelCertificate::ref pCertificate; + + ScopedCertContext pServerCert; + SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); + if (status != SEC_E_OK) + return pCertificate; + + pCertificate.reset( new SchannelCertificate(pServerCert) ); + return pCertificate; +} + +//------------------------------------------------------------------------ + +CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const +{ + boost::shared_ptr<CertificateVerificationError> pCertError; + + if (m_state == Error) + pCertError.reset( new CertificateVerificationError(m_verificationError) ); + + return pCertError; +} + +//------------------------------------------------------------------------ + +ByteArray SchannelContext::getFinishMessage() const +{ + // TODO: Implement + + ByteArray emptyArray; + return emptyArray; +} + +//------------------------------------------------------------------------ + +} diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h new file mode 100644 index 0000000..66467fe --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include "Swiften/Base/boost_bsignals.h" + +#include "Swiften/TLS/TLSContext.h" +#include "Swiften/TLS/Schannel/SchannelUtil.h" +#include "Swiften/Base/ByteArray.h" + +#define SECURITY_WIN32 +#include <Windows.h> +#include <Schannel.h> +#include <security.h> +#include <schnlsp.h> + +#include <boost/noncopyable.hpp> + +namespace Swift +{ + class SchannelContext : public TLSContext, boost::noncopyable + { + public: + typedef boost::shared_ptr<SchannelContext> sp_t; + + public: + SchannelContext(); + + // + // TLSContext + // + virtual void connect(); + virtual bool setClientCertificate(const PKCS12Certificate&); + + virtual void handleDataFromNetwork(const SafeByteArray& data); + virtual void handleDataFromApplication(const SafeByteArray& data); + + virtual Certificate::ref getPeerCertificate() const; + virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; + + virtual ByteArray getFinishMessage() const; + + private: + void determineStreamSizes(); + void continueHandshake(const SafeByteArray& data); + void indicateError(); + + void sendDataOnNetwork(const void* pData, size_t dataSize); + void forwardDataToApplication(const void* pData, size_t dataSize); + + void decryptAndProcessData(const SafeByteArray& data); + void encryptAndSendData(const SafeByteArray& data); + + void appendNewData(const SafeByteArray& data); + + private: + enum SchannelState + { + Start, + Connecting, + Connected, + Error + + }; + + SchannelState m_state; + CertificateVerificationError m_verificationError; + + ULONG m_secContext; + ScopedCredHandle m_credHandle; + ScopedCtxtHandle m_ctxtHandle; + DWORD m_ctxtFlags; + SecPkgContext_StreamSizes m_streamSizes; + + std::vector<char> m_receivedData; + }; +} diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.cpp b/Swiften/TLS/Schannel/SchannelContextFactory.cpp new file mode 100644 index 0000000..8ab7c6c --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelContextFactory.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#include "Swiften/TLS/Schannel/SchannelContextFactory.h" +#include "Swiften/TLS/Schannel/SchannelContext.h" + +namespace Swift { + +bool SchannelContextFactory::canCreate() const { + return true; +} + +TLSContext* SchannelContextFactory::createTLSContext() { + return new SchannelContext(); +} + +} diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.h b/Swiften/TLS/Schannel/SchannelContextFactory.h new file mode 100644 index 0000000..43c39a9 --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelContextFactory.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include "Swiften/TLS/TLSContextFactory.h" + +namespace Swift { + class SchannelContextFactory : public TLSContextFactory { + public: + bool canCreate() const; + virtual TLSContext* createTLSContext(); + }; +} diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h new file mode 100644 index 0000000..0a54f16 --- /dev/null +++ b/Swiften/TLS/Schannel/SchannelUtil.h @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2011 Soren Dreijer + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#define SECURITY_WIN32 +#include <Windows.h> +#include <Schannel.h> +#include <security.h> +#include <schnlsp.h> + +#include <boost/noncopyable.hpp> + +namespace Swift +{ + // + // Convenience wrapper around the Schannel CredHandle struct. + // + class ScopedCredHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + HandleContext(const CredHandle& h) + { + memcpy(&m_h, &h, sizeof(m_h)); + } + + ~HandleContext() + { + ::FreeCredentialsHandle(&m_h); + } + + CredHandle m_h; + }; + + public: + ScopedCredHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCredHandle(const CredHandle& h) + : m_pHandle( new HandleContext(h) ) + { + } + + // Copy constructor + explicit ScopedCredHandle(const ScopedCredHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCredHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCredHandle& operator=(const ScopedCredHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CtxtHandle struct. + // + class ScopedCtxtHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + ~HandleContext() + { + ::DeleteSecurityContext(&m_h); + } + + CtxtHandle m_h; + }; + + public: + ScopedCtxtHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCtxtHandle(CredHandle h) + : m_pHandle( new HandleContext ) + { + } + + // Copy constructor + explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCtxtHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel ScopedSecBuffer struct. + // + class ScopedSecBuffer : boost::noncopyable + { + public: + ScopedSecBuffer(PSecBuffer pSecBuffer) + : m_pSecBuffer(pSecBuffer) + { + } + + ~ScopedSecBuffer() + { + // Loop through all the output buffers and make sure we free them + if (m_pSecBuffer->pvBuffer) + FreeContextBuffer(m_pSecBuffer->pvBuffer); + } + + PSecBuffer AsPtr() + { + return m_pSecBuffer; + } + + PSecBuffer operator->() + { + return m_pSecBuffer; + } + + private: + PSecBuffer m_pSecBuffer; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel PCCERT_CONTEXT. + // + class ScopedCertContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertCtxt(NULL) + { + } + + HandleContext(PCCERT_CONTEXT pCert) + : m_pCertCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertCtxt) + CertFreeCertificateContext(m_pCertCtxt); + } + + PCCERT_CONTEXT m_pCertCtxt; + }; + + public: + ScopedCertContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertContext(PCCERT_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + explicit ScopedCertContext(const ScopedCertContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertContext() + { + m_pHandle.reset(); + } + + PCCERT_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertCtxt; + } + + operator PCCERT_CONTEXT() const + { + return m_pHandle->m_pCertCtxt; + } + + PCCERT_CONTEXT operator->() const + { + return m_pHandle->m_pCertCtxt; + } + + ScopedCertContext& operator=(const ScopedCertContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; +} |