summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften')
-rw-r--r--Swiften/Base/ByteArray.h2
-rw-r--r--Swiften/EventLoop/SConscript1
-rw-r--r--Swiften/EventLoop/SingleThreadedEventLoop.cpp65
-rw-r--r--Swiften/EventLoop/SingleThreadedEventLoop.h58
-rw-r--r--Swiften/IDN/StringPrep.cpp6
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp8
-rw-r--r--Swiften/TLS/PlatformTLSFactories.cpp11
-rw-r--r--Swiften/TLS/SConscript6
-rw-r--r--Swiften/TLS/Schannel/SchannelCertificate.cpp197
-rw-r--r--Swiften/TLS/Schannel/SchannelCertificate.h81
-rw-r--r--Swiften/TLS/Schannel/SchannelCertificateFactory.h19
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp503
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.h81
-rw-r--r--Swiften/TLS/Schannel/SchannelContextFactory.cpp20
-rw-r--r--Swiften/TLS/Schannel/SchannelContextFactory.h17
-rw-r--r--Swiften/TLS/Schannel/SchannelUtil.h294
16 files changed, 1357 insertions, 12 deletions
diff --git a/Swiften/Base/ByteArray.h b/Swiften/Base/ByteArray.h
index b368ef8..01cd5d0 100644
--- a/Swiften/Base/ByteArray.h
+++ b/Swiften/Base/ByteArray.h
@@ -36,7 +36,7 @@ namespace Swift {
static T* vecptr(std::vector<T, A>& v) {
return v.empty() ? NULL : &v[0];
}
-
+
std::string byteArrayToString(const ByteArray& b);
void readByteArrayFromFile(ByteArray&, const std::string& file);
diff --git a/Swiften/EventLoop/SConscript b/Swiften/EventLoop/SConscript
index e448f43..b405f6b 100644
--- a/Swiften/EventLoop/SConscript
+++ b/Swiften/EventLoop/SConscript
@@ -6,6 +6,7 @@ sources = [
"Event.cpp",
"SimpleEventLoop.cpp",
"DummyEventLoop.cpp",
+ "SingleThreadedEventLoop.cpp",
]
objects = swiften_env.SwiftenObject(sources)
diff --git a/Swiften/EventLoop/SingleThreadedEventLoop.cpp b/Swiften/EventLoop/SingleThreadedEventLoop.cpp
new file mode 100644
index 0000000..4c5e209
--- /dev/null
+++ b/Swiften/EventLoop/SingleThreadedEventLoop.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2010 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include "Swiften/EventLoop/SingleThreadedEventLoop.h"
+
+#include <boost/bind.hpp>
+#include <iostream>
+
+#include "Swiften/Base/foreach.h"
+
+
+namespace Swift {
+
+SingleThreadedEventLoop::SingleThreadedEventLoop()
+: shouldShutDown_(false)
+{
+}
+
+SingleThreadedEventLoop::~SingleThreadedEventLoop() {
+ if (!events_.empty()) {
+ std::cerr << "Warning: Pending events in SingleThreadedEventLoop at destruction time." << std::endl;
+ }
+}
+
+void SingleThreadedEventLoop::waitForEvents() {
+ boost::unique_lock<boost::mutex> lock(eventsMutex_);
+ while (events_.size() == 0 && !shouldShutDown_) {
+ eventsAvailable_.wait(lock);
+ }
+
+ if (shouldShutDown_)
+ throw EventLoopCanceledException();
+}
+
+void SingleThreadedEventLoop::handleEvents() {
+ // Make a copy of the list of events so we don't block any threads that post
+ // events while we process them.
+ std::vector<Event> events;
+ {
+ boost::unique_lock<boost::mutex> lock(eventsMutex_);
+ events.swap(events_);
+ }
+
+ // Loop through all the events and handle them
+ foreach(const Event& event, events) {
+ handleEvent(event);
+ }
+}
+
+void SingleThreadedEventLoop::stop() {
+ boost::unique_lock<boost::mutex> lock(eventsMutex_);
+ shouldShutDown_ = true;
+ eventsAvailable_.notify_one();
+}
+
+void SingleThreadedEventLoop::post(const Event& event) {
+ boost::lock_guard<boost::mutex> lock(eventsMutex_);
+ events_.push_back(event);
+ eventsAvailable_.notify_one();
+}
+
+} // namespace Swift
diff --git a/Swiften/EventLoop/SingleThreadedEventLoop.h b/Swiften/EventLoop/SingleThreadedEventLoop.h
new file mode 100644
index 0000000..75ffad0
--- /dev/null
+++ b/Swiften/EventLoop/SingleThreadedEventLoop.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2010 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <vector>
+#include <boost/thread/mutex.hpp>
+#include <boost/thread/condition_variable.hpp>
+
+#include "Swiften/EventLoop/EventLoop.h"
+
+// DESCRIPTION:
+//
+// All interaction with Swiften should happen on the same thread, such as the main GUI thread,
+// since the library isn't thread-safe.
+// For applications that don't have a main loop, such as WPF and MFC applications, we need a
+// different approach to process events from Swiften.
+//
+// The SingleThreadedEventLoop class implements an event loop that can be used from such applications.
+//
+// USAGE:
+//
+// Spawn a new thread in the desired framework and call SingleThreadedEventLoop::waitForEvents(). The method
+// blocks until a new event has arrived at which time it'll return, or until the wait is canceled
+// at which time it throws EventLoopCanceledException.
+//
+// When a new event has arrived and SingleThreadedEventLoop::waitForEvents() returns, the caller should then
+// call SingleThreadedEventLoop::handleEvents() on the main GUI thread. For WPF applications, for instance,
+// the Dispatcher class can be used to execute the call on the GUI thread.
+//
+
+namespace Swift {
+ class SingleThreadedEventLoop : public EventLoop {
+ public:
+ class EventLoopCanceledException : public std::exception { };
+
+ public:
+ SingleThreadedEventLoop();
+ ~SingleThreadedEventLoop();
+
+ // Blocks while waiting for new events and returns when new events are available.
+ // Throws EventLoopCanceledException when the wait is canceled.
+ void waitForEvents();
+ void handleEvents();
+ void stop();
+
+ virtual void post(const Event& event);
+
+ private:
+ bool shouldShutDown_;
+ std::vector<Event> events_;
+ boost::mutex eventsMutex_;
+ boost::condition_variable eventsAvailable_;
+ };
+}
diff --git a/Swiften/IDN/StringPrep.cpp b/Swiften/IDN/StringPrep.cpp
index db09523..9085569 100644
--- a/Swiften/IDN/StringPrep.cpp
+++ b/Swiften/IDN/StringPrep.cpp
@@ -6,7 +6,11 @@
#include <Swiften/IDN/StringPrep.h>
-#include <stringprep.h>
+extern "C"
+{
+ #include <stringprep.h>
+};
+
#include <vector>
#include <cassert>
#include <Swiften/Base/SafeAllocator.h>
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
index 06ce360..ac36f4f 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
@@ -35,10 +35,6 @@ OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) {
}
ByteArray OpenSSLCertificate::toDER() const {
- if (!cert) {
- return ByteArray();
- }
-
ByteArray result;
result.resize(i2d_X509(cert.get(), NULL));
unsigned char* p = vecptr(result);
@@ -47,10 +43,6 @@ ByteArray OpenSSLCertificate::toDER() const {
}
void OpenSSLCertificate::parse() {
- if (!cert) {
- return;
- }
-
// Subject name
X509_NAME* subjectName = X509_get_subject_name(cert.get());
if (subjectName) {
diff --git a/Swiften/TLS/PlatformTLSFactories.cpp b/Swiften/TLS/PlatformTLSFactories.cpp
index dec8788..5f57793 100644
--- a/Swiften/TLS/PlatformTLSFactories.cpp
+++ b/Swiften/TLS/PlatformTLSFactories.cpp
@@ -4,14 +4,18 @@
* See Documentation/Licenses/GPLv3.txt for more information.
*/
+#include <Swiften/Base/Platform.h>
#include <Swiften/TLS/PlatformTLSFactories.h>
#include <cstring>
#include <cassert>
#ifdef HAVE_OPENSSL
-#include <Swiften/TLS/OpenSSL/OpenSSLContextFactory.h>
-#include <Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h>
+ #include "Swiften/TLS/OpenSSL/OpenSSLContextFactory.h"
+ #include "Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h"
+#elif defined SWIFTEN_PLATFORM_WINDOWS
+ #include "Swiften/TLS/Schannel/SchannelContextFactory.h"
+#include "Swiften/TLS/Schannel/SchannelCertificateFactory.h"
#endif
namespace Swift {
@@ -20,6 +24,9 @@ PlatformTLSFactories::PlatformTLSFactories() : contextFactory(NULL), certificate
#ifdef HAVE_OPENSSL
contextFactory = new OpenSSLContextFactory();
certificateFactory = new OpenSSLCertificateFactory();
+#elif defined SWIFTEN_PLATFORM_WINDOWS
+ contextFactory = new SchannelContextFactory();
+ certificateFactory = new SchannelCertificateFactory();
#endif
}
diff --git a/Swiften/TLS/SConscript b/Swiften/TLS/SConscript
index b5829d6..225aa0a 100644
--- a/Swiften/TLS/SConscript
+++ b/Swiften/TLS/SConscript
@@ -18,6 +18,12 @@ if myenv.get("HAVE_OPENSSL", 0) :
"OpenSSL/OpenSSLContextFactory.cpp",
])
myenv.Append(CPPDEFINES = "HAVE_OPENSSL")
+elif myenv["PLATFORM"] == "win32" :
+ objects += myenv.StaticObject([
+ "Schannel/SchannelContext.cpp",
+ "Schannel/SchannelCertificate.cpp",
+ "Schannel/SchannelContextFactory.cpp",
+ ])
objects += myenv.SwiftenObject(["PlatformTLSFactories.cpp"])
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.cpp b/Swiften/TLS/Schannel/SchannelCertificate.cpp
new file mode 100644
index 0000000..8aaec00
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelCertificate.cpp
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include "Swiften/TLS/Schannel/SchannelCertificate.h"
+#include "Swiften/Base/ByteArray.h"
+
+#define SECURITY_WIN32
+#include <Windows.h>
+#include <Schannel.h>
+#include <security.h>
+#include <schnlsp.h>
+#include <Wincrypt.h>
+
+using std::vector;
+
+namespace Swift {
+
+//------------------------------------------------------------------------
+
+SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt)
+: m_cert(certCtxt)
+{
+ parse();
+}
+
+//------------------------------------------------------------------------
+
+SchannelCertificate::SchannelCertificate(const ByteArray& der)
+{
+ if (!der.empty())
+ {
+ // Convert the DER encoded certificate to a PCERT_CONTEXT
+ CERT_BLOB certBlob = {0};
+ certBlob.cbData = der.size();
+ certBlob.pbData = (BYTE*)&der[0];
+
+ if (!CryptQueryObject(
+ CERT_QUERY_OBJECT_BLOB,
+ &certBlob,
+ CERT_QUERY_CONTENT_FLAG_CERT,
+ CERT_QUERY_FORMAT_FLAG_ALL,
+ 0,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ (const void**)m_cert.Reset()))
+ {
+ // TODO: Because Swiften isn't exception safe, we have no way to indicate failure
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+
+ByteArray SchannelCertificate::toDER() const
+{
+ ByteArray result;
+
+ // Serialize the certificate. The CERT_CONTEXT is already DER encoded.
+ result.resize(m_cert->cbCertEncoded);
+ memcpy(&result[0], m_cert->pbCertEncoded, result.size());
+
+ return result;
+}
+
+//------------------------------------------------------------------------
+
+std::string SchannelCertificate::wstrToStr(const std::wstring& wstr)
+{
+ if (wstr.empty())
+ return "";
+
+ // First request the size of the required UTF-8 buffer
+ int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL);
+ if (!numRequiredBytes)
+ return "";
+
+ // Allocate memory for the UTF-8 string
+ std::vector<char> utf8Str(numRequiredBytes);
+
+ int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL);
+ if (!numConverted)
+ return "";
+
+ std::string str(&utf8Str[0], numConverted);
+ return str;
+}
+
+//------------------------------------------------------------------------
+
+void SchannelCertificate::parse()
+{
+ //
+ // Subject name
+ //
+ DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0);
+ if (requiredSize > 1)
+ {
+ vector<char> rawSubjectName(requiredSize);
+ CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size());
+ m_subjectName = std::string(&rawSubjectName[0]);
+ }
+
+ //
+ // Common name
+ //
+ // Note: We only pull out one common name from the cert.
+ requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0);
+ if (requiredSize > 1)
+ {
+ vector<char> rawCommonName(requiredSize);
+ requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size());
+ m_commonNames.push_back( std::string(&rawCommonName[0]) );
+ }
+
+ //
+ // Subject alternative names
+ //
+ PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension);
+ if (pExtensions)
+ {
+ CRYPT_DECODE_PARA decodePara = {0};
+ decodePara.cbSize = sizeof(decodePara);
+
+ CERT_ALT_NAME_INFO* pAltNameInfo = NULL;
+ DWORD altNameInfoSize = 0;
+
+ BOOL status = CryptDecodeObjectEx(
+ X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
+ szOID_SUBJECT_ALT_NAME2,
+ pExtensions->Value.pbData,
+ pExtensions->Value.cbData,
+ CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG,
+ &decodePara,
+ &pAltNameInfo,
+ &altNameInfoSize);
+
+ if (status && pAltNameInfo)
+ {
+ for (int i = 0; i < pAltNameInfo->cAltEntry; i++)
+ {
+ if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME)
+ addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) );
+ }
+ }
+ }
+
+ // if (pExtensions)
+ // {
+ // vector<wchar_t> subjectAlt
+ // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, )
+ // }
+ //
+ // // subjectAltNames
+ // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1);
+ // if(subjectAltNameLoc != -1) {
+ // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc);
+ // boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free);
+ // boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free);
+ // boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free);
+ // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) {
+ // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i);
+ // if (generalName->type == GEN_OTHERNAME) {
+ // OTHERNAME* otherName = generalName->d.otherName;
+ // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) {
+ // // XmppAddr
+ // if (otherName->value->type != V_ASN1_UTF8STRING) {
+ // continue;
+ // }
+ // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string;
+ // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString());
+ // }
+ // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) {
+ // // SRVName
+ // if (otherName->value->type != V_ASN1_IA5STRING) {
+ // continue;
+ // }
+ // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string;
+ // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString());
+ // }
+ // }
+ // else if (generalName->type == GEN_DNS) {
+ // // DNSName
+ // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString());
+ // }
+ // }
+ // }
+}
+
+//------------------------------------------------------------------------
+
+}
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h
new file mode 100644
index 0000000..f531cff
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelCertificate.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/TLS/Certificate.h"
+#include "Swiften/TLS/Schannel/SchannelUtil.h"
+
+namespace Swift
+{
+ class SchannelCertificate : public Certificate
+ {
+ public:
+ typedef boost::shared_ptr<SchannelCertificate> ref;
+
+ public:
+ SchannelCertificate(const ScopedCertContext& certCtxt);
+ SchannelCertificate(const ByteArray& der);
+
+ std::string getSubjectName() const
+ {
+ return m_subjectName;
+ }
+
+ std::vector<std::string> getCommonNames() const
+ {
+ return m_commonNames;
+ }
+
+ std::vector<std::string> getSRVNames() const
+ {
+ return m_srvNames;
+ }
+
+ std::vector<std::string> getDNSNames() const
+ {
+ return m_dnsNames;
+ }
+
+ std::vector<std::string> getXMPPAddresses() const
+ {
+ return m_xmppAddresses;
+ }
+
+ ByteArray toDER() const;
+
+ private:
+ void parse();
+ std::string wstrToStr(const std::wstring& wstr);
+
+ void addSRVName(const std::string& name)
+ {
+ m_srvNames.push_back(name);
+ }
+
+ void addDNSName(const std::string& name)
+ {
+ m_dnsNames.push_back(name);
+ }
+
+ void addXMPPAddress(const std::string& addr)
+ {
+ m_xmppAddresses.push_back(addr);
+ }
+
+ private:
+ ScopedCertContext m_cert;
+
+ std::string m_subjectName;
+ std::vector<std::string> m_commonNames;
+ std::vector<std::string> m_dnsNames;
+ std::vector<std::string> m_xmppAddresses;
+ std::vector<std::string> m_srvNames;
+ };
+}
diff --git a/Swiften/TLS/Schannel/SchannelCertificateFactory.h b/Swiften/TLS/Schannel/SchannelCertificateFactory.h
new file mode 100644
index 0000000..d09bb54
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelCertificateFactory.h
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/TLS/CertificateFactory.h>
+#include <Swiften/TLS/Schannel/SchannelCertificate.h>
+
+namespace Swift {
+ class SchannelCertificateFactory : public CertificateFactory {
+ public:
+ virtual Certificate::ref createCertificateFromDER(const ByteArray& der) {
+ return Certificate::ref(new SchannelCertificate(der));
+ }
+ };
+}
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
new file mode 100644
index 0000000..6771d4a
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -0,0 +1,503 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include "Swiften/TLS/Schannel/SchannelContext.h"
+#include "Swiften/TLS/Schannel/SchannelCertificate.h"
+
+namespace Swift {
+
+//------------------------------------------------------------------------
+
+SchannelContext::SchannelContext()
+: m_state(Start)
+, m_secContext(0)
+, m_verificationError(CertificateVerificationError::UnknownError)
+{
+ m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_INTEGRITY |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_USE_SUPPLIED_CREDS |
+ ISC_REQ_STREAM;
+
+ ZeroMemory(&m_streamSizes, sizeof(m_streamSizes));
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::determineStreamSizes()
+{
+ QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::connect()
+{
+ m_state = Connecting;
+
+ // We use an empty list for client certificates
+ PCCERT_CONTEXT clientCerts[1] = {0};
+
+ SCHANNEL_CRED sc = {0};
+ sc.dwVersion = SCHANNEL_CRED_VERSION;
+ sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
+ sc.paCred = clientCerts;
+ sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
+ sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN;
+
+ // Swiften performs the server name check for us
+ sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
+
+ SECURITY_STATUS status = AcquireCredentialsHandle(
+ NULL,
+ UNISP_NAME,
+ SECPKG_CRED_OUTBOUND,
+ NULL,
+ &sc,
+ NULL,
+ NULL,
+ m_credHandle.Reset(),
+ NULL);
+
+ if (status != SEC_E_OK)
+ {
+ // We failed to obtain the credentials handle
+ indicateError();
+ return;
+ }
+
+ SecBuffer outBuffers[2];
+
+ // We let Schannel allocate the output buffer for us
+ outBuffers[0].pvBuffer = NULL;
+ outBuffers[0].cbBuffer = 0;
+ outBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ // Contains alert data if an alert is generated
+ outBuffers[1].pvBuffer = NULL;
+ outBuffers[1].cbBuffer = 0;
+ outBuffers[1].BufferType = SECBUFFER_ALERT;
+
+ // Make sure the output buffers are freed
+ ScopedSecBuffer scopedOutputData(&outBuffers[0]);
+ ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 2;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ // Create the initial security context
+ status = InitializeSecurityContext(
+ m_credHandle,
+ NULL,
+ NULL,
+ m_ctxtFlags,
+ 0,
+ 0,
+ NULL,
+ 0,
+ m_ctxtHandle.Reset(),
+ &outBufferDesc,
+ &m_secContext,
+ NULL);
+
+ if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED)
+ {
+ // We failed to initialize the security context
+ indicateError();
+ return;
+ }
+
+ // Start the handshake
+ sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer);
+
+ if (status == SEC_E_OK)
+ {
+ m_state = Connected;
+ determineStreamSizes();
+
+ onConnected();
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::appendNewData(const SafeByteArray& data)
+{
+ size_t originalSize = m_receivedData.size();
+ m_receivedData.resize( originalSize + data.size() );
+ memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() );
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::continueHandshake(const SafeByteArray& data)
+{
+ appendNewData(data);
+
+ while (!m_receivedData.empty())
+ {
+ SecBuffer inBuffers[2];
+
+ // Provide Schannel with the remote host's handshake data
+ inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]);
+ inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size();
+ inBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ inBuffers[1].pvBuffer = NULL;
+ inBuffers[1].cbBuffer = 0;
+ inBuffers[1].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc inBufferDesc = {0};
+ inBufferDesc.cBuffers = 2;
+ inBufferDesc.pBuffers = inBuffers;
+ inBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SecBuffer outBuffers[2];
+
+ // We let Schannel allocate the output buffer for us
+ outBuffers[0].pvBuffer = NULL;
+ outBuffers[0].cbBuffer = 0;
+ outBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ // Contains alert data if an alert is generated
+ outBuffers[1].pvBuffer = NULL;
+ outBuffers[1].cbBuffer = 0;
+ outBuffers[1].BufferType = SECBUFFER_ALERT;
+
+ // Make sure the output buffers are freed
+ ScopedSecBuffer scopedOutputData(&outBuffers[0]);
+ ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 2;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SECURITY_STATUS status = InitializeSecurityContext(
+ m_credHandle,
+ m_ctxtHandle,
+ NULL,
+ m_ctxtFlags,
+ 0,
+ 0,
+ &inBufferDesc,
+ 0,
+ NULL,
+ &outBufferDesc,
+ &m_secContext,
+ NULL);
+
+ if (status == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ // Wait for more data to arrive
+ break;
+ }
+ else if (status == SEC_I_CONTINUE_NEEDED)
+ {
+ SecBuffer* pDataBuffer = &outBuffers[0];
+ SecBuffer* pExtraBuffer = &inBuffers[1];
+
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+
+ if (pExtraBuffer->BufferType == SECBUFFER_EXTRA)
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ else
+ m_receivedData.clear();
+
+ break;
+ }
+ else if (status == SEC_E_OK)
+ {
+ SecBuffer* pExtraBuffer = &inBuffers[1];
+
+ if (pExtraBuffer && pExtraBuffer->cbBuffer > 0)
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ else
+ m_receivedData.clear();
+
+ m_state = Connected;
+ determineStreamSizes();
+
+ onConnected();
+ }
+ else
+ {
+ // We failed to initialize the security context
+ indicateError();
+ return;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize)
+{
+ if (dataSize > 0 && pData)
+ {
+ SafeByteArray byteArray(dataSize);
+ memcpy(&byteArray[0], pData, dataSize);
+
+ onDataForNetwork(byteArray);
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize)
+{
+ SafeByteArray byteArray(dataSize);
+ memcpy(&byteArray[0], pData, dataSize);
+
+ onDataForApplication(byteArray);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::handleDataFromApplication(const SafeByteArray& data)
+{
+ // Don't attempt to send data until we're fully connected
+ if (m_state == Connecting)
+ return;
+
+ // Encrypt the data
+ encryptAndSendData(data);
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::handleDataFromNetwork(const SafeByteArray& data)
+{
+ switch (m_state)
+ {
+ case Connecting:
+ {
+ // We're still establishing the connection, so continue the handshake
+ continueHandshake(data);
+ }
+ break;
+
+ case Connected:
+ {
+ // Decrypt the data
+ decryptAndProcessData(data);
+ }
+ break;
+
+ default:
+ return;
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::indicateError()
+{
+ m_state = Error;
+ m_receivedData.clear();
+ onError();
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::decryptAndProcessData(const SafeByteArray& data)
+{
+ SecBuffer inBuffers[4] = {0};
+
+ appendNewData(data);
+
+ while (!m_receivedData.empty())
+ {
+ //
+ // MSDN:
+ // When using the Schannel SSP with contexts that are not connection oriented, on input,
+ // the structure must contain four SecBuffer structures. Exactly one buffer must be of type
+ // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining
+ // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented
+ // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented
+ // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token
+ // must also be supplied.
+ //
+ inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]);
+ inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size();
+ inBuffers[0].BufferType = SECBUFFER_DATA;
+
+ inBuffers[1].BufferType = SECBUFFER_EMPTY;
+ inBuffers[2].BufferType = SECBUFFER_EMPTY;
+ inBuffers[3].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc inBufferDesc = {0};
+ inBufferDesc.cBuffers = 4;
+ inBufferDesc.pBuffers = inBuffers;
+ inBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ size_t inData = m_receivedData.size();
+ SECURITY_STATUS status = DecryptMessage(m_ctxtHandle, &inBufferDesc, 0, NULL);
+
+ if (status == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ // Wait for more data to arrive
+ break;
+ }
+ else if (status == SEC_I_RENEGOTIATE)
+ {
+ // TODO: Handle renegotiation scenarios
+ indicateError();
+ break;
+ }
+ else if (status == SEC_I_CONTEXT_EXPIRED)
+ {
+ indicateError();
+ break;
+ }
+ else if (status != SEC_E_OK)
+ {
+ indicateError();
+ break;
+ }
+
+ SecBuffer* pDataBuffer = NULL;
+ SecBuffer* pExtraBuffer = NULL;
+ for (int i = 0; i < 4; ++i)
+ {
+ if (inBuffers[i].BufferType == SECBUFFER_DATA)
+ pDataBuffer = &inBuffers[i];
+
+ else if (inBuffers[i].BufferType == SECBUFFER_EXTRA)
+ pExtraBuffer = &inBuffers[i];
+ }
+
+ if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL)
+ forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer);
+
+ // If there is extra data left over from the decryption operation, we call DecryptMessage() again
+ if (pExtraBuffer)
+ {
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer);
+ }
+ else
+ {
+ // We're done
+ m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData);
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+
+void SchannelContext::encryptAndSendData(const SafeByteArray& data)
+{
+ SecBuffer outBuffers[4] = {0};
+
+ // Calculate the largest required size of the send buffer
+ size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage)
+ ? m_streamSizes.cbMaximumMessage
+ : data.size();
+
+ // Allocate a packet for the encrypted data
+ SafeByteArray sendBuffer;
+ sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer);
+
+ size_t bytesSent = 0;
+ do
+ {
+ size_t bytesLeftToSend = data.size() - bytesSent;
+
+ // Calculate how much of the send buffer we'll be using for this chunk
+ size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage)
+ ? m_streamSizes.cbMaximumMessage
+ : bytesLeftToSend;
+
+ // Copy the plain text data into the send buffer
+ memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend);
+
+ outBuffers[0].pvBuffer = &sendBuffer[0];
+ outBuffers[0].cbBuffer = m_streamSizes.cbHeader;
+ outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
+
+ outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader;
+ outBuffers[1].cbBuffer = (unsigned long)bytesToSend;
+ outBuffers[1].BufferType = SECBUFFER_DATA;
+
+ outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend;
+ outBuffers[2].cbBuffer = m_streamSizes.cbTrailer;
+ outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
+
+ outBuffers[3].pvBuffer = 0;
+ outBuffers[3].cbBuffer = 0;
+ outBuffers[3].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc outBufferDesc = {0};
+ outBufferDesc.cBuffers = 4;
+ outBufferDesc.pBuffers = outBuffers;
+ outBufferDesc.ulVersion = SECBUFFER_VERSION;
+
+ SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0);
+ if (status != SEC_E_OK)
+ {
+ indicateError();
+ return;
+ }
+
+ sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer);
+ bytesSent += bytesToSend;
+
+ } while (bytesSent < data.size());
+}
+
+//------------------------------------------------------------------------
+
+bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate)
+{
+ return false;
+}
+
+//------------------------------------------------------------------------
+
+Certificate::ref SchannelContext::getPeerCertificate() const
+{
+ SchannelCertificate::ref pCertificate;
+
+ ScopedCertContext pServerCert;
+ SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
+ if (status != SEC_E_OK)
+ return pCertificate;
+
+ pCertificate.reset( new SchannelCertificate(pServerCert) );
+ return pCertificate;
+}
+
+//------------------------------------------------------------------------
+
+CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const
+{
+ boost::shared_ptr<CertificateVerificationError> pCertError;
+
+ if (m_state == Error)
+ pCertError.reset( new CertificateVerificationError(m_verificationError) );
+
+ return pCertError;
+}
+
+//------------------------------------------------------------------------
+
+ByteArray SchannelContext::getFinishMessage() const
+{
+ // TODO: Implement
+
+ ByteArray emptyArray;
+ return emptyArray;
+}
+
+//------------------------------------------------------------------------
+
+}
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
new file mode 100644
index 0000000..66467fe
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include "Swiften/Base/boost_bsignals.h"
+
+#include "Swiften/TLS/TLSContext.h"
+#include "Swiften/TLS/Schannel/SchannelUtil.h"
+#include "Swiften/Base/ByteArray.h"
+
+#define SECURITY_WIN32
+#include <Windows.h>
+#include <Schannel.h>
+#include <security.h>
+#include <schnlsp.h>
+
+#include <boost/noncopyable.hpp>
+
+namespace Swift
+{
+ class SchannelContext : public TLSContext, boost::noncopyable
+ {
+ public:
+ typedef boost::shared_ptr<SchannelContext> sp_t;
+
+ public:
+ SchannelContext();
+
+ //
+ // TLSContext
+ //
+ virtual void connect();
+ virtual bool setClientCertificate(const PKCS12Certificate&);
+
+ virtual void handleDataFromNetwork(const SafeByteArray& data);
+ virtual void handleDataFromApplication(const SafeByteArray& data);
+
+ virtual Certificate::ref getPeerCertificate() const;
+ virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const;
+
+ virtual ByteArray getFinishMessage() const;
+
+ private:
+ void determineStreamSizes();
+ void continueHandshake(const SafeByteArray& data);
+ void indicateError();
+
+ void sendDataOnNetwork(const void* pData, size_t dataSize);
+ void forwardDataToApplication(const void* pData, size_t dataSize);
+
+ void decryptAndProcessData(const SafeByteArray& data);
+ void encryptAndSendData(const SafeByteArray& data);
+
+ void appendNewData(const SafeByteArray& data);
+
+ private:
+ enum SchannelState
+ {
+ Start,
+ Connecting,
+ Connected,
+ Error
+
+ };
+
+ SchannelState m_state;
+ CertificateVerificationError m_verificationError;
+
+ ULONG m_secContext;
+ ScopedCredHandle m_credHandle;
+ ScopedCtxtHandle m_ctxtHandle;
+ DWORD m_ctxtFlags;
+ SecPkgContext_StreamSizes m_streamSizes;
+
+ std::vector<char> m_receivedData;
+ };
+}
diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.cpp b/Swiften/TLS/Schannel/SchannelContextFactory.cpp
new file mode 100644
index 0000000..8ab7c6c
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelContextFactory.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include "Swiften/TLS/Schannel/SchannelContextFactory.h"
+#include "Swiften/TLS/Schannel/SchannelContext.h"
+
+namespace Swift {
+
+bool SchannelContextFactory::canCreate() const {
+ return true;
+}
+
+TLSContext* SchannelContextFactory::createTLSContext() {
+ return new SchannelContext();
+}
+
+}
diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.h b/Swiften/TLS/Schannel/SchannelContextFactory.h
new file mode 100644
index 0000000..43c39a9
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelContextFactory.h
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include "Swiften/TLS/TLSContextFactory.h"
+
+namespace Swift {
+ class SchannelContextFactory : public TLSContextFactory {
+ public:
+ bool canCreate() const;
+ virtual TLSContext* createTLSContext();
+ };
+}
diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h
new file mode 100644
index 0000000..0a54f16
--- /dev/null
+++ b/Swiften/TLS/Schannel/SchannelUtil.h
@@ -0,0 +1,294 @@
+/*
+ * Copyright (c) 2011 Soren Dreijer
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#define SECURITY_WIN32
+#include <Windows.h>
+#include <Schannel.h>
+#include <security.h>
+#include <schnlsp.h>
+
+#include <boost/noncopyable.hpp>
+
+namespace Swift
+{
+ //
+ // Convenience wrapper around the Schannel CredHandle struct.
+ //
+ class ScopedCredHandle
+ {
+ private:
+ struct HandleContext
+ {
+ HandleContext()
+ {
+ ZeroMemory(&m_h, sizeof(m_h));
+ }
+
+ HandleContext(const CredHandle& h)
+ {
+ memcpy(&m_h, &h, sizeof(m_h));
+ }
+
+ ~HandleContext()
+ {
+ ::FreeCredentialsHandle(&m_h);
+ }
+
+ CredHandle m_h;
+ };
+
+ public:
+ ScopedCredHandle()
+ : m_pHandle( new HandleContext )
+ {
+ }
+
+ explicit ScopedCredHandle(const CredHandle& h)
+ : m_pHandle( new HandleContext(h) )
+ {
+ }
+
+ // Copy constructor
+ explicit ScopedCredHandle(const ScopedCredHandle& rhs)
+ {
+ m_pHandle = rhs.m_pHandle;
+ }
+
+ ~ScopedCredHandle()
+ {
+ m_pHandle.reset();
+ }
+
+ PCredHandle Reset()
+ {
+ CloseHandle();
+ return &m_pHandle->m_h;
+ }
+
+ operator PCredHandle() const
+ {
+ return &m_pHandle->m_h;
+ }
+
+ ScopedCredHandle& operator=(const ScopedCredHandle& sh)
+ {
+ // Only update the internal handle if it's different
+ if (&m_pHandle->m_h != &sh.m_pHandle->m_h)
+ {
+ m_pHandle = sh.m_pHandle;
+ }
+
+ return *this;
+ }
+
+ void CloseHandle()
+ {
+ m_pHandle.reset( new HandleContext );
+ }
+
+ private:
+ boost::shared_ptr<HandleContext> m_pHandle;
+ };
+
+ //------------------------------------------------------------------------
+
+ //
+ // Convenience wrapper around the Schannel CtxtHandle struct.
+ //
+ class ScopedCtxtHandle
+ {
+ private:
+ struct HandleContext
+ {
+ HandleContext()
+ {
+ ZeroMemory(&m_h, sizeof(m_h));
+ }
+
+ ~HandleContext()
+ {
+ ::DeleteSecurityContext(&m_h);
+ }
+
+ CtxtHandle m_h;
+ };
+
+ public:
+ ScopedCtxtHandle()
+ : m_pHandle( new HandleContext )
+ {
+ }
+
+ explicit ScopedCtxtHandle(CredHandle h)
+ : m_pHandle( new HandleContext )
+ {
+ }
+
+ // Copy constructor
+ explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs)
+ {
+ m_pHandle = rhs.m_pHandle;
+ }
+
+ ~ScopedCtxtHandle()
+ {
+ m_pHandle.reset();
+ }
+
+ PCredHandle Reset()
+ {
+ CloseHandle();
+ return &m_pHandle->m_h;
+ }
+
+ operator PCredHandle() const
+ {
+ return &m_pHandle->m_h;
+ }
+
+ ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh)
+ {
+ // Only update the internal handle if it's different
+ if (&m_pHandle->m_h != &sh.m_pHandle->m_h)
+ {
+ m_pHandle = sh.m_pHandle;
+ }
+
+ return *this;
+ }
+
+ void CloseHandle()
+ {
+ m_pHandle.reset( new HandleContext );
+ }
+
+ private:
+ boost::shared_ptr<HandleContext> m_pHandle;
+ };
+
+ //------------------------------------------------------------------------
+
+ //
+ // Convenience wrapper around the Schannel ScopedSecBuffer struct.
+ //
+ class ScopedSecBuffer : boost::noncopyable
+ {
+ public:
+ ScopedSecBuffer(PSecBuffer pSecBuffer)
+ : m_pSecBuffer(pSecBuffer)
+ {
+ }
+
+ ~ScopedSecBuffer()
+ {
+ // Loop through all the output buffers and make sure we free them
+ if (m_pSecBuffer->pvBuffer)
+ FreeContextBuffer(m_pSecBuffer->pvBuffer);
+ }
+
+ PSecBuffer AsPtr()
+ {
+ return m_pSecBuffer;
+ }
+
+ PSecBuffer operator->()
+ {
+ return m_pSecBuffer;
+ }
+
+ private:
+ PSecBuffer m_pSecBuffer;
+ };
+
+ //------------------------------------------------------------------------
+
+ //
+ // Convenience wrapper around the Schannel PCCERT_CONTEXT.
+ //
+ class ScopedCertContext
+ {
+ private:
+ struct HandleContext
+ {
+ HandleContext()
+ : m_pCertCtxt(NULL)
+ {
+ }
+
+ HandleContext(PCCERT_CONTEXT pCert)
+ : m_pCertCtxt(pCert)
+ {
+ }
+
+ ~HandleContext()
+ {
+ if (m_pCertCtxt)
+ CertFreeCertificateContext(m_pCertCtxt);
+ }
+
+ PCCERT_CONTEXT m_pCertCtxt;
+ };
+
+ public:
+ ScopedCertContext()
+ : m_pHandle( new HandleContext )
+ {
+ }
+
+ explicit ScopedCertContext(PCCERT_CONTEXT pCert)
+ : m_pHandle( new HandleContext(pCert) )
+ {
+ }
+
+ // Copy constructor
+ explicit ScopedCertContext(const ScopedCertContext& rhs)
+ {
+ m_pHandle = rhs.m_pHandle;
+ }
+
+ ~ScopedCertContext()
+ {
+ m_pHandle.reset();
+ }
+
+ PCCERT_CONTEXT* Reset()
+ {
+ FreeContext();
+ return &m_pHandle->m_pCertCtxt;
+ }
+
+ operator PCCERT_CONTEXT() const
+ {
+ return m_pHandle->m_pCertCtxt;
+ }
+
+ PCCERT_CONTEXT operator->() const
+ {
+ return m_pHandle->m_pCertCtxt;
+ }
+
+ ScopedCertContext& operator=(const ScopedCertContext& sh)
+ {
+ // Only update the internal handle if it's different
+ if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt)
+ {
+ m_pHandle = sh.m_pHandle;
+ }
+
+ return *this;
+ }
+
+ void FreeContext()
+ {
+ m_pHandle.reset( new HandleContext );
+ }
+
+ private:
+ boost::shared_ptr<HandleContext> m_pHandle;
+ };
+}