summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlexey Melnikov <alexey.melnikov@isode.com>2012-02-13 17:54:23 (GMT)
committerKevin Smith <git@kismith.co.uk>2012-02-22 14:08:13 (GMT)
commit110eb87e848b85dd74a6f19413c775520a75ea35 (patch)
treeb10236387180fca676a29f24c747c9d0fd94d8dd
parent64fc103d0d5d1d523d00dcc5b231715160475f7e (diff)
downloadswift-contrib-110eb87e848b85dd74a6f19413c775520a75ea35.zip
swift-contrib-110eb87e848b85dd74a6f19413c775520a75ea35.tar.bz2
Initial implementation of using CAPI certificates with Schannel.
Introduced a new parent class for all certificates with keys (class CertificateWithKey is the new parent for PKCS12Certificate.) Switched to using "CertificateWithKey *" instead of "const CertificateWithKey&" Added calling of a Windows dialog for certificate selection when Schannel TLS implementation is used. This compiles, but is not tested. License: This patch is BSD-licensed, see Documentation/Licenses/BSD-simplified.txt for details.
-rw-r--r--Swift/QtUI/CAPICertificateSelector.cpp138
-rw-r--r--Swift/QtUI/CAPICertificateSelector.h13
-rw-r--r--Swift/QtUI/QtLoginWindow.cpp19
-rw-r--r--Swift/QtUI/SConscript3
-rw-r--r--Swiften/Client/CoreClient.cpp27
-rw-r--r--Swiften/Client/CoreClient.h2
-rw-r--r--Swiften/Session/SessionStream.cpp1
-rw-r--r--Swiften/Session/SessionStream.h12
-rw-r--r--Swiften/StreamStack/TLSLayer.cpp2
-rw-r--r--Swiften/StreamStack/TLSLayer.h4
-rw-r--r--Swiften/TLS/CAPICertificate.h196
-rw-r--r--Swiften/TLS/CertificateWithKey.h32
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp14
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h4
-rw-r--r--Swiften/TLS/PKCS12Certificate.h27
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.cpp82
-rw-r--r--Swiften/TLS/Schannel/SchannelContext.h11
-rw-r--r--Swiften/TLS/TLSContext.h4
18 files changed, 559 insertions, 32 deletions
diff --git a/Swift/QtUI/CAPICertificateSelector.cpp b/Swift/QtUI/CAPICertificateSelector.cpp
new file mode 100644
index 0000000..44f5793
--- /dev/null
+++ b/Swift/QtUI/CAPICertificateSelector.cpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#include <string>
+
+#define SECURITY_WIN32
+#include <Windows.h>
+#include <WinCrypt.h>
+#include <cryptuiapi.h>
+
+#include "CAPICertificateSelector.h"
+
+namespace Swift {
+
+#define cert_dlg_title L"TLS Client Certificate Selection"
+#define cert_dlg_prompt L"Select a certificate to use for authentication"
+/////Hmm, maybe we should not exlude the "location" column
+#define exclude_columns CRYPTUI_SELECT_LOCATION_COLUMN \
+ |CRYPTUI_SELECT_INTENDEDUSE_COLUMN
+
+
+
+static std::string getCertUri(PCCERT_CONTEXT cert, const char * cert_store_name) {
+ DWORD required_size;
+ char * comma;
+ char * p_in;
+ char * p_out;
+ char * subject_name;
+ std::string ret = std::string("certstore:") + cert_store_name + ":";
+
+ required_size = CertNameToStrA(cert->dwCertEncodingType,
+ &cert->pCertInfo->Subject,
+ /* Discard attribute names: */
+ CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG,
+ NULL,
+ 0);
+
+ subject_name = static_cast<char *>(malloc(required_size+1));
+
+ if (!CertNameToStrA(cert->dwCertEncodingType,
+ &cert->pCertInfo->Subject,
+ /* Discard attribute names: */
+ CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG,
+ subject_name,
+ required_size)) {
+ return "";
+ }
+
+ /* Now search for the "," (ignoring escapes)
+ and truncate the rest of the string */
+ if (subject_name[0] == '"') {
+ for (comma = subject_name + 1; comma[0]; comma++) {
+ if (comma[0] == '"') {
+ comma++;
+ if (comma[0] != '"') {
+ break;
+ }
+ }
+ }
+ } else {
+ comma = strchr(subject_name, ',');
+ }
+
+ if (comma != NULL) {
+ *comma = '\0';
+ }
+
+ /* We now need to unescape the returned RDN */
+ if (subject_name[0] == '"') {
+ for (p_in = subject_name + 1, p_out = subject_name; p_in[0]; p_in++, p_out++) {
+ if (p_in[0] == '"') {
+ p_in++;
+ }
+
+ p_out[0] = p_in[0];
+ }
+ p_out[0] = '\0';
+ }
+
+ ret += subject_name;
+ free(subject_name);
+
+ return ret;
+}
+
+std::string selectCAPICertificate() {
+
+ const char * cert_store_name = "MY";
+ PCCERT_CONTEXT cert;
+ DWORD store_flags;
+ HCERTSTORE hstore;
+ HWND hwnd;
+
+ store_flags = CERT_STORE_OPEN_EXISTING_FLAG |
+ CERT_STORE_READONLY_FLAG |
+ CERT_SYSTEM_STORE_CURRENT_USER;
+
+ hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, store_flags, cert_store_name);
+ if (!hstore) {
+ return "";
+ }
+
+
+////Does this handle need to be freed as well?
+ hwnd = GetForegroundWindow();
+ if (!hwnd) {
+ hwnd = GetActiveWindow();
+ }
+
+ /* Call Windows dialog to select a suitable certificate */
+ cert = CryptUIDlgSelectCertificateFromStore(hstore,
+ hwnd,
+ cert_dlg_title,
+ cert_dlg_prompt,
+ exclude_columns,
+ 0,
+ NULL);
+
+ if (hstore) {
+ CertCloseStore(hstore, 0);
+ }
+
+ if (cert) {
+ std::string ret = getCertUri(cert, cert_store_name);
+
+ CertFreeCertificateContext(cert);
+
+ return ret;
+ } else {
+ return "";
+ }
+}
+
+
+}
diff --git a/Swift/QtUI/CAPICertificateSelector.h b/Swift/QtUI/CAPICertificateSelector.h
new file mode 100644
index 0000000..9a0ee92
--- /dev/null
+++ b/Swift/QtUI/CAPICertificateSelector.h
@@ -0,0 +1,13 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <string>
+
+namespace Swift {
+ std::string selectCAPICertificate();
+}
diff --git a/Swift/QtUI/QtLoginWindow.cpp b/Swift/QtUI/QtLoginWindow.cpp
index 1cd3206..6b9d389 100644
--- a/Swift/QtUI/QtLoginWindow.cpp
+++ b/Swift/QtUI/QtLoginWindow.cpp
@@ -9,70 +9,74 @@
#include <boost/bind.hpp>
#include <boost/smart_ptr/make_shared.hpp>
#include <algorithm>
#include <cassert>
#include <QApplication>
#include <QBoxLayout>
#include <QComboBox>
#include <QDesktopWidget>
#include <QFileDialog>
#include <QStatusBar>
#include <QToolButton>
#include <QLabel>
#include <QMenuBar>
#include <QHBoxLayout>
#include <qdebug.h>
#include <QCloseEvent>
#include <QCursor>
#include <QMessageBox>
#include <QKeyEvent>
#include <Swift/Controllers/UIEvents/UIEventStream.h>
#include <Swift/Controllers/UIEvents/RequestXMLConsoleUIEvent.h>
#include <Swift/Controllers/UIEvents/RequestFileTransferListUIEvent.h>
#include <Swift/Controllers/Settings/SettingsProvider.h>
#include <Swift/Controllers/SettingConstants.h>
#include <Swift/QtUI/QtUISettingConstants.h>
#include <Swiften/Base/Platform.h>
#include <Swiften/Base/Paths.h>
#include <QtAboutWidget.h>
#include <QtSwiftUtil.h>
#include <QtMainWindow.h>
#include <QtUtilities.h>
+#ifdef HAVE_SCHANNEL
+#include "CAPICertificateSelector.h"
+#endif
+
namespace Swift{
QtLoginWindow::QtLoginWindow(UIEventStream* uiEventStream, SettingsProvider* settings) : QMainWindow(), settings_(settings) {
uiEventStream_ = uiEventStream;
setWindowTitle("Swift");
#ifndef Q_WS_MAC
setWindowIcon(QIcon(":/logo-icon-16.png"));
#endif
QtUtilities::setX11Resource(this, "Main");
resize(200, 500);
setContentsMargins(0,0,0,0);
QWidget *centralWidget = new QWidget(this);
setCentralWidget(centralWidget);
QBoxLayout *topLayout = new QBoxLayout(QBoxLayout::TopToBottom, centralWidget);
stack_ = new QStackedWidget(centralWidget);
topLayout->addWidget(stack_);
topLayout->setMargin(0);
loginWidgetWrapper_ = new QWidget(this);
loginWidgetWrapper_->setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Expanding));
QBoxLayout *layout = new QBoxLayout(QBoxLayout::TopToBottom, loginWidgetWrapper_);
layout->addStretch(2);
QLabel* logo = new QLabel(this);
logo->setPixmap(QPixmap(":/logo-shaded-text.256.png"));
logo->setScaledContents(true);
logo->setFixedSize(192,192);
QWidget *logoWidget = new QWidget(this);
QHBoxLayout *logoLayout = new QHBoxLayout();
logoLayout->setMargin(0);
logoLayout->addStretch(0);
logoLayout->addWidget(logo);
logoLayout->addStretch(0);
@@ -325,74 +329,81 @@ void QtLoginWindow::setIsLoggingIn(bool loggingIn) {
bool eagle = settings_->getSetting(SettingConstants::FORGET_PASSWORDS);
remember_->setEnabled(!eagle);
loginAutomatically_->setEnabled(!eagle);
}
void QtLoginWindow::loginClicked() {
if (username_->isEnabled()) {
std::string banner = settings_->getSetting(QtUISettingConstants::CLICKTHROUGH_BANNER);
if (!banner.empty()) {
QMessageBox msgBox;
msgBox.setWindowTitle(tr("Confirm terms of use"));
msgBox.setText("");
msgBox.setInformativeText(P2QSTRING(banner));
msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::No);
msgBox.setDefaultButton(QMessageBox::No);
if (msgBox.exec() != QMessageBox::Yes) {
return;
}
}
onLoginRequest(Q2PSTRING(username_->currentText()), Q2PSTRING(password_->text()), Q2PSTRING(certificateFile_), remember_->isChecked(), loginAutomatically_->isChecked());
if (settings_->getSetting(SettingConstants::FORGET_PASSWORDS)) { /* Mustn't remember logins */
username_->clearEditText();
password_->setText("");
}
} else {
onCancelLoginRequest();
}
}
void QtLoginWindow::setLoginAutomatically(bool loginAutomatically) {
loginAutomatically_->setChecked(loginAutomatically);
}
void QtLoginWindow::handleCertficateChecked(bool checked) {
if (checked) {
- certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx"));
- if (certificateFile_.isEmpty()) {
- certificateButton_->setChecked(false);
- }
+#ifdef HAVE_SCHANNEL
+ certificateFile_ = selectCAPICertificate();
+ if (certificateFile_.isEmpty()) {
+ certificateButton_->setChecked(false);
+ }
+#else
+ certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx"));
+ if (certificateFile_.isEmpty()) {
+ certificateButton_->setChecked(false);
+ }
+#endif
}
else {
certificateFile_ = "";
}
}
void QtLoginWindow::handleAbout() {
if (!aboutDialog_) {
aboutDialog_ = new QtAboutWidget();
aboutDialog_->show();
}
else {
aboutDialog_->show();
aboutDialog_->raise();
aboutDialog_->activateWindow();
}
}
void QtLoginWindow::handleShowXMLConsole() {
uiEventStream_->send(boost::shared_ptr<RequestXMLConsoleUIEvent>(new RequestXMLConsoleUIEvent()));
}
void QtLoginWindow::handleShowFileTransferOverview() {
uiEventStream_->send(boost::make_shared<RequestFileTransferListUIEvent>());
}
void QtLoginWindow::handleToggleSounds(bool enabled) {
settings_->storeSetting(SettingConstants::PLAY_SOUNDS, enabled);
}
void QtLoginWindow::handleToggleNotifications(bool enabled) {
settings_->storeSetting(SettingConstants::SHOW_NOTIFICATIONS, enabled);
}
void QtLoginWindow::handleQuit() {
diff --git a/Swift/QtUI/SConscript b/Swift/QtUI/SConscript
index d37958f..a8b8c78 100644
--- a/Swift/QtUI/SConscript
+++ b/Swift/QtUI/SConscript
@@ -23,70 +23,72 @@ myenv = env.Clone()
myenv["CXXFLAGS"] = filter(lambda x : x != "-Wfloat-equal", myenv["CXXFLAGS"])
myenv.UseFlags(env["SWIFT_CONTROLLERS_FLAGS"])
myenv.UseFlags(env["SWIFTOOLS_FLAGS"])
if myenv["HAVE_XSS"] :
myenv.UseFlags(env["XSS_FLAGS"])
if env["PLATFORM"] == "posix" :
myenv.Append(LIBS = ["X11"])
if myenv["HAVE_SPARKLE"] :
myenv.UseFlags(env["SPARKLE_FLAGS"])
myenv.UseFlags(env["SWIFTEN_FLAGS"])
myenv.UseFlags(env["SWIFTEN_DEP_FLAGS"])
if myenv.get("HAVE_GROWL", False) :
myenv.UseFlags(myenv["GROWL_FLAGS"])
myenv.Append(CPPDEFINES = ["HAVE_GROWL"])
if myenv["swift_mobile"] :
myenv.Append(CPPDEFINES = ["SWIFT_MOBILE"])
if myenv.get("HAVE_SNARL", False) :
myenv.UseFlags(myenv["SNARL_FLAGS"])
myenv.Append(CPPDEFINES = ["HAVE_SNARL"])
myenv.UseFlags(myenv["PLATFORM_FLAGS"])
myenv.Tool("qt4", toolpath = ["#/BuildTools/SCons/Tools"])
myenv.Tool("nsis", toolpath = ["#/BuildTools/SCons/Tools"])
myenv.Tool("wix", toolpath = ["#/BuildTools/SCons/Tools"])
qt4modules = ['QtCore', 'QtGui', 'QtWebKit']
if env["PLATFORM"] == "posix" :
qt4modules += ["QtDBus"]
myenv.EnableQt4Modules(qt4modules, debug = False)
myenv.Append(CPPPATH = ["."])
if env["PLATFORM"] == "win32" :
#myenv["LINKFLAGS"] = ["/SUBSYSTEM:CONSOLE"]
myenv.Append(LINKFLAGS = ["/SUBSYSTEM:WINDOWS"])
myenv.Append(LIBS = "qtmain")
+ if myenv.get("HAVE_SCHANNEL", 0) :
+ myenv.Append(LIBS = "Cryptui")
myenv.WriteVal("DefaultTheme.qrc", myenv.Value(generateDefaultTheme(myenv.Dir("#/Swift/resources/themes/Default"))))
sources = [
"main.cpp",
"QtAboutWidget.cpp",
"QtAvatarWidget.cpp",
"QtUIFactory.cpp",
"QtChatWindowFactory.cpp",
"QtChatWindow.cpp",
"QtClickableLabel.cpp",
"QtLoginWindow.cpp",
"QtMainWindow.cpp",
"QtProfileWindow.cpp",
"QtNameWidget.cpp",
"QtSettingsProvider.cpp",
"QtStatusWidget.cpp",
"QtScaledAvatarCache.cpp",
"QtSwift.cpp",
"QtURIHandler.cpp",
"QtChatView.cpp",
"QtChatTheme.cpp",
"QtChatTabs.cpp",
"QtSoundPlayer.cpp",
"QtSystemTray.cpp",
"QtCachedImageScaler.cpp",
"QtTabbable.cpp",
"QtTabWidget.cpp",
"QtTextEdit.cpp",
"QtXMLConsoleWidget.cpp",
"QtFileTransferListWidget.cpp",
"QtFileTransferListItemModel.cpp",
"QtAdHocCommandWindow.cpp",
"QtUtilities.cpp",
"QtBookmarkDetailWindow.cpp",
@@ -119,70 +121,71 @@ sources = [
"ChatList/ChatListModel.cpp",
"ChatList/ChatListDelegate.cpp",
"ChatList/ChatListMUCItem.cpp",
"ChatList/ChatListRecentItem.cpp",
"MUCSearch/QtMUCSearchWindow.cpp",
"MUCSearch/MUCSearchModel.cpp",
"MUCSearch/MUCSearchRoomItem.cpp",
"MUCSearch/MUCSearchEmptyItem.cpp",
"MUCSearch/MUCSearchDelegate.cpp",
"UserSearch/QtUserSearchFirstPage.cpp",
"UserSearch/QtUserSearchFieldsPage.cpp",
"UserSearch/QtUserSearchResultsPage.cpp",
"UserSearch/QtUserSearchDetailsPage.cpp",
"UserSearch/QtUserSearchWindow.cpp",
"UserSearch/UserSearchModel.cpp",
"UserSearch/UserSearchDelegate.cpp",
"QtSubscriptionRequestWindow.cpp",
"QtRosterHeader.cpp",
"QtWebView.cpp",
"qrc_DefaultTheme.cc",
"qrc_Swift.cc",
"QtFileTransferJSBridge.cpp",
"QtMUCConfigurationWindow.cpp",
"QtAffiliationEditor.cpp",
"QtUISettingConstants.cpp"
]
myenv["SWIFT_VERSION"] = Version.getBuildVersion(env.Dir("#").abspath, "swift")
if env["PLATFORM"] == "win32" :
res = myenv.RES("#/Swift/resources/Windows/Swift.rc")
# For some reason, SCons isn't picking up the dependency correctly
# Adding it explicitly until i figure out why
myenv.Depends(res, "../Controllers/BuildVersion.h")
sources += [
+ "CAPICertificateSelector.cpp",
"WindowsNotifier.cpp",
"#/Swift/resources/Windows/Swift.res"
]
if env["PLATFORM"] == "posix" :
sources += [
"FreeDesktopNotifier.cpp",
"QtDBUSURIHandler.cpp",
]
if env["PLATFORM"] == "darwin" or env["PLATFORM"] == "win32" :
swiftProgram = myenv.Program("Swift", sources)
else :
swiftProgram = myenv.Program("swift", sources)
if env["PLATFORM"] != "darwin" and env["PLATFORM"] != "win32" :
openURIProgram = myenv.Program("swift-open-uri", "swift-open-uri.cpp")
else :
openURIProgram = []
myenv.Uic4("MUCSearch/QtMUCSearchWindow.ui")
myenv.Uic4("UserSearch/QtUserSearchWizard.ui")
myenv.Uic4("UserSearch/QtUserSearchFirstPage.ui")
myenv.Uic4("UserSearch/QtUserSearchFieldsPage.ui")
myenv.Uic4("UserSearch/QtUserSearchResultsPage.ui")
myenv.Uic4("QtBookmarkDetailWindow.ui")
myenv.Uic4("QtAffiliationEditor.ui")
myenv.Uic4("QtJoinMUCWindow.ui")
myenv.Qrc("DefaultTheme.qrc")
myenv.Qrc("Swift.qrc")
# Resources
commonResources = {
"": ["#/Swift/resources/sounds"]
}
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index de12fb7..36bfe35 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -94,89 +94,114 @@ void CoreClient::connect(const std::string& host) {
networkFactories->getDomainNameResolver(),
host,
options.boshHTTPConnectProxyURL,
options.boshHTTPConnectProxyAuthID,
options.boshHTTPConnectProxyAuthPassword));
sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
bindSessionToStream();
}
}
void CoreClient::bindSessionToStream() {
session_ = ClientSession::create(jid_, sessionStream_);
session_->setCertificateTrustChecker(certificateTrustChecker);
session_->setUseStreamCompression(options.useStreamCompression);
session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS);
switch(options.useTLS) {
case ClientOptions::UseTLSWhenAvailable:
session_->setUseTLS(ClientSession::UseTLSWhenAvailable);
break;
case ClientOptions::NeverUseTLS:
session_->setUseTLS(ClientSession::NeverUseTLS);
break;
case ClientOptions::RequireTLS:
session_->setUseTLS(ClientSession::RequireTLS);
break;
}
session_->setUseAcks(options.useAcks);
stanzaChannel_->setSession(session_);
session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this));
session_->start();
}
+bool CoreClient::isCAPIURI() {
+#ifdef HAVE_SCHANNEL
+ if (!boost::iequals(certificate_.substr(0, 10), "certstore:")) {
+ return false;
+ }
+
+ return true;
+
+#else
+ return false;
+#endif
+}
+
/**
* Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream.
*/
void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
resetConnector();
if (!connection) {
if (options.forgetPassword) {
purgePassword();
}
onDisconnected(disconnectRequested_ ? boost::optional<ClientError>() : boost::optional<ClientError>(ClientError::ConnectionError));
}
else {
assert(!connection_);
connection_ = connection;
assert(!sessionStream_);
sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory());
if (!certificate_.empty()) {
- sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_));
+ CertificateWithKey* cert;
+
+#if defined(SWIFTEN_PLATFORM_WIN32)
+ if (isCAPIURI()) {
+ cert = new CAPICertificate(certificate_);
+ } else {
+ cert = new PKCS12Certificate(certificate_, password_);
+ }
+#else
+ cert = new PKCS12Certificate(certificate_, password_);
+#endif
+
+ sessionStream_->setTLSCertificate(cert);
}
sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
bindSessionToStream();
}
}
void CoreClient::disconnect() {
// FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between
// connector finishing without a connection due to an error or because of a disconnect.
disconnectRequested_ = true;
if (session_ && !session_->isFinished()) {
session_->finish();
}
else if (connector_) {
connector_->stop();
}
}
void CoreClient::setCertificate(const std::string& certificate) {
certificate_ = certificate;
}
void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) {
if (options.forgetPassword) {
purgePassword();
}
resetSession();
boost::optional<ClientError> actualError;
if (error) {
ClientError clientError;
if (boost::shared_ptr<ClientSession::Error> actualError = boost::dynamic_pointer_cast<ClientSession::Error>(error)) {
switch(actualError->type) {
diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h
index c231fdc..6712e03 100644
--- a/Swiften/Client/CoreClient.h
+++ b/Swiften/Client/CoreClient.h
@@ -164,69 +164,71 @@ namespace Swift {
*/
boost::signal<void (const SafeByteArray&)> onDataWritten;
/**
* Emitted when a message is received.
*/
boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived;
/**
* Emitted when a presence stanza is received.
*/
boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived;
/**
* Emitted when the server acknowledges receipt of a
* stanza (if acknowledgements are available).
*
* \see getStreamManagementEnabled()
*/
boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaAcked;
protected:
boost::shared_ptr<ClientSession> getSession() const {
return session_;
}
NetworkFactories* getNetworkFactories() const {
return networkFactories;
}
/**
* Called before onConnected signal is emmitted.
*/
virtual void handleConnected() {};
+ bool isCAPIURI();
+
private:
void handleConnectorFinished(boost::shared_ptr<Connection>);
void handleStanzaChannelAvailableChanged(bool available);
void handleSessionFinished(boost::shared_ptr<Error>);
void handleNeedCredentials();
void handleDataRead(const SafeByteArray&);
void handleDataWritten(const SafeByteArray&);
void handlePresenceReceived(boost::shared_ptr<Presence>);
void handleMessageReceived(boost::shared_ptr<Message>);
void handleStanzaAcked(boost::shared_ptr<Stanza>);
void purgePassword();
void bindSessionToStream();
void resetConnector();
void resetSession();
void forceReset();
private:
JID jid_;
SafeByteArray password_;
NetworkFactories* networkFactories;
ClientSessionStanzaChannel* stanzaChannel_;
IQRouter* iqRouter_;
ClientOptions options;
boost::shared_ptr<ChainedConnector> connector_;
std::vector<ConnectionFactory*> proxyConnectionFactories;
boost::shared_ptr<Connection> connection_;
boost::shared_ptr<SessionStream> sessionStream_;
boost::shared_ptr<ClientSession> session_;
std::string certificate_;
bool disconnectRequested_;
CertificateTrustChecker* certificateTrustChecker;
};
}
diff --git a/Swiften/Session/SessionStream.cpp b/Swiften/Session/SessionStream.cpp
index 0d73b63..487ad8b 100644
--- a/Swiften/Session/SessionStream.cpp
+++ b/Swiften/Session/SessionStream.cpp
@@ -1,14 +1,15 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Session/SessionStream.h>
namespace Swift {
SessionStream::~SessionStream() {
+ delete certificate;
}
};
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index 096f185..58015b3 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -1,87 +1,89 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <Swiften/Base/boost_bsignals.h>
#include <boost/shared_ptr.hpp>
#include <boost/optional.hpp>
#include <Swiften/Elements/ProtocolHeader.h>
#include <Swiften/Elements/Element.h>
#include <Swiften/Base/Error.h>
#include <Swiften/Base/SafeByteArray.h>
-#include <Swiften/TLS/PKCS12Certificate.h>
+#include <Swiften/TLS/CertificateWithKey.h>
#include <Swiften/TLS/Certificate.h>
#include <Swiften/TLS/CertificateVerificationError.h>
namespace Swift {
class SessionStream {
public:
class Error : public Swift::Error {
public:
enum Type {
ParseError,
TLSError,
InvalidTLSCertificateError,
ConnectionReadError,
ConnectionWriteError
};
Error(Type type) : type(type) {}
Type type;
};
+ SessionStream(): certificate(0) {}
+
virtual ~SessionStream();
virtual void close() = 0;
virtual bool isOpen() = 0;
virtual void writeHeader(const ProtocolHeader& header) = 0;
virtual void writeFooter() = 0;
virtual void writeElement(boost::shared_ptr<Element>) = 0;
virtual void writeData(const std::string& data) = 0;
virtual bool supportsZLibCompression() = 0;
virtual void addZLibCompression() = 0;
virtual bool supportsTLSEncryption() = 0;
virtual void addTLSEncryption() = 0;
virtual bool isTLSEncrypted() = 0;
virtual void setWhitespacePingEnabled(bool enabled) = 0;
virtual void resetXMPPParser() = 0;
- void setTLSCertificate(const PKCS12Certificate& cert) {
+ void setTLSCertificate(CertificateWithKey* cert) {
certificate = cert;
}
virtual bool hasTLSCertificate() {
- return !certificate.isNull();
+ return certificate && !certificate->isNull();
}
virtual Certificate::ref getPeerCertificate() const = 0;
virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
virtual ByteArray getTLSFinishMessage() const = 0;
boost::signal<void (const ProtocolHeader&)> onStreamStartReceived;
boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
boost::signal<void (boost::shared_ptr<Error>)> onClosed;
boost::signal<void ()> onTLSEncrypted;
boost::signal<void (const SafeByteArray&)> onDataRead;
boost::signal<void (const SafeByteArray&)> onDataWritten;
protected:
- const PKCS12Certificate& getTLSCertificate() const {
+ CertificateWithKey * getTLSCertificate() const {
return certificate;
}
private:
- PKCS12Certificate certificate;
+ CertificateWithKey * certificate;
};
}
diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp
index 6f2223d..b7efbcb 100644
--- a/Swiften/StreamStack/TLSLayer.cpp
+++ b/Swiften/StreamStack/TLSLayer.cpp
@@ -5,48 +5,48 @@
*/
#include <Swiften/StreamStack/TLSLayer.h>
#include <boost/bind.hpp>
#include <Swiften/TLS/TLSContextFactory.h>
#include <Swiften/TLS/TLSContext.h>
namespace Swift {
TLSLayer::TLSLayer(TLSContextFactory* factory) {
context = factory->createTLSContext();
context->onDataForNetwork.connect(boost::bind(&TLSLayer::writeDataToChildLayer, this, _1));
context->onDataForApplication.connect(boost::bind(&TLSLayer::writeDataToParentLayer, this, _1));
context->onConnected.connect(onConnected);
context->onError.connect(onError);
}
TLSLayer::~TLSLayer() {
delete context;
}
void TLSLayer::connect() {
context->connect();
}
void TLSLayer::writeData(const SafeByteArray& data) {
context->handleDataFromApplication(data);
}
void TLSLayer::handleDataRead(const SafeByteArray& data) {
context->handleDataFromNetwork(data);
}
-bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) {
+bool TLSLayer::setClientCertificate(CertificateWithKey * certificate) {
return context->setClientCertificate(certificate);
}
Certificate::ref TLSLayer::getPeerCertificate() const {
return context->getPeerCertificate();
}
boost::shared_ptr<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const {
return context->getPeerCertificateVerificationError();
}
}
diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h
index a8693d5..6dc9135 100644
--- a/Swiften/StreamStack/TLSLayer.h
+++ b/Swiften/StreamStack/TLSLayer.h
@@ -1,44 +1,44 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Base/boost_bsignals.h>
#include <Swiften/Base/SafeByteArray.h>
#include <Swiften/StreamStack/StreamLayer.h>
#include <Swiften/TLS/Certificate.h>
#include <Swiften/TLS/CertificateVerificationError.h>
namespace Swift {
class TLSContext;
class TLSContextFactory;
- class PKCS12Certificate;
+ class CertificateWithKey;
class TLSLayer : public StreamLayer {
public:
TLSLayer(TLSContextFactory*);
~TLSLayer();
void connect();
- bool setClientCertificate(const PKCS12Certificate&);
+ bool setClientCertificate(CertificateWithKey * cert);
Certificate::ref getPeerCertificate() const;
boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
void writeData(const SafeByteArray& data);
void handleDataRead(const SafeByteArray& data);
TLSContext* getContext() const {
return context;
}
public:
boost::signal<void ()> onError;
boost::signal<void ()> onConnected;
private:
TLSContext* context;
};
}
diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h
new file mode 100644
index 0000000..fcdb4c2
--- /dev/null
+++ b/Swiften/TLS/CAPICertificate.h
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2012 Isode Limited, London, England.
+ * Licensed under the simplified BSD license.
+ * See Documentation/Licenses/BSD-simplified.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/SafeByteArray.h>
+#include <Swiften/TLS/CertificateWithKey.h>
+
+#include <boost/algorithm/string/predicate.hpp>
+
+#define SECURITY_WIN32
+#include <WinCrypt.h>
+
+namespace Swift {
+ class CAPICertificate : public Swift::CertificateWithKey {
+ public:
+ CAPICertificate(const std::string& capiUri)
+ : valid_(false), uri_(capiUri), cert_store_handle_(0), cert_store_(NULL), cert_name_(NULL) {
+ setUri(capiUri);
+ }
+
+ virtual ~CAPICertificate() {
+ if (cert_store_handle_ != NULL)
+ {
+ CertCloseStore(cert_store_handle_, 0);
+ }
+ }
+
+ virtual bool isNull() const {
+ return uri_.empty() || !valid_;
+ }
+
+ virtual bool isPrivateKeyExportable() const {
+ /* We can check with CAPI, but for now the answer is "no" */
+ return false;
+ }
+
+ virtual const std::string& getCertStoreName() const {
+ return cert_store_;
+ }
+
+ virtual const std::string& getCertName() const {
+ return cert_name_;
+ }
+
+ const ByteArray& getData() const {
+////Might need to throw an exception here, or really generate PKCS12 blob from CAPI data?
+ assert(0);
+ }
+
+ void setData(const ByteArray& data) {
+ assert(0);
+ }
+
+ const SafeByteArray& getPassword() const {
+/////Can't pass NULL to createSafeByteArray!
+/////Should this throw an exception instead?
+ return createSafeByteArray("");
+ }
+
+ protected:
+ void setUri (const std::string& capiUri) {
+
+ valid_ = false;
+
+ /* Syntax: "certstore:" [<cert_store> ":"] <cert_id> */
+
+ if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) {
+ return;
+ }
+
+ /* Substring of subject: uses "storename" */
+ std::string capi_identity = capiUri.substr(10);
+ std::string new_cert_store_name;
+ size_t pos = capi_identity.find_first_of (':');
+
+ if (pos == std::string::npos) {
+ /* Using the default certificate store */
+ new_cert_store_name = "MY";
+ cert_name_ = capi_identity;
+ } else {
+ new_cert_store_name = capi_identity.substr(0, pos);
+ cert_name_ = capi_identity.substr(pos + 1);
+ }
+
+ PCCERT_CONTEXT pCertContext = NULL;
+
+ if (cert_store_handle_ != NULL)
+ {
+ if (new_cert_store_name != cert_store_) {
+ CertCloseStore(cert_store_handle_, 0);
+ cert_store_handle_ = NULL;
+ }
+ }
+
+ if (cert_store_handle_ == NULL)
+ {
+ cert_store_handle_ = CertOpenSystemStore(0, cert_store_.c_str());
+ if (!cert_store_handle_)
+ {
+ return;
+ }
+ }
+
+ cert_store_ = new_cert_store_name;
+
+ /* NB: This might have to change, depending on how we locate certificates */
+
+ // Find client certificate. Note that this sample just searches for a
+ // certificate that contains the user name somewhere in the subject name.
+ pCertContext = CertFindCertificateInStore(cert_store_handle_,
+ X509_ASN_ENCODING,
+ 0, // dwFindFlags
+ CERT_FIND_SUBJECT_STR_A,
+ cert_name_.c_str(), // *pvFindPara
+ NULL ); // pPrevCertContext
+
+ if (pCertContext == NULL)
+ {
+ return;
+ }
+
+
+ /* Now verify that we can have access to the corresponding private key */
+
+ DWORD len;
+ CRYPT_KEY_PROV_INFO *pinfo;
+ HCRYPTPROV hprov;
+ HCRYPTKEY key;
+
+ if (!CertGetCertificateContextProperty(pCertContext,
+ CERT_KEY_PROV_INFO_PROP_ID,
+ NULL,
+ &len))
+ {
+ CertFreeCertificateContext(pCertContext);
+ return;
+ }
+
+ pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len));
+ if (!pinfo) {
+ CertFreeCertificateContext(pCertContext);
+ return;
+ }
+
+ if (!CertGetCertificateContextProperty(pCertContext,
+ CERT_KEY_PROV_INFO_PROP_ID,
+ pinfo,
+ &len))
+ {
+ CertFreeCertificateContext(pCertContext);
+ free(pinfo);
+ return;
+ }
+
+ CertFreeCertificateContext(pCertContext);
+
+ // Now verify if we have access to the private key
+ if (!CryptAcquireContextW(&hprov,
+ pinfo->pwszContainerName,
+ pinfo->pwszProvName,
+ pinfo->dwProvType,
+ 0))
+ {
+ free(pinfo);
+ return;
+ }
+
+ if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key))
+ {
+ CryptReleaseContext(hprov, 0);
+ free(pinfo);
+ return;
+ }
+
+ CryptDestroyKey(key);
+ CryptReleaseContext(hprov, 0);
+ free(pinfo);
+
+ valid_ = true;
+ }
+
+ private:
+ bool valid_;
+ std::string uri_;
+
+ HCERTSTORE cert_store_handle_;
+
+ /* Parsed components of the uri_ */
+ std::string cert_store_;
+ std::string cert_name_;
+ };
+}
diff --git a/Swiften/TLS/CertificateWithKey.h b/Swiften/TLS/CertificateWithKey.h
new file mode 100644
index 0000000..6f6ea39
--- /dev/null
+++ b/Swiften/TLS/CertificateWithKey.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2010-2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/SafeByteArray.h>
+
+namespace Swift {
+ class CertificateWithKey {
+ public:
+ CertificateWithKey() {}
+
+ virtual ~CertificateWithKey() {}
+
+ virtual bool isNull() const = 0;
+
+ virtual bool isPrivateKeyExportable() const = 0;
+
+ virtual const std::string& getCertStoreName() const = 0;
+
+ virtual const std::string& getCertName() const = 0;
+
+ virtual const ByteArray& getData() const = 0;
+
+ virtual void setData(const ByteArray& data) = 0;
+
+ virtual const SafeByteArray& getPassword() const = 0;
+ };
+}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 220e7f9..dd3462f 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -1,59 +1,59 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Base/Platform.h>
#ifdef SWIFTEN_PLATFORM_WINDOWS
#include <windows.h>
#include <wincrypt.h>
#endif
#include <vector>
#include <openssl/err.h>
#include <openssl/pkcs12.h>
#include <boost/smart_ptr/make_shared.hpp>
#if defined(SWIFTEN_PLATFORM_MACOSX) && OPENSSL_VERSION_NUMBER < 0x00908000
#include <Security/Security.h>
#endif
#include <Swiften/TLS/OpenSSL/OpenSSLContext.h>
#include <Swiften/TLS/OpenSSL/OpenSSLCertificate.h>
-#include <Swiften/TLS/PKCS12Certificate.h>
+#include <Swiften/TLS/CertificateWithKey.h>
#pragma GCC diagnostic ignored "-Wold-style-cast"
namespace Swift {
static const int MAX_FINISHED_SIZE = 4096;
static const int SSL_READ_BUFFERSIZE = 8192;
void freeX509Stack(STACK_OF(X509)* stack) {
sk_X509_free(stack);
}
OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) {
ensureLibraryInitialized();
context_ = SSL_CTX_new(TLSv1_client_method());
// Load system certs
#if defined(SWIFTEN_PLATFORM_WINDOWS)
X509_STORE* store = SSL_CTX_get_cert_store(context_);
HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT");
if (systemStore) {
PCCERT_CONTEXT certContext = NULL;
while (true) {
certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext);
if (!certContext) {
break;
}
OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded));
if (store && cert.getInternalX509()) {
X509_STORE_add_cert(store, cert.getInternalX509().get());
}
}
}
#elif !defined(SWIFTEN_PLATFORM_MACOSX)
SSL_CTX_load_verify_locations(context_, NULL, "/etc/ssl/certs");
@@ -153,89 +153,93 @@ void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) {
break;
case Connected:
sendPendingDataToApplication();
break;
case Start: assert(false); break;
case Error: /*assert(false);*/ break;
}
}
void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) {
if (SSL_write(handle_, vecptr(data), data.size()) >= 0) {
sendPendingDataToNetwork();
}
else {
state_ = Error;
onError();
}
}
void OpenSSLContext::sendPendingDataToApplication() {
SafeByteArray data;
data.resize(SSL_READ_BUFFERSIZE);
int ret = SSL_read(handle_, vecptr(data), data.size());
while (ret > 0) {
data.resize(ret);
onDataForApplication(data);
data.resize(SSL_READ_BUFFERSIZE);
ret = SSL_read(handle_, vecptr(data), data.size());
}
if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) {
state_ = Error;
onError();
}
}
-bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) {
- if (certificate.isNull()) {
+bool OpenSSLContext::setClientCertificate(CertificateWithKey * certificate) {
+ if (!certificate || certificate->isNull()) {
+ return false;
+ }
+
+ if (!certificate->isPrivateKeyExportable()) {
return false;
}
// Create a PKCS12 structure
BIO* bio = BIO_new(BIO_s_mem());
- BIO_write(bio, vecptr(certificate.getData()), certificate.getData().size());
+ BIO_write(bio, vecptr(certificate->getData()), certificate->getData().size());
boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free);
BIO_free(bio);
if (!pkcs12) {
return false;
}
// Parse PKCS12
X509 *certPtr = 0;
EVP_PKEY* privateKeyPtr = 0;
STACK_OF(X509)* caCertsPtr = 0;
- int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate.getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr);
+ int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate->getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr);
if (result != 1) {
return false;
}
boost::shared_ptr<X509> cert(certPtr, X509_free);
boost::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free);
boost::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack);
// Use the key & certificates
if (SSL_CTX_use_certificate(context_, cert.get()) != 1) {
return false;
}
if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) {
return false;
}
for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) {
SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i));
}
return true;
}
Certificate::ref OpenSSLContext::getPeerCertificate() const {
boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free);
if (x509Cert) {
return boost::make_shared<OpenSSLCertificate>(x509Cert);
}
else {
return Certificate::ref();
}
}
boost::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
int verifyResult = SSL_get_verify_result(handle_);
if (verifyResult != X509_V_OK) {
return boost::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult));
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 04693a3..b53e715 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -1,53 +1,53 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <openssl/ssl.h>
#include <Swiften/Base/boost_bsignals.h>
#include <boost/noncopyable.hpp>
#include <Swiften/TLS/TLSContext.h>
#include <Swiften/Base/ByteArray.h>
namespace Swift {
- class PKCS12Certificate;
+ class CertificateWithKey;
class OpenSSLContext : public TLSContext, boost::noncopyable {
public:
OpenSSLContext();
~OpenSSLContext();
void connect();
- bool setClientCertificate(const PKCS12Certificate& cert);
+ bool setClientCertificate(CertificateWithKey * cert);
void handleDataFromNetwork(const SafeByteArray&);
void handleDataFromApplication(const SafeByteArray&);
Certificate::ref getPeerCertificate() const;
boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
virtual ByteArray getFinishMessage() const;
private:
static void ensureLibraryInitialized();
static CertificateVerificationError::Type getVerificationErrorTypeForResult(int);
void doConnect();
void sendPendingDataToNetwork();
void sendPendingDataToApplication();
private:
enum State { Start, Connecting, Connected, Error };
State state_;
SSL_CTX* context_;
SSL* handle_;
BIO* readBIO_;
BIO* writeBIO_;
};
}
diff --git a/Swiften/TLS/PKCS12Certificate.h b/Swiften/TLS/PKCS12Certificate.h
index c0e01d0..2f70456 100644
--- a/Swiften/TLS/PKCS12Certificate.h
+++ b/Swiften/TLS/PKCS12Certificate.h
@@ -1,40 +1,59 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <Swiften/Base/SafeByteArray.h>
+#include <Swiften/TLS/CertificateWithKey.h>
namespace Swift {
- class PKCS12Certificate {
+ class PKCS12Certificate : public Swift::CertificateWithKey {
public:
PKCS12Certificate() {}
PKCS12Certificate(const std::string& filename, const SafeByteArray& password) : password_(password) {
readByteArrayFromFile(data_, filename);
}
- bool isNull() const {
+ virtual ~PKCS12Certificate() {}
+
+ virtual bool isNull() const {
return data_.empty();
}
- const ByteArray& getData() const {
+ virtual bool isPrivateKeyExportable() const {
+/////Hopefully a PKCS12 is never missing a private key
+ return true;
+ }
+
+ virtual const std::string& getCertStoreName() const {
+///// assert(0);
+ throw std::exception();
+ }
+
+ virtual const std::string& getCertName() const {
+ /* We can return the original filename instead, if we care */
+///// assert(0);
+ throw std::exception();
+ }
+
+ virtual const ByteArray& getData() const {
return data_;
}
void setData(const ByteArray& data) {
data_ = data;
}
- const SafeByteArray& getPassword() const {
+ virtual const SafeByteArray& getPassword() const {
return password_;
}
private:
ByteArray data_;
SafeByteArray password_;
};
}
diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp
index 6771d4a..6f50b3a 100644
--- a/Swiften/TLS/Schannel/SchannelContext.cpp
+++ b/Swiften/TLS/Schannel/SchannelContext.cpp
@@ -1,102 +1,163 @@
/*
* Copyright (c) 2011 Soren Dreijer
* Licensed under the simplified BSD license.
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
#include "Swiften/TLS/Schannel/SchannelContext.h"
#include "Swiften/TLS/Schannel/SchannelCertificate.h"
namespace Swift {
//------------------------------------------------------------------------
SchannelContext::SchannelContext()
: m_state(Start)
, m_secContext(0)
, m_verificationError(CertificateVerificationError::UnknownError)
+, m_my_cert_store(NULL)
+, m_cert_store_name("MY")
+, m_cert_name(NULL)
{
m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_CONFIDENTIALITY |
ISC_REQ_EXTENDED_ERROR |
ISC_REQ_INTEGRITY |
ISC_REQ_REPLAY_DETECT |
ISC_REQ_SEQUENCE_DETECT |
ISC_REQ_USE_SUPPLIED_CREDS |
ISC_REQ_STREAM;
ZeroMemory(&m_streamSizes, sizeof(m_streamSizes));
}
//------------------------------------------------------------------------
+SchannelContext::~SchannelContext()
+{
+ if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0);
+}
+
+//------------------------------------------------------------------------
+
void SchannelContext::determineStreamSizes()
{
QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes);
}
//------------------------------------------------------------------------
void SchannelContext::connect()
{
+ PCCERT_CONTEXT pCertContext = NULL;
+
m_state = Connecting;
+ // If a user name is specified, then attempt to find a client
+ // certificate. Otherwise, just create a NULL credential.
+ if (!m_cert_name.empty())
+ {
+ if (m_my_cert_store == NULL)
+ {
+ m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str());
+ if (!m_my_cert_store)
+ {
+///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() );
+ indicateError();
+ return;
+ }
+ }
+
+ // Find client certificate. Note that this sample just searches for a
+ // certificate that contains the user name somewhere in the subject name.
+ pCertContext = CertFindCertificateInStore( m_my_cert_store,
+ X509_ASN_ENCODING,
+ 0, // dwFindFlags
+ CERT_FIND_SUBJECT_STR_A,
+ m_cert_name.c_str(), // *pvFindPara
+ NULL ); // pPrevCertContext
+
+ if (pCertContext == NULL)
+ {
+///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError());
+ indicateError();
+ return;
+ }
+ }
+
// We use an empty list for client certificates
PCCERT_CONTEXT clientCerts[1] = {0};
SCHANNEL_CRED sc = {0};
sc.dwVersion = SCHANNEL_CRED_VERSION;
- sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
- sc.paCred = clientCerts;
+
+/////SSL3?
sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
- sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN;
+/////Check SCH_CRED_REVOCATION_CHECK_CHAIN
+ sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN;
+
+ if (pCertContext)
+ {
+ sc.cCreds = 1;
+ sc.paCred = &pCertContext;
+ sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
+ }
+ else
+ {
+ sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us
+ sc.paCred = clientCerts;
+ sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;
+ }
// Swiften performs the server name check for us
sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
SECURITY_STATUS status = AcquireCredentialsHandle(
NULL,
UNISP_NAME,
SECPKG_CRED_OUTBOUND,
NULL,
&sc,
NULL,
NULL,
m_credHandle.Reset(),
NULL);
+ // cleanup: Free the certificate context. Schannel has already made its own copy.
+ if (pCertContext) CertFreeCertificateContext(pCertContext);
+
if (status != SEC_E_OK)
{
// We failed to obtain the credentials handle
indicateError();
return;
}
SecBuffer outBuffers[2];
// We let Schannel allocate the output buffer for us
outBuffers[0].pvBuffer = NULL;
outBuffers[0].cbBuffer = 0;
outBuffers[0].BufferType = SECBUFFER_TOKEN;
// Contains alert data if an alert is generated
outBuffers[1].pvBuffer = NULL;
outBuffers[1].cbBuffer = 0;
outBuffers[1].BufferType = SECBUFFER_ALERT;
// Make sure the output buffers are freed
ScopedSecBuffer scopedOutputData(&outBuffers[0]);
ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]);
SecBufferDesc outBufferDesc = {0};
outBufferDesc.cBuffers = 2;
outBufferDesc.pBuffers = outBuffers;
outBufferDesc.ulVersion = SECBUFFER_VERSION;
// Create the initial security context
status = InitializeSecurityContext(
m_credHandle,
NULL,
NULL,
m_ctxtFlags,
0,
@@ -424,72 +485,85 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data)
outBuffers[0].cbBuffer = m_streamSizes.cbHeader;
outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader;
outBuffers[1].cbBuffer = (unsigned long)bytesToSend;
outBuffers[1].BufferType = SECBUFFER_DATA;
outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend;
outBuffers[2].cbBuffer = m_streamSizes.cbTrailer;
outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
outBuffers[3].pvBuffer = 0;
outBuffers[3].cbBuffer = 0;
outBuffers[3].BufferType = SECBUFFER_EMPTY;
SecBufferDesc outBufferDesc = {0};
outBufferDesc.cBuffers = 4;
outBufferDesc.pBuffers = outBuffers;
outBufferDesc.ulVersion = SECBUFFER_VERSION;
SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0);
if (status != SEC_E_OK)
{
indicateError();
return;
}
sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer);
bytesSent += bytesToSend;
} while (bytesSent < data.size());
}
//------------------------------------------------------------------------
-bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate)
+bool SchannelContext::setClientCertificate(CertificateWithKey * certificate)
{
+ if (!certificate || certificate->isNull()) {
+ return false;
+ }
+
+ if (!certificate->isPrivateKeyExportable()) {
+ // We assume that the Certificate Store Name/Certificate Name
+ // are valid at this point
+ m_cert_store_name = certificate->getCertStoreName();
+ m_cert_name = certificate->getCertName();
+
+ return true;
+ }
+
return false;
}
//------------------------------------------------------------------------
Certificate::ref SchannelContext::getPeerCertificate() const
{
SchannelCertificate::ref pCertificate;
ScopedCertContext pServerCert;
SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset());
if (status != SEC_E_OK)
return pCertificate;
pCertificate.reset( new SchannelCertificate(pServerCert) );
return pCertificate;
}
//------------------------------------------------------------------------
CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const
{
boost::shared_ptr<CertificateVerificationError> pCertError;
if (m_state == Error)
pCertError.reset( new CertificateVerificationError(m_verificationError) );
return pCertError;
}
//------------------------------------------------------------------------
ByteArray SchannelContext::getFinishMessage() const
{
// TODO: Implement
diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h
index 66467fe..0cdb3d7 100644
--- a/Swiften/TLS/Schannel/SchannelContext.h
+++ b/Swiften/TLS/Schannel/SchannelContext.h
@@ -1,81 +1,88 @@
/*
* Copyright (c) 2011 Soren Dreijer
* Licensed under the simplified BSD license.
* See Documentation/Licenses/BSD-simplified.txt for more information.
*/
#pragma once
#include "Swiften/Base/boost_bsignals.h"
#include "Swiften/TLS/TLSContext.h"
#include "Swiften/TLS/Schannel/SchannelUtil.h"
+#include <Swiften/TLS/CertificateWithKey.h>
#include "Swiften/Base/ByteArray.h"
#define SECURITY_WIN32
#include <Windows.h>
#include <Schannel.h>
#include <security.h>
#include <schnlsp.h>
#include <boost/noncopyable.hpp>
namespace Swift
{
class SchannelContext : public TLSContext, boost::noncopyable
{
public:
typedef boost::shared_ptr<SchannelContext> sp_t;
public:
- SchannelContext();
+ SchannelContext();
+
+ ~SchannelContext();
//
// TLSContext
//
virtual void connect();
- virtual bool setClientCertificate(const PKCS12Certificate&);
+ virtual bool setClientCertificate(CertificateWithKey * cert);
virtual void handleDataFromNetwork(const SafeByteArray& data);
virtual void handleDataFromApplication(const SafeByteArray& data);
virtual Certificate::ref getPeerCertificate() const;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const;
virtual ByteArray getFinishMessage() const;
private:
void determineStreamSizes();
void continueHandshake(const SafeByteArray& data);
void indicateError();
void sendDataOnNetwork(const void* pData, size_t dataSize);
void forwardDataToApplication(const void* pData, size_t dataSize);
void decryptAndProcessData(const SafeByteArray& data);
void encryptAndSendData(const SafeByteArray& data);
void appendNewData(const SafeByteArray& data);
private:
enum SchannelState
{
Start,
Connecting,
Connected,
Error
};
SchannelState m_state;
CertificateVerificationError m_verificationError;
ULONG m_secContext;
ScopedCredHandle m_credHandle;
ScopedCtxtHandle m_ctxtHandle;
DWORD m_ctxtFlags;
SecPkgContext_StreamSizes m_streamSizes;
std::vector<char> m_receivedData;
+
+ HCERTSTORE m_my_cert_store;
+ std::string m_cert_store_name;
+ std::string m_cert_name;
};
}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 1538863..ada813a 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -1,41 +1,41 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <Swiften/Base/boost_bsignals.h>
#include <boost/shared_ptr.hpp>
#include <Swiften/Base/SafeByteArray.h>
#include <Swiften/TLS/Certificate.h>
#include <Swiften/TLS/CertificateVerificationError.h>
namespace Swift {
- class PKCS12Certificate;
+ class CertificateWithKey;
class TLSContext {
public:
virtual ~TLSContext();
virtual void connect() = 0;
- virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0;
+ virtual bool setClientCertificate(CertificateWithKey * cert) = 0;
virtual void handleDataFromNetwork(const SafeByteArray&) = 0;
virtual void handleDataFromApplication(const SafeByteArray&) = 0;
virtual Certificate::ref getPeerCertificate() const = 0;
virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0;
virtual ByteArray getFinishMessage() const = 0;
public:
boost::signal<void (const SafeByteArray&)> onDataForNetwork;
boost::signal<void (const SafeByteArray&)> onDataForApplication;
boost::signal<void ()> onError;
boost::signal<void ()> onConnected;
};
}