From 5f9e12d9d197195a859ad523a39fdb752f2c4cff Mon Sep 17 00:00:00 2001
From: dreijer <dreijer@echobit.net>
Date: Thu, 22 Mar 2012 09:17:38 -0500
Subject: Manual certificate verification. Added two additional TLS errors
 related to revocation.

License: This patch is BSD-licensed, see http://www.opensource.org/licenses/bsd-license.php

diff --git a/BuildTools/SCons/SConscript.boot b/BuildTools/SCons/SConscript.boot
index 188184c..dc8a8a5 100644
--- a/BuildTools/SCons/SConscript.boot
+++ b/BuildTools/SCons/SConscript.boot
@@ -55,6 +55,7 @@ vars.Add(PathVariable("docbook_xsl", "DocBook XSL", None, PathVariable.PathAccep
 vars.Add(BoolVariable("build_examples", "Build example programs", "yes"))
 vars.Add(BoolVariable("enable_variants", "Build in a separate dir under build/, depending on compile flags", "no"))
 vars.Add(BoolVariable("experimental", "Build experimental features", "no"))
+vars.Add(BoolVariable("set_iterator_debug_level", "Set _ITERATOR_DEBUG_LEVEL=0", "yes"))
 
 ################################################################################
 # Set up default build & configure environment
@@ -134,7 +135,8 @@ if env["debug"] :
 	if env["PLATFORM"] == "win32" :
 		env.Append(CCFLAGS = ["/Zi", "/MDd"])
 		env.Append(LINKFLAGS = ["/DEBUG"])
-		env.Append(CPPDEFINES = ["_ITERATOR_DEBUG_LEVEL=0"])
+		if env["set_iterator_debug_level"] :
+			env.Append(CPPDEFINES = ["_ITERATOR_DEBUG_LEVEL=0"])
 	else :
 		env.Append(CCFLAGS = ["-g"])
 elif env["PLATFORM"] == "win32" :
diff --git a/Swift/SConscript b/Swift/SConscript
index b66058b..49aa985 100644
--- a/Swift/SConscript
+++ b/Swift/SConscript
@@ -11,6 +11,6 @@ if env["SCONS_STAGE"] == "build" :
 			env["PROJECTS"].remove("Swift")
 	elif not GetOption("help") and not env.get("HAVE_QT", 0) :
 		print "Error: Swift requires Qt. Not building Swift."
-		env["PROJECTS"].remove("Swift")
+#		env["PROJECTS"].remove("Swift")
 	elif env["target"] == "native":
 		 SConscript("QtUI/SConscript")
diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h
index baf1b0a..2f2d2af 100644
--- a/Swiften/Client/ClientError.h
+++ b/Swiften/Client/ClientError.h
@@ -40,6 +40,8 @@ namespace Swift {
 				InvalidCertificateSignatureError,
 				InvalidCAError,
 				InvalidServerIdentityError,
+				RevokedError,
+				RevocationCheckFailedError
 			};
 
 			ClientError(Type type = UnknownError) : type_(type) {}
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index f7e3b21..14481c6 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -271,6 +271,12 @@ void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) {
 				case CertificateVerificationError::InvalidServerIdentity:
 					clientError = ClientError(ClientError::InvalidServerIdentityError);
 					break;
+				case CertificateVerificationError::Revoked:
+					clientError = ClientError(ClientError::RevokedError);
+					break;
+				case CertificateVerificationError::RevocationCheckFailed:
+					clientError = ClientError(ClientError::RevocationCheckFailedError);
+					break;
 			}
 		}
 		actualError = boost::optional<ClientError>(clientError);
diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h
index 22e6eaf..b17f5df 100644
--- a/Swiften/TLS/CertificateVerificationError.h
+++ b/Swiften/TLS/CertificateVerificationError.h
@@ -26,6 +26,8 @@ namespace Swift {
 				InvalidSignature,
 				InvalidCA,
 				InvalidServerIdentity,
+				Revoked,
+				RevocationCheckFailed
 			};
 
 			CertificateVerificationError(Type type = UnknownError) : type(type) {}
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h
index f531cff..395d3ec 100644
--- a/Swiften/TLS/Schannel/SchannelCertificate.h
+++ b/Swiften/TLS/Schannel/SchannelCertificate.h
@@ -48,8 +48,13 @@ namespace Swift
 			return m_xmppAddresses;
 		}
 
-		ByteArray toDER() const;
+		ScopedCertContext getCertContext() const
+		{
+			return m_cert;
+		}
 
+		ByteArray toDER() const;
+		
 	private:
 		void parse();
 		std::string wstrToStr(const std::wstring& wstr);
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index b2fea65..9be1ded 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -4,9 +4,10 @@
  * See Documentation/Licenses/BSD-simplified.txt for more information.
  */
 
-#include <Swiften/TLS/Schannel/SchannelContext.h>
-#include <Swiften/TLS/Schannel/SchannelCertificate.h>
+#include "Swiften/TLS/Schannel/SchannelContext.h"
+#include "Swiften/TLS/Schannel/SchannelCertificate.h"
 #include <Swiften/TLS/CAPICertificate.h>
+#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID
 
 namespace Swift {
 
@@ -15,7 +16,6 @@ namespace Swift {
 SchannelContext::SchannelContext() 
 : m_state(Start)
 , m_secContext(0)
-, m_verificationError(CertificateVerificationError::UnknownError)
 , m_my_cert_store(NULL)
 , m_cert_store_name("MY")
 , m_cert_name()
@@ -50,7 +50,7 @@ void SchannelContext::determineStreamSizes()
 
 void SchannelContext::connect() 
 {
-	PCCERT_CONTEXT   pCertContext = NULL;
+	ScopedCertContext pCertContext;
 
 	m_state = Connecting;
 
@@ -86,12 +86,12 @@ void SchannelContext::connect()
 
 /////SSL3?
 	sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
-	sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN;
+	sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION;
 
 	if (pCertContext)
 	{
 		sc.cCreds = 1;
-		sc.paCred = &pCertContext;
+		sc.paCred = pCertContext.GetPointer();
 		sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
 	}
 	else
@@ -114,10 +114,7 @@ void SchannelContext::connect()
 		NULL,
 		m_credHandle.Reset(),
 		NULL);
-
-	// cleanup: Free the certificate context. Schannel has already made its own copy.
-	if (pCertContext) CertFreeCertificateContext(pCertContext);
-
+	
 	if (status != SEC_E_OK) 
 	{
 		// We failed to obtain the credentials handle
@@ -164,6 +161,7 @@ void SchannelContext::connect()
 	if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) 
 	{
 		// We failed to initialize the security context
+		handleCertError(status);
 		indicateError();
 		return;
 	}
@@ -173,6 +171,10 @@ void SchannelContext::connect()
 
 	if (status == SEC_E_OK) 
 	{
+		status = validateServerCertificate();
+		if (status != SEC_E_OK)
+			handleCertError(status);
+
 		m_state = Connected;
 		determineStreamSizes();
 
@@ -182,6 +184,74 @@ void SchannelContext::connect()
 
 //------------------------------------------------------------------------
 
+SECURITY_STATUS SchannelContext::validateServerCertificate()
+{
+	SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() );
+	if (!pServerCert)
+		return SEC_E_WRONG_PRINCIPAL;
+
+	const LPSTR usage[] = 
+	{
+		szOID_PKIX_KP_SERVER_AUTH,
+		szOID_SERVER_GATED_CRYPTO,
+		szOID_SGC_NETSCAPE
+	};
+
+	CERT_CHAIN_PARA chainParams = {0};
+	chainParams.cbSize = sizeof(chainParams);
+	chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR;
+	chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage);
+	chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage);
+
+	DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
+
+	ScopedCertChainContext pChainContext;
+
+	BOOL success = CertGetCertificateChain(
+		NULL, // Use the chain engine for the current user (assumes a user is logged in)
+		pServerCert->getCertContext(),
+		NULL,
+		NULL,
+		&chainParams,
+		chainFlags,
+		NULL,
+		pChainContext.Reset());
+
+	if (!success)
+		return GetLastError();
+
+	SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0};
+	sslChainPolicy.cbSize = sizeof(sslChainPolicy);
+	sslChainPolicy.dwAuthType = AUTHTYPE_SERVER;
+	sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking?
+	sslChainPolicy.pwszServerName = NULL;
+
+	CERT_CHAIN_POLICY_PARA certChainPolicy = {0};
+	certChainPolicy.cbSize = sizeof(certChainPolicy);
+	certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking?
+	certChainPolicy.pvExtraPolicyPara = &sslChainPolicy;
+
+	CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0};
+	certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus);
+
+	// Verify the chain
+	if (!CertVerifyCertificateChainPolicy(
+		CERT_CHAIN_POLICY_SSL,
+		pChainContext,
+		&certChainPolicy,
+		&certChainPolicyStatus))
+	{
+		return GetLastError();
+	}
+
+	if (certChainPolicyStatus.dwError != S_OK)
+		return certChainPolicyStatus.dwError;
+
+	return S_OK;
+}
+
+//------------------------------------------------------------------------
+
 void SchannelContext::appendNewData(const SafeByteArray& data)
 {
 	size_t originalSize = m_receivedData.size();
@@ -270,6 +340,10 @@ void SchannelContext::continueHandshake(const SafeByteArray& data)
 		}
 		else if (status == SEC_E_OK) 
 		{
+			status = validateServerCertificate();
+			if (status != SEC_E_OK)
+				handleCertError(status);
+
 			SecBuffer* pExtraBuffer = &inBuffers[1];
 			
 			if (pExtraBuffer && pExtraBuffer->cbBuffer > 0)
@@ -285,6 +359,7 @@ void SchannelContext::continueHandshake(const SafeByteArray& data)
 		else 
 		{
 			// We failed to initialize the security context
+			handleCertError(status);
 			indicateError();
 			return;
 		}
@@ -293,6 +368,47 @@ void SchannelContext::continueHandshake(const SafeByteArray& data)
 
 //------------------------------------------------------------------------
 
+void SchannelContext::handleCertError(SECURITY_STATUS status) 
+{
+	if (status == SEC_E_UNTRUSTED_ROOT			|| 
+		status == CERT_E_UNTRUSTEDROOT			||
+		status == CRYPT_E_ISSUER_SERIALNUMBER	|| 
+		status == CRYPT_E_SIGNER_NOT_FOUND		||
+		status == CRYPT_E_NO_TRUSTED_SIGNER)
+	{
+		m_verificationError = CertificateVerificationError::Untrusted;
+	}
+	else if (status == SEC_E_CERT_EXPIRED || 
+			 status == CERT_E_EXPIRED)
+	{
+		m_verificationError = CertificateVerificationError::Expired;
+	}
+	else if (status == CRYPT_E_SELF_SIGNED)
+	{
+		m_verificationError = CertificateVerificationError::SelfSigned;
+	}
+	else if (status == CRYPT_E_HASH_VALUE	||
+			 status == TRUST_E_CERT_SIGNATURE)
+	{
+		m_verificationError = CertificateVerificationError::InvalidSignature;
+	}
+	else if (status == CRYPT_E_REVOKED)
+	{
+		m_verificationError = CertificateVerificationError::Revoked;
+	}
+	else if (status == CRYPT_E_NO_REVOCATION_CHECK ||
+			 status == CRYPT_E_REVOCATION_OFFLINE)
+	{
+		m_verificationError = CertificateVerificationError::RevocationCheckFailed;
+	}
+	else
+	{
+		m_verificationError = CertificateVerificationError::UnknownError;
+	}
+}
+
+//------------------------------------------------------------------------
+
 void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) 
 {
 	if (dataSize > 0 && pData) 
@@ -449,6 +565,9 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
 
 void SchannelContext::encryptAndSendData(const SafeByteArray& data) 
 {	
+	if (m_streamSizes.cbMaximumMessage == 0)
+		return;
+
 	SecBuffer outBuffers[4]	= {0};
 
 	// Calculate the largest required size of the send buffer
@@ -544,8 +663,8 @@ CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificatio
 {
 	boost::shared_ptr<CertificateVerificationError> pCertError;
 
-	if (m_state == Error)
-		pCertError.reset( new CertificateVerificationError(m_verificationError) );
+	if (m_verificationError)
+		pCertError.reset( new CertificateVerificationError(*m_verificationError) );
 	
 	return pCertError;
 }
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index 7726c41..70b0694 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -51,6 +51,7 @@ namespace Swift
 		void			determineStreamSizes();
 		void			continueHandshake(const SafeByteArray& data);
 		void			indicateError();
+		void			handleCertError(SECURITY_STATUS status) ;
 
 		void			sendDataOnNetwork(const void* pData, size_t dataSize);
 		void			forwardDataToApplication(const void* pData, size_t dataSize);
@@ -59,6 +60,7 @@ namespace Swift
 		void			encryptAndSendData(const SafeByteArray& data);
 
 		void			appendNewData(const SafeByteArray& data);
+		SECURITY_STATUS validateServerCertificate();
 
 	private:
 		enum SchannelState
@@ -71,7 +73,7 @@ namespace Swift
 		};
 
 		SchannelState		m_state;
-		CertificateVerificationError m_verificationError;
+		boost::optional<CertificateVerificationError> m_verificationError;
 
 		ULONG				m_secContext;
 		ScopedCredHandle	m_credHandle;
diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h
index 0a54f16..4f73aac 100644
--- a/Swiften/TLS/Schannel/SchannelUtil.h
+++ b/Swiften/TLS/Schannel/SchannelUtil.h
@@ -246,7 +246,7 @@ namespace Swift
 		}
 
 		// Copy constructor
-		explicit ScopedCertContext(const ScopedCertContext& rhs)
+		ScopedCertContext(const ScopedCertContext& rhs)
 		{
 			m_pHandle = rhs.m_pHandle;			
 		}
@@ -267,6 +267,11 @@ namespace Swift
 			return m_pHandle->m_pCertCtxt;
 		}
 
+		PCCERT_CONTEXT* GetPointer() const
+		{
+			return &m_pHandle->m_pCertCtxt;
+		}
+
 		PCCERT_CONTEXT operator->() const
 		{
 			return m_pHandle->m_pCertCtxt;
@@ -283,6 +288,132 @@ namespace Swift
 			return *this;
 		}
 
+		ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt)
+		{
+			// Only update the internal handle if it's different
+			if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt)
+				m_pHandle.reset( new HandleContext(pCertCtxt) );
+
+			return *this;
+		}
+
+		void FreeContext()
+		{
+			m_pHandle.reset( new HandleContext );
+		}
+
+	private:
+		boost::shared_ptr<HandleContext> m_pHandle;		
+	};
+
+	//------------------------------------------------------------------------
+
+	//
+	// Convenience wrapper around the Schannel HCERTSTORE.
+	//
+	class ScopedCertStore : boost::noncopyable
+	{
+	public:
+		ScopedCertStore(HCERTSTORE hCertStore)
+		: m_hCertStore(hCertStore)
+		{
+		}
+		
+		~ScopedCertStore()
+		{
+			// Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this
+			// cert store have been closed at this point.
+			if (m_hCertStore)
+				CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
+		}
+
+		operator HCERTSTORE() const
+		{
+			return m_hCertStore;
+		}
+
+	private:
+		HCERTSTORE m_hCertStore;
+	};
+
+	//------------------------------------------------------------------------
+
+	//
+	// Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT.
+	//
+	class ScopedCertChainContext
+	{
+	private:
+		struct HandleContext
+		{
+			HandleContext()
+			: m_pCertChainCtxt(NULL)
+			{
+			}
+
+			HandleContext(PCCERT_CHAIN_CONTEXT pCert)
+			: m_pCertChainCtxt(pCert)
+			{
+			}
+
+			~HandleContext()
+			{
+				if (m_pCertChainCtxt)
+					CertFreeCertificateChain(m_pCertChainCtxt);
+			}
+
+			PCCERT_CHAIN_CONTEXT m_pCertChainCtxt;
+		};
+
+	public:
+		ScopedCertChainContext()
+		: m_pHandle( new HandleContext )
+		{
+		}
+
+		explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert)
+		: m_pHandle( new HandleContext(pCert) )
+		{
+		}
+
+		// Copy constructor
+		ScopedCertChainContext(const ScopedCertChainContext& rhs)
+		{
+			m_pHandle = rhs.m_pHandle;			
+		}
+
+		~ScopedCertChainContext()
+		{
+			m_pHandle.reset();
+		}
+
+		PCCERT_CHAIN_CONTEXT* Reset()
+		{
+			FreeContext();
+			return &m_pHandle->m_pCertChainCtxt;
+		}
+
+		operator PCCERT_CHAIN_CONTEXT() const
+		{
+			return m_pHandle->m_pCertChainCtxt;
+		}
+
+		PCCERT_CHAIN_CONTEXT operator->() const
+		{
+			return m_pHandle->m_pCertChainCtxt;
+		}
+
+		ScopedCertChainContext& operator=(const ScopedCertChainContext& sh)
+		{
+			// Only update the internal handle if it's different
+			if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt)
+			{
+				m_pHandle = sh.m_pHandle;
+			}
+
+			return *this;
+		}
+
 		void FreeContext()
 		{
 			m_pHandle.reset( new HandleContext );
@@ -290,5 +421,5 @@ namespace Swift
 
 	private:
 		boost::shared_ptr<HandleContext> m_pHandle;		
-	};	 
+	};
 }
-- 
cgit v0.10.2-6-g49f6