diff options
-rw-r--r-- | BuildTools/SCons/SConscript.boot | 4 | ||||
-rw-r--r-- | Swift/SConscript | 2 | ||||
-rw-r--r-- | Swiften/Client/ClientError.h | 2 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 6 | ||||
-rw-r--r-- | Swiften/TLS/CertificateVerificationError.h | 2 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificate.h | 7 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 143 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 4 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelUtil.h | 135 |
9 files changed, 287 insertions, 18 deletions
diff --git a/BuildTools/SCons/SConscript.boot b/BuildTools/SCons/SConscript.boot index 188184c..dc8a8a5 100644 --- a/BuildTools/SCons/SConscript.boot +++ b/BuildTools/SCons/SConscript.boot @@ -23,70 +23,71 @@ vars.Add(BoolVariable("max_jobs", "Build with maximum number of parallel jobs", vars.Add(EnumVariable("target", "Choose a target platform for compilation", "native", ["native", "iphone-simulator", "iphone-device", "xcode"])) vars.Add(BoolVariable("swift_mobile", "Build mobile Swift", "no")) if os.name != "nt" : vars.Add(BoolVariable("coverage", "Compile with coverage information", "no")) if os.name == "posix" : vars.Add(BoolVariable("valgrind", "Run tests with valgrind", "no")) if os.name == "mac" or (os.name == "posix" and os.uname()[0] == "Darwin"): vars.Add(BoolVariable("universal", "Create universal binaries", "no")) vars.Add(BoolVariable("mac105", "Link against the 10.5 frameworks", "no")) if os.name == "nt" : vars.Add(PathVariable("vcredist", "MSVC redistributable dir", "", PathVariable.PathAccept)) if os.name == "nt" : vars.Add(PathVariable("wix_bindir", "Path to WiX binaries", "", PathVariable.PathAccept)) if os.name == "nt" : vars.Add(PackageVariable("bonjour", "Bonjour SDK location", "yes")) vars.Add(PackageVariable("openssl", "OpenSSL location", "yes")) vars.Add(PathVariable("boost_includedir", "Boost headers location", None, PathVariable.PathAccept)) vars.Add(PathVariable("boost_libdir", "Boost library location", None, PathVariable.PathAccept)) vars.Add(PathVariable("expat_includedir", "Expat headers location", None, PathVariable.PathAccept)) vars.Add(PathVariable("expat_libdir", "Expat library location", None, PathVariable.PathAccept)) vars.Add("expat_libname", "Expat library name", "libexpat" if os.name == "nt" else "expat") vars.Add(PathVariable("libidn_includedir", "LibIDN headers location", None, PathVariable.PathAccept)) vars.Add(PathVariable("libidn_libdir", "LibIDN library location", None, PathVariable.PathAccept)) vars.Add("libidn_libname", "LibIDN library name", "libidn" if os.name == "nt" else "idn") vars.Add(PathVariable("sqlite_includedir", "SQLite headers location", None, PathVariable.PathAccept)) vars.Add(PathVariable("sqlite_libdir", "SQLite library location", None, PathVariable.PathAccept)) vars.Add("sqlite_libname", "SQLite library name", "libsqlite3" if os.name == "nt" else "sqlite3") vars.Add(PathVariable("avahi_includedir", "Avahi headers location", None, PathVariable.PathAccept)) vars.Add(PathVariable("avahi_libdir", "Avahi library location", None, PathVariable.PathAccept)) vars.Add(PathVariable("qt", "Qt location", "", PathVariable.PathAccept)) vars.Add(PathVariable("docbook_xml", "DocBook XML", None, PathVariable.PathAccept)) vars.Add(PathVariable("docbook_xsl", "DocBook XSL", None, PathVariable.PathAccept)) vars.Add(BoolVariable("build_examples", "Build example programs", "yes")) vars.Add(BoolVariable("enable_variants", "Build in a separate dir under build/, depending on compile flags", "no")) vars.Add(BoolVariable("experimental", "Build experimental features", "no")) +vars.Add(BoolVariable("set_iterator_debug_level", "Set _ITERATOR_DEBUG_LEVEL=0", "yes")) ################################################################################ # Set up default build & configure environment ################################################################################ env = Environment(CPPPATH = ["#"], ENV = { 'PATH' : os.environ['PATH'], 'LD_LIBRARY_PATH' : os.environ.get("LD_LIBRARY_PATH", ""), }, variables = vars) Help(vars.GenerateHelpText(env)) # Default environment variables env["PLATFORM_FLAGS"] = {} # Default custom tools env.Tool("Test", toolpath = ["#/BuildTools/SCons/Tools"]) env.Tool("WriteVal", toolpath = ["#/BuildTools/SCons/Tools"]) env.Tool("BuildVersion", toolpath = ["#/BuildTools/SCons/Tools"]) env.Tool("Flags", toolpath = ["#/BuildTools/SCons/Tools"]) if env["PLATFORM"] == "darwin" : env.Tool("Nib", toolpath = ["#/BuildTools/SCons/Tools"]) env.Tool("AppBundle", toolpath = ["#/BuildTools/SCons/Tools"]) if env["PLATFORM"] == "win32" : env.Tool("WindowsBundle", toolpath = ["#/BuildTools/SCons/Tools"]) #So we don't need to escalate with UAC if "TMP" in os.environ.keys() : env['ENV']['TMP'] = os.environ['TMP'] env.Tool("SLOCCount", toolpath = ["#/BuildTools/SCons/Tools"]) # Max out the number of jobs if env["max_jobs"] : try : import multiprocessing SetOption("num_jobs", multiprocessing.cpu_count()) @@ -102,71 +103,72 @@ if env.get("distcc", False) : if "distcc_hosts" in env : env["ENV"]["DISTCC_HOSTS"] = env["distcc_hosts"] env["CC"] = "distcc gcc" env["CXX"] = "distcc g++" if "cc" in env : env["CC"] = env["cc"] if "cxx" in env : env["CXX"] = env["cxx"] ccflags = env.get("ccflags", []) if isinstance(ccflags, str) : # FIXME: Make the splitting more robust env["CCFLAGS"] = ccflags.split(" ") else : env["CCFLAGS"] = ccflags if "link" in env : env["SHLINK"] = env["link"] env["LINK"] = env["link"] env["LINKFLAGS"] = env.get("linkflags", []) # This isn't a real flag (yet) AFAIK. Be sure to append it to the CXXFLAGS # where you need it env["OBJCCFLAGS"] = [] if env["optimize"] : if env["PLATFORM"] == "win32" : env.Append(CCFLAGS = ["/O2", "/GL"]) env.Append(LINKFLAGS = ["/INCREMENTAL:NO", "/LTCG"]) else : env.Append(CCFLAGS = ["-O2"]) if env["target"] == "xcode" and os.environ["CONFIGURATION"] == "Release" : env.Append(CCFLAGS = ["-Os"]) if env["debug"] : if env["PLATFORM"] == "win32" : env.Append(CCFLAGS = ["/Zi", "/MDd"]) env.Append(LINKFLAGS = ["/DEBUG"]) - env.Append(CPPDEFINES = ["_ITERATOR_DEBUG_LEVEL=0"]) + if env["set_iterator_debug_level"] : + env.Append(CPPDEFINES = ["_ITERATOR_DEBUG_LEVEL=0"]) else : env.Append(CCFLAGS = ["-g"]) elif env["PLATFORM"] == "win32" : env.Append(CCFLAGS = ["/MD"]) if env.get("universal", 0) : assert(env["PLATFORM"] == "darwin") env.Append(CCFLAGS = [ "-isysroot", "/Developer/SDKs/MacOSX10.4u.sdk", "-arch", "i386", "-arch", "ppc"]) env.Append(LINKFLAGS = [ "-mmacosx-version-min=10.4", "-isysroot", "/Developer/SDKs/MacOSX10.4u.sdk", "-arch", "i386", "-arch", "ppc"]) if env.get("mac105", 0) : assert(env["PLATFORM"] == "darwin") env.Append(CCFLAGS = [ "-isysroot", "/Developer/SDKs/MacOSX10.5.sdk", "-arch", "i386"]) env.Append(LINKFLAGS = [ "-mmacosx-version-min=10.5", "-isysroot", "/Developer/SDKs/MacOSX10.5.sdk", "-arch", "i386"]) env.Append(FRAMEWORKS = ["Security"]) if not env["assertions"] : env.Append(CPPDEFINES = ["NDEBUG"]) if env["experimental"] : env.Append(CPPDEFINES = ["SWIFT_EXPERIMENTAL_FT"]) # If we build shared libs on AMD64, we need -fPIC. diff --git a/Swift/SConscript b/Swift/SConscript index b66058b..49aa985 100644 --- a/Swift/SConscript +++ b/Swift/SConscript @@ -1,16 +1,16 @@ import datetime Import("env") SConscript("Controllers/SConscript") if env["SCONS_STAGE"] == "build" : if not GetOption("help") and not env.get("HAVE_OPENSSL", 0) and not env.get("HAVE_SCHANNEL", 0) : print "Error: Swift requires OpenSSL support, and OpenSSL was not found." if "Swift" in env["PROJECTS"] : env["PROJECTS"].remove("Swift") elif not GetOption("help") and not env.get("HAVE_QT", 0) : print "Error: Swift requires Qt. Not building Swift." - env["PROJECTS"].remove("Swift") +# env["PROJECTS"].remove("Swift") elif env["target"] == "native": SConscript("QtUI/SConscript") diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h index baf1b0a..2f2d2af 100644 --- a/Swiften/Client/ClientError.h +++ b/Swiften/Client/ClientError.h @@ -8,45 +8,47 @@ namespace Swift { class ClientError { public: enum Type { UnknownError, DomainNameResolveError, ConnectionError, ConnectionReadError, ConnectionWriteError, XMLError, AuthenticationFailedError, CompressionFailedError, ServerVerificationFailedError, NoSupportedAuthMechanismsError, UnexpectedElementError, ResourceBindError, SessionStartError, StreamError, TLSError, ClientCertificateLoadError, ClientCertificateError, // Certificate verification errors UnknownCertificateError, CertificateExpiredError, CertificateNotYetValidError, CertificateSelfSignedError, CertificateRejectedError, CertificateUntrustedError, InvalidCertificatePurposeError, CertificatePathLengthExceededError, InvalidCertificateSignatureError, InvalidCAError, InvalidServerIdentityError, + RevokedError, + RevocationCheckFailedError }; ClientError(Type type = UnknownError) : type_(type) {} Type getType() const { return type_; } private: Type type_; }; } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index f7e3b21..14481c6 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -239,70 +239,76 @@ void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) { else if (boost::shared_ptr<CertificateVerificationError> verificationError = boost::dynamic_pointer_cast<CertificateVerificationError>(error)) { switch(verificationError->getType()) { case CertificateVerificationError::UnknownError: clientError = ClientError(ClientError::UnknownCertificateError); break; case CertificateVerificationError::Expired: clientError = ClientError(ClientError::CertificateExpiredError); break; case CertificateVerificationError::NotYetValid: clientError = ClientError(ClientError::CertificateNotYetValidError); break; case CertificateVerificationError::SelfSigned: clientError = ClientError(ClientError::CertificateSelfSignedError); break; case CertificateVerificationError::Rejected: clientError = ClientError(ClientError::CertificateRejectedError); break; case CertificateVerificationError::Untrusted: clientError = ClientError(ClientError::CertificateUntrustedError); break; case CertificateVerificationError::InvalidPurpose: clientError = ClientError(ClientError::InvalidCertificatePurposeError); break; case CertificateVerificationError::PathLengthExceeded: clientError = ClientError(ClientError::CertificatePathLengthExceededError); break; case CertificateVerificationError::InvalidSignature: clientError = ClientError(ClientError::InvalidCertificateSignatureError); break; case CertificateVerificationError::InvalidCA: clientError = ClientError(ClientError::InvalidCAError); break; case CertificateVerificationError::InvalidServerIdentity: clientError = ClientError(ClientError::InvalidServerIdentityError); break; + case CertificateVerificationError::Revoked: + clientError = ClientError(ClientError::RevokedError); + break; + case CertificateVerificationError::RevocationCheckFailed: + clientError = ClientError(ClientError::RevocationCheckFailedError); + break; } } actualError = boost::optional<ClientError>(clientError); } onDisconnected(actualError); } void CoreClient::handleNeedCredentials() { assert(session_); session_->sendCredentials(password_); if (options.forgetPassword) { purgePassword(); } } void CoreClient::handleDataRead(const SafeByteArray& data) { onDataRead(data); } void CoreClient::handleDataWritten(const SafeByteArray& data) { onDataWritten(data); } void CoreClient::handleStanzaChannelAvailableChanged(bool available) { if (available) { iqRouter_->setJID(session_->getLocalJID()); handleConnected(); onConnected(); } } void CoreClient::sendMessage(boost::shared_ptr<Message> message) { stanzaChannel_->sendMessage(message); } diff --git a/Swiften/TLS/CertificateVerificationError.h b/Swiften/TLS/CertificateVerificationError.h index 22e6eaf..b17f5df 100644 --- a/Swiften/TLS/CertificateVerificationError.h +++ b/Swiften/TLS/CertificateVerificationError.h @@ -1,40 +1,42 @@ /* * Copyright (c) 2010-2012 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/Error.h> namespace Swift { class CertificateVerificationError : public Error { public: typedef boost::shared_ptr<CertificateVerificationError> ref; enum Type { UnknownError, Expired, NotYetValid, SelfSigned, Rejected, Untrusted, InvalidPurpose, PathLengthExceeded, InvalidSignature, InvalidCA, InvalidServerIdentity, + Revoked, + RevocationCheckFailed }; CertificateVerificationError(Type type = UnknownError) : type(type) {} Type getType() const { return type; } private: Type type; }; } diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h index f531cff..395d3ec 100644 --- a/Swiften/TLS/Schannel/SchannelCertificate.h +++ b/Swiften/TLS/Schannel/SchannelCertificate.h @@ -16,66 +16,71 @@ namespace Swift { class SchannelCertificate : public Certificate { public: typedef boost::shared_ptr<SchannelCertificate> ref; public: SchannelCertificate(const ScopedCertContext& certCtxt); SchannelCertificate(const ByteArray& der); std::string getSubjectName() const { return m_subjectName; } std::vector<std::string> getCommonNames() const { return m_commonNames; } std::vector<std::string> getSRVNames() const { return m_srvNames; } std::vector<std::string> getDNSNames() const { return m_dnsNames; } std::vector<std::string> getXMPPAddresses() const { return m_xmppAddresses; } - ByteArray toDER() const; + ScopedCertContext getCertContext() const + { + return m_cert; + } + ByteArray toDER() const; + private: void parse(); std::string wstrToStr(const std::wstring& wstr); void addSRVName(const std::string& name) { m_srvNames.push_back(name); } void addDNSName(const std::string& name) { m_dnsNames.push_back(name); } void addXMPPAddress(const std::string& addr) { m_xmppAddresses.push_back(addr); } private: ScopedCertContext m_cert; std::string m_subjectName; std::vector<std::string> m_commonNames; std::vector<std::string> m_dnsNames; std::vector<std::string> m_xmppAddresses; std::vector<std::string> m_srvNames; }; } diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index b2fea65..9be1ded 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -1,219 +1,289 @@ /* * Copyright (c) 2011 Soren Dreijer * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include <Swiften/TLS/Schannel/SchannelContext.h> -#include <Swiften/TLS/Schannel/SchannelCertificate.h> +#include "Swiften/TLS/Schannel/SchannelContext.h" +#include "Swiften/TLS/Schannel/SchannelCertificate.h" #include <Swiften/TLS/CAPICertificate.h> +#include <WinHTTP.h> // For SECURITY_FLAG_IGNORE_CERT_CN_INVALID namespace Swift { //------------------------------------------------------------------------ SchannelContext::SchannelContext() : m_state(Start) , m_secContext(0) -, m_verificationError(CertificateVerificationError::UnknownError) , m_my_cert_store(NULL) , m_cert_store_name("MY") , m_cert_name() { m_ctxtFlags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR | ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM; ZeroMemory(&m_streamSizes, sizeof(m_streamSizes)); } //------------------------------------------------------------------------ SchannelContext::~SchannelContext() { if (m_my_cert_store) CertCloseStore(m_my_cert_store, 0); } //------------------------------------------------------------------------ void SchannelContext::determineStreamSizes() { QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes); } //------------------------------------------------------------------------ void SchannelContext::connect() { - PCCERT_CONTEXT pCertContext = NULL; + ScopedCertContext pCertContext; m_state = Connecting; // If a user name is specified, then attempt to find a client // certificate. Otherwise, just create a NULL credential. if (!m_cert_name.empty()) { if (m_my_cert_store == NULL) { m_my_cert_store = CertOpenSystemStore(0, m_cert_store_name.c_str()); if (!m_my_cert_store) { ///// printf( "**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError() ); indicateError(); return; } } pCertContext = findCertificateInStore( m_my_cert_store, m_cert_name ); if (pCertContext == NULL) { ///// printf("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); indicateError(); return; } } // We use an empty list for client certificates PCCERT_CONTEXT clientCerts[1] = {0}; SCHANNEL_CRED sc = {0}; sc.dwVersion = SCHANNEL_CRED_VERSION; /////SSL3? sc.grbitEnabledProtocols = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; - sc.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN; + sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; if (pCertContext) { sc.cCreds = 1; - sc.paCred = &pCertContext; + sc.paCred = pCertContext.GetPointer(); sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; } else { sc.cCreds = 0; // Let Crypto API find the appropriate certificate for us sc.paCred = clientCerts; sc.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; } // Swiften performs the server name check for us sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; SECURITY_STATUS status = AcquireCredentialsHandle( NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, NULL, &sc, NULL, NULL, m_credHandle.Reset(), NULL); - - // cleanup: Free the certificate context. Schannel has already made its own copy. - if (pCertContext) CertFreeCertificateContext(pCertContext); - + if (status != SEC_E_OK) { // We failed to obtain the credentials handle indicateError(); return; } SecBuffer outBuffers[2]; // We let Schannel allocate the output buffer for us outBuffers[0].pvBuffer = NULL; outBuffers[0].cbBuffer = 0; outBuffers[0].BufferType = SECBUFFER_TOKEN; // Contains alert data if an alert is generated outBuffers[1].pvBuffer = NULL; outBuffers[1].cbBuffer = 0; outBuffers[1].BufferType = SECBUFFER_ALERT; // Make sure the output buffers are freed ScopedSecBuffer scopedOutputData(&outBuffers[0]); ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); SecBufferDesc outBufferDesc = {0}; outBufferDesc.cBuffers = 2; outBufferDesc.pBuffers = outBuffers; outBufferDesc.ulVersion = SECBUFFER_VERSION; // Create the initial security context status = InitializeSecurityContext( m_credHandle, NULL, NULL, m_ctxtFlags, 0, 0, NULL, 0, m_ctxtHandle.Reset(), &outBufferDesc, &m_secContext, NULL); if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { // We failed to initialize the security context + handleCertError(status); indicateError(); return; } // Start the handshake sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) + handleCertError(status); + m_state = Connected; determineStreamSizes(); onConnected(); } } //------------------------------------------------------------------------ +SECURITY_STATUS SchannelContext::validateServerCertificate() +{ + SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); + if (!pServerCert) + return SEC_E_WRONG_PRINCIPAL; + + const LPSTR usage[] = + { + szOID_PKIX_KP_SERVER_AUTH, + szOID_SERVER_GATED_CRYPTO, + szOID_SGC_NETSCAPE + }; + + CERT_CHAIN_PARA chainParams = {0}; + chainParams.cbSize = sizeof(chainParams); + chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; + chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage); + chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage); + + DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; + + ScopedCertChainContext pChainContext; + + BOOL success = CertGetCertificateChain( + NULL, // Use the chain engine for the current user (assumes a user is logged in) + pServerCert->getCertContext(), + NULL, + NULL, + &chainParams, + chainFlags, + NULL, + pChainContext.Reset()); + + if (!success) + return GetLastError(); + + SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; + sslChainPolicy.cbSize = sizeof(sslChainPolicy); + sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; + sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + sslChainPolicy.pwszServerName = NULL; + + CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; + certChainPolicy.cbSize = sizeof(certChainPolicy); + certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + certChainPolicy.pvExtraPolicyPara = &sslChainPolicy; + + CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; + certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); + + // Verify the chain + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, + pChainContext, + &certChainPolicy, + &certChainPolicyStatus)) + { + return GetLastError(); + } + + if (certChainPolicyStatus.dwError != S_OK) + return certChainPolicyStatus.dwError; + + return S_OK; +} + +//------------------------------------------------------------------------ + void SchannelContext::appendNewData(const SafeByteArray& data) { size_t originalSize = m_receivedData.size(); m_receivedData.resize( originalSize + data.size() ); memcpy( &m_receivedData[0] + originalSize, &data[0], data.size() ); } //------------------------------------------------------------------------ void SchannelContext::continueHandshake(const SafeByteArray& data) { appendNewData(data); while (!m_receivedData.empty()) { SecBuffer inBuffers[2]; // Provide Schannel with the remote host's handshake data inBuffers[0].pvBuffer = (char*)(&m_receivedData[0]); inBuffers[0].cbBuffer = (unsigned long)m_receivedData.size(); inBuffers[0].BufferType = SECBUFFER_TOKEN; inBuffers[1].pvBuffer = NULL; inBuffers[1].cbBuffer = 0; inBuffers[1].BufferType = SECBUFFER_EMPTY; SecBufferDesc inBufferDesc = {0}; inBufferDesc.cBuffers = 2; inBufferDesc.pBuffers = inBuffers; inBufferDesc.ulVersion = SECBUFFER_VERSION; SecBuffer outBuffers[2]; // We let Schannel allocate the output buffer for us outBuffers[0].pvBuffer = NULL; @@ -238,93 +308,139 @@ void SchannelContext::continueHandshake(const SafeByteArray& data) m_credHandle, m_ctxtHandle, NULL, m_ctxtFlags, 0, 0, &inBufferDesc, 0, NULL, &outBufferDesc, &m_secContext, NULL); if (status == SEC_E_INCOMPLETE_MESSAGE) { // Wait for more data to arrive break; } else if (status == SEC_I_CONTINUE_NEEDED) { SecBuffer* pDataBuffer = &outBuffers[0]; SecBuffer* pExtraBuffer = &inBuffers[1]; if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); else m_receivedData.clear(); break; } else if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) + handleCertError(status); + SecBuffer* pExtraBuffer = &inBuffers[1]; if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); else m_receivedData.clear(); m_state = Connected; determineStreamSizes(); onConnected(); } else { // We failed to initialize the security context + handleCertError(status); indicateError(); return; } } } //------------------------------------------------------------------------ +void SchannelContext::handleCertError(SECURITY_STATUS status) +{ + if (status == SEC_E_UNTRUSTED_ROOT || + status == CERT_E_UNTRUSTEDROOT || + status == CRYPT_E_ISSUER_SERIALNUMBER || + status == CRYPT_E_SIGNER_NOT_FOUND || + status == CRYPT_E_NO_TRUSTED_SIGNER) + { + m_verificationError = CertificateVerificationError::Untrusted; + } + else if (status == SEC_E_CERT_EXPIRED || + status == CERT_E_EXPIRED) + { + m_verificationError = CertificateVerificationError::Expired; + } + else if (status == CRYPT_E_SELF_SIGNED) + { + m_verificationError = CertificateVerificationError::SelfSigned; + } + else if (status == CRYPT_E_HASH_VALUE || + status == TRUST_E_CERT_SIGNATURE) + { + m_verificationError = CertificateVerificationError::InvalidSignature; + } + else if (status == CRYPT_E_REVOKED) + { + m_verificationError = CertificateVerificationError::Revoked; + } + else if (status == CRYPT_E_NO_REVOCATION_CHECK || + status == CRYPT_E_REVOCATION_OFFLINE) + { + m_verificationError = CertificateVerificationError::RevocationCheckFailed; + } + else + { + m_verificationError = CertificateVerificationError::UnknownError; + } +} + +//------------------------------------------------------------------------ + void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) { if (dataSize > 0 && pData) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); onDataForNetwork(byteArray); } } //------------------------------------------------------------------------ void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) { SafeByteArray byteArray(dataSize); memcpy(&byteArray[0], pData, dataSize); onDataForApplication(byteArray); } //------------------------------------------------------------------------ void SchannelContext::handleDataFromApplication(const SafeByteArray& data) { // Don't attempt to send data until we're fully connected if (m_state == Connecting) return; // Encrypt the data encryptAndSendData(data); } //------------------------------------------------------------------------ @@ -417,70 +533,73 @@ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) indicateError(); break; } SecBuffer* pDataBuffer = NULL; SecBuffer* pExtraBuffer = NULL; for (int i = 0; i < 4; ++i) { if (inBuffers[i].BufferType == SECBUFFER_DATA) pDataBuffer = &inBuffers[i]; else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) pExtraBuffer = &inBuffers[i]; } if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); // If there is extra data left over from the decryption operation, we call DecryptMessage() again if (pExtraBuffer) { m_receivedData.erase(m_receivedData.begin(), m_receivedData.end() - pExtraBuffer->cbBuffer); } else { // We're done m_receivedData.erase(m_receivedData.begin(), m_receivedData.begin() + inData); } } } //------------------------------------------------------------------------ void SchannelContext::encryptAndSendData(const SafeByteArray& data) { + if (m_streamSizes.cbMaximumMessage == 0) + return; + SecBuffer outBuffers[4] = {0}; // Calculate the largest required size of the send buffer size_t messageBufferSize = (data.size() > m_streamSizes.cbMaximumMessage) ? m_streamSizes.cbMaximumMessage : data.size(); // Allocate a packet for the encrypted data SafeByteArray sendBuffer; sendBuffer.resize(m_streamSizes.cbHeader + messageBufferSize + m_streamSizes.cbTrailer); size_t bytesSent = 0; do { size_t bytesLeftToSend = data.size() - bytesSent; // Calculate how much of the send buffer we'll be using for this chunk size_t bytesToSend = (bytesLeftToSend > m_streamSizes.cbMaximumMessage) ? m_streamSizes.cbMaximumMessage : bytesLeftToSend; // Copy the plain text data into the send buffer memcpy(&sendBuffer[0] + m_streamSizes.cbHeader, &data[0] + bytesSent, bytesToSend); outBuffers[0].pvBuffer = &sendBuffer[0]; outBuffers[0].cbBuffer = m_streamSizes.cbHeader; outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; outBuffers[1].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader; outBuffers[1].cbBuffer = (unsigned long)bytesToSend; outBuffers[1].BufferType = SECBUFFER_DATA; outBuffers[2].pvBuffer = &sendBuffer[0] + m_streamSizes.cbHeader + bytesToSend; outBuffers[2].cbBuffer = m_streamSizes.cbTrailer; outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; @@ -512,54 +631,54 @@ void SchannelContext::encryptAndSendData(const SafeByteArray& data) bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) { boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); if (!capiCertificate || capiCertificate->isNull()) { return false; } // We assume that the Certificate Store Name/Certificate Name // are valid at this point m_cert_store_name = capiCertificate->getCertStoreName(); m_cert_name = capiCertificate->getCertName(); return true; } //------------------------------------------------------------------------ Certificate::ref SchannelContext::getPeerCertificate() const { SchannelCertificate::ref pCertificate; ScopedCertContext pServerCert; SECURITY_STATUS status = QueryContextAttributes(m_ctxtHandle, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); if (status != SEC_E_OK) return pCertificate; pCertificate.reset( new SchannelCertificate(pServerCert) ); return pCertificate; } //------------------------------------------------------------------------ CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { boost::shared_ptr<CertificateVerificationError> pCertError; - if (m_state == Error) - pCertError.reset( new CertificateVerificationError(m_verificationError) ); + if (m_verificationError) + pCertError.reset( new CertificateVerificationError(*m_verificationError) ); return pCertError; } //------------------------------------------------------------------------ ByteArray SchannelContext::getFinishMessage() const { // TODO: Implement ByteArray emptyArray; return emptyArray; } //------------------------------------------------------------------------ } diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index 7726c41..70b0694 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -19,70 +19,72 @@ #include <security.h> #include <schnlsp.h> #include <boost/noncopyable.hpp> namespace Swift { class SchannelContext : public TLSContext, boost::noncopyable { public: typedef boost::shared_ptr<SchannelContext> sp_t; public: SchannelContext(); ~SchannelContext(); // // TLSContext // virtual void connect(); virtual bool setClientCertificate(CertificateWithKey::ref cert); virtual void handleDataFromNetwork(const SafeByteArray& data); virtual void handleDataFromApplication(const SafeByteArray& data); virtual Certificate::ref getPeerCertificate() const; virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; virtual ByteArray getFinishMessage() const; private: void determineStreamSizes(); void continueHandshake(const SafeByteArray& data); void indicateError(); + void handleCertError(SECURITY_STATUS status) ; void sendDataOnNetwork(const void* pData, size_t dataSize); void forwardDataToApplication(const void* pData, size_t dataSize); void decryptAndProcessData(const SafeByteArray& data); void encryptAndSendData(const SafeByteArray& data); void appendNewData(const SafeByteArray& data); + SECURITY_STATUS validateServerCertificate(); private: enum SchannelState { Start, Connecting, Connected, Error }; SchannelState m_state; - CertificateVerificationError m_verificationError; + boost::optional<CertificateVerificationError> m_verificationError; ULONG m_secContext; ScopedCredHandle m_credHandle; ScopedCtxtHandle m_ctxtHandle; DWORD m_ctxtFlags; SecPkgContext_StreamSizes m_streamSizes; std::vector<char> m_receivedData; HCERTSTORE m_my_cert_store; std::string m_cert_store_name; std::string m_cert_name; }; } diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h index 0a54f16..4f73aac 100644 --- a/Swiften/TLS/Schannel/SchannelUtil.h +++ b/Swiften/TLS/Schannel/SchannelUtil.h @@ -214,81 +214,212 @@ namespace Swift { private: struct HandleContext { HandleContext() : m_pCertCtxt(NULL) { } HandleContext(PCCERT_CONTEXT pCert) : m_pCertCtxt(pCert) { } ~HandleContext() { if (m_pCertCtxt) CertFreeCertificateContext(m_pCertCtxt); } PCCERT_CONTEXT m_pCertCtxt; }; public: ScopedCertContext() : m_pHandle( new HandleContext ) { } explicit ScopedCertContext(PCCERT_CONTEXT pCert) : m_pHandle( new HandleContext(pCert) ) { } // Copy constructor - explicit ScopedCertContext(const ScopedCertContext& rhs) + ScopedCertContext(const ScopedCertContext& rhs) { m_pHandle = rhs.m_pHandle; } ~ScopedCertContext() { m_pHandle.reset(); } PCCERT_CONTEXT* Reset() { FreeContext(); return &m_pHandle->m_pCertCtxt; } operator PCCERT_CONTEXT() const { return m_pHandle->m_pCertCtxt; } + PCCERT_CONTEXT* GetPointer() const + { + return &m_pHandle->m_pCertCtxt; + } + PCCERT_CONTEXT operator->() const { return m_pHandle->m_pCertCtxt; } ScopedCertContext& operator=(const ScopedCertContext& sh) { // Only update the internal handle if it's different if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) { m_pHandle = sh.m_pHandle; } return *this; } + ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt) + { + // Only update the internal handle if it's different + if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt) + m_pHandle.reset( new HandleContext(pCertCtxt) ); + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel HCERTSTORE. + // + class ScopedCertStore : boost::noncopyable + { + public: + ScopedCertStore(HCERTSTORE hCertStore) + : m_hCertStore(hCertStore) + { + } + + ~ScopedCertStore() + { + // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this + // cert store have been closed at this point. + if (m_hCertStore) + CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG); + } + + operator HCERTSTORE() const + { + return m_hCertStore; + } + + private: + HCERTSTORE m_hCertStore; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT. + // + class ScopedCertChainContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertChainCtxt(NULL) + { + } + + HandleContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pCertChainCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertChainCtxt) + CertFreeCertificateChain(m_pCertChainCtxt); + } + + PCCERT_CHAIN_CONTEXT m_pCertChainCtxt; + }; + + public: + ScopedCertChainContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + ScopedCertChainContext(const ScopedCertChainContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertChainContext() + { + m_pHandle.reset(); + } + + PCCERT_CHAIN_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertChainCtxt; + } + + operator PCCERT_CHAIN_CONTEXT() const + { + return m_pHandle->m_pCertChainCtxt; + } + + PCCERT_CHAIN_CONTEXT operator->() const + { + return m_pHandle->m_pCertChainCtxt; + } + + ScopedCertChainContext& operator=(const ScopedCertChainContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + void FreeContext() { m_pHandle.reset( new HandleContext ); } private: boost::shared_ptr<HandleContext> m_pHandle; - }; + }; } |