diff options
-rw-r--r-- | Swift/QtUI/CAPICertificateSelector.cpp | 138 | ||||
-rw-r--r-- | Swift/QtUI/CAPICertificateSelector.h | 13 | ||||
-rw-r--r-- | Swift/QtUI/QtLoginWindow.cpp | 19 | ||||
-rw-r--r-- | Swift/QtUI/SConscript | 3 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 27 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.h | 2 | ||||
-rw-r--r-- | Swiften/Session/SessionStream.cpp | 1 | ||||
-rw-r--r-- | Swiften/Session/SessionStream.h | 12 | ||||
-rw-r--r-- | Swiften/StreamStack/TLSLayer.cpp | 2 | ||||
-rw-r--r-- | Swiften/StreamStack/TLSLayer.h | 4 | ||||
-rw-r--r-- | Swiften/TLS/CAPICertificate.h | 196 | ||||
-rw-r--r-- | Swiften/TLS/CertificateWithKey.h | 32 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 14 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 4 | ||||
-rw-r--r-- | Swiften/TLS/PKCS12Certificate.h | 27 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 82 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 11 | ||||
-rw-r--r-- | Swiften/TLS/TLSContext.h | 4 |
18 files changed, 559 insertions, 32 deletions
diff --git a/Swift/QtUI/CAPICertificateSelector.cpp b/Swift/QtUI/CAPICertificateSelector.cpp new file mode 100644 index 0000000..44f5793 --- /dev/null +++ b/Swift/QtUI/CAPICertificateSelector.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#include <string> + +#define SECURITY_WIN32 +#include <Windows.h> +#include <WinCrypt.h> +#include <cryptuiapi.h> + +#include "CAPICertificateSelector.h" + +namespace Swift { + +#define cert_dlg_title L"TLS Client Certificate Selection" +#define cert_dlg_prompt L"Select a certificate to use for authentication" +/////Hmm, maybe we should not exlude the "location" column +#define exclude_columns CRYPTUI_SELECT_LOCATION_COLUMN \ + |CRYPTUI_SELECT_INTENDEDUSE_COLUMN + + + +static std::string getCertUri(PCCERT_CONTEXT cert, const char * cert_store_name) { + DWORD required_size; + char * comma; + char * p_in; + char * p_out; + char * subject_name; + std::string ret = std::string("certstore:") + cert_store_name + ":"; + + required_size = CertNameToStrA(cert->dwCertEncodingType, + &cert->pCertInfo->Subject, + /* Discard attribute names: */ + CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + NULL, + 0); + + subject_name = static_cast<char *>(malloc(required_size+1)); + + if (!CertNameToStrA(cert->dwCertEncodingType, + &cert->pCertInfo->Subject, + /* Discard attribute names: */ + CERT_SIMPLE_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + subject_name, + required_size)) { + return ""; + } + + /* Now search for the "," (ignoring escapes) + and truncate the rest of the string */ + if (subject_name[0] == '"') { + for (comma = subject_name + 1; comma[0]; comma++) { + if (comma[0] == '"') { + comma++; + if (comma[0] != '"') { + break; + } + } + } + } else { + comma = strchr(subject_name, ','); + } + + if (comma != NULL) { + *comma = '\0'; + } + + /* We now need to unescape the returned RDN */ + if (subject_name[0] == '"') { + for (p_in = subject_name + 1, p_out = subject_name; p_in[0]; p_in++, p_out++) { + if (p_in[0] == '"') { + p_in++; + } + + p_out[0] = p_in[0]; + } + p_out[0] = '\0'; + } + + ret += subject_name; + free(subject_name); + + return ret; +} + +std::string selectCAPICertificate() { + + const char * cert_store_name = "MY"; + PCCERT_CONTEXT cert; + DWORD store_flags; + HCERTSTORE hstore; + HWND hwnd; + + store_flags = CERT_STORE_OPEN_EXISTING_FLAG | + CERT_STORE_READONLY_FLAG | + CERT_SYSTEM_STORE_CURRENT_USER; + + hstore = CertOpenStore(CERT_STORE_PROV_SYSTEM_A, 0, 0, store_flags, cert_store_name); + if (!hstore) { + return ""; + } + + +////Does this handle need to be freed as well? + hwnd = GetForegroundWindow(); + if (!hwnd) { + hwnd = GetActiveWindow(); + } + + /* Call Windows dialog to select a suitable certificate */ + cert = CryptUIDlgSelectCertificateFromStore(hstore, + hwnd, + cert_dlg_title, + cert_dlg_prompt, + exclude_columns, + 0, + NULL); + + if (hstore) { + CertCloseStore(hstore, 0); + } + + if (cert) { + std::string ret = getCertUri(cert, cert_store_name); + + CertFreeCertificateContext(cert); + + return ret; + } else { + return ""; + } +} + + +} diff --git a/Swift/QtUI/CAPICertificateSelector.h b/Swift/QtUI/CAPICertificateSelector.h new file mode 100644 index 0000000..9a0ee92 --- /dev/null +++ b/Swift/QtUI/CAPICertificateSelector.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <string> + +namespace Swift { + std::string selectCAPICertificate(); +} diff --git a/Swift/QtUI/QtLoginWindow.cpp b/Swift/QtUI/QtLoginWindow.cpp index 1cd3206..6b9d389 100644 --- a/Swift/QtUI/QtLoginWindow.cpp +++ b/Swift/QtUI/QtLoginWindow.cpp @@ -9,70 +9,74 @@ #include <boost/bind.hpp> #include <boost/smart_ptr/make_shared.hpp> #include <algorithm> #include <cassert> #include <QApplication> #include <QBoxLayout> #include <QComboBox> #include <QDesktopWidget> #include <QFileDialog> #include <QStatusBar> #include <QToolButton> #include <QLabel> #include <QMenuBar> #include <QHBoxLayout> #include <qdebug.h> #include <QCloseEvent> #include <QCursor> #include <QMessageBox> #include <QKeyEvent> #include <Swift/Controllers/UIEvents/UIEventStream.h> #include <Swift/Controllers/UIEvents/RequestXMLConsoleUIEvent.h> #include <Swift/Controllers/UIEvents/RequestFileTransferListUIEvent.h> #include <Swift/Controllers/Settings/SettingsProvider.h> #include <Swift/Controllers/SettingConstants.h> #include <Swift/QtUI/QtUISettingConstants.h> #include <Swiften/Base/Platform.h> #include <Swiften/Base/Paths.h> #include <QtAboutWidget.h> #include <QtSwiftUtil.h> #include <QtMainWindow.h> #include <QtUtilities.h> +#ifdef HAVE_SCHANNEL +#include "CAPICertificateSelector.h" +#endif + namespace Swift{ QtLoginWindow::QtLoginWindow(UIEventStream* uiEventStream, SettingsProvider* settings) : QMainWindow(), settings_(settings) { uiEventStream_ = uiEventStream; setWindowTitle("Swift"); #ifndef Q_WS_MAC setWindowIcon(QIcon(":/logo-icon-16.png")); #endif QtUtilities::setX11Resource(this, "Main"); resize(200, 500); setContentsMargins(0,0,0,0); QWidget *centralWidget = new QWidget(this); setCentralWidget(centralWidget); QBoxLayout *topLayout = new QBoxLayout(QBoxLayout::TopToBottom, centralWidget); stack_ = new QStackedWidget(centralWidget); topLayout->addWidget(stack_); topLayout->setMargin(0); loginWidgetWrapper_ = new QWidget(this); loginWidgetWrapper_->setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Expanding)); QBoxLayout *layout = new QBoxLayout(QBoxLayout::TopToBottom, loginWidgetWrapper_); layout->addStretch(2); QLabel* logo = new QLabel(this); logo->setPixmap(QPixmap(":/logo-shaded-text.256.png")); logo->setScaledContents(true); logo->setFixedSize(192,192); QWidget *logoWidget = new QWidget(this); QHBoxLayout *logoLayout = new QHBoxLayout(); logoLayout->setMargin(0); logoLayout->addStretch(0); logoLayout->addWidget(logo); logoLayout->addStretch(0); @@ -325,74 +329,81 @@ void QtLoginWindow::setIsLoggingIn(bool loggingIn) { bool eagle = settings_->getSetting(SettingConstants::FORGET_PASSWORDS); remember_->setEnabled(!eagle); loginAutomatically_->setEnabled(!eagle); } void QtLoginWindow::loginClicked() { if (username_->isEnabled()) { std::string banner = settings_->getSetting(QtUISettingConstants::CLICKTHROUGH_BANNER); if (!banner.empty()) { QMessageBox msgBox; msgBox.setWindowTitle(tr("Confirm terms of use")); msgBox.setText(""); msgBox.setInformativeText(P2QSTRING(banner)); msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::No); msgBox.setDefaultButton(QMessageBox::No); if (msgBox.exec() != QMessageBox::Yes) { return; } } onLoginRequest(Q2PSTRING(username_->currentText()), Q2PSTRING(password_->text()), Q2PSTRING(certificateFile_), remember_->isChecked(), loginAutomatically_->isChecked()); if (settings_->getSetting(SettingConstants::FORGET_PASSWORDS)) { /* Mustn't remember logins */ username_->clearEditText(); password_->setText(""); } } else { onCancelLoginRequest(); } } void QtLoginWindow::setLoginAutomatically(bool loginAutomatically) { loginAutomatically_->setChecked(loginAutomatically); } void QtLoginWindow::handleCertficateChecked(bool checked) { if (checked) { - certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx")); - if (certificateFile_.isEmpty()) { - certificateButton_->setChecked(false); - } +#ifdef HAVE_SCHANNEL + certificateFile_ = selectCAPICertificate(); + if (certificateFile_.isEmpty()) { + certificateButton_->setChecked(false); + } +#else + certificateFile_ = QFileDialog::getOpenFileName(this, tr("Select an authentication certificate"), QString(), QString("*.cert;*.p12;*.pfx")); + if (certificateFile_.isEmpty()) { + certificateButton_->setChecked(false); + } +#endif } else { certificateFile_ = ""; } } void QtLoginWindow::handleAbout() { if (!aboutDialog_) { aboutDialog_ = new QtAboutWidget(); aboutDialog_->show(); } else { aboutDialog_->show(); aboutDialog_->raise(); aboutDialog_->activateWindow(); } } void QtLoginWindow::handleShowXMLConsole() { uiEventStream_->send(boost::shared_ptr<RequestXMLConsoleUIEvent>(new RequestXMLConsoleUIEvent())); } void QtLoginWindow::handleShowFileTransferOverview() { uiEventStream_->send(boost::make_shared<RequestFileTransferListUIEvent>()); } void QtLoginWindow::handleToggleSounds(bool enabled) { settings_->storeSetting(SettingConstants::PLAY_SOUNDS, enabled); } void QtLoginWindow::handleToggleNotifications(bool enabled) { settings_->storeSetting(SettingConstants::SHOW_NOTIFICATIONS, enabled); } void QtLoginWindow::handleQuit() { diff --git a/Swift/QtUI/SConscript b/Swift/QtUI/SConscript index d37958f..a8b8c78 100644 --- a/Swift/QtUI/SConscript +++ b/Swift/QtUI/SConscript @@ -23,70 +23,72 @@ myenv = env.Clone() myenv["CXXFLAGS"] = filter(lambda x : x != "-Wfloat-equal", myenv["CXXFLAGS"]) myenv.UseFlags(env["SWIFT_CONTROLLERS_FLAGS"]) myenv.UseFlags(env["SWIFTOOLS_FLAGS"]) if myenv["HAVE_XSS"] : myenv.UseFlags(env["XSS_FLAGS"]) if env["PLATFORM"] == "posix" : myenv.Append(LIBS = ["X11"]) if myenv["HAVE_SPARKLE"] : myenv.UseFlags(env["SPARKLE_FLAGS"]) myenv.UseFlags(env["SWIFTEN_FLAGS"]) myenv.UseFlags(env["SWIFTEN_DEP_FLAGS"]) if myenv.get("HAVE_GROWL", False) : myenv.UseFlags(myenv["GROWL_FLAGS"]) myenv.Append(CPPDEFINES = ["HAVE_GROWL"]) if myenv["swift_mobile"] : myenv.Append(CPPDEFINES = ["SWIFT_MOBILE"]) if myenv.get("HAVE_SNARL", False) : myenv.UseFlags(myenv["SNARL_FLAGS"]) myenv.Append(CPPDEFINES = ["HAVE_SNARL"]) myenv.UseFlags(myenv["PLATFORM_FLAGS"]) myenv.Tool("qt4", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.Tool("nsis", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.Tool("wix", toolpath = ["#/BuildTools/SCons/Tools"]) qt4modules = ['QtCore', 'QtGui', 'QtWebKit'] if env["PLATFORM"] == "posix" : qt4modules += ["QtDBus"] myenv.EnableQt4Modules(qt4modules, debug = False) myenv.Append(CPPPATH = ["."]) if env["PLATFORM"] == "win32" : #myenv["LINKFLAGS"] = ["/SUBSYSTEM:CONSOLE"] myenv.Append(LINKFLAGS = ["/SUBSYSTEM:WINDOWS"]) myenv.Append(LIBS = "qtmain") + if myenv.get("HAVE_SCHANNEL", 0) : + myenv.Append(LIBS = "Cryptui") myenv.WriteVal("DefaultTheme.qrc", myenv.Value(generateDefaultTheme(myenv.Dir("#/Swift/resources/themes/Default")))) sources = [ "main.cpp", "QtAboutWidget.cpp", "QtAvatarWidget.cpp", "QtUIFactory.cpp", "QtChatWindowFactory.cpp", "QtChatWindow.cpp", "QtClickableLabel.cpp", "QtLoginWindow.cpp", "QtMainWindow.cpp", "QtProfileWindow.cpp", "QtNameWidget.cpp", "QtSettingsProvider.cpp", "QtStatusWidget.cpp", "QtScaledAvatarCache.cpp", "QtSwift.cpp", "QtURIHandler.cpp", "QtChatView.cpp", "QtChatTheme.cpp", "QtChatTabs.cpp", "QtSoundPlayer.cpp", "QtSystemTray.cpp", "QtCachedImageScaler.cpp", "QtTabbable.cpp", "QtTabWidget.cpp", "QtTextEdit.cpp", "QtXMLConsoleWidget.cpp", "QtFileTransferListWidget.cpp", "QtFileTransferListItemModel.cpp", "QtAdHocCommandWindow.cpp", "QtUtilities.cpp", "QtBookmarkDetailWindow.cpp", @@ -119,70 +121,71 @@ sources = [ "ChatList/ChatListModel.cpp", "ChatList/ChatListDelegate.cpp", "ChatList/ChatListMUCItem.cpp", "ChatList/ChatListRecentItem.cpp", "MUCSearch/QtMUCSearchWindow.cpp", "MUCSearch/MUCSearchModel.cpp", "MUCSearch/MUCSearchRoomItem.cpp", "MUCSearch/MUCSearchEmptyItem.cpp", "MUCSearch/MUCSearchDelegate.cpp", "UserSearch/QtUserSearchFirstPage.cpp", "UserSearch/QtUserSearchFieldsPage.cpp", "UserSearch/QtUserSearchResultsPage.cpp", "UserSearch/QtUserSearchDetailsPage.cpp", "UserSearch/QtUserSearchWindow.cpp", "UserSearch/UserSearchModel.cpp", "UserSearch/UserSearchDelegate.cpp", "QtSubscriptionRequestWindow.cpp", "QtRosterHeader.cpp", "QtWebView.cpp", "qrc_DefaultTheme.cc", "qrc_Swift.cc", "QtFileTransferJSBridge.cpp", "QtMUCConfigurationWindow.cpp", "QtAffiliationEditor.cpp", "QtUISettingConstants.cpp" ] myenv["SWIFT_VERSION"] = Version.getBuildVersion(env.Dir("#").abspath, "swift") if env["PLATFORM"] == "win32" : res = myenv.RES("#/Swift/resources/Windows/Swift.rc") # For some reason, SCons isn't picking up the dependency correctly # Adding it explicitly until i figure out why myenv.Depends(res, "../Controllers/BuildVersion.h") sources += [ + "CAPICertificateSelector.cpp", "WindowsNotifier.cpp", "#/Swift/resources/Windows/Swift.res" ] if env["PLATFORM"] == "posix" : sources += [ "FreeDesktopNotifier.cpp", "QtDBUSURIHandler.cpp", ] if env["PLATFORM"] == "darwin" or env["PLATFORM"] == "win32" : swiftProgram = myenv.Program("Swift", sources) else : swiftProgram = myenv.Program("swift", sources) if env["PLATFORM"] != "darwin" and env["PLATFORM"] != "win32" : openURIProgram = myenv.Program("swift-open-uri", "swift-open-uri.cpp") else : openURIProgram = [] myenv.Uic4("MUCSearch/QtMUCSearchWindow.ui") myenv.Uic4("UserSearch/QtUserSearchWizard.ui") myenv.Uic4("UserSearch/QtUserSearchFirstPage.ui") myenv.Uic4("UserSearch/QtUserSearchFieldsPage.ui") myenv.Uic4("UserSearch/QtUserSearchResultsPage.ui") myenv.Uic4("QtBookmarkDetailWindow.ui") myenv.Uic4("QtAffiliationEditor.ui") myenv.Uic4("QtJoinMUCWindow.ui") myenv.Qrc("DefaultTheme.qrc") myenv.Qrc("Swift.qrc") # Resources commonResources = { "": ["#/Swift/resources/sounds"] } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index de12fb7..36bfe35 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -94,89 +94,114 @@ void CoreClient::connect(const std::string& host) { networkFactories->getDomainNameResolver(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword)); sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); bindSessionToStream(); } } void CoreClient::bindSessionToStream() { session_ = ClientSession::create(jid_, sessionStream_); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); break; case ClientOptions::NeverUseTLS: session_->setUseTLS(ClientSession::NeverUseTLS); break; case ClientOptions::RequireTLS: session_->setUseTLS(ClientSession::RequireTLS); break; } session_->setUseAcks(options.useAcks); stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); session_->start(); } +bool CoreClient::isCAPIURI() { +#ifdef HAVE_SCHANNEL + if (!boost::iequals(certificate_.substr(0, 10), "certstore:")) { + return false; + } + + return true; + +#else + return false; +#endif +} + /** * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream. */ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connection) { resetConnector(); if (!connection) { if (options.forgetPassword) { purgePassword(); } onDisconnected(disconnectRequested_ ? boost::optional<ClientError>() : boost::optional<ClientError>(ClientError::ConnectionError)); } else { assert(!connection_); connection_ = connection; assert(!sessionStream_); sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory()); if (!certificate_.empty()) { - sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); + CertificateWithKey* cert; + +#if defined(SWIFTEN_PLATFORM_WIN32) + if (isCAPIURI()) { + cert = new CAPICertificate(certificate_); + } else { + cert = new PKCS12Certificate(certificate_, password_); + } +#else + cert = new PKCS12Certificate(certificate_, password_); +#endif + + sessionStream_->setTLSCertificate(cert); } sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); bindSessionToStream(); } } void CoreClient::disconnect() { // FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between // connector finishing without a connection due to an error or because of a disconnect. disconnectRequested_ = true; if (session_ && !session_->isFinished()) { session_->finish(); } else if (connector_) { connector_->stop(); } } void CoreClient::setCertificate(const std::string& certificate) { certificate_ = certificate; } void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) { if (options.forgetPassword) { purgePassword(); } resetSession(); boost::optional<ClientError> actualError; if (error) { ClientError clientError; if (boost::shared_ptr<ClientSession::Error> actualError = boost::dynamic_pointer_cast<ClientSession::Error>(error)) { switch(actualError->type) { diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h index c231fdc..6712e03 100644 --- a/Swiften/Client/CoreClient.h +++ b/Swiften/Client/CoreClient.h @@ -164,69 +164,71 @@ namespace Swift { */ boost::signal<void (const SafeByteArray&)> onDataWritten; /** * Emitted when a message is received. */ boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived; /** * Emitted when a presence stanza is received. */ boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived; /** * Emitted when the server acknowledges receipt of a * stanza (if acknowledgements are available). * * \see getStreamManagementEnabled() */ boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaAcked; protected: boost::shared_ptr<ClientSession> getSession() const { return session_; } NetworkFactories* getNetworkFactories() const { return networkFactories; } /** * Called before onConnected signal is emmitted. */ virtual void handleConnected() {}; + bool isCAPIURI(); + private: void handleConnectorFinished(boost::shared_ptr<Connection>); void handleStanzaChannelAvailableChanged(bool available); void handleSessionFinished(boost::shared_ptr<Error>); void handleNeedCredentials(); void handleDataRead(const SafeByteArray&); void handleDataWritten(const SafeByteArray&); void handlePresenceReceived(boost::shared_ptr<Presence>); void handleMessageReceived(boost::shared_ptr<Message>); void handleStanzaAcked(boost::shared_ptr<Stanza>); void purgePassword(); void bindSessionToStream(); void resetConnector(); void resetSession(); void forceReset(); private: JID jid_; SafeByteArray password_; NetworkFactories* networkFactories; ClientSessionStanzaChannel* stanzaChannel_; IQRouter* iqRouter_; ClientOptions options; boost::shared_ptr<ChainedConnector> connector_; std::vector<ConnectionFactory*> proxyConnectionFactories; boost::shared_ptr<Connection> connection_; boost::shared_ptr<SessionStream> sessionStream_; boost::shared_ptr<ClientSession> session_; std::string certificate_; bool disconnectRequested_; CertificateTrustChecker* certificateTrustChecker; }; } diff --git a/Swiften/Session/SessionStream.cpp b/Swiften/Session/SessionStream.cpp index 0d73b63..487ad8b 100644 --- a/Swiften/Session/SessionStream.cpp +++ b/Swiften/Session/SessionStream.cpp @@ -1,14 +1,15 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Session/SessionStream.h> namespace Swift { SessionStream::~SessionStream() { + delete certificate; } }; diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index 096f185..58015b3 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -1,87 +1,89 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <Swiften/Base/boost_bsignals.h> #include <boost/shared_ptr.hpp> #include <boost/optional.hpp> #include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/Element.h> #include <Swiften/Base/Error.h> #include <Swiften/Base/SafeByteArray.h> -#include <Swiften/TLS/PKCS12Certificate.h> +#include <Swiften/TLS/CertificateWithKey.h> #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/CertificateVerificationError.h> namespace Swift { class SessionStream { public: class Error : public Swift::Error { public: enum Type { ParseError, TLSError, InvalidTLSCertificateError, ConnectionReadError, ConnectionWriteError }; Error(Type type) : type(type) {} Type type; }; + SessionStream(): certificate(0) {} + virtual ~SessionStream(); virtual void close() = 0; virtual bool isOpen() = 0; virtual void writeHeader(const ProtocolHeader& header) = 0; virtual void writeFooter() = 0; virtual void writeElement(boost::shared_ptr<Element>) = 0; virtual void writeData(const std::string& data) = 0; virtual bool supportsZLibCompression() = 0; virtual void addZLibCompression() = 0; virtual bool supportsTLSEncryption() = 0; virtual void addTLSEncryption() = 0; virtual bool isTLSEncrypted() = 0; virtual void setWhitespacePingEnabled(bool enabled) = 0; virtual void resetXMPPParser() = 0; - void setTLSCertificate(const PKCS12Certificate& cert) { + void setTLSCertificate(CertificateWithKey* cert) { certificate = cert; } virtual bool hasTLSCertificate() { - return !certificate.isNull(); + return certificate && !certificate->isNull(); } virtual Certificate::ref getPeerCertificate() const = 0; virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0; virtual ByteArray getTLSFinishMessage() const = 0; boost::signal<void (const ProtocolHeader&)> onStreamStartReceived; boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; boost::signal<void (boost::shared_ptr<Error>)> onClosed; boost::signal<void ()> onTLSEncrypted; boost::signal<void (const SafeByteArray&)> onDataRead; boost::signal<void (const SafeByteArray&)> onDataWritten; protected: - const PKCS12Certificate& getTLSCertificate() const { + CertificateWithKey * getTLSCertificate() const { return certificate; } private: - PKCS12Certificate certificate; + CertificateWithKey * certificate; }; } diff --git a/Swiften/StreamStack/TLSLayer.cpp b/Swiften/StreamStack/TLSLayer.cpp index 6f2223d..b7efbcb 100644 --- a/Swiften/StreamStack/TLSLayer.cpp +++ b/Swiften/StreamStack/TLSLayer.cpp @@ -5,48 +5,48 @@ */ #include <Swiften/StreamStack/TLSLayer.h> #include <boost/bind.hpp> #include <Swiften/TLS/TLSContextFactory.h> #include <Swiften/TLS/TLSContext.h> namespace Swift { TLSLayer::TLSLayer(TLSContextFactory* factory) { context = factory->createTLSContext(); context->onDataForNetwork.connect(boost::bind(&TLSLayer::writeDataToChildLayer, this, _1)); context->onDataForApplication.connect(boost::bind(&TLSLayer::writeDataToParentLayer, this, _1)); context->onConnected.connect(onConnected); context->onError.connect(onError); } TLSLayer::~TLSLayer() { delete context; } void TLSLayer::connect() { context->connect(); } void TLSLayer::writeData(const SafeByteArray& data) { context->handleDataFromApplication(data); } void TLSLayer::handleDataRead(const SafeByteArray& data) { context->handleDataFromNetwork(data); } -bool TLSLayer::setClientCertificate(const PKCS12Certificate& certificate) { +bool TLSLayer::setClientCertificate(CertificateWithKey * certificate) { return context->setClientCertificate(certificate); } Certificate::ref TLSLayer::getPeerCertificate() const { return context->getPeerCertificate(); } boost::shared_ptr<CertificateVerificationError> TLSLayer::getPeerCertificateVerificationError() const { return context->getPeerCertificateVerificationError(); } } diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h index a8693d5..6dc9135 100644 --- a/Swiften/StreamStack/TLSLayer.h +++ b/Swiften/StreamStack/TLSLayer.h @@ -1,44 +1,44 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Base/boost_bsignals.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/StreamStack/StreamLayer.h> #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/CertificateVerificationError.h> namespace Swift { class TLSContext; class TLSContextFactory; - class PKCS12Certificate; + class CertificateWithKey; class TLSLayer : public StreamLayer { public: TLSLayer(TLSContextFactory*); ~TLSLayer(); void connect(); - bool setClientCertificate(const PKCS12Certificate&); + bool setClientCertificate(CertificateWithKey * cert); Certificate::ref getPeerCertificate() const; boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; void writeData(const SafeByteArray& data); void handleDataRead(const SafeByteArray& data); TLSContext* getContext() const { return context; } public: boost::signal<void ()> onError; boost::signal<void ()> onConnected; private: TLSContext* context; }; } diff --git a/Swiften/TLS/CAPICertificate.h b/Swiften/TLS/CAPICertificate.h new file mode 100644 index 0000000..fcdb4c2 --- /dev/null +++ b/Swiften/TLS/CAPICertificate.h @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2012 Isode Limited, London, England. + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +#pragma once + +#include <Swiften/Base/SafeByteArray.h> +#include <Swiften/TLS/CertificateWithKey.h> + +#include <boost/algorithm/string/predicate.hpp> + +#define SECURITY_WIN32 +#include <WinCrypt.h> + +namespace Swift { + class CAPICertificate : public Swift::CertificateWithKey { + public: + CAPICertificate(const std::string& capiUri) + : valid_(false), uri_(capiUri), cert_store_handle_(0), cert_store_(NULL), cert_name_(NULL) { + setUri(capiUri); + } + + virtual ~CAPICertificate() { + if (cert_store_handle_ != NULL) + { + CertCloseStore(cert_store_handle_, 0); + } + } + + virtual bool isNull() const { + return uri_.empty() || !valid_; + } + + virtual bool isPrivateKeyExportable() const { + /* We can check with CAPI, but for now the answer is "no" */ + return false; + } + + virtual const std::string& getCertStoreName() const { + return cert_store_; + } + + virtual const std::string& getCertName() const { + return cert_name_; + } + + const ByteArray& getData() const { +////Might need to throw an exception here, or really generate PKCS12 blob from CAPI data? + assert(0); + } + + void setData(const ByteArray& data) { + assert(0); + } + + const SafeByteArray& getPassword() const { +/////Can't pass NULL to createSafeByteArray! +/////Should this throw an exception instead? + return createSafeByteArray(""); + } + + protected: + void setUri (const std::string& capiUri) { + + valid_ = false; + + /* Syntax: "certstore:" [<cert_store> ":"] <cert_id> */ + + if (!boost::iequals(capiUri.substr(0, 10), "certstore:")) { + return; + } + + /* Substring of subject: uses "storename" */ + std::string capi_identity = capiUri.substr(10); + std::string new_cert_store_name; + size_t pos = capi_identity.find_first_of (':'); + + if (pos == std::string::npos) { + /* Using the default certificate store */ + new_cert_store_name = "MY"; + cert_name_ = capi_identity; + } else { + new_cert_store_name = capi_identity.substr(0, pos); + cert_name_ = capi_identity.substr(pos + 1); + } + + PCCERT_CONTEXT pCertContext = NULL; + + if (cert_store_handle_ != NULL) + { + if (new_cert_store_name != cert_store_) { + CertCloseStore(cert_store_handle_, 0); + cert_store_handle_ = NULL; + } + } + + if (cert_store_handle_ == NULL) + { + cert_store_handle_ = CertOpenSystemStore(0, cert_store_.c_str()); + if (!cert_store_handle_) + { + return; + } + } + + cert_store_ = new_cert_store_name; + + /* NB: This might have to change, depending on how we locate certificates */ + + // Find client certificate. Note that this sample just searches for a + // certificate that contains the user name somewhere in the subject name. + pCertContext = CertFindCertificateInStore(cert_store_handle_, + X509_ASN_ENCODING, + 0, // dwFindFlags + CERT_FIND_SUBJECT_STR_A, + cert_name_.c_str(), // *pvFindPara + NULL ); // pPrevCertContext + + if (pCertContext == NULL) + { + return; + } + + + /* Now verify that we can have access to the corresponding private key */ + + DWORD len; + CRYPT_KEY_PROV_INFO *pinfo; + HCRYPTPROV hprov; + HCRYPTKEY key; + + if (!CertGetCertificateContextProperty(pCertContext, + CERT_KEY_PROV_INFO_PROP_ID, + NULL, + &len)) + { + CertFreeCertificateContext(pCertContext); + return; + } + + pinfo = static_cast<CRYPT_KEY_PROV_INFO *>(malloc(len)); + if (!pinfo) { + CertFreeCertificateContext(pCertContext); + return; + } + + if (!CertGetCertificateContextProperty(pCertContext, + CERT_KEY_PROV_INFO_PROP_ID, + pinfo, + &len)) + { + CertFreeCertificateContext(pCertContext); + free(pinfo); + return; + } + + CertFreeCertificateContext(pCertContext); + + // Now verify if we have access to the private key + if (!CryptAcquireContextW(&hprov, + pinfo->pwszContainerName, + pinfo->pwszProvName, + pinfo->dwProvType, + 0)) + { + free(pinfo); + return; + } + + if (!CryptGetUserKey(hprov, pinfo->dwKeySpec, &key)) + { + CryptReleaseContext(hprov, 0); + free(pinfo); + return; + } + + CryptDestroyKey(key); + CryptReleaseContext(hprov, 0); + free(pinfo); + + valid_ = true; + } + + private: + bool valid_; + std::string uri_; + + HCERTSTORE cert_store_handle_; + + /* Parsed components of the uri_ */ + std::string cert_store_; + std::string cert_name_; + }; +} diff --git a/Swiften/TLS/CertificateWithKey.h b/Swiften/TLS/CertificateWithKey.h new file mode 100644 index 0000000..6f6ea39 --- /dev/null +++ b/Swiften/TLS/CertificateWithKey.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2010-2012 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <Swiften/Base/SafeByteArray.h> + +namespace Swift { + class CertificateWithKey { + public: + CertificateWithKey() {} + + virtual ~CertificateWithKey() {} + + virtual bool isNull() const = 0; + + virtual bool isPrivateKeyExportable() const = 0; + + virtual const std::string& getCertStoreName() const = 0; + + virtual const std::string& getCertName() const = 0; + + virtual const ByteArray& getData() const = 0; + + virtual void setData(const ByteArray& data) = 0; + + virtual const SafeByteArray& getPassword() const = 0; + }; +} diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp index 220e7f9..dd3462f 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -1,59 +1,59 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Base/Platform.h> #ifdef SWIFTEN_PLATFORM_WINDOWS #include <windows.h> #include <wincrypt.h> #endif #include <vector> #include <openssl/err.h> #include <openssl/pkcs12.h> #include <boost/smart_ptr/make_shared.hpp> #if defined(SWIFTEN_PLATFORM_MACOSX) && OPENSSL_VERSION_NUMBER < 0x00908000 #include <Security/Security.h> #endif #include <Swiften/TLS/OpenSSL/OpenSSLContext.h> #include <Swiften/TLS/OpenSSL/OpenSSLCertificate.h> -#include <Swiften/TLS/PKCS12Certificate.h> +#include <Swiften/TLS/CertificateWithKey.h> #pragma GCC diagnostic ignored "-Wold-style-cast" namespace Swift { static const int MAX_FINISHED_SIZE = 4096; static const int SSL_READ_BUFFERSIZE = 8192; void freeX509Stack(STACK_OF(X509)* stack) { sk_X509_free(stack); } OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) { ensureLibraryInitialized(); context_ = SSL_CTX_new(TLSv1_client_method()); // Load system certs #if defined(SWIFTEN_PLATFORM_WINDOWS) X509_STORE* store = SSL_CTX_get_cert_store(context_); HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT"); if (systemStore) { PCCERT_CONTEXT certContext = NULL; while (true) { certContext = CertFindCertificateInStore(systemStore, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, NULL, certContext); if (!certContext) { break; } OpenSSLCertificate cert(createByteArray(certContext->pbCertEncoded, certContext->cbCertEncoded)); if (store && cert.getInternalX509()) { X509_STORE_add_cert(store, cert.getInternalX509().get()); } } } #elif !defined(SWIFTEN_PLATFORM_MACOSX) SSL_CTX_load_verify_locations(context_, NULL, "/etc/ssl/certs"); @@ -153,89 +153,93 @@ void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) { break; case Connected: sendPendingDataToApplication(); break; case Start: assert(false); break; case Error: /*assert(false);*/ break; } } void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) { if (SSL_write(handle_, vecptr(data), data.size()) >= 0) { sendPendingDataToNetwork(); } else { state_ = Error; onError(); } } void OpenSSLContext::sendPendingDataToApplication() { SafeByteArray data; data.resize(SSL_READ_BUFFERSIZE); int ret = SSL_read(handle_, vecptr(data), data.size()); while (ret > 0) { data.resize(ret); onDataForApplication(data); data.resize(SSL_READ_BUFFERSIZE); ret = SSL_read(handle_, vecptr(data), data.size()); } if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) { state_ = Error; onError(); } } -bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) { - if (certificate.isNull()) { +bool OpenSSLContext::setClientCertificate(CertificateWithKey * certificate) { + if (!certificate || certificate->isNull()) { + return false; + } + + if (!certificate->isPrivateKeyExportable()) { return false; } // Create a PKCS12 structure BIO* bio = BIO_new(BIO_s_mem()); - BIO_write(bio, vecptr(certificate.getData()), certificate.getData().size()); + BIO_write(bio, vecptr(certificate->getData()), certificate->getData().size()); boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free); BIO_free(bio); if (!pkcs12) { return false; } // Parse PKCS12 X509 *certPtr = 0; EVP_PKEY* privateKeyPtr = 0; STACK_OF(X509)* caCertsPtr = 0; - int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate.getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr); + int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(certificate->getPassword())), &privateKeyPtr, &certPtr, &caCertsPtr); if (result != 1) { return false; } boost::shared_ptr<X509> cert(certPtr, X509_free); boost::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free); boost::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack); // Use the key & certificates if (SSL_CTX_use_certificate(context_, cert.get()) != 1) { return false; } if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) { return false; } for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) { SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i)); } return true; } Certificate::ref OpenSSLContext::getPeerCertificate() const { boost::shared_ptr<X509> x509Cert(SSL_get_peer_certificate(handle_), X509_free); if (x509Cert) { return boost::make_shared<OpenSSLCertificate>(x509Cert); } else { return Certificate::ref(); } } boost::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const { int verifyResult = SSL_get_verify_result(handle_); if (verifyResult != X509_V_OK) { return boost::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult)); } diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h index 04693a3..b53e715 100644 --- a/Swiften/TLS/OpenSSL/OpenSSLContext.h +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -1,53 +1,53 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <openssl/ssl.h> #include <Swiften/Base/boost_bsignals.h> #include <boost/noncopyable.hpp> #include <Swiften/TLS/TLSContext.h> #include <Swiften/Base/ByteArray.h> namespace Swift { - class PKCS12Certificate; + class CertificateWithKey; class OpenSSLContext : public TLSContext, boost::noncopyable { public: OpenSSLContext(); ~OpenSSLContext(); void connect(); - bool setClientCertificate(const PKCS12Certificate& cert); + bool setClientCertificate(CertificateWithKey * cert); void handleDataFromNetwork(const SafeByteArray&); void handleDataFromApplication(const SafeByteArray&); Certificate::ref getPeerCertificate() const; boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; virtual ByteArray getFinishMessage() const; private: static void ensureLibraryInitialized(); static CertificateVerificationError::Type getVerificationErrorTypeForResult(int); void doConnect(); void sendPendingDataToNetwork(); void sendPendingDataToApplication(); private: enum State { Start, Connecting, Connected, Error }; State state_; SSL_CTX* context_; SSL* handle_; BIO* readBIO_; BIO* writeBIO_; }; } diff --git a/Swiften/TLS/PKCS12Certificate.h b/Swiften/TLS/PKCS12Certificate.h index c0e01d0..2f70456 100644 --- a/Swiften/TLS/PKCS12Certificate.h +++ b/Swiften/TLS/PKCS12Certificate.h @@ -1,40 +1,59 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <Swiften/Base/SafeByteArray.h> +#include <Swiften/TLS/CertificateWithKey.h> namespace Swift { - class PKCS12Certificate { + class PKCS12Certificate : public Swift::CertificateWithKey { public: PKCS12Certificate() {} PKCS12Certificate(const std::string& filename, const SafeByteArray& password) : password_(password) { readByteArrayFromFile(data_, filename); } - bool isNull() const { + virtual ~PKCS12Certificate() {} + + virtual bool isNull() const { return data_.empty(); } - const ByteArray& getData() const { + virtual bool isPrivateKeyExportable() const { +/////Hopefully a PKCS12 is never missing a private key + return true; + } + + virtual const std::string& getCertStoreName() const { +///// assert(0); + throw std::exception(); + } + + virtual const std::string& getCertName() const { + /* We can return the original filename instead, if we care */ +///// assert(0); + throw std::exception(); + } + + virtual const ByteArray& getData() const { return data_; } void setData(const ByteArray& data) { data_ = data; } - const SafeByteArray& getPassword() const { + virtual const SafeByteArray& getPassword() const { return password_; } private: ByteArray data_; SafeByteArray password_; }; } diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 6771d4a..6f50b3a 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -1,102 +1,163 @@ /* * Copyright (c) 2011 Soren Dreijer * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ #include "Swiften/TLS/Schannel/SchannelContext.h" #include "Swiften/TLS/Schannel/SchannelCertificate.h" namespace Swift { //------------------------------------------------------------------------ SchannelContext::SchannelContext() : m_state(Start) , m_secContext(0) , m_verificationError(CertificateVerificationError::UnknownError) +, m_my_cert_store(NULL) +, m_cert_store_name("MY") +, m_cert_name(NULL) { m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR | ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM; ZeroMemory(&m_streamSizes, sizeof(m_streamSizes)); } //------------------------------------------------------------------------ +SchannelContext::~SchannelContext() +{ + if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0); +} + +//------------------------------------------------------------------------ + void SchannelContext::determineStreamSizes() { QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); } //------------------------------------------------------------------------ void SchannelContext::connect() { + PCCERT_CONTEXT pCertContext = NULL; + m_state = Connecting; + // If a user name is specified, then attempt to find a client + // certificate. Otherwise, just create a NULL credential. + if (!m_cert_name.empty()) + { + if (m_my_cert_store == NULL) + { + m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str()); + if (!m_my_cert_store) + { +///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() ); + indicateError(); + return; + } + } + + // Find client certificate. Note that this sample just searches for a + // certificate that contains the user name somewhere in the subject name. + pCertContext = CertFindCertificateInStore( m_my_cert_store, + X509_ASN_ENCODING, + 0, // dwFindFlags + CERT_FIND_SUBJECT_STR_A, + m_cert_name.c_str(), // *pvFindPara + NULL ); // pPrevCertContext + + if (pCertContext == NULL) + { +///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); + indicateError(); + return; + } + } + // We use an empty list for client certificates PCCERT_CONTEXT clientCerts[1] = {0}; SCHANNEL_CRED sc = {0}; sc.dwVersion = SCHANNEL_CRED_VERSION; - sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us - sc.paCred = clientCerts; + +/////SSL3? sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; - sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | /*SCH_CRED_NO_DEFAULT_CREDS*/ SCH_CRED_USE_DEFAULT_CREDS | SCH_CRED_REVOCATION_CHECK_CHAIN; +/////Check SCH_CRED_REVOCATION_CHECK_CHAIN + sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN; + + if (pCertContext) + { + sc.cCreds = 1; + sc.paCred = &pCertContext; + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + } + else + { + sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us + sc.paCred = clientCerts; + sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; + } // Swiften performs the server name check for us sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; SECURITY_STATUS status = AcquireCredentialsHandle( NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, NULL, &sc, NULL, NULL, m_credHandle.Reset(), NULL); + // cleanup: Free the certificate context. Schannel has already made its own copy. + if (pCertContext) CertFreeCertificateContext(pCertContext); + if (status != SEC_E_OK) { // We failed to obtain the credentials handle indicateError(); return; } SecBuffer outBuffers[2]; // We let Schannel allocate the output buffer for us outBuffers[0].pvBuffer = NULL; outBuffers[0].cbBuffer = 0; outBuffers[0].BufferType = SECBUFFER_TOKEN; // Contains alert data if an alert is generated outBuffers[1].pvBuffer = NULL; outBuffers[1].cbBuffer = 0; outBuffers[1].BufferType = SECBUFFER_ALERT; // Make sure the output buffers are freed ScopedSecBuffer scopedOutputData(&outBuffers[0]); ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); SecBufferDesc outBufferDesc = {0}; outBufferDesc.cBuffers = 2; outBufferDesc.pBuffers = outBuffers; outBufferDesc.ulVersion = SECBUFFER_VERSION; // Create the initial security context status = InitializeSecurityContext( m_credHandle, NULL, NULL, m_ctxtFlags, 0, @@ -424,72 +485,85 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) outBuffers[0].cbBuffer = m_streamSizes.cbHeader; outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader; outBuffers[1].cbBuffer = (unsigned long)bytesToSend; outBuffers[1].BufferType = SECBUFFER_DATA; outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend; outBuffers[2].cbBuffer = m_streamSizes.cbTrailer; outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; outBuffers[3].pvBuffer = 0; outBuffers[3].cbBuffer = 0; outBuffers[3].BufferType = SECBUFFER_EMPTY; SecBufferDesc outBufferDesc = {0}; outBufferDesc.cBuffers = 4; outBufferDesc.pBuffers = outBuffers; outBufferDesc.ulVersion = SECBUFFER_VERSION; SECURITY_STATUS status = EncryptMessage(m_ctxtHandle, 0, &outBufferDesc, 0); if (status != SEC_E_OK) { indicateError(); return; } sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); bytesSent += bytesToSend; } while (bytesSent < data.size()); } //------------------------------------------------------------------------ -bool SchannelContext::setClientCertificate(const PKCS12Certificate& certificate) +bool SchannelContext::setClientCertificate(CertificateWithKey * certificate) { + if (!certificate || certificate->isNull()) { + return false; + } + + if (!certificate->isPrivateKeyExportable()) { + // We assume that the Certificate Store Name/Certificate Name + // are valid at this point + m_cert_store_name = certificate->getCertStoreName(); + m_cert_name = certificate->getCertName(); + + return true; + } + return false; } //------------------------------------------------------------------------ Certificate::ref SchannelContext::getPeerCertificate() const { SchannelCertificate::ref pCertificate; ScopedCertContext pServerCert; SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); if (status != SEC_E_OK) return pCertificate; pCertificate.reset( new SchannelCertificate(pServerCert) ); return pCertificate; } //------------------------------------------------------------------------ CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { boost::shared_ptr<CertificateVerificationError> pCertError; if (m_state == Error) pCertError.reset( new CertificateVerificationError(m_verificationError) ); return pCertError; } //------------------------------------------------------------------------ ByteArray SchannelContext::getFinishMessage() const { // TODO: Implement diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index 66467fe..0cdb3d7 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -1,81 +1,88 @@ /* * Copyright (c) 2011 Soren Dreijer * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ #pragma once #include "Swiften/Base/boost_bsignals.h" #include "Swiften/TLS/TLSContext.h" #include "Swiften/TLS/Schannel/SchannelUtil.h" +#include <Swiften/TLS/CertificateWithKey.h> #include "Swiften/Base/ByteArray.h" #define SECURITY_WIN32 #include <Windows.h> #include <Schannel.h> #include <security.h> #include <schnlsp.h> #include <boost/noncopyable.hpp> namespace Swift { class SchannelContext : public TLSContext, boost::noncopyable { public: typedef boost::shared_ptr<SchannelContext> sp_t; public: - SchannelContext(); + SchannelContext(); + + ~SchannelContext(); // // TLSContext // virtual void connect(); - virtual bool setClientCertificate(const PKCS12Certificate&); + virtual bool setClientCertificate(CertificateWithKey * cert); virtual void handleDataFromNetwork(const SafeByteArray& data); virtual void handleDataFromApplication(const SafeByteArray& data); virtual Certificate::ref getPeerCertificate() const; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; virtual ByteArray getFinishMessage() const; private: void determineStreamSizes(); void continueHandshake(const SafeByteArray& data); void indicateError(); void sendDataOnNetwork(const void* pData, size_t dataSize); void forwardDataToApplication(const void* pData, size_t dataSize); void decryptAndProcessData(const SafeByteArray& data); void encryptAndSendData(const SafeByteArray& data); void appendNewData(const SafeByteArray& data); private: enum SchannelState { Start, Connecting, Connected, Error }; SchannelState m_state; CertificateVerificationError m_verificationError; ULONG m_secContext; ScopedCredHandle m_credHandle; ScopedCtxtHandle m_ctxtHandle; DWORD m_ctxtFlags; SecPkgContext_StreamSizes m_streamSizes; std::vector<char> m_receivedData; + + HCERTSTORE m_my_cert_store; + std::string m_cert_store_name; + std::string m_cert_name; }; } diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h index 1538863..ada813a 100644 --- a/Swiften/TLS/TLSContext.h +++ b/Swiften/TLS/TLSContext.h @@ -1,41 +1,41 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <Swiften/Base/boost_bsignals.h> #include <boost/shared_ptr.hpp> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/CertificateVerificationError.h> namespace Swift { - class PKCS12Certificate; + class CertificateWithKey; class TLSContext { public: virtual ~TLSContext(); virtual void connect() = 0; - virtual bool setClientCertificate(const PKCS12Certificate& cert) = 0; + virtual bool setClientCertificate(CertificateWithKey * cert) = 0; virtual void handleDataFromNetwork(const SafeByteArray&) = 0; virtual void handleDataFromApplication(const SafeByteArray&) = 0; virtual Certificate::ref getPeerCertificate() const = 0; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0; virtual ByteArray getFinishMessage() const = 0; public: boost::signal<void (const SafeByteArray&)> onDataForNetwork; boost::signal<void (const SafeByteArray&)> onDataForApplication; boost::signal<void ()> onError; boost::signal<void ()> onConnected; }; } |