From 110eb87e848b85dd74a6f19413c775520a75ea35 Mon Sep 17 00:00:00 2001
From: Alexey Melnikov <alexey.melnikov@isode.com>
Date: Mon, 13 Feb 2012 17:54:23 +0000
Subject: Initial implementation of using CAPI certificates with Schannel.

Introduced a new parent class for all certificates with keys
(class CertificateWithKey is the new parent for PKCS12Certificate.)
Switched to using "CertificateWithKey *" instead of "const CertificateWithKey&"
Added calling of a Windows dialog for certificate selection when Schannel
TLS implementation is used.

This compiles, but is not tested.

License: This patch is BSD-licensed, see Documentation/Licenses/BSD-simplified.txt for details.

diff --git a/Swift/QtUI/CAPICertificateSelector.cpp b/Swift/QtUI/CAPICertificateSelector.cpp
new file mode 100644
index 0000000..44f5793
--- /dev/null
+++ b/Swift/QtUI/CAPICertificateSelector.cpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include <string>
+
+#define SECURITY_WIN32
+#include <Windows.h>
+#include <WinCrypt.h>
+#include <cryptuiapi.h>
+
+#include "CAPICertificateSelector.h"
+
+namespace Swift {
+
+#define cert_dlg_title L"TLS Client Certificate Selection"
+#define cert_dlg_prompt L"Select a certificate to use for authentication"
+/////Hmm, maybe we should not exlude the "location" column
+#define exclude_columns	 CRYPTUI_SELECT_LOCATION_COLUMN \
+			|CRYPTUI_SELECT_INTENDEDUSE_COLUMN
+
+
+
+static std::string getCertUri(PCCERT_CONTEXT cert, const char * cert_store_name) {
+	DWORD required_size;
+	char * comma;
+	char * p_in;
+	char * p_out;
+	char * subject_name;
+	std::string ret = std::string("certstore:") + cert_store_name + ":";
+
+	required_size = CertNameToStrA(cert->dwCertEncodingType,
+				&cert->pCertInfo->Subject,
+				/* Discard attribute names: */
+				CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG,
+				NULL,
+				0);
+
+	subject_name = static_cast<char *>(malloc(required_size+1));
+
+	if (!CertNameToStrA(cert->dwCertEncodingType,
+			    &cert->pCertInfo->Subject,
+			    /* Discard attribute names: */
+			    CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG,
+			    subject_name,
+			    required_size)) {
+		return "";
+	}
+
+	/* Now search for the "," (ignoring escapes)
+	    and truncate the rest of the string */
+	if (subject_name[0] == '"') {
+		for (comma = subject_name + 1; comma[0]; comma++) {
+			if (comma[0] == '"') {
+				comma++;
+				if (comma[0] != '"') {
+					break;
+				}
+			}
+		}
+	} else {
+		comma = strchr(subject_name, ',');
+	}
+
+	if (comma != NULL) {
+		*comma = '\0';
+	}
+
+	/* We now need to unescape the returned RDN */
+	if (subject_name[0] == '"') {
+		for (p_in = subject_name + 1, p_out = subject_name; p_in[0]; p_in++, p_out++) {
+			if (p_in[0] == '"') {
+				p_in++;
+			}
+
+			p_out[0] = p_in[0];
+		}
+		p_out[0] = '\0';
+	}
+
+	ret += subject_name;
+	free(subject_name);
+
+	return ret;
+}
+
+std::string selectCAPICertificate() {
+
+	const char * cert_store_name = "MY";
+	PCCERT_CONTEXT cert;
+	DWORD store_flags;
+	HCERTSTORE hstore;
+	HWND hwnd;
+
+	store_flags = CERT_STORE_OPEN_EXISTING_FLAG |
+		      CERT_STORE_READONLY_FLAG |
+		      CERT_SYSTEM_STORE_CURRENT_USER;
+
+	hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, store_flags, cert_store_name);
+	if (!hstore) {
+		return "";
+	}
+
+
+////Does this handle need to be freed as well?
+	hwnd = GetForegroundWindow();
+	if (!hwnd) {
+		hwnd = GetActiveWindow();
+	}
+
+	/* Call Windows dialog to select a suitable certificate */
+	cert = CryptUIDlgSelectCertificateFromStore(hstore,
+						  hwnd,
+						  cert_dlg_title,
+						  cert_dlg_prompt,
+						  exclude_columns,
+						  0,
+						  NULL);
+
+	if (hstore) {
+		CertCloseStore(hstore, 0);
+	}
+
+	if (cert) {
+		std::string ret = getCertUri(cert, cert_store_name);
+
+		CertFreeCertificateContext(cert);
+
+		return ret;
+	} else {
+		return "";
+	}
+}
+
+
+}
diff --git a/Swift/QtUI/CAPICertificateSelector.h b/Swift/QtUI/CAPICertificateSelector.h
new file mode 100644
index 0000000..9a0ee92
--- /dev/null
+++ b/Swift/QtUI/CAPICertificateSelector.h
@@ -0,0 +1,13 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <string>
+
+namespace Swift {
+	std::string selectCAPICertificate();
+}
diff --git a/Swift/QtUI/QtLoginWindow.cpp b/Swift/QtUI/QtLoginWindow.cpp
index 1cd3206..6b9d389 100644
--- a/Swift/QtUI/QtLoginWindow.cpp
+++ b/Swift/QtUI/QtLoginWindow.cpp
@@ -41,6 +41,10 @@
 #include <QtMainWindow.h>
 #include <QtUtilities.h>
 
+#ifdef HAVE_SCHANNEL
+#include "CAPICertificateSelector.h"
+#endif
+
 namespace Swift{
 
 QtLoginWindow::QtLoginWindow(UIEventStream* uiEventStream, SettingsProvider* settings) : QMainWindow(), settings_(settings) {
@@ -357,10 +361,17 @@ void QtLoginWindow::setLoginAutomatically(bool loginAutomatically) {
 
 void QtLoginWindow::handleCertficateChecked(bool checked) {
 	if (checked) {
-		 certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx"));
-		 if (certificateFile_.isEmpty()) {
-			 certificateButton_->setChecked(false);
-		 }
+#ifdef HAVE_SCHANNEL
+		certificateFile_ = selectCAPICertificate();
+		if (certificateFile_.isEmpty()) {
+			certificateButton_->setChecked(false);
+		}
+#else
+		certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx"));
+		if (certificateFile_.isEmpty()) {
+			certificateButton_->setChecked(false);
+		}
+#endif
 	}
 	else {
 		certificateFile_ = "";
diff --git a/Swift/QtUI/SConscript b/Swift/QtUI/SConscript
index d37958f..a8b8c78 100644
--- a/Swift/QtUI/SConscript
+++ b/Swift/QtUI/SConscript
@@ -55,6 +55,8 @@ if env["PLATFORM"] == "win32" :
   #myenv["LINKFLAGS"] = ["/SUBSYSTEM:CONSOLE"]
   myenv.Append(LINKFLAGS = ["/SUBSYSTEM:WINDOWS"])
   myenv.Append(LIBS = "qtmain")
+  if myenv.get("HAVE_SCHANNEL", 0) :
+    myenv.Append(LIBS = "Cryptui")
 
 myenv.WriteVal("DefaultTheme.qrc", myenv.Value(generateDefaultTheme(myenv.Dir("#/Swift/resources/themes/Default"))))
 
@@ -151,6 +153,7 @@ if env["PLATFORM"] == "win32" :
 	# Adding it explicitly until i figure out why
   myenv.Depends(res, "../Controllers/BuildVersion.h")
   sources += [
+			"CAPICertificateSelector.cpp",
 			"WindowsNotifier.cpp",
 			"#/Swift/resources/Windows/Swift.res"
 		]
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index de12fb7..36bfe35 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -126,6 +126,19 @@ void CoreClient::bindSessionToStream() {
 	session_->start();
 }
 
+bool CoreClient::isCAPIURI() {
+#ifdef HAVE_SCHANNEL
+	if (!boost::iequals(certificate_.substr(0, 10), "certstore:")) {
+		return false;
+	}
+
+	return true;
+
+#else
+	return false;
+#endif
+}
+
 /**
  * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream.
  */
@@ -144,7 +157,19 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio
 		assert(!sessionStream_);
 		sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory());
 		if (!certificate_.empty()) {
-			sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_));
+			CertificateWithKey* cert;
+
+#if defined(SWIFTEN_PLATFORM_WIN32)
+			if (isCAPIURI()) {
+				cert = new CAPICertificate(certificate_);
+			} else {
+				cert = new PKCS12Certificate(certificate_, password_);
+			}
+#else
+			cert = new PKCS12Certificate(certificate_, password_);
+#endif
+
+			sessionStream_->setTLSCertificate(cert);
 		}
 		sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
 		sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h
index c231fdc..6712e03 100644
--- a/Swiften/Client/CoreClient.h
+++ b/Swiften/Client/CoreClient.h
@@ -196,6 +196,8 @@ namespace Swift {
 			 */
 			virtual void handleConnected() {};
 
+			bool isCAPIURI();
+
 		private:
 			void handleConnectorFinished(boost::shared_ptr<Connection>);
 			void handleStanzaChannelAvailableChanged(bool available);
diff --git a/Swiften/Session/SessionStream.cpp b/Swiften/Session/SessionStream.cpp
index 0d73b63..487ad8b 100644
--- a/Swiften/Session/SessionStream.cpp
+++ b/Swiften/Session/SessionStream.cpp
@@ -9,6 +9,7 @@
 namespace Swift {
 
 SessionStream::~SessionStream() {
+	delete certificate;
 }
 
 };
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index 096f185..58015b3 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -14,7 +14,7 @@
 #include <Swiften/Elements/Element.h>
 #include <Swiften/Base/Error.h>
 #include <Swiften/Base/SafeByteArray.h>
-#include <Swiften/TLS/PKCS12Certificate.h>
+#include <Swiften/TLS/CertificateWithKey.h>
 #include <Swiften/TLS/Certificate.h>
 #include <Swiften/TLS/CertificateVerificationError.h>
 
@@ -36,6 +36,8 @@ namespace Swift {
 					Type type;
 			};
 
+			SessionStream(): certificate(0) {}
+
 			virtual ~SessionStream();
 
 			virtual void close() = 0;
@@ -56,12 +58,12 @@ namespace Swift {
 
 			virtual void resetXMPPParser() = 0;
 
-			void setTLSCertificate(const PKCS12Certificate& cert) {
+			void setTLSCertificate(CertificateWithKey* cert) {
 				certificate = cert;
 			}
 
 			virtual bool hasTLSCertificate() {
-				return !certificate.isNull();
+				return certificate && !certificate->isNull();
 			}
 
 			virtual Certificate::ref getPeerCertificate() const = 0;
@@ -77,11 +79,11 @@ namespace Swift {
 			boost::signal<void (const SafeByteArray&)> onDataWritten;
 
 		protected:
-			const PKCS12Certificate& getTLSCertificate() const {
+			CertificateWithKey * getTLSCertificate() const {
 				return certificate;
 			}
 
 		private:
-			PKCS12Certificate certificate;
+			CertificateWithKey * certificate;
 	};
 }
diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp
index 6f2223d..b7efbcb 100644
--- a/Swiften/StreamStack/TLSLayer.cpp
+++ b/Swiften/StreamStack/TLSLayer.cpp
@@ -37,7 +37,7 @@ void TLSLayer::handleDataRead(const SafeByteArray& data) {
 	context->handleDataFromNetwork(data);
 }
 
-bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) {
+bool TLSLayer::setClientCertificate(CertificateWithKey * certificate) {
 	return context->setClientCertificate(certificate);
 }
 
diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h
index a8693d5..6dc9135 100644
--- a/Swiften/StreamStack/TLSLayer.h
+++ b/Swiften/StreamStack/TLSLayer.h
@@ -14,7 +14,7 @@
 namespace Swift {
 	class TLSContext;
 	class TLSContextFactory;
-	class PKCS12Certificate;
+	class CertificateWithKey;
 
 	class TLSLayer : public StreamLayer {
 		public:
@@ -22,7 +22,7 @@ namespace Swift {
 			~TLSLayer();
 
 			void connect();
-			bool setClientCertificate(const PKCS12Certificate&);
+			bool setClientCertificate(CertificateWithKey * cert);
 
 			Certificate::ref getPeerCertificate() const;
 			boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h
new file mode 100644
index 0000000..fcdb4c2
--- /dev/null
+++ b/Swiften/TLS/CAPICertificate.h
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/SafeByteArray.h>
+#include <Swiften/TLS/CertificateWithKey.h>
+
+#include <boost/algorithm/string/predicate.hpp>
+
+#define SECURITY_WIN32
+#include <WinCrypt.h>
+
+namespace Swift {
+	class CAPICertificate : public Swift::CertificateWithKey {
+		public:
+			CAPICertificate(const std::string& capiUri)
+			    : valid_(false), uri_(capiUri), cert_store_handle_(0), cert_store_(NULL), cert_name_(NULL) {
+				setUri(capiUri);
+			}
+
+			virtual ~CAPICertificate() {
+				if (cert_store_handle_ != NULL)
+				{
+					CertCloseStore(cert_store_handle_, 0);
+				}
+			}
+
+			virtual bool isNull() const {
+				return uri_.empty() || !valid_;
+			}
+
+			virtual bool isPrivateKeyExportable() const {
+				/* We can check with CAPI, but for now the answer is "no" */
+				return false;
+			}
+
+			virtual const std::string& getCertStoreName() const {
+			    return cert_store_;
+			}
+
+			virtual const std::string& getCertName() const {
+			    return cert_name_;
+			}
+
+			const ByteArray& getData() const {
+////Might need to throw an exception here, or really generate PKCS12 blob from CAPI data?
+				assert(0);
+			}
+
+			void setData(const ByteArray& data) {
+				assert(0);
+			}
+
+			const SafeByteArray& getPassword() const {
+/////Can't pass NULL to createSafeByteArray!
+/////Should this throw an exception instead?
+				return createSafeByteArray("");
+			}
+
+		protected:
+			void setUri (const std::string& capiUri) {
+
+				valid_ = false;
+
+				/* Syntax: "certstore:" [<cert_store> ":"] <cert_id> */
+
+				if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) {
+					return;
+				}
+
+				/* Substring of subject: uses "storename" */
+				std::string capi_identity = capiUri.substr(10);
+				std::string new_cert_store_name;
+				size_t pos = capi_identity.find_first_of (':');
+
+				if (pos == std::string::npos) {
+					/* Using the default certificate store */
+					new_cert_store_name = "MY";
+					cert_name_ = capi_identity;
+				} else {
+					new_cert_store_name = capi_identity.substr(0, pos);
+					cert_name_ = capi_identity.substr(pos + 1);
+				}
+
+				PCCERT_CONTEXT pCertContext = NULL;
+
+				if (cert_store_handle_ != NULL)
+				{
+					if (new_cert_store_name != cert_store_) {
+						CertCloseStore(cert_store_handle_, 0);
+						cert_store_handle_ = NULL;
+					}
+				}
+
+				if (cert_store_handle_ == NULL)
+				{
+					cert_store_handle_ = CertOpenSystemStore(0, cert_store_.c_str());
+					if (!cert_store_handle_)
+					{
+						return;
+					}
+				}
+
+				cert_store_ = new_cert_store_name;
+
+				/* NB: This might have to change, depending on how we locate certificates */
+
+				// Find client certificate. Note that this sample just searches for a 
+				// certificate that contains the user name somewhere in the subject name.
+				pCertContext = CertFindCertificateInStore(cert_store_handle_,
+					X509_ASN_ENCODING,
+					0,				// dwFindFlags
+					CERT_FIND_SUBJECT_STR_A,
+					cert_name_.c_str(),		// *pvFindPara
+					NULL );				// pPrevCertContext
+
+				if (pCertContext == NULL)
+				{
+					return;
+				}
+
+
+				/* Now verify that we can have access to the corresponding private key */
+
+				DWORD len;
+				CRYPT_KEY_PROV_INFO *pinfo;
+				HCRYPTPROV hprov;
+				HCRYPTKEY key;
+
+				if (!CertGetCertificateContextProperty(pCertContext,
+								       CERT_KEY_PROV_INFO_PROP_ID,
+								       NULL,
+								       &len))
+				{
+					CertFreeCertificateContext(pCertContext);
+					return;
+				}
+
+				pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len));
+				if (!pinfo) {
+					CertFreeCertificateContext(pCertContext);
+					return;
+				}
+
+				if (!CertGetCertificateContextProperty(pCertContext,
+								       CERT_KEY_PROV_INFO_PROP_ID,
+								       pinfo,
+								       &len))
+				{
+					CertFreeCertificateContext(pCertContext);
+					free(pinfo);
+					return;
+				}
+
+				CertFreeCertificateContext(pCertContext);
+
+				// Now verify if we have access to the private key
+				if (!CryptAcquireContextW(&hprov,
+							  pinfo->pwszContainerName,
+							  pinfo->pwszProvName,
+							  pinfo->dwProvType,
+							  0))
+				{
+					free(pinfo);
+					return;
+				}
+
+				if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key))
+				{
+					CryptReleaseContext(hprov, 0);
+					free(pinfo);
+					return;
+				}
+
+				CryptDestroyKey(key);
+				CryptReleaseContext(hprov, 0);
+				free(pinfo);
+
+				valid_ = true;
+			}
+
+		private:
+			bool valid_;
+			std::string uri_;
+
+			HCERTSTORE cert_store_handle_;
+
+			/* Parsed components of the uri_ */
+			std::string cert_store_;
+			std::string cert_name_;
+	};
+}
diff --git a/Swiften/TLS/CertificateWithKey.h b/Swiften/TLS/CertificateWithKey.h
new file mode 100644
index 0000000..6f6ea39
--- /dev/null
+++ b/Swiften/TLS/CertificateWithKey.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2010-2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/SafeByteArray.h>
+
+namespace Swift {
+	class CertificateWithKey {
+		public:
+			CertificateWithKey() {}
+
+			virtual ~CertificateWithKey() {}
+
+			virtual bool isNull() const = 0;
+
+			virtual bool isPrivateKeyExportable() const = 0;
+
+			virtual const std::string& getCertStoreName() const = 0;
+
+			virtual const std::string& getCertName() const = 0;
+
+			virtual const ByteArray& getData() const = 0;
+
+			virtual void setData(const ByteArray& data) = 0;
+
+			virtual const SafeByteArray& getPassword() const = 0;
+	};
+}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 220e7f9..dd3462f 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -21,7 +21,7 @@
 
 #include <Swiften/TLS/OpenSSL/OpenSSLContext.h>
 #include <Swiften/TLS/OpenSSL/OpenSSLCertificate.h>
-#include <Swiften/TLS/PKCS12Certificate.h>
+#include <Swiften/TLS/CertificateWithKey.h>
 
 #pragma GCC diagnostic ignored "-Wold-style-cast"
 
@@ -185,14 +185,18 @@ void OpenSSLContext::sendPendingDataToApplication() {
 	}
 }
 
-bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) {
-	if (certificate.isNull()) {
+bool OpenSSLContext::setClientCertificate(CertificateWithKey * certificate) {
+	if (!certificate || certificate->isNull()) {
+		return false;
+	}
+
+	if (!certificate->isPrivateKeyExportable()) {
 		return false;
 	}
 
 	// Create a PKCS12 structure
 	BIO* bio = BIO_new(BIO_s_mem());
-	BIO_write(bio, vecptr(certificate.getData()), certificate.getData().size());
+	BIO_write(bio, vecptr(certificate->getData()), certificate->getData().size());
 	boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free);
 	BIO_free(bio);
 	if (!pkcs12) {
@@ -203,7 +207,7 @@ bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate)
 	X509 *certPtr = 0;
 	EVP_PKEY* privateKeyPtr = 0;
 	STACK_OF(X509)* caCertsPtr = 0;
-	int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate.getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr);
+	int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate->getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr);
 	if (result != 1) { 
 		return false;
 	}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 04693a3..b53e715 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -14,7 +14,7 @@
 #include <Swiften/Base/ByteArray.h>
 
 namespace Swift {
-	class PKCS12Certificate;
+	class CertificateWithKey;
 
 	class OpenSSLContext : public TLSContext, boost::noncopyable {
 		public:
@@ -22,7 +22,7 @@ namespace Swift {
 			~OpenSSLContext();
 
 			void connect();
-			bool setClientCertificate(const PKCS12Certificate& cert);
+			bool setClientCertificate(CertificateWithKey * cert);
 
 			void handleDataFromNetwork(const SafeByteArray&);
 			void handleDataFromApplication(const SafeByteArray&);
diff --git a/Swiften/TLS/PKCS12Certificate.h b/Swiften/TLS/PKCS12Certificate.h
index c0e01d0..2f70456 100644
--- a/Swiften/TLS/PKCS12Certificate.h
+++ b/Swiften/TLS/PKCS12Certificate.h
@@ -7,9 +7,10 @@
 #pragma once
 
 #include <Swiften/Base/SafeByteArray.h>
+#include <Swiften/TLS/CertificateWithKey.h>
 
 namespace Swift {
-	class PKCS12Certificate {
+	class PKCS12Certificate : public Swift::CertificateWithKey {
 		public:
 			PKCS12Certificate() {}
 
@@ -17,11 +18,29 @@ namespace Swift {
 				readByteArrayFromFile(data_, filename);
 			}
 
-			bool isNull() const {
+			virtual ~PKCS12Certificate() {}
+
+			virtual bool isNull() const {
 				return data_.empty();
 			}
 
-			const ByteArray& getData() const {
+			virtual bool isPrivateKeyExportable() const {
+/////Hopefully a PKCS12 is never missing a private key
+				return true;
+			}
+
+			virtual const std::string& getCertStoreName() const {
+/////				assert(0);
+				throw std::exception();
+			}
+
+			virtual const std::string& getCertName() const {
+				/* We can return the original filename instead, if we care */
+/////				assert(0);
+				throw std::exception();
+			}
+
+			virtual const ByteArray& getData() const {
 				return data_;
 			}
 
@@ -29,7 +48,7 @@ namespace Swift {
 				data_ = data;
 			}
 
-			const SafeByteArray& getPassword() const {
+			virtual const SafeByteArray& getPassword() const {
 				return password_;
 			}
 
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index 6771d4a..6f50b3a 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -15,6 +15,9 @@ SchannelContext::SchannelContext()
 : m_state(Start)
 , m_secContext(0)
 , m_verificationError(CertificateVerificationError::UnknownError)
+, m_my_cert_store(NULL)
+, m_cert_store_name("MY")
+, m_cert_name(NULL)
 {
 	m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | 
 				  ISC_REQ_CONFIDENTIALITY |
@@ -30,6 +33,13 @@ SchannelContext::SchannelContext()
 
 //------------------------------------------------------------------------
 
+SchannelContext::~SchannelContext()
+{
+	if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0);
+}
+
+//------------------------------------------------------------------------
+
 void SchannelContext::determineStreamSizes()
 {
 	QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
@@ -39,17 +49,65 @@ void SchannelContext::determineStreamSizes()
 
 void SchannelContext::connect() 
 {
+	PCCERT_CONTEXT   pCertContext = NULL;
+
 	m_state = Connecting;
 
+	// If a user name is specified, then attempt to find a client
+	// certificate. Otherwise, just create a NULL credential.
+	if (!m_cert_name.empty())
+	{
+		if (m_my_cert_store == NULL)
+		{
+			m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str());
+			if (!m_my_cert_store)
+			{
+/////			printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() );
+				indicateError();
+				return;
+			}
+		}
+
+		// Find client certificate. Note that this sample just searches for a 
+		// certificate that contains the user name somewhere in the subject name.
+		pCertContext = CertFindCertificateInStore( m_my_cert_store,
+			X509_ASN_ENCODING,
+			0,				// dwFindFlags
+			CERT_FIND_SUBJECT_STR_A,
+			m_cert_name.c_str(),		// *pvFindPara
+			NULL );				// pPrevCertContext
+
+		if (pCertContext == NULL)
+		{
+/////		printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError());
+			indicateError();
+			return;
+		}
+	}
+
 	// We use an empty list for client certificates
 	PCCERT_CONTEXT clientCerts[1] = {0};
 
 	SCHANNEL_CRED sc = {0};
 	sc.dwVersion = SCHANNEL_CRED_VERSION;
-	sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
-	sc.paCred = clientCerts;
+
+/////SSL3?
 	sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
-	sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN;
+/////Check SCH_CRED_REVOCATION_CHECK_CHAIN
+	sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN;
+
+	if (pCertContext)
+	{
+		sc.cCreds = 1;
+		sc.paCred = &pCertContext;
+		sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
+	}
+	else
+	{
+		sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
+		sc.paCred = clientCerts;
+		sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;
+	}
 
 	// Swiften performs the server name check for us
 	sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
@@ -65,6 +123,9 @@ void SchannelContext::connect()
 		m_credHandle.Reset(),
 		NULL);
 
+	// cleanup: Free the certificate context. Schannel has already made its own copy.
+	if (pCertContext) CertFreeCertificateContext(pCertContext);
+
 	if (status != SEC_E_OK) 
 	{
 		// We failed to obtain the credentials handle
@@ -456,8 +517,21 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
 
 //------------------------------------------------------------------------
 
-bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate) 
+bool SchannelContext::setClientCertificate(CertificateWithKey * certificate)
 {
+	if (!certificate || certificate->isNull()) {
+		return false;
+	}
+
+	if (!certificate->isPrivateKeyExportable()) {
+		// We assume that the Certificate Store Name/Certificate Name
+		// are valid at this point
+		m_cert_store_name = certificate->getCertStoreName();
+		m_cert_name = certificate->getCertName();
+
+		return true;
+	}
+
 	return false;
 }
 
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index 66467fe..0cdb3d7 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -10,6 +10,7 @@
 
 #include "Swiften/TLS/TLSContext.h"
 #include "Swiften/TLS/Schannel/SchannelUtil.h"
+#include <Swiften/TLS/CertificateWithKey.h>
 #include "Swiften/Base/ByteArray.h"
 
 #define SECURITY_WIN32
@@ -28,13 +29,15 @@ namespace Swift
 		typedef boost::shared_ptr<SchannelContext> sp_t;
 
 	public:
-						SchannelContext();
+		SchannelContext();
+
+		~SchannelContext();
 
 		//
 		// TLSContext
 		//
 		virtual void	connect();
-		virtual bool	setClientCertificate(const PKCS12Certificate&);
+		virtual bool	setClientCertificate(CertificateWithKey * cert);
 
 		virtual void	handleDataFromNetwork(const SafeByteArray& data);
 		virtual void	handleDataFromApplication(const SafeByteArray& data);
@@ -77,5 +80,9 @@ namespace Swift
 		SecPkgContext_StreamSizes m_streamSizes;
 
 		std::vector<char>	m_receivedData;
+
+		HCERTSTORE		m_my_cert_store;
+		std::string		m_cert_store_name;
+		std::string		m_cert_name;
 	};
 }
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 1538863..ada813a 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -14,7 +14,7 @@
 #include <Swiften/TLS/CertificateVerificationError.h>
 
 namespace Swift {
-	class PKCS12Certificate;
+	class CertificateWithKey;
 
 	class TLSContext {
 		public:
@@ -22,7 +22,7 @@ namespace Swift {
 
 			virtual void connect() = 0;
 
-			virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0;
+			virtual bool setClientCertificate(CertificateWithKey * cert) = 0;
 
 			virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
 			virtual void handleDataFromApplication(const SafeByteArray&) = 0;
-- 
cgit v0.10.2-6-g49f6