summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authordreijer <dreijer@echobit.net>2012-03-22 14:17:38 (GMT)
committerdreijer <dreijer@echobit.net>2012-03-22 15:56:31 (GMT)
commit5f9e12d9d197195a859ad523a39fdb752f2c4cff (patch)
tree827b31bc062cfef1432eb4b984760ec48d9e32b0 /Swiften/TLS
parent2fa37f2976b933ca0bcf5f85dd1615805776d67d (diff)
downloadswift-contrib-dreijer/schannel.zip
swift-contrib-dreijer/schannel.tar.bz2
Manual certificate verification.dreijer/schannel
Added two additional TLS errors related to revocation. License: This patch is BSD-licensed, see http://www.opensource.org/licenses/bsd-license.php
Diffstat (limited to 'Swiften/TLS')
-rw-r--r--Swiften/TLS/CertificateVerificationError.h2
-rw-r--r--Swiften/TLS/Schannel/SchannelCertificate.h7
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp143
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.h4
-rw-r--r--Swiften/TLS/Schannel/SchannelUtil.h135
5 files changed, 275 insertions, 16 deletions
diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h
index 22e6eaf..b17f5df 100644
--- a/Swiften/TLS/CertificateVerificationError.h
+++ b/Swiften/TLS/CertificateVerificationError.h
@@ -1,40 +1,42 @@
/*
* Copyright (c) 2010-2012 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <boost/shared_ptr.hpp>
#include <Swiften/Base/Error.h>
namespace Swift {
class CertificateVerificationError : public Error {
public:
typedef boost::shared_ptr<CertificateVerificationError> ref;
enum Type {
UnknownError,
Expired,
NotYetValid,
SelfSigned,
Rejected,
Untrusted,
InvalidPurpose,
PathLengthExceeded,
InvalidSignature,
InvalidCA,
InvalidServerIdentity,
+ Revoked,
+ RevocationCheckFailed
};
CertificateVerificationError(Type type = UnknownError) : type(type) {}
Type getType() const {
return type;
}
private:
Type type;
};
}
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h
index f531cff..395d3ec 100644
--- a/Swiften/TLS/Schannel/SchannelCertificate.h
+++ b/Swiften/TLS/Schannel/SchannelCertificate.h
@@ -16,66 +16,71 @@ namespace Swift
{
class SchannelCertificate : public Certificate
{
public:
typedef boost::shared_ptr<SchannelCertificate> ref;
public:
SchannelCertificate(const ScopedCertContext& certCtxt);
SchannelCertificate(const ByteArray& der);
std::string getSubjectName() const
{
return m_subjectName;
}
std::vector<std::string> getCommonNames() const
{
return m_commonNames;
}
std::vector<std::string> getSRVNames() const
{
return m_srvNames;
}
std::vector<std::string> getDNSNames() const
{
return m_dnsNames;
}
std::vector<std::string> getXMPPAddresses() const
{
return m_xmppAddresses;
}
- ByteArray toDER() const;
+ ScopedCertContext getCertContext() const
+ {
+ return m_cert;
+ }
+ ByteArray toDER() const;
+
private:
void parse();
std::string wstrToStr(const std::wstring& wstr);
void addSRVName(const std::string& name)
{
m_srvNames.push_back(name);
}
void addDNSName(const std::string& name)
{
m_dnsNames.push_back(name);
}
void addXMPPAddress(const std::string& addr)
{
m_xmppAddresses.push_back(addr);
}
private:
ScopedCertContext m_cert;
std::string m_subjectName;
std::vector<std::string> m_commonNames;
std::vector<std::string> m_dnsNames;
std::vector<std::string> m_xmppAddresses;
std::vector<std::string> m_srvNames;
};
}
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index b2fea65..9be1ded 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -1,219 +1,289 @@
/*
* Copyright (c) 2011 Soren Dreijer
* Licensed under the simplified BSD license.
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
-#include <Swiften/TLS/Schannel/SchannelContext.h>
-#include <Swiften/TLS/Schannel/SchannelCertificate.h>
+#include "Swiften/TLS/Schannel/SchannelContext.h"
+#include "Swiften/TLS/Schannel/SchannelCertificate.h"
#include <Swiften/TLS/CAPICertificate.h>
+#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID
namespace Swift {
//------------------------------------------------------------------------
SchannelContext::SchannelContext()
: m_state(Start)
, m_secContext(0)
-, m_verificationError(CertificateVerificationError::UnknownError)
, m_my_cert_store(NULL)
, m_cert_store_name("MY")
, m_cert_name()
{
m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_CONFIDENTIALITY |
ISC_REQ_EXTENDED_ERROR |
ISC_REQ_INTEGRITY |
ISC_REQ_REPLAY_DETECT |
ISC_REQ_SEQUENCE_DETECT |
ISC_REQ_USE_SUPPLIED_CREDS |
ISC_REQ_STREAM;
ZeroMemory(&m_streamSizes, sizeof(m_streamSizes));
}
//------------------------------------------------------------------------
SchannelContext::~SchannelContext()
{
if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0);
}
//------------------------------------------------------------------------
void SchannelContext::determineStreamSizes()
{
QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
}
//------------------------------------------------------------------------
void SchannelContext::connect()
{
- PCCERT_CONTEXT pCertContext = NULL;
+ ScopedCertContext pCertContext;
m_state = Connecting;
// If a user name is specified, then attempt to find a client
// certificate. Otherwise, just create a NULL credential.
if (!m_cert_name.empty())
{
if (m_my_cert_store == NULL)
{
m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str());
if (!m_my_cert_store)
{
///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() );
indicateError();
return;
}
}
pCertContext = findCertificateInStore( m_my_cert_store, m_cert_name );
if (pCertContext == NULL)
{
///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError());
indicateError();
return;
}
}
// We use an empty list for client certificates
PCCERT_CONTEXT clientCerts[1] = {0};
SCHANNEL_CRED sc = {0};
sc.dwVersion = SCHANNEL_CRED_VERSION;
/////SSL3?
sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
- sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN;
+ sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION;
if (pCertContext)
{
sc.cCreds = 1;
- sc.paCred = &pCertContext;
+ sc.paCred = pCertContext.GetPointer();
sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
}
else
{
sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
sc.paCred = clientCerts;
sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;
}
// Swiften performs the server name check for us
sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
SECURITY_STATUS status = AcquireCredentialsHandle(
NULL,
UNISP_NAME,
SECPKG_CRED_OUTBOUND,
NULL,
&sc,
NULL,
NULL,
m_credHandle.Reset(),
NULL);
-
- // cleanup: Free the certificate context. Schannel has already made its own copy.
- if (pCertContext) CertFreeCertificateContext(pCertContext);
-
+
if (status != SEC_E_OK)
{
// We failed to obtain the credentials handle
indicateError();
return;
}
SecBuffer outBuffers[2];
// We let Schannel allocate the output buffer for us
outBuffers[0].pvBuffer = NULL;
outBuffers[0].cbBuffer = 0;
outBuffers[0].BufferType = SECBUFFER_TOKEN;
// Contains alert data if an alert is generated
outBuffers[1].pvBuffer = NULL;
outBuffers[1].cbBuffer = 0;
outBuffers[1].BufferType = SECBUFFER_ALERT;
// Make sure the output buffers are freed
ScopedSecBuffer scopedOutputData(&outBuffers[0]);
ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
SecBufferDesc outBufferDesc = {0};
outBufferDesc.cBuffers = 2;
outBufferDesc.pBuffers = outBuffers;
outBufferDesc.ulVersion = SECBUFFER_VERSION;
// Create the initial security context
status = InitializeSecurityContext(
m_credHandle,
NULL,
NULL,
m_ctxtFlags,
0,
0,
NULL,
0,
m_ctxtHandle.Reset(),
&outBufferDesc,
&m_secContext,
NULL);
if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED)
{
// We failed to initialize the security context
+ handleCertError(status);
indicateError();
return;
}
// Start the handshake
sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer);
if (status == SEC_E_OK)
{
+ status = validateServerCertificate();
+ if (status != SEC_E_OK)
+ handleCertError(status);
+
m_state = Connected;
determineStreamSizes();
onConnected();
}
}
//------------------------------------------------------------------------
+SECURITY_STATUS SchannelContext::validateServerCertificate()
+{
+ SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() );
+ if (!pServerCert)
+ return SEC_E_WRONG_PRINCIPAL;
+
+ const LPSTR usage[] =
+ {
+ szOID_PKIX_KP_SERVER_AUTH,
+ szOID_SERVER_GATED_CRYPTO,
+ szOID_SGC_NETSCAPE
+ };
+
+ CERT_CHAIN_PARA chainParams = {0};
+ chainParams.cbSize = sizeof(chainParams);
+ chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR;
+ chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage);
+ chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage);
+
+ DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
+
+ ScopedCertChainContext pChainContext;
+
+ BOOL success = CertGetCertificateChain(
+ NULL, // Use the chain engine for the current user (assumes a user is logged in)
+ pServerCert->getCertContext(),
+ NULL,
+ NULL,
+ &chainParams,
+ chainFlags,
+ NULL,
+ pChainContext.Reset());
+
+ if (!success)
+ return GetLastError();
+
+ SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0};
+ sslChainPolicy.cbSize = sizeof(sslChainPolicy);
+ sslChainPolicy.dwAuthType = AUTHTYPE_SERVER;
+ sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking?
+ sslChainPolicy.pwszServerName = NULL;
+
+ CERT_CHAIN_POLICY_PARA certChainPolicy = {0};
+ certChainPolicy.cbSize = sizeof(certChainPolicy);
+ certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking?
+ certChainPolicy.pvExtraPolicyPara = &sslChainPolicy;
+
+ CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0};
+ certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus);
+
+ // Verify the chain
+ if (!CertVerifyCertificateChainPolicy(
+ CERT_CHAIN_POLICY_SSL,
+ pChainContext,
+ &certChainPolicy,
+ &certChainPolicyStatus))
+ {
+ return GetLastError();
+ }
+
+ if (certChainPolicyStatus.dwError != S_OK)
+ return certChainPolicyStatus.dwError;
+
+ return S_OK;
+}
+
+//------------------------------------------------------------------------
+
void SchannelContext::appendNewData(const SafeByteArray& data)
{
size_t originalSize = m_receivedData.size();
m_receivedData.resize( originalSize + data.size() );
memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() );
}
//------------------------------------------------------------------------
void SchannelContext::continueHandshake(const SafeByteArray& data)
{
appendNewData(data);
while (!m_receivedData.empty())
{
SecBuffer inBuffers[2];
// Provide Schannel with the remote host's handshake data
inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]);
inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size();
inBuffers[0].BufferType = SECBUFFER_TOKEN;
inBuffers[1].pvBuffer = NULL;
inBuffers[1].cbBuffer = 0;
inBuffers[1].BufferType = SECBUFFER_EMPTY;
SecBufferDesc inBufferDesc = {0};
inBufferDesc.cBuffers = 2;
inBufferDesc.pBuffers = inBuffers;
inBufferDesc.ulVersion = SECBUFFER_VERSION;
SecBuffer outBuffers[2];
// We let Schannel allocate the output buffer for us
outBuffers[0].pvBuffer = NULL;
@@ -238,93 +308,139 @@ void SchannelContext::continueHandshake(const SafeByteArray& data)
m_credHandle,
m_ctxtHandle,
NULL,
m_ctxtFlags,
0,
0,
&inBufferDesc,
0,
NULL,
&outBufferDesc,
&m_secContext,
NULL);
if (status == SEC_E_INCOMPLETE_MESSAGE)
{
// Wait for more data to arrive
break;
}
else if (status == SEC_I_CONTINUE_NEEDED)
{
SecBuffer* pDataBuffer = &outBuffers[0];
SecBuffer* pExtraBuffer = &inBuffers[1];
if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
if (pExtraBuffer->BufferType == SECBUFFER_EXTRA)
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
else
m_receivedData.clear();
break;
}
else if (status == SEC_E_OK)
{
+ status = validateServerCertificate();
+ if (status != SEC_E_OK)
+ handleCertError(status);
+
SecBuffer* pExtraBuffer = &inBuffers[1];
if (pExtraBuffer && pExtraBuffer->cbBuffer > 0)
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
else
m_receivedData.clear();
m_state = Connected;
determineStreamSizes();
onConnected();
}
else
{
// We failed to initialize the security context
+ handleCertError(status);
indicateError();
return;
}
}
}
//------------------------------------------------------------------------
+void SchannelContext::handleCertError(SECURITY_STATUS status)
+{
+ if (status == SEC_E_UNTRUSTED_ROOT ||
+ status == CERT_E_UNTRUSTEDROOT ||
+ status == CRYPT_E_ISSUER_SERIALNUMBER ||
+ status == CRYPT_E_SIGNER_NOT_FOUND ||
+ status == CRYPT_E_NO_TRUSTED_SIGNER)
+ {
+ m_verificationError = CertificateVerificationError::Untrusted;
+ }
+ else if (status == SEC_E_CERT_EXPIRED ||
+ status == CERT_E_EXPIRED)
+ {
+ m_verificationError = CertificateVerificationError::Expired;
+ }
+ else if (status == CRYPT_E_SELF_SIGNED)
+ {
+ m_verificationError = CertificateVerificationError::SelfSigned;
+ }
+ else if (status == CRYPT_E_HASH_VALUE ||
+ status == TRUST_E_CERT_SIGNATURE)
+ {
+ m_verificationError = CertificateVerificationError::InvalidSignature;
+ }
+ else if (status == CRYPT_E_REVOKED)
+ {
+ m_verificationError = CertificateVerificationError::Revoked;
+ }
+ else if (status == CRYPT_E_NO_REVOCATION_CHECK ||
+ status == CRYPT_E_REVOCATION_OFFLINE)
+ {
+ m_verificationError = CertificateVerificationError::RevocationCheckFailed;
+ }
+ else
+ {
+ m_verificationError = CertificateVerificationError::UnknownError;
+ }
+}
+
+//------------------------------------------------------------------------
+
void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize)
{
if (dataSize > 0 && pData)
{
SafeByteArray byteArray(dataSize);
memcpy(&byteArray[0], pData, dataSize);
onDataForNetwork(byteArray);
}
}
//------------------------------------------------------------------------
void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize)
{
SafeByteArray byteArray(dataSize);
memcpy(&byteArray[0], pData, dataSize);
onDataForApplication(byteArray);
}
//------------------------------------------------------------------------
void SchannelContext::handleDataFromApplication(const SafeByteArray& data)
{
// Don't attempt to send data until we're fully connected
if (m_state == Connecting)
return;
// Encrypt the data
encryptAndSendData(data);
}
//------------------------------------------------------------------------
@@ -417,70 +533,73 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
indicateError();
break;
}
SecBuffer* pDataBuffer = NULL;
SecBuffer* pExtraBuffer = NULL;
for (int i = 0; i < 4; ++i)
{
if (inBuffers[i].BufferType == SECBUFFER_DATA)
pDataBuffer = &inBuffers[i];
else if (inBuffers[i].BufferType == SECBUFFER_EXTRA)
pExtraBuffer = &inBuffers[i];
}
if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
// If there is extra data left over from the decryption operation, we call DecryptMessage() again
if (pExtraBuffer)
{
m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
}
else
{
// We're done
m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData);
}
}
}
//------------------------------------------------------------------------
void SchannelContext::encryptAndSendData(const SafeByteArray& data)
{
+ if (m_streamSizes.cbMaximumMessage == 0)
+ return;
+
SecBuffer outBuffers[4] = {0};
// Calculate the largest required size of the send buffer
size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage)
? m_streamSizes.cbMaximumMessage
: data.size();
// Allocate a packet for the encrypted data
SafeByteArray sendBuffer;
sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer);
size_t bytesSent = 0;
do
{
size_t bytesLeftToSend = data.size() - bytesSent;
// Calculate how much of the send buffer we'll be using for this chunk
size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage)
? m_streamSizes.cbMaximumMessage
: bytesLeftToSend;
// Copy the plain text data into the send buffer
memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend);
outBuffers[0].pvBuffer = &sendBuffer[0];
outBuffers[0].cbBuffer = m_streamSizes.cbHeader;
outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader;
outBuffers[1].cbBuffer = (unsigned long)bytesToSend;
outBuffers[1].BufferType = SECBUFFER_DATA;
outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend;
outBuffers[2].cbBuffer = m_streamSizes.cbTrailer;
outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
@@ -512,54 +631,54 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate)
{
boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate);
if (!capiCertificate || capiCertificate->isNull()) {
return false;
}
// We assume that the Certificate Store Name/Certificate Name
// are valid at this point
m_cert_store_name = capiCertificate->getCertStoreName();
m_cert_name = capiCertificate->getCertName();
return true;
}
//------------------------------------------------------------------------
Certificate::ref SchannelContext::getPeerCertificate() const
{
SchannelCertificate::ref pCertificate;
ScopedCertContext pServerCert;
SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
if (status != SEC_E_OK)
return pCertificate;
pCertificate.reset( new SchannelCertificate(pServerCert) );
return pCertificate;
}
//------------------------------------------------------------------------
CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const
{
boost::shared_ptr<CertificateVerificationError> pCertError;
- if (m_state == Error)
- pCertError.reset( new CertificateVerificationError(m_verificationError) );
+ if (m_verificationError)
+ pCertError.reset( new CertificateVerificationError(*m_verificationError) );
return pCertError;
}
//------------------------------------------------------------------------
ByteArray SchannelContext::getFinishMessage() const
{
// TODO: Implement
ByteArray emptyArray;
return emptyArray;
}
//------------------------------------------------------------------------
}
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index 7726c41..70b0694 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -19,70 +19,72 @@
#include <security.h>
#include <schnlsp.h>
#include <boost/noncopyable.hpp>
namespace Swift
{
class SchannelContext : public TLSContext, boost::noncopyable
{
public:
typedef boost::shared_ptr<SchannelContext> sp_t;
public:
SchannelContext();
~SchannelContext();
//
// TLSContext
//
virtual void connect();
virtual bool setClientCertificate(CertificateWithKey::ref cert);
virtual void handleDataFromNetwork(const SafeByteArray& data);
virtual void handleDataFromApplication(const SafeByteArray& data);
virtual Certificate::ref getPeerCertificate() const;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const;
virtual ByteArray getFinishMessage() const;
private:
void determineStreamSizes();
void continueHandshake(const SafeByteArray& data);
void indicateError();
+ void handleCertError(SECURITY_STATUS status) ;
void sendDataOnNetwork(const void* pData, size_t dataSize);
void forwardDataToApplication(const void* pData, size_t dataSize);
void decryptAndProcessData(const SafeByteArray& data);
void encryptAndSendData(const SafeByteArray& data);
void appendNewData(const SafeByteArray& data);
+ SECURITY_STATUS validateServerCertificate();
private:
enum SchannelState
{
Start,
Connecting,
Connected,
Error
};
SchannelState m_state;
- CertificateVerificationError m_verificationError;
+ boost::optional<CertificateVerificationError> m_verificationError;
ULONG m_secContext;
ScopedCredHandle m_credHandle;
ScopedCtxtHandle m_ctxtHandle;
DWORD m_ctxtFlags;
SecPkgContext_StreamSizes m_streamSizes;
std::vector<char> m_receivedData;
HCERTSTORE m_my_cert_store;
std::string m_cert_store_name;
std::string m_cert_name;
};
}
diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h
index 0a54f16..4f73aac 100644
--- a/Swiften/TLS/Schannel/SchannelUtil.h
+++ b/Swiften/TLS/Schannel/SchannelUtil.h
@@ -214,81 +214,212 @@ namespace Swift
{
private:
struct HandleContext
{
HandleContext()
: m_pCertCtxt(NULL)
{
}
HandleContext(PCCERT_CONTEXT pCert)
: m_pCertCtxt(pCert)
{
}
~HandleContext()
{
if (m_pCertCtxt)
CertFreeCertificateContext(m_pCertCtxt);
}
PCCERT_CONTEXT m_pCertCtxt;
};
public:
ScopedCertContext()
: m_pHandle( new HandleContext )
{
}
explicit ScopedCertContext(PCCERT_CONTEXT pCert)
: m_pHandle( new HandleContext(pCert) )
{
}
// Copy constructor
- explicit ScopedCertContext(const ScopedCertContext& rhs)
+ ScopedCertContext(const ScopedCertContext& rhs)
{
m_pHandle = rhs.m_pHandle;
}
~ScopedCertContext()
{
m_pHandle.reset();
}
PCCERT_CONTEXT* Reset()
{
FreeContext();
return &m_pHandle->m_pCertCtxt;
}
operator PCCERT_CONTEXT() const
{
return m_pHandle->m_pCertCtxt;
}
+ PCCERT_CONTEXT* GetPointer() const
+ {
+ return &m_pHandle->m_pCertCtxt;
+ }
+
PCCERT_CONTEXT operator->() const
{
return m_pHandle->m_pCertCtxt;
}
ScopedCertContext& operator=(const ScopedCertContext& sh)
{
// Only update the internal handle if it's different
if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt)
{
m_pHandle = sh.m_pHandle;
}
return *this;
}
+ ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt)
+ {
+ // Only update the internal handle if it's different
+ if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt)
+ m_pHandle.reset( new HandleContext(pCertCtxt) );
+
+ return *this;
+ }
+
+ void FreeContext()
+ {
+ m_pHandle.reset( new HandleContext );
+ }
+
+ private:
+ boost::shared_ptr<HandleContext> m_pHandle;
+ };
+
+ //------------------------------------------------------------------------
+
+ //
+ // Convenience wrapper around the Schannel HCERTSTORE.
+ //
+ class ScopedCertStore : boost::noncopyable
+ {
+ public:
+ ScopedCertStore(HCERTSTORE hCertStore)
+ : m_hCertStore(hCertStore)
+ {
+ }
+
+ ~ScopedCertStore()
+ {
+ // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this
+ // cert store have been closed at this point.
+ if (m_hCertStore)
+ CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
+ }
+
+ operator HCERTSTORE() const
+ {
+ return m_hCertStore;
+ }
+
+ private:
+ HCERTSTORE m_hCertStore;
+ };
+
+ //------------------------------------------------------------------------
+
+ //
+ // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT.
+ //
+ class ScopedCertChainContext
+ {
+ private:
+ struct HandleContext
+ {
+ HandleContext()
+ : m_pCertChainCtxt(NULL)
+ {
+ }
+
+ HandleContext(PCCERT_CHAIN_CONTEXT pCert)
+ : m_pCertChainCtxt(pCert)
+ {
+ }
+
+ ~HandleContext()
+ {
+ if (m_pCertChainCtxt)
+ CertFreeCertificateChain(m_pCertChainCtxt);
+ }
+
+ PCCERT_CHAIN_CONTEXT m_pCertChainCtxt;
+ };
+
+ public:
+ ScopedCertChainContext()
+ : m_pHandle( new HandleContext )
+ {
+ }
+
+ explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert)
+ : m_pHandle( new HandleContext(pCert) )
+ {
+ }
+
+ // Copy constructor
+ ScopedCertChainContext(const ScopedCertChainContext& rhs)
+ {
+ m_pHandle = rhs.m_pHandle;
+ }
+
+ ~ScopedCertChainContext()
+ {
+ m_pHandle.reset();
+ }
+
+ PCCERT_CHAIN_CONTEXT* Reset()
+ {
+ FreeContext();
+ return &m_pHandle->m_pCertChainCtxt;
+ }
+
+ operator PCCERT_CHAIN_CONTEXT() const
+ {
+ return m_pHandle->m_pCertChainCtxt;
+ }
+
+ PCCERT_CHAIN_CONTEXT operator->() const
+ {
+ return m_pHandle->m_pCertChainCtxt;
+ }
+
+ ScopedCertChainContext& operator=(const ScopedCertChainContext& sh)
+ {
+ // Only update the internal handle if it's different
+ if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt)
+ {
+ m_pHandle = sh.m_pHandle;
+ }
+
+ return *this;
+ }
+
void FreeContext()
{
m_pHandle.reset( new HandleContext );
}
private:
boost::shared_ptr<HandleContext> m_pHandle;
- };
+ };
}