summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/TLS/Schannel/SchannelContext.cpp')
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp503
1 files changed, 503 insertions, 0 deletions
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
new file mode 100644
index 0000000..6771d4a
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -0,0 +1,503 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include "Swiften/TLS/Schannel/SchannelContext.h"
+#include "Swiften/TLS/Schannel/SchannelCertificate.h"
+
+namespace Swift {
+
+//------------------------------------------------------------------------
+
+SchannelContext::SchannelContext()
+: m_state(Start)
+, m_secContext(0)
+, m_verificationError(CertificateVerificationError::UnknownError)
+{
+ m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_INTEGRITY |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_USE_SUPPLIED_CREDS |
+ ISC_REQ_STREAM;
+
+ ZeroMemory(&m_streamSizes, sizeof(m_streamSizes));
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::determineStreamSizes()
+{
+ QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::connect()
+{
+ m_state = Connecting;
+
+ // We use an empty list for client certificates
+ PCCERT_CONTEXT clientCerts[1] = {0};
+
+ SCHANNEL_CRED sc = {0};
+ sc.dwVersion = SCHANNEL_CRED_VERSION;
+ sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
+ sc.paCred = clientCerts;
+ sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
+ sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN;
+
+ // Swiften performs the server name check for us
+ sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
+
+ SECURITY_STATUS status = AcquireCredentialsHandle(
+ NULL,
+ UNISP_NAME,
+ SECPKG_CRED_OUTBOUND,
+ NULL,
+ &sc,
+ NULL,
+ NULL,
+ m_credHandle.Reset(),
+ NULL);
+
+ if (status != SEC_E_OK)
+ {
+ // We failed to obtain the credentials handle
+ indicateError();
+ return;
+ }
+
+ SecBuffer outBuffers[2];
+
+ // We let Schannel allocate the output buffer for us
+ outBuffers[0].pvBuffer = NULL;
+ outBuffers[0].cbBuffer = 0;
+ outBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ // Contains alert data if an alert is generated
+ outBuffers[1].pvBuffer = NULL;
+ outBuffers[1].cbBuffer = 0;
+ outBuffers[1].BufferType = SECBUFFER_ALERT;
+
+ // Make sure the output buffers are freed
+ ScopedSecBuffer scopedOutputData(&outBuffers[0]);
+ ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 2;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ // Create the initial security context
+ status = InitializeSecurityContext(
+ m_credHandle,
+ NULL,
+ NULL,
+ m_ctxtFlags,
+ 0,
+ 0,
+ NULL,
+ 0,
+ m_ctxtHandle.Reset(),
+ &outBufferDesc,
+ &m_secContext,
+ NULL);
+
+ if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED)
+ {
+ // We failed to initialize the security context
+ indicateError();
+ return;
+ }
+
+ // Start the handshake
+ sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer);
+
+ if (status == SEC_E_OK)
+ {
+ m_state = Connected;
+ determineStreamSizes();
+
+ onConnected();
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::appendNewData(const SafeByteArray& data)
+{
+ size_t originalSize = m_receivedData.size();
+ m_receivedData.resize( originalSize + data.size() );
+ memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() );
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::continueHandshake(const SafeByteArray& data)
+{
+ appendNewData(data);
+
+ while (!m_receivedData.empty())
+ {
+ SecBuffer inBuffers[2];
+
+ // Provide Schannel with the remote host's handshake data
+ inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]);
+ inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size();
+ inBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ inBuffers[1].pvBuffer = NULL;
+ inBuffers[1].cbBuffer = 0;
+ inBuffers[1].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc inBufferDesc = {0};
+ inBufferDesc.cBuffers = 2;
+ inBufferDesc.pBuffers = inBuffers;
+ inBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SecBuffer outBuffers[2];
+
+ // We let Schannel allocate the output buffer for us
+ outBuffers[0].pvBuffer = NULL;
+ outBuffers[0].cbBuffer = 0;
+ outBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ // Contains alert data if an alert is generated
+ outBuffers[1].pvBuffer = NULL;
+ outBuffers[1].cbBuffer = 0;
+ outBuffers[1].BufferType = SECBUFFER_ALERT;
+
+ // Make sure the output buffers are freed
+ ScopedSecBuffer scopedOutputData(&outBuffers[0]);
+ ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 2;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SECURITY_STATUS status = InitializeSecurityContext(
+ m_credHandle,
+ m_ctxtHandle,
+ NULL,
+ m_ctxtFlags,
+ 0,
+ 0,
+ &inBufferDesc,
+ 0,
+ NULL,
+ &outBufferDesc,
+ &m_secContext,
+ NULL);
+
+ if (status == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ // Wait for more data to arrive
+ break;
+ }
+ else if (status == SEC_I_CONTINUE_NEEDED)
+ {
+ SecBuffer* pDataBuffer = &outBuffers[0];
+ SecBuffer* pExtraBuffer = &inBuffers[1];
+
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+
+ if (pExtraBuffer->BufferType == SECBUFFER_EXTRA)
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ else
+ m_receivedData.clear();
+
+ break;
+ }
+ else if (status == SEC_E_OK)
+ {
+ SecBuffer* pExtraBuffer = &inBuffers[1];
+
+ if (pExtraBuffer && pExtraBuffer->cbBuffer > 0)
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ else
+ m_receivedData.clear();
+
+ m_state = Connected;
+ determineStreamSizes();
+
+ onConnected();
+ }
+ else
+ {
+ // We failed to initialize the security context
+ indicateError();
+ return;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize)
+{
+ if (dataSize > 0 && pData)
+ {
+ SafeByteArray byteArray(dataSize);
+ memcpy(&byteArray[0], pData, dataSize);
+
+ onDataForNetwork(byteArray);
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize)
+{
+ SafeByteArray byteArray(dataSize);
+ memcpy(&byteArray[0], pData, dataSize);
+
+ onDataForApplication(byteArray);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::handleDataFromApplication(const SafeByteArray& data)
+{
+ // Don't attempt to send data until we're fully connected
+ if (m_state == Connecting)
+ return;
+
+ // Encrypt the data
+ encryptAndSendData(data);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::handleDataFromNetwork(const SafeByteArray& data)
+{
+ switch (m_state)
+ {
+ case Connecting:
+ {
+ // We're still establishing the connection, so continue the handshake
+ continueHandshake(data);
+ }
+ break;
+
+ case Connected:
+ {
+ // Decrypt the data
+ decryptAndProcessData(data);
+ }
+ break;
+
+ default:
+ return;
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::indicateError()
+{
+ m_state = Error;
+ m_receivedData.clear();
+ onError();
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
+{
+ SecBuffer inBuffers[4] = {0};
+
+ appendNewData(data);
+
+ while (!m_receivedData.empty())
+ {
+ //
+ // MSDN:
+ // When using the Schannel SSP with contexts that are not connection oriented, on input,
+ // the structure must contain four SecBuffer structures. Exactly one buffer must be of type
+ // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining
+ // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented
+ // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented
+ // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token
+ // must also be supplied.
+ //
+ inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]);
+ inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size();
+ inBuffers[0].BufferType = SECBUFFER_DATA;
+
+ inBuffers[1].BufferType = SECBUFFER_EMPTY;
+ inBuffers[2].BufferType = SECBUFFER_EMPTY;
+ inBuffers[3].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc inBufferDesc = {0};
+ inBufferDesc.cBuffers = 4;
+ inBufferDesc.pBuffers = inBuffers;
+ inBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ size_t inData = m_receivedData.size();
+ SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL);
+
+ if (status == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ // Wait for more data to arrive
+ break;
+ }
+ else if (status == SEC_I_RENEGOTIATE)
+ {
+ // TODO: Handle renegotiation scenarios
+ indicateError();
+ break;
+ }
+ else if (status == SEC_I_CONTEXT_EXPIRED)
+ {
+ indicateError();
+ break;
+ }
+ else if (status != SEC_E_OK)
+ {
+ indicateError();
+ break;
+ }
+
+ SecBuffer* pDataBuffer = NULL;
+ SecBuffer* pExtraBuffer = NULL;
+ for (int i = 0; i < 4; ++i)
+ {
+ if (inBuffers[i].BufferType == SECBUFFER_DATA)
+ pDataBuffer = &inBuffers[i];
+
+ else if (inBuffers[i].BufferType == SECBUFFER_EXTRA)
+ pExtraBuffer = &inBuffers[i];
+ }
+
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+
+ // If there is extra data left over from the decryption operation, we call DecryptMessage() again
+ if (pExtraBuffer)
+ {
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ }
+ else
+ {
+ // We're done
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData);
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::encryptAndSendData(const SafeByteArray& data)
+{
+ SecBuffer outBuffers[4] = {0};
+
+ // Calculate the largest required size of the send buffer
+ size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage)
+ ? m_streamSizes.cbMaximumMessage
+ : data.size();
+
+ // Allocate a packet for the encrypted data
+ SafeByteArray sendBuffer;
+ sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer);
+
+ size_t bytesSent = 0;
+ do
+ {
+ size_t bytesLeftToSend = data.size() - bytesSent;
+
+ // Calculate how much of the send buffer we'll be using for this chunk
+ size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage)
+ ? m_streamSizes.cbMaximumMessage
+ : bytesLeftToSend;
+
+ // Copy the plain text data into the send buffer
+ memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend);
+
+ outBuffers[0].pvBuffer = &sendBuffer[0];
+ outBuffers[0].cbBuffer = m_streamSizes.cbHeader;
+ outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
+
+ outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader;
+ outBuffers[1].cbBuffer = (unsigned long)bytesToSend;
+ outBuffers[1].BufferType = SECBUFFER_DATA;
+
+ outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend;
+ outBuffers[2].cbBuffer = m_streamSizes.cbTrailer;
+ outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
+
+ outBuffers[3].pvBuffer = 0;
+ outBuffers[3].cbBuffer = 0;
+ outBuffers[3].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 4;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0);
+ if (status != SEC_E_OK)
+ {
+ indicateError();
+ return;
+ }
+
+ sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer);
+ bytesSent += bytesToSend;
+
+ } while (bytesSent < data.size());
+}
+
+//------------------------------------------------------------------------
+
+bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate)
+{
+ return false;
+}
+
+//------------------------------------------------------------------------
+
+Certificate::ref SchannelContext::getPeerCertificate() const
+{
+ SchannelCertificate::ref pCertificate;
+
+ ScopedCertContext pServerCert;
+ SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
+ if (status != SEC_E_OK)
+ return pCertificate;
+
+ pCertificate.reset( new SchannelCertificate(pServerCert) );
+ return pCertificate;
+}
+
+//------------------------------------------------------------------------
+
+CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const
+{
+ boost::shared_ptr<CertificateVerificationError> pCertError;
+
+ if (m_state == Error)
+ pCertError.reset( new CertificateVerificationError(m_verificationError) );
+
+ return pCertError;
+}
+
+//------------------------------------------------------------------------
+
+ByteArray SchannelContext::getFinishMessage() const
+{
+ // TODO: Implement
+
+ ByteArray emptyArray;
+ return emptyArray;
+}
+
+//------------------------------------------------------------------------
+
+}