diff options
Diffstat (limited to 'Swiften/TLS/Schannel')
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificate.cpp | 284 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificate.h | 140 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelCertificateFactory.h | 12 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.cpp | 1080 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContext.h | 122 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContextFactory.cpp | 14 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelContextFactory.h | 22 | ||||
-rw-r--r-- | Swiften/TLS/Schannel/SchannelUtil.h | 814 |
8 files changed, 1244 insertions, 1244 deletions
diff --git a/Swiften/TLS/Schannel/SchannelCertificate.cpp b/Swiften/TLS/Schannel/SchannelCertificate.cpp index 8aaec00..68dd0cf 100644 --- a/Swiften/TLS/Schannel/SchannelCertificate.cpp +++ b/Swiften/TLS/Schannel/SchannelCertificate.cpp @@ -20,176 +20,176 @@ namespace Swift { //------------------------------------------------------------------------ -SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt) -: m_cert(certCtxt) +SchannelCertificate::SchannelCertificate(const ScopedCertContext& certCtxt) +: m_cert(certCtxt) { - parse(); + parse(); } //------------------------------------------------------------------------ SchannelCertificate::SchannelCertificate(const ByteArray& der) { - if (!der.empty()) - { - // Convert the DER encoded certificate to a PCERT_CONTEXT - CERT_BLOB certBlob = {0}; - certBlob.cbData = der.size(); - certBlob.pbData = (BYTE*)&der[0]; - - if (!CryptQueryObject( - CERT_QUERY_OBJECT_BLOB, - &certBlob, - CERT_QUERY_CONTENT_FLAG_CERT, - CERT_QUERY_FORMAT_FLAG_ALL, - 0, - NULL, - NULL, - NULL, - NULL, - NULL, - (const void**)m_cert.Reset())) - { - // TODO: Because Swiften isn't exception safe, we have no way to indicate failure - } - } + if (!der.empty()) + { + // Convert the DER encoded certificate to a PCERT_CONTEXT + CERT_BLOB certBlob = {0}; + certBlob.cbData = der.size(); + certBlob.pbData = (BYTE*)&der[0]; + + if (!CryptQueryObject( + CERT_QUERY_OBJECT_BLOB, + &certBlob, + CERT_QUERY_CONTENT_FLAG_CERT, + CERT_QUERY_FORMAT_FLAG_ALL, + 0, + NULL, + NULL, + NULL, + NULL, + NULL, + (const void**)m_cert.Reset())) + { + // TODO: Because Swiften isn't exception safe, we have no way to indicate failure + } + } } //------------------------------------------------------------------------ -ByteArray SchannelCertificate::toDER() const +ByteArray SchannelCertificate::toDER() const { - ByteArray result; + ByteArray result; - // Serialize the certificate. The CERT_CONTEXT is already DER encoded. - result.resize(m_cert->cbCertEncoded); - memcpy(&result[0], m_cert->pbCertEncoded, result.size()); - - return result; + // Serialize the certificate. The CERT_CONTEXT is already DER encoded. + result.resize(m_cert->cbCertEncoded); + memcpy(&result[0], m_cert->pbCertEncoded, result.size()); + + return result; } //------------------------------------------------------------------------ std::string SchannelCertificate::wstrToStr(const std::wstring& wstr) { - if (wstr.empty()) - return ""; + if (wstr.empty()) + return ""; - // First request the size of the required UTF-8 buffer - int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL); - if (!numRequiredBytes) - return ""; + // First request the size of the required UTF-8 buffer + int numRequiredBytes = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), NULL, 0, NULL, NULL); + if (!numRequiredBytes) + return ""; - // Allocate memory for the UTF-8 string - std::vector<char> utf8Str(numRequiredBytes); + // Allocate memory for the UTF-8 string + std::vector<char> utf8Str(numRequiredBytes); - int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL); - if (!numConverted) - return ""; + int numConverted = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.size(), &utf8Str[0], numRequiredBytes, NULL, NULL); + if (!numConverted) + return ""; - std::string str(&utf8Str[0], numConverted); - return str; + std::string str(&utf8Str[0], numConverted); + return str; } //------------------------------------------------------------------------ -void SchannelCertificate::parse() +void SchannelCertificate::parse() { - // - // Subject name - // - DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0); - if (requiredSize > 1) - { - vector<char> rawSubjectName(requiredSize); - CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size()); - m_subjectName = std::string(&rawSubjectName[0]); - } - - // - // Common name - // - // Note: We only pull out one common name from the cert. - requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0); - if (requiredSize > 1) - { - vector<char> rawCommonName(requiredSize); - requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size()); - m_commonNames.push_back( std::string(&rawCommonName[0]) ); - } - - // - // Subject alternative names - // - PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension); - if (pExtensions) - { - CRYPT_DECODE_PARA decodePara = {0}; - decodePara.cbSize = sizeof(decodePara); - - CERT_ALT_NAME_INFO* pAltNameInfo = NULL; - DWORD altNameInfoSize = 0; - - BOOL status = CryptDecodeObjectEx( - X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, - szOID_SUBJECT_ALT_NAME2, - pExtensions->Value.pbData, - pExtensions->Value.cbData, - CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG, - &decodePara, - &pAltNameInfo, - &altNameInfoSize); - - if (status && pAltNameInfo) - { - for (int i = 0; i < pAltNameInfo->cAltEntry; i++) - { - if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME) - addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) ); - } - } - } - - // if (pExtensions) - // { - // vector<wchar_t> subjectAlt - // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, ) - // } - // - // // subjectAltNames - // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); - // if(subjectAltNameLoc != -1) { - // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); - // boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); - // boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); - // boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); - // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { - // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); - // if (generalName->type == GEN_OTHERNAME) { - // OTHERNAME* otherName = generalName->d.otherName; - // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { - // // XmppAddr - // if (otherName->value->type != V_ASN1_UTF8STRING) { - // continue; - // } - // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; - // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); - // } - // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { - // // SRVName - // if (otherName->value->type != V_ASN1_IA5STRING) { - // continue; - // } - // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; - // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); - // } - // } - // else if (generalName->type == GEN_DNS) { - // // DNSName - // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); - // } - // } - // } + // + // Subject name + // + DWORD requiredSize = CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawSubjectName(requiredSize); + CertNameToStr(X509_ASN_ENCODING, &m_cert->pCertInfo->Subject, CERT_OID_NAME_STR, &rawSubjectName[0], rawSubjectName.size()); + m_subjectName = std::string(&rawSubjectName[0]); + } + + // + // Common name + // + // Note: We only pull out one common name from the cert. + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, NULL, 0); + if (requiredSize > 1) + { + vector<char> rawCommonName(requiredSize); + requiredSize = CertGetNameString(m_cert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, &rawCommonName[0], rawCommonName.size()); + m_commonNames.push_back( std::string(&rawCommonName[0]) ); + } + + // + // Subject alternative names + // + PCERT_EXTENSION pExtensions = CertFindExtension(szOID_SUBJECT_ALT_NAME2, m_cert->pCertInfo->cExtension, m_cert->pCertInfo->rgExtension); + if (pExtensions) + { + CRYPT_DECODE_PARA decodePara = {0}; + decodePara.cbSize = sizeof(decodePara); + + CERT_ALT_NAME_INFO* pAltNameInfo = NULL; + DWORD altNameInfoSize = 0; + + BOOL status = CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + szOID_SUBJECT_ALT_NAME2, + pExtensions->Value.pbData, + pExtensions->Value.cbData, + CRYPT_DECODE_ALLOC_FLAG | CRYPT_DECODE_NOCOPY_FLAG, + &decodePara, + &pAltNameInfo, + &altNameInfoSize); + + if (status && pAltNameInfo) + { + for (int i = 0; i < pAltNameInfo->cAltEntry; i++) + { + if (pAltNameInfo->rgAltEntry[i].dwAltNameChoice == CERT_ALT_NAME_DNS_NAME) + addDNSName( wstrToStr( pAltNameInfo->rgAltEntry[i].pwszDNSName ) ); + } + } + } + + // if (pExtensions) + // { + // vector<wchar_t> subjectAlt + // CryptDecodeObject(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, szOID_SUBJECT_ALT_NAME, pExtensions->Value->pbData, pExtensions->Value->cbData, ) + // } + // + // // subjectAltNames + // int subjectAltNameLoc = X509_get_ext_by_NID(cert.get(), NID_subject_alt_name, -1); + // if(subjectAltNameLoc != -1) { + // X509_EXTENSION* extension = X509_get_ext(cert.get(), subjectAltNameLoc); + // boost::shared_ptr<GENERAL_NAMES> generalNames(reinterpret_cast<GENERAL_NAMES*>(X509V3_EXT_d2i(extension)), GENERAL_NAMES_free); + // boost::shared_ptr<ASN1_OBJECT> xmppAddrObject(OBJ_txt2obj(ID_ON_XMPPADDR_OID, 1), ASN1_OBJECT_free); + // boost::shared_ptr<ASN1_OBJECT> dnsSRVObject(OBJ_txt2obj(ID_ON_DNSSRV_OID, 1), ASN1_OBJECT_free); + // for (int i = 0; i < sk_GENERAL_NAME_num(generalNames.get()); ++i) { + // GENERAL_NAME* generalName = sk_GENERAL_NAME_value(generalNames.get(), i); + // if (generalName->type == GEN_OTHERNAME) { + // OTHERNAME* otherName = generalName->d.otherName; + // if (OBJ_cmp(otherName->type_id, xmppAddrObject.get()) == 0) { + // // XmppAddr + // if (otherName->value->type != V_ASN1_UTF8STRING) { + // continue; + // } + // ASN1_UTF8STRING* xmppAddrValue = otherName->value->value.utf8string; + // addXMPPAddress(ByteArray(ASN1_STRING_data(xmppAddrValue), ASN1_STRING_length(xmppAddrValue)).toString()); + // } + // else if (OBJ_cmp(otherName->type_id, dnsSRVObject.get()) == 0) { + // // SRVName + // if (otherName->value->type != V_ASN1_IA5STRING) { + // continue; + // } + // ASN1_IA5STRING* srvNameValue = otherName->value->value.ia5string; + // addSRVName(ByteArray(ASN1_STRING_data(srvNameValue), ASN1_STRING_length(srvNameValue)).toString()); + // } + // } + // else if (generalName->type == GEN_DNS) { + // // DNSName + // addDNSName(ByteArray(ASN1_STRING_data(generalName->d.dNSName), ASN1_STRING_length(generalName->d.dNSName)).toString()); + // } + // } + // } } //------------------------------------------------------------------------ diff --git a/Swiften/TLS/Schannel/SchannelCertificate.h b/Swiften/TLS/Schannel/SchannelCertificate.h index 0f4e9c1..814f344 100644 --- a/Swiften/TLS/Schannel/SchannelCertificate.h +++ b/Swiften/TLS/Schannel/SchannelCertificate.h @@ -18,75 +18,75 @@ #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/Schannel/SchannelUtil.h> -namespace Swift +namespace Swift { - class SchannelCertificate : public Certificate - { - public: - typedef boost::shared_ptr<SchannelCertificate> ref; - - public: - SchannelCertificate(const ScopedCertContext& certCtxt); - SchannelCertificate(const ByteArray& der); - - std::string getSubjectName() const - { - return m_subjectName; - } - - std::vector<std::string> getCommonNames() const - { - return m_commonNames; - } - - std::vector<std::string> getSRVNames() const - { - return m_srvNames; - } - - std::vector<std::string> getDNSNames() const - { - return m_dnsNames; - } - - std::vector<std::string> getXMPPAddresses() const - { - return m_xmppAddresses; - } - - ScopedCertContext getCertContext() const - { - return m_cert; - } - - ByteArray toDER() const; - - private: - void parse(); - std::string wstrToStr(const std::wstring& wstr); - - void addSRVName(const std::string& name) - { - m_srvNames.push_back(name); - } - - void addDNSName(const std::string& name) - { - m_dnsNames.push_back(name); - } - - void addXMPPAddress(const std::string& addr) - { - m_xmppAddresses.push_back(addr); - } - - private: - ScopedCertContext m_cert; - - std::string m_subjectName; - std::vector<std::string> m_commonNames; - std::vector<std::string> m_dnsNames; - std::vector<std::string> m_xmppAddresses; - std::vector<std::string> m_srvNames; - }; + class SchannelCertificate : public Certificate + { + public: + typedef boost::shared_ptr<SchannelCertificate> ref; + + public: + SchannelCertificate(const ScopedCertContext& certCtxt); + SchannelCertificate(const ByteArray& der); + + std::string getSubjectName() const + { + return m_subjectName; + } + + std::vector<std::string> getCommonNames() const + { + return m_commonNames; + } + + std::vector<std::string> getSRVNames() const + { + return m_srvNames; + } + + std::vector<std::string> getDNSNames() const + { + return m_dnsNames; + } + + std::vector<std::string> getXMPPAddresses() const + { + return m_xmppAddresses; + } + + ScopedCertContext getCertContext() const + { + return m_cert; + } + + ByteArray toDER() const; + + private: + void parse(); + std::string wstrToStr(const std::wstring& wstr); + + void addSRVName(const std::string& name) + { + m_srvNames.push_back(name); + } + + void addDNSName(const std::string& name) + { + m_dnsNames.push_back(name); + } + + void addXMPPAddress(const std::string& addr) + { + m_xmppAddresses.push_back(addr); + } + + private: + ScopedCertContext m_cert; + + std::string m_subjectName; + std::vector<std::string> m_commonNames; + std::vector<std::string> m_dnsNames; + std::vector<std::string> m_xmppAddresses; + std::vector<std::string> m_srvNames; + }; } diff --git a/Swiften/TLS/Schannel/SchannelCertificateFactory.h b/Swiften/TLS/Schannel/SchannelCertificateFactory.h index 5a2b208..be97c52 100644 --- a/Swiften/TLS/Schannel/SchannelCertificateFactory.h +++ b/Swiften/TLS/Schannel/SchannelCertificateFactory.h @@ -10,10 +10,10 @@ #include <Swiften/TLS/Schannel/SchannelCertificate.h> namespace Swift { - class SchannelCertificateFactory : public CertificateFactory { - public: - virtual Certificate* createCertificateFromDER(const ByteArray& der) { - return new SchannelCertificate(der); - } - }; + class SchannelCertificateFactory : public CertificateFactory { + public: + virtual Certificate* createCertificateFromDER(const ByteArray& der) { + return new SchannelCertificate(der); + } + }; } diff --git a/Swiften/TLS/Schannel/SchannelContext.cpp b/Swiften/TLS/Schannel/SchannelContext.cpp index 62aa137..7b67f4c 100644 --- a/Swiften/TLS/Schannel/SchannelContext.cpp +++ b/Swiften/TLS/Schannel/SchannelContext.cpp @@ -24,671 +24,671 @@ namespace Swift { //------------------------------------------------------------------------ SchannelContext::SchannelContext(bool tls1_0Workaround) : state_(Start), secContext_(0), myCertStore_(NULL), certStoreName_("MY"), certName_(), smartCardReader_(), checkCertificateRevocation_(true), tls1_0Workaround_(tls1_0Workaround), disconnectOnCardRemoval_(true) { - contextFlags_ = ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_CONFIDENTIALITY | - ISC_REQ_EXTENDED_ERROR | - ISC_REQ_INTEGRITY | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_USE_SUPPLIED_CREDS | - ISC_REQ_STREAM; - - ZeroMemory(&streamSizes_, sizeof(streamSizes_)); + contextFlags_ = ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | + ISC_REQ_INTEGRITY | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_STREAM; + + ZeroMemory(&streamSizes_, sizeof(streamSizes_)); } //------------------------------------------------------------------------ SchannelContext::~SchannelContext() { - if (myCertStore_) CertCloseStore(myCertStore_, 0); + if (myCertStore_) CertCloseStore(myCertStore_, 0); } //------------------------------------------------------------------------ void SchannelContext::determineStreamSizes() { - QueryContextAttributes(contextHandle_, SECPKG_ATTR_STREAM_SIZES, &streamSizes_); + QueryContextAttributes(contextHandle_, SECPKG_ATTR_STREAM_SIZES, &streamSizes_); } //------------------------------------------------------------------------ void SchannelContext::connect() { - ScopedCertContext pCertContext; - - state_ = Connecting; - - // If a user name is specified, then attempt to find a client - // certificate. Otherwise, just create a NULL credential. - if (!certName_.empty()) { - if (myCertStore_ == NULL) { - myCertStore_ = CertOpenSystemStore(0, certStoreName_.c_str()); - if (!myCertStore_) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } - - pCertContext = findCertificateInStore( myCertStore_, certName_ ); - if (pCertContext == NULL) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } - - // We use an empty list for client certificates - PCCERT_CONTEXT clientCerts[1] = {0}; - - SCHANNEL_CRED sc = {0}; - sc.dwVersion = SCHANNEL_CRED_VERSION; - - if (tls1_0Workaround_) { - sc.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; - } - else { - sc.grbitEnabledProtocols = /*SP_PROT_SSL3_CLIENT | */SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; - } - - sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; - - if (pCertContext) { - sc.cCreds = 1; - sc.paCred = pCertContext.GetPointer(); - sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; - } - else { - sc.cCreds = 0; - sc.paCred = clientCerts; - sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; - } - - // Swiften performs the server name check for us - sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; - - SECURITY_STATUS status = AcquireCredentialsHandle( - NULL, - UNISP_NAME, - SECPKG_CRED_OUTBOUND, - NULL, - &sc, - NULL, - NULL, - credHandle_.Reset(), - NULL); - - if (status != SEC_E_OK) { - // We failed to obtain the credentials handle - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - - SecBuffer outBuffers[2]; - - // We let Schannel allocate the output buffer for us - outBuffers[0].pvBuffer = NULL; - outBuffers[0].cbBuffer = 0; - outBuffers[0].BufferType = SECBUFFER_TOKEN; - - // Contains alert data if an alert is generated - outBuffers[1].pvBuffer = NULL; - outBuffers[1].cbBuffer = 0; - outBuffers[1].BufferType = SECBUFFER_ALERT; - - // Make sure the output buffers are freed - ScopedSecBuffer scopedOutputData(&outBuffers[0]); - ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); - - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 2; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; - - // Create the initial security context - status = InitializeSecurityContext( - credHandle_, - NULL, - NULL, - contextFlags_, - 0, - 0, - NULL, - 0, - contextHandle_.Reset(), - &outBufferDesc, - &secContext_, - NULL); - - if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { - // We failed to initialize the security context - handleCertError(status); - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - - // Start the handshake - sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); - - if (status == SEC_E_OK) { - status = validateServerCertificate(); - if (status != SEC_E_OK) { - handleCertError(status); - } - - state_ = Connected; - determineStreamSizes(); - - onConnected(); - } + ScopedCertContext pCertContext; + + state_ = Connecting; + + // If a user name is specified, then attempt to find a client + // certificate. Otherwise, just create a NULL credential. + if (!certName_.empty()) { + if (myCertStore_ == NULL) { + myCertStore_ = CertOpenSystemStore(0, certStoreName_.c_str()); + if (!myCertStore_) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } + + pCertContext = findCertificateInStore( myCertStore_, certName_ ); + if (pCertContext == NULL) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } + + // We use an empty list for client certificates + PCCERT_CONTEXT clientCerts[1] = {0}; + + SCHANNEL_CRED sc = {0}; + sc.dwVersion = SCHANNEL_CRED_VERSION; + + if (tls1_0Workaround_) { + sc.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; + } + else { + sc.grbitEnabledProtocols = /*SP_PROT_SSL3_CLIENT | */SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT; + } + + sc.dwFlags = SCH_CRED_MANUAL_CRED_VALIDATION; + + if (pCertContext) { + sc.cCreds = 1; + sc.paCred = pCertContext.GetPointer(); + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + } + else { + sc.cCreds = 0; + sc.paCred = clientCerts; + sc.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + } + + // Swiften performs the server name check for us + sc.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; + + SECURITY_STATUS status = AcquireCredentialsHandle( + NULL, + UNISP_NAME, + SECPKG_CRED_OUTBOUND, + NULL, + &sc, + NULL, + NULL, + credHandle_.Reset(), + NULL); + + if (status != SEC_E_OK) { + // We failed to obtain the credentials handle + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + // Create the initial security context + status = InitializeSecurityContext( + credHandle_, + NULL, + NULL, + contextFlags_, + 0, + 0, + NULL, + 0, + contextHandle_.Reset(), + &outBufferDesc, + &secContext_, + NULL); + + if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED) { + // We failed to initialize the security context + handleCertError(status); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + + // Start the handshake + sendDataOnNetwork(outBuffers[0].pvBuffer, outBuffers[0].cbBuffer); + + if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) { + handleCertError(status); + } + + state_ = Connected; + determineStreamSizes(); + + onConnected(); + } } //------------------------------------------------------------------------ SECURITY_STATUS SchannelContext::validateServerCertificate() { - SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); - if (!pServerCert) { - return SEC_E_WRONG_PRINCIPAL; - } - - const LPSTR usage[] = - { - szOID_PKIX_KP_SERVER_AUTH, - szOID_SERVER_GATED_CRYPTO, - szOID_SGC_NETSCAPE - }; - - CERT_CHAIN_PARA chainParams = {0}; - chainParams.cbSize = sizeof(chainParams); - chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; - chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage); - chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage); - - DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT; - if (checkCertificateRevocation_) { - chainFlags |= CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; - } - - ScopedCertChainContext pChainContext; - - BOOL success = CertGetCertificateChain( - NULL, // Use the chain engine for the current user (assumes a user is logged in) - pServerCert->getCertContext(), - NULL, - pServerCert->getCertContext()->hCertStore, - &chainParams, - chainFlags, - NULL, - pChainContext.Reset()); - - if (!success) { - return GetLastError(); - } - - SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; - sslChainPolicy.cbSize = sizeof(sslChainPolicy); - sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; - sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? - sslChainPolicy.pwszServerName = NULL; - - CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; - certChainPolicy.cbSize = sizeof(certChainPolicy); - certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking? - certChainPolicy.pvExtraPolicyPara = &sslChainPolicy; - - CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; - certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); - - // Verify the chain - if (!CertVerifyCertificateChainPolicy( - CERT_CHAIN_POLICY_SSL, - pChainContext, - &certChainPolicy, - &certChainPolicyStatus)) { - return GetLastError(); - } - - if (certChainPolicyStatus.dwError != S_OK) { - return certChainPolicyStatus.dwError; - } - - return S_OK; + SchannelCertificate::ref pServerCert = boost::dynamic_pointer_cast<SchannelCertificate>( getPeerCertificate() ); + if (!pServerCert) { + return SEC_E_WRONG_PRINCIPAL; + } + + const LPSTR usage[] = + { + szOID_PKIX_KP_SERVER_AUTH, + szOID_SERVER_GATED_CRYPTO, + szOID_SGC_NETSCAPE + }; + + CERT_CHAIN_PARA chainParams = {0}; + chainParams.cbSize = sizeof(chainParams); + chainParams.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; + chainParams.RequestedUsage.Usage.cUsageIdentifier = ARRAYSIZE(usage); + chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = const_cast<LPSTR*>(usage); + + DWORD chainFlags = CERT_CHAIN_CACHE_END_CERT; + if (checkCertificateRevocation_) { + chainFlags |= CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; + } + + ScopedCertChainContext pChainContext; + + BOOL success = CertGetCertificateChain( + NULL, // Use the chain engine for the current user (assumes a user is logged in) + pServerCert->getCertContext(), + NULL, + pServerCert->getCertContext()->hCertStore, + &chainParams, + chainFlags, + NULL, + pChainContext.Reset()); + + if (!success) { + return GetLastError(); + } + + SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslChainPolicy = {0}; + sslChainPolicy.cbSize = sizeof(sslChainPolicy); + sslChainPolicy.dwAuthType = AUTHTYPE_SERVER; + sslChainPolicy.fdwChecks = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + sslChainPolicy.pwszServerName = NULL; + + CERT_CHAIN_POLICY_PARA certChainPolicy = {0}; + certChainPolicy.cbSize = sizeof(certChainPolicy); + certChainPolicy.dwFlags = CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG; // Swiften checks the server name for us. Is this the correct way to disable server name checking? + certChainPolicy.pvExtraPolicyPara = &sslChainPolicy; + + CERT_CHAIN_POLICY_STATUS certChainPolicyStatus = {0}; + certChainPolicyStatus.cbSize = sizeof(certChainPolicyStatus); + + // Verify the chain + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, + pChainContext, + &certChainPolicy, + &certChainPolicyStatus)) { + return GetLastError(); + } + + if (certChainPolicyStatus.dwError != S_OK) { + return certChainPolicyStatus.dwError; + } + + return S_OK; } //------------------------------------------------------------------------ void SchannelContext::appendNewData(const SafeByteArray& data) { - size_t originalSize = receivedData_.size(); - receivedData_.resize(originalSize + data.size()); - memcpy(&receivedData_[0] + originalSize, &data[0], data.size()); + size_t originalSize = receivedData_.size(); + receivedData_.resize(originalSize + data.size()); + memcpy(&receivedData_[0] + originalSize, &data[0], data.size()); } //------------------------------------------------------------------------ void SchannelContext::continueHandshake(const SafeByteArray& data) { - appendNewData(data); - - while (!receivedData_.empty()) { - SecBuffer inBuffers[2]; - - // Provide Schannel with the remote host's handshake data - inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); - inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); - inBuffers[0].BufferType = SECBUFFER_TOKEN; - - inBuffers[1].pvBuffer = NULL; - inBuffers[1].cbBuffer = 0; - inBuffers[1].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc inBufferDesc = {0}; - inBufferDesc.cBuffers = 2; - inBufferDesc.pBuffers = inBuffers; - inBufferDesc.ulVersion = SECBUFFER_VERSION; - - SecBuffer outBuffers[2]; - - // We let Schannel allocate the output buffer for us - outBuffers[0].pvBuffer = NULL; - outBuffers[0].cbBuffer = 0; - outBuffers[0].BufferType = SECBUFFER_TOKEN; - - // Contains alert data if an alert is generated - outBuffers[1].pvBuffer = NULL; - outBuffers[1].cbBuffer = 0; - outBuffers[1].BufferType = SECBUFFER_ALERT; - - // Make sure the output buffers are freed - ScopedSecBuffer scopedOutputData(&outBuffers[0]); - ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); - - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 2; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; - - SECURITY_STATUS status = InitializeSecurityContext( - credHandle_, - contextHandle_, - NULL, - contextFlags_, - 0, - 0, - &inBufferDesc, - 0, - NULL, - &outBufferDesc, - &secContext_, - NULL); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - // Wait for more data to arrive - break; - } - else if (status == SEC_I_CONTINUE_NEEDED) { - SecBuffer* pDataBuffer = &outBuffers[0]; - SecBuffer* pExtraBuffer = &inBuffers[1]; - - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { - sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); - } - - if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - receivedData_.clear(); - } - - break; - } - else if (status == SEC_E_OK) { - status = validateServerCertificate(); - if (status != SEC_E_OK) { - handleCertError(status); - } - - SecBuffer* pExtraBuffer = &inBuffers[1]; - - if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - receivedData_.clear(); - } - - state_ = Connected; - determineStreamSizes(); - - onConnected(); - } - else { - // We failed to initialize the security context - handleCertError(status); - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } - } + appendNewData(data); + + while (!receivedData_.empty()) { + SecBuffer inBuffers[2]; + + // Provide Schannel with the remote host's handshake data + inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); + inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); + inBuffers[0].BufferType = SECBUFFER_TOKEN; + + inBuffers[1].pvBuffer = NULL; + inBuffers[1].cbBuffer = 0; + inBuffers[1].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 2; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + SecBuffer outBuffers[2]; + + // We let Schannel allocate the output buffer for us + outBuffers[0].pvBuffer = NULL; + outBuffers[0].cbBuffer = 0; + outBuffers[0].BufferType = SECBUFFER_TOKEN; + + // Contains alert data if an alert is generated + outBuffers[1].pvBuffer = NULL; + outBuffers[1].cbBuffer = 0; + outBuffers[1].BufferType = SECBUFFER_ALERT; + + // Make sure the output buffers are freed + ScopedSecBuffer scopedOutputData(&outBuffers[0]); + ScopedSecBuffer scopedOutputAlertData(&outBuffers[1]); + + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 2; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status = InitializeSecurityContext( + credHandle_, + contextHandle_, + NULL, + contextFlags_, + 0, + 0, + &inBufferDesc, + 0, + NULL, + &outBufferDesc, + &secContext_, + NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_CONTINUE_NEEDED) { + SecBuffer* pDataBuffer = &outBuffers[0]; + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { + sendDataOnNetwork(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } + + if (pExtraBuffer->BufferType == SECBUFFER_EXTRA) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + receivedData_.clear(); + } + + break; + } + else if (status == SEC_E_OK) { + status = validateServerCertificate(); + if (status != SEC_E_OK) { + handleCertError(status); + } + + SecBuffer* pExtraBuffer = &inBuffers[1]; + + if (pExtraBuffer && pExtraBuffer->cbBuffer > 0) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + receivedData_.clear(); + } + + state_ = Connected; + determineStreamSizes(); + + onConnected(); + } + else { + // We failed to initialize the security context + handleCertError(status); + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } + } } //------------------------------------------------------------------------ void SchannelContext::handleCertError(SECURITY_STATUS status) { - if (status == SEC_E_UNTRUSTED_ROOT || - status == CERT_E_UNTRUSTEDROOT || - status == CRYPT_E_ISSUER_SERIALNUMBER || - status == CRYPT_E_SIGNER_NOT_FOUND || - status == CRYPT_E_NO_TRUSTED_SIGNER) { - verificationError_ = CertificateVerificationError::Untrusted; - } - else if (status == SEC_E_CERT_EXPIRED || - status == CERT_E_EXPIRED) { - verificationError_ = CertificateVerificationError::Expired; - } - else if (status == CRYPT_E_SELF_SIGNED) { - verificationError_ = CertificateVerificationError::SelfSigned; - } - else if (status == CRYPT_E_HASH_VALUE || - status == TRUST_E_CERT_SIGNATURE) { - verificationError_ = CertificateVerificationError::InvalidSignature; - } - else if (status == CRYPT_E_REVOKED) { - verificationError_ = CertificateVerificationError::Revoked; - } - else if (status == CRYPT_E_NO_REVOCATION_CHECK || - status == CRYPT_E_REVOCATION_OFFLINE) { - verificationError_ = CertificateVerificationError::RevocationCheckFailed; - } - else if (status == CERT_E_WRONG_USAGE) { - verificationError_ = CertificateVerificationError::InvalidPurpose; - } - else { - verificationError_ = CertificateVerificationError::UnknownError; - } + if (status == SEC_E_UNTRUSTED_ROOT || + status == CERT_E_UNTRUSTEDROOT || + status == CRYPT_E_ISSUER_SERIALNUMBER || + status == CRYPT_E_SIGNER_NOT_FOUND || + status == CRYPT_E_NO_TRUSTED_SIGNER) { + verificationError_ = CertificateVerificationError::Untrusted; + } + else if (status == SEC_E_CERT_EXPIRED || + status == CERT_E_EXPIRED) { + verificationError_ = CertificateVerificationError::Expired; + } + else if (status == CRYPT_E_SELF_SIGNED) { + verificationError_ = CertificateVerificationError::SelfSigned; + } + else if (status == CRYPT_E_HASH_VALUE || + status == TRUST_E_CERT_SIGNATURE) { + verificationError_ = CertificateVerificationError::InvalidSignature; + } + else if (status == CRYPT_E_REVOKED) { + verificationError_ = CertificateVerificationError::Revoked; + } + else if (status == CRYPT_E_NO_REVOCATION_CHECK || + status == CRYPT_E_REVOCATION_OFFLINE) { + verificationError_ = CertificateVerificationError::RevocationCheckFailed; + } + else if (status == CERT_E_WRONG_USAGE) { + verificationError_ = CertificateVerificationError::InvalidPurpose; + } + else { + verificationError_ = CertificateVerificationError::UnknownError; + } } //------------------------------------------------------------------------ void SchannelContext::sendDataOnNetwork(const void* pData, size_t dataSize) { - if (dataSize > 0 && pData) { - SafeByteArray byteArray(dataSize); - memcpy(&byteArray[0], pData, dataSize); + if (dataSize > 0 && pData) { + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); - onDataForNetwork(byteArray); - } + onDataForNetwork(byteArray); + } } //------------------------------------------------------------------------ void SchannelContext::forwardDataToApplication(const void* pData, size_t dataSize) { - SafeByteArray byteArray(dataSize); - memcpy(&byteArray[0], pData, dataSize); + SafeByteArray byteArray(dataSize); + memcpy(&byteArray[0], pData, dataSize); - onDataForApplication(byteArray); + onDataForApplication(byteArray); } //------------------------------------------------------------------------ void SchannelContext::handleDataFromApplication(const SafeByteArray& data) { - // Don't attempt to send data until we're fully connected - if (state_ == Connecting) { - return; - } + // Don't attempt to send data until we're fully connected + if (state_ == Connecting) { + return; + } - // Encrypt the data - encryptAndSendData(data); + // Encrypt the data + encryptAndSendData(data); } //------------------------------------------------------------------------ void SchannelContext::handleDataFromNetwork(const SafeByteArray& data) { - switch (state_) { - case Connecting: - { - // We're still establishing the connection, so continue the handshake - continueHandshake(data); - } - break; - - case Connected: - { - // Decrypt the data - decryptAndProcessData(data); - } - break; - - default: - return; - } + switch (state_) { + case Connecting: + { + // We're still establishing the connection, so continue the handshake + continueHandshake(data); + } + break; + + case Connected: + { + // Decrypt the data + decryptAndProcessData(data); + } + break; + + default: + return; + } } //------------------------------------------------------------------------ void SchannelContext::indicateError(boost::shared_ptr<TLSError> error) { - state_ = Error; - receivedData_.clear(); - onError(error); + state_ = Error; + receivedData_.clear(); + onError(error); } //------------------------------------------------------------------------ void SchannelContext::decryptAndProcessData(const SafeByteArray& data) { - SecBuffer inBuffers[4] = {0}; - - appendNewData(data); - - while (!receivedData_.empty()) { - // - // MSDN: - // When using the Schannel SSP with contexts that are not connection oriented, on input, - // the structure must contain four SecBuffer structures. Exactly one buffer must be of type - // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining - // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented - // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented - // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token - // must also be supplied. - // - inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); - inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); - inBuffers[0].BufferType = SECBUFFER_DATA; - - inBuffers[1].BufferType = SECBUFFER_EMPTY; - inBuffers[2].BufferType = SECBUFFER_EMPTY; - inBuffers[3].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc inBufferDesc = {0}; - inBufferDesc.cBuffers = 4; - inBufferDesc.pBuffers = inBuffers; - inBufferDesc.ulVersion = SECBUFFER_VERSION; - - size_t inData = receivedData_.size(); - SECURITY_STATUS status = DecryptMessage(contextHandle_, &inBufferDesc, 0, NULL); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - // Wait for more data to arrive - break; - } - else if (status == SEC_I_RENEGOTIATE) { - // TODO: Handle renegotiation scenarios - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - else if (status == SEC_I_CONTEXT_EXPIRED) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - else if (status != SEC_E_OK) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - break; - } - - SecBuffer* pDataBuffer = NULL; - SecBuffer* pExtraBuffer = NULL; - for (int i = 0; i < 4; ++i) { - if (inBuffers[i].BufferType == SECBUFFER_DATA) { - pDataBuffer = &inBuffers[i]; - } - else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { - pExtraBuffer = &inBuffers[i]; - } - } - - if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { - forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); - } - - // If there is extra data left over from the decryption operation, we call DecryptMessage() again - if (pExtraBuffer) { - receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); - } - else { - // We're done - receivedData_.erase(receivedData_.begin(), receivedData_.begin() + inData); - } - } + SecBuffer inBuffers[4] = {0}; + + appendNewData(data); + + while (!receivedData_.empty()) { + // + // MSDN: + // When using the Schannel SSP with contexts that are not connection oriented, on input, + // the structure must contain four SecBuffer structures. Exactly one buffer must be of type + // SECBUFFER_DATA and contain an encrypted message, which is decrypted in place. The remaining + // buffers are used for output and must be of type SECBUFFER_EMPTY. For connection-oriented + // contexts, a SECBUFFER_DATA type buffer must be supplied, as noted for nonconnection-oriented + // contexts. Additionally, a second SECBUFFER_TOKEN type buffer that contains a security token + // must also be supplied. + // + inBuffers[0].pvBuffer = (char*)(&receivedData_[0]); + inBuffers[0].cbBuffer = (unsigned long)receivedData_.size(); + inBuffers[0].BufferType = SECBUFFER_DATA; + + inBuffers[1].BufferType = SECBUFFER_EMPTY; + inBuffers[2].BufferType = SECBUFFER_EMPTY; + inBuffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc inBufferDesc = {0}; + inBufferDesc.cBuffers = 4; + inBufferDesc.pBuffers = inBuffers; + inBufferDesc.ulVersion = SECBUFFER_VERSION; + + size_t inData = receivedData_.size(); + SECURITY_STATUS status = DecryptMessage(contextHandle_, &inBufferDesc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + // Wait for more data to arrive + break; + } + else if (status == SEC_I_RENEGOTIATE) { + // TODO: Handle renegotiation scenarios + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + else if (status == SEC_I_CONTEXT_EXPIRED) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + else if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + break; + } + + SecBuffer* pDataBuffer = NULL; + SecBuffer* pExtraBuffer = NULL; + for (int i = 0; i < 4; ++i) { + if (inBuffers[i].BufferType == SECBUFFER_DATA) { + pDataBuffer = &inBuffers[i]; + } + else if (inBuffers[i].BufferType == SECBUFFER_EXTRA) { + pExtraBuffer = &inBuffers[i]; + } + } + + if (pDataBuffer && pDataBuffer->cbBuffer > 0 && pDataBuffer->pvBuffer != NULL) { + forwardDataToApplication(pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + } + + // If there is extra data left over from the decryption operation, we call DecryptMessage() again + if (pExtraBuffer) { + receivedData_.erase(receivedData_.begin(), receivedData_.end() - pExtraBuffer->cbBuffer); + } + else { + // We're done + receivedData_.erase(receivedData_.begin(), receivedData_.begin() + inData); + } + } } //------------------------------------------------------------------------ void SchannelContext::encryptAndSendData(const SafeByteArray& data) { - if (streamSizes_.cbMaximumMessage == 0) { - return; - } + if (streamSizes_.cbMaximumMessage == 0) { + return; + } - SecBuffer outBuffers[4] = {0}; + SecBuffer outBuffers[4] = {0}; - // Calculate the largest required size of the send buffer - size_t messageBufferSize = (data.size() > streamSizes_.cbMaximumMessage) - ? streamSizes_.cbMaximumMessage - : data.size(); + // Calculate the largest required size of the send buffer + size_t messageBufferSize = (data.size() > streamSizes_.cbMaximumMessage) + ? streamSizes_.cbMaximumMessage + : data.size(); - // Allocate a packet for the encrypted data - SafeByteArray sendBuffer; - sendBuffer.resize(streamSizes_.cbHeader + messageBufferSize + streamSizes_.cbTrailer); + // Allocate a packet for the encrypted data + SafeByteArray sendBuffer; + sendBuffer.resize(streamSizes_.cbHeader + messageBufferSize + streamSizes_.cbTrailer); - size_t bytesSent = 0; - do { - size_t bytesLeftToSend = data.size() - bytesSent; + size_t bytesSent = 0; + do { + size_t bytesLeftToSend = data.size() - bytesSent; - // Calculate how much of the send buffer we'll be using for this chunk - size_t bytesToSend = (bytesLeftToSend > streamSizes_.cbMaximumMessage) - ? streamSizes_.cbMaximumMessage - : bytesLeftToSend; + // Calculate how much of the send buffer we'll be using for this chunk + size_t bytesToSend = (bytesLeftToSend > streamSizes_.cbMaximumMessage) + ? streamSizes_.cbMaximumMessage + : bytesLeftToSend; - // Copy the plain text data into the send buffer - memcpy(&sendBuffer[0] + streamSizes_.cbHeader, &data[0] + bytesSent, bytesToSend); + // Copy the plain text data into the send buffer + memcpy(&sendBuffer[0] + streamSizes_.cbHeader, &data[0] + bytesSent, bytesToSend); - outBuffers[0].pvBuffer = &sendBuffer[0]; - outBuffers[0].cbBuffer = streamSizes_.cbHeader; - outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; + outBuffers[0].pvBuffer = &sendBuffer[0]; + outBuffers[0].cbBuffer = streamSizes_.cbHeader; + outBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; - outBuffers[1].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader; - outBuffers[1].cbBuffer = (unsigned long)bytesToSend; - outBuffers[1].BufferType = SECBUFFER_DATA; + outBuffers[1].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader; + outBuffers[1].cbBuffer = (unsigned long)bytesToSend; + outBuffers[1].BufferType = SECBUFFER_DATA; - outBuffers[2].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader + bytesToSend; - outBuffers[2].cbBuffer = streamSizes_.cbTrailer; - outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + outBuffers[2].pvBuffer = &sendBuffer[0] + streamSizes_.cbHeader + bytesToSend; + outBuffers[2].cbBuffer = streamSizes_.cbTrailer; + outBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; - outBuffers[3].pvBuffer = 0; - outBuffers[3].cbBuffer = 0; - outBuffers[3].BufferType = SECBUFFER_EMPTY; + outBuffers[3].pvBuffer = 0; + outBuffers[3].cbBuffer = 0; + outBuffers[3].BufferType = SECBUFFER_EMPTY; - SecBufferDesc outBufferDesc = {0}; - outBufferDesc.cBuffers = 4; - outBufferDesc.pBuffers = outBuffers; - outBufferDesc.ulVersion = SECBUFFER_VERSION; + SecBufferDesc outBufferDesc = {0}; + outBufferDesc.cBuffers = 4; + outBufferDesc.pBuffers = outBuffers; + outBufferDesc.ulVersion = SECBUFFER_VERSION; - SECURITY_STATUS status = EncryptMessage(contextHandle_, 0, &outBufferDesc, 0); - if (status != SEC_E_OK) { - indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); - return; - } + SECURITY_STATUS status = EncryptMessage(contextHandle_, 0, &outBufferDesc, 0); + if (status != SEC_E_OK) { + indicateError(boost::make_shared<TLSError>(TLSError::UnknownError)); + return; + } - sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); - bytesSent += bytesToSend; + sendDataOnNetwork(&sendBuffer[0], outBuffers[0].cbBuffer + outBuffers[1].cbBuffer + outBuffers[2].cbBuffer); + bytesSent += bytesToSend; - } while (bytesSent < data.size()); + } while (bytesSent < data.size()); } //------------------------------------------------------------------------ bool SchannelContext::setClientCertificate(CertificateWithKey::ref certificate) { - boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); - if (!capiCertificate || capiCertificate->isNull()) { - return false; - } + boost::shared_ptr<CAPICertificate> capiCertificate = boost::dynamic_pointer_cast<CAPICertificate>(certificate); + if (!capiCertificate || capiCertificate->isNull()) { + return false; + } - userCertificate_ = capiCertificate; + userCertificate_ = capiCertificate; - // We assume that the Certificate Store Name/Certificate Name - // are valid at this point - certStoreName_ = capiCertificate->getCertStoreName(); - certName_ = capiCertificate->getCertName(); + // We assume that the Certificate Store Name/Certificate Name + // are valid at this point + certStoreName_ = capiCertificate->getCertStoreName(); + certName_ = capiCertificate->getCertName(); ////At the moment this is only useful for logging: - smartCardReader_ = capiCertificate->getSmartCardReaderName(); + smartCardReader_ = capiCertificate->getSmartCardReaderName(); - capiCertificate->onCertificateCardRemoved.connect(boost::bind(&SchannelContext::handleCertificateCardRemoved, this)); + capiCertificate->onCertificateCardRemoved.connect(boost::bind(&SchannelContext::handleCertificateCardRemoved, this)); - return true; + return true; } //------------------------------------------------------------------------ void SchannelContext::handleCertificateCardRemoved() { - if (disconnectOnCardRemoval_) { - indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved)); - } + if (disconnectOnCardRemoval_) { + indicateError(boost::make_shared<TLSError>(TLSError::CertificateCardRemoved)); + } } //------------------------------------------------------------------------ std::vector<Certificate::ref> SchannelContext::getPeerCertificateChain() const { - std::vector<Certificate::ref> certificateChain; - ScopedCertContext pServerCert; - ScopedCertContext pIssuerCert; - ScopedCertContext pCurrentCert; - SECURITY_STATUS status = QueryContextAttributes(contextHandle_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); - - if (status != SEC_E_OK) { - return certificateChain; - } - certificateChain.push_back(boost::make_shared<SchannelCertificate>(pServerCert)); - - pCurrentCert = pServerCert; - while(pCurrentCert.GetPointer()) { - DWORD dwVerificationFlags = 0; - pIssuerCert = CertGetIssuerCertificateFromStore(pServerCert->hCertStore, pCurrentCert, NULL, &dwVerificationFlags ); - if (!(*pIssuerCert.GetPointer())) { - break; - } - certificateChain.push_back(boost::make_shared<SchannelCertificate>(pIssuerCert)); - - pCurrentCert = pIssuerCert; - pIssuerCert = NULL; - } - return certificateChain; + std::vector<Certificate::ref> certificateChain; + ScopedCertContext pServerCert; + ScopedCertContext pIssuerCert; + ScopedCertContext pCurrentCert; + SECURITY_STATUS status = QueryContextAttributes(contextHandle_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, pServerCert.Reset()); + + if (status != SEC_E_OK) { + return certificateChain; + } + certificateChain.push_back(boost::make_shared<SchannelCertificate>(pServerCert)); + + pCurrentCert = pServerCert; + while(pCurrentCert.GetPointer()) { + DWORD dwVerificationFlags = 0; + pIssuerCert = CertGetIssuerCertificateFromStore(pServerCert->hCertStore, pCurrentCert, NULL, &dwVerificationFlags ); + if (!(*pIssuerCert.GetPointer())) { + break; + } + certificateChain.push_back(boost::make_shared<SchannelCertificate>(pIssuerCert)); + + pCurrentCert = pIssuerCert; + pIssuerCert = NULL; + } + return certificateChain; } //------------------------------------------------------------------------ CertificateVerificationError::ref SchannelContext::getPeerCertificateVerificationError() const { - return verificationError_ ? boost::make_shared<CertificateVerificationError>(*verificationError_) : CertificateVerificationError::ref(); + return verificationError_ ? boost::make_shared<CertificateVerificationError>(*verificationError_) : CertificateVerificationError::ref(); } //------------------------------------------------------------------------ ByteArray SchannelContext::getFinishMessage() const { - SecPkgContext_Bindings bindings; - int ret = QueryContextAttributes(contextHandle_, SECPKG_ATTR_UNIQUE_BINDINGS, &bindings); - if (ret == SEC_E_OK) { - return createByteArray(((unsigned char*) bindings.Bindings) + bindings.Bindings->dwApplicationDataOffset + 11 /* tls-unique:*/, bindings.Bindings->cbApplicationDataLength - 11); - } - return ByteArray(); + SecPkgContext_Bindings bindings; + int ret = QueryContextAttributes(contextHandle_, SECPKG_ATTR_UNIQUE_BINDINGS, &bindings); + if (ret == SEC_E_OK) { + return createByteArray(((unsigned char*) bindings.Bindings) + bindings.Bindings->dwApplicationDataOffset + 11 /* tls-unique:*/, bindings.Bindings->cbApplicationDataLength - 11); + } + return ByteArray(); } //------------------------------------------------------------------------ void SchannelContext::setCheckCertificateRevocation(bool b) { - checkCertificateRevocation_ = b; + checkCertificateRevocation_ = b; } void SchannelContext::setDisconnectOnCardRemoval(bool b) { - disconnectOnCardRemoval_ = b; + disconnectOnCardRemoval_ = b; } diff --git a/Swiften/TLS/Schannel/SchannelContext.h b/Swiften/TLS/Schannel/SchannelContext.h index be30a7c..2c6a3ff 100644 --- a/Swiften/TLS/Schannel/SchannelContext.h +++ b/Swiften/TLS/Schannel/SchannelContext.h @@ -28,85 +28,85 @@ #include <boost/noncopyable.hpp> -namespace Swift -{ - class CAPICertificate; - class SchannelContext : public TLSContext, boost::noncopyable - { - public: - typedef boost::shared_ptr<SchannelContext> sp_t; +namespace Swift +{ + class CAPICertificate; + class SchannelContext : public TLSContext, boost::noncopyable + { + public: + typedef boost::shared_ptr<SchannelContext> sp_t; - public: - SchannelContext(bool tls1_0Workaround); + public: + SchannelContext(bool tls1_0Workaround); - virtual ~SchannelContext(); + virtual ~SchannelContext(); - // - // TLSContext - // - virtual void connect(); - virtual bool setClientCertificate(CertificateWithKey::ref cert); + // + // TLSContext + // + virtual void connect(); + virtual bool setClientCertificate(CertificateWithKey::ref cert); - virtual void handleDataFromNetwork(const SafeByteArray& data); - virtual void handleDataFromApplication(const SafeByteArray& data); + virtual void handleDataFromNetwork(const SafeByteArray& data); + virtual void handleDataFromApplication(const SafeByteArray& data); - virtual std::vector<Certificate::ref> getPeerCertificateChain() const; - virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; + virtual std::vector<Certificate::ref> getPeerCertificateChain() const; + virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const; - virtual ByteArray getFinishMessage() const; + virtual ByteArray getFinishMessage() const; - virtual void setCheckCertificateRevocation(bool b); + virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); + virtual void setDisconnectOnCardRemoval(bool b); - private: - void determineStreamSizes(); - void continueHandshake(const SafeByteArray& data); - void indicateError(boost::shared_ptr<TLSError> error); - //FIXME: Remove - void indicateError() {indicateError(boost::make_shared<TLSError>());} - void handleCertError(SECURITY_STATUS status) ; + private: + void determineStreamSizes(); + void continueHandshake(const SafeByteArray& data); + void indicateError(boost::shared_ptr<TLSError> error); + //FIXME: Remove + void indicateError() {indicateError(boost::make_shared<TLSError>());} + void handleCertError(SECURITY_STATUS status) ; - void sendDataOnNetwork(const void* pData, size_t dataSize); - void forwardDataToApplication(const void* pData, size_t dataSize); + void sendDataOnNetwork(const void* pData, size_t dataSize); + void forwardDataToApplication(const void* pData, size_t dataSize); - void decryptAndProcessData(const SafeByteArray& data); - void encryptAndSendData(const SafeByteArray& data); + void decryptAndProcessData(const SafeByteArray& data); + void encryptAndSendData(const SafeByteArray& data); - void appendNewData(const SafeByteArray& data); - SECURITY_STATUS validateServerCertificate(); + void appendNewData(const SafeByteArray& data); + SECURITY_STATUS validateServerCertificate(); - void handleCertificateCardRemoved(); + void handleCertificateCardRemoved(); - private: - enum SchannelState - { - Start, - Connecting, - Connected, - Error + private: + enum SchannelState + { + Start, + Connecting, + Connected, + Error - }; + }; - SchannelState state_; - boost::optional<CertificateVerificationError> verificationError_; + SchannelState state_; + boost::optional<CertificateVerificationError> verificationError_; - ULONG secContext_; - ScopedCredHandle credHandle_; - ScopedCtxtHandle contextHandle_; - DWORD contextFlags_; - SecPkgContext_StreamSizes streamSizes_; + ULONG secContext_; + ScopedCredHandle credHandle_; + ScopedCtxtHandle contextHandle_; + DWORD contextFlags_; + SecPkgContext_StreamSizes streamSizes_; - std::vector<char> receivedData_; + std::vector<char> receivedData_; - HCERTSTORE myCertStore_; - std::string certStoreName_; - std::string certName_; + HCERTSTORE myCertStore_; + std::string certStoreName_; + std::string certName_; ////Not needed, most likely - std::string smartCardReader_; //Can be empty string for non SmartCard certificates - boost::shared_ptr<CAPICertificate> userCertificate_; - bool checkCertificateRevocation_; - bool tls1_0Workaround_; - bool disconnectOnCardRemoval_; - }; + std::string smartCardReader_; //Can be empty string for non SmartCard certificates + boost::shared_ptr<CAPICertificate> userCertificate_; + bool checkCertificateRevocation_; + bool tls1_0Workaround_; + bool disconnectOnCardRemoval_; + }; } diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.cpp b/Swiften/TLS/Schannel/SchannelContextFactory.cpp index dacf19e..f78d386 100644 --- a/Swiften/TLS/Schannel/SchannelContextFactory.cpp +++ b/Swiften/TLS/Schannel/SchannelContextFactory.cpp @@ -20,22 +20,22 @@ SchannelContextFactory::SchannelContextFactory() : checkCertificateRevocation(tr } bool SchannelContextFactory::canCreate() const { - return true; + return true; } TLSContext* SchannelContextFactory::createTLSContext(const TLSOptions& tlsOptions) { - SchannelContext* context = new SchannelContext(tlsOptions.schannelTLS1_0Workaround); - context->setCheckCertificateRevocation(checkCertificateRevocation); - context->setDisconnectOnCardRemoval(disconnectOnCardRemoval); - return context; + SchannelContext* context = new SchannelContext(tlsOptions.schannelTLS1_0Workaround); + context->setCheckCertificateRevocation(checkCertificateRevocation); + context->setDisconnectOnCardRemoval(disconnectOnCardRemoval); + return context; } void SchannelContextFactory::setCheckCertificateRevocation(bool b) { - checkCertificateRevocation = b; + checkCertificateRevocation = b; } void SchannelContextFactory::setDisconnectOnCardRemoval(bool b) { - disconnectOnCardRemoval = b; + disconnectOnCardRemoval = b; } } diff --git a/Swiften/TLS/Schannel/SchannelContextFactory.h b/Swiften/TLS/Schannel/SchannelContextFactory.h index 27b7dc9..142f193 100644 --- a/Swiften/TLS/Schannel/SchannelContextFactory.h +++ b/Swiften/TLS/Schannel/SchannelContextFactory.h @@ -15,19 +15,19 @@ #include <Swiften/TLS/TLSContextFactory.h> namespace Swift { - class SchannelContextFactory : public TLSContextFactory { - public: - SchannelContextFactory(); + class SchannelContextFactory : public TLSContextFactory { + public: + SchannelContextFactory(); - bool canCreate() const; - virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); + bool canCreate() const; + virtual TLSContext* createTLSContext(const TLSOptions& tlsOptions); - virtual void setCheckCertificateRevocation(bool b); + virtual void setCheckCertificateRevocation(bool b); - virtual void setDisconnectOnCardRemoval(bool b); + virtual void setDisconnectOnCardRemoval(bool b); - public: - bool checkCertificateRevocation; - bool disconnectOnCardRemoval; - }; + public: + bool checkCertificateRevocation; + bool disconnectOnCardRemoval; + }; } diff --git a/Swiften/TLS/Schannel/SchannelUtil.h b/Swiften/TLS/Schannel/SchannelUtil.h index 4f73aac..194ec35 100644 --- a/Swiften/TLS/Schannel/SchannelUtil.h +++ b/Swiften/TLS/Schannel/SchannelUtil.h @@ -14,412 +14,412 @@ #include <boost/noncopyable.hpp> -namespace Swift +namespace Swift { - // - // Convenience wrapper around the Schannel CredHandle struct. - // - class ScopedCredHandle - { - private: - struct HandleContext - { - HandleContext() - { - ZeroMemory(&m_h, sizeof(m_h)); - } - - HandleContext(const CredHandle& h) - { - memcpy(&m_h, &h, sizeof(m_h)); - } - - ~HandleContext() - { - ::FreeCredentialsHandle(&m_h); - } - - CredHandle m_h; - }; - - public: - ScopedCredHandle() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCredHandle(const CredHandle& h) - : m_pHandle( new HandleContext(h) ) - { - } - - // Copy constructor - explicit ScopedCredHandle(const ScopedCredHandle& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCredHandle() - { - m_pHandle.reset(); - } - - PCredHandle Reset() - { - CloseHandle(); - return &m_pHandle->m_h; - } - - operator PCredHandle() const - { - return &m_pHandle->m_h; - } - - ScopedCredHandle& operator=(const ScopedCredHandle& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_h != &sh.m_pHandle->m_h) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void CloseHandle() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel CtxtHandle struct. - // - class ScopedCtxtHandle - { - private: - struct HandleContext - { - HandleContext() - { - ZeroMemory(&m_h, sizeof(m_h)); - } - - ~HandleContext() - { - ::DeleteSecurityContext(&m_h); - } - - CtxtHandle m_h; - }; - - public: - ScopedCtxtHandle() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCtxtHandle(CredHandle h) - : m_pHandle( new HandleContext ) - { - } - - // Copy constructor - explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCtxtHandle() - { - m_pHandle.reset(); - } - - PCredHandle Reset() - { - CloseHandle(); - return &m_pHandle->m_h; - } - - operator PCredHandle() const - { - return &m_pHandle->m_h; - } - - ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_h != &sh.m_pHandle->m_h) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void CloseHandle() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel ScopedSecBuffer struct. - // - class ScopedSecBuffer : boost::noncopyable - { - public: - ScopedSecBuffer(PSecBuffer pSecBuffer) - : m_pSecBuffer(pSecBuffer) - { - } - - ~ScopedSecBuffer() - { - // Loop through all the output buffers and make sure we free them - if (m_pSecBuffer->pvBuffer) - FreeContextBuffer(m_pSecBuffer->pvBuffer); - } - - PSecBuffer AsPtr() - { - return m_pSecBuffer; - } - - PSecBuffer operator->() - { - return m_pSecBuffer; - } - - private: - PSecBuffer m_pSecBuffer; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel PCCERT_CONTEXT. - // - class ScopedCertContext - { - private: - struct HandleContext - { - HandleContext() - : m_pCertCtxt(NULL) - { - } - - HandleContext(PCCERT_CONTEXT pCert) - : m_pCertCtxt(pCert) - { - } - - ~HandleContext() - { - if (m_pCertCtxt) - CertFreeCertificateContext(m_pCertCtxt); - } - - PCCERT_CONTEXT m_pCertCtxt; - }; - - public: - ScopedCertContext() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCertContext(PCCERT_CONTEXT pCert) - : m_pHandle( new HandleContext(pCert) ) - { - } - - // Copy constructor - ScopedCertContext(const ScopedCertContext& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCertContext() - { - m_pHandle.reset(); - } - - PCCERT_CONTEXT* Reset() - { - FreeContext(); - return &m_pHandle->m_pCertCtxt; - } - - operator PCCERT_CONTEXT() const - { - return m_pHandle->m_pCertCtxt; - } - - PCCERT_CONTEXT* GetPointer() const - { - return &m_pHandle->m_pCertCtxt; - } - - PCCERT_CONTEXT operator->() const - { - return m_pHandle->m_pCertCtxt; - } - - ScopedCertContext& operator=(const ScopedCertContext& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt) - { - // Only update the internal handle if it's different - if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt) - m_pHandle.reset( new HandleContext(pCertCtxt) ); - - return *this; - } - - void FreeContext() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel HCERTSTORE. - // - class ScopedCertStore : boost::noncopyable - { - public: - ScopedCertStore(HCERTSTORE hCertStore) - : m_hCertStore(hCertStore) - { - } - - ~ScopedCertStore() - { - // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this - // cert store have been closed at this point. - if (m_hCertStore) - CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG); - } - - operator HCERTSTORE() const - { - return m_hCertStore; - } - - private: - HCERTSTORE m_hCertStore; - }; - - //------------------------------------------------------------------------ - - // - // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT. - // - class ScopedCertChainContext - { - private: - struct HandleContext - { - HandleContext() - : m_pCertChainCtxt(NULL) - { - } - - HandleContext(PCCERT_CHAIN_CONTEXT pCert) - : m_pCertChainCtxt(pCert) - { - } - - ~HandleContext() - { - if (m_pCertChainCtxt) - CertFreeCertificateChain(m_pCertChainCtxt); - } - - PCCERT_CHAIN_CONTEXT m_pCertChainCtxt; - }; - - public: - ScopedCertChainContext() - : m_pHandle( new HandleContext ) - { - } - - explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert) - : m_pHandle( new HandleContext(pCert) ) - { - } - - // Copy constructor - ScopedCertChainContext(const ScopedCertChainContext& rhs) - { - m_pHandle = rhs.m_pHandle; - } - - ~ScopedCertChainContext() - { - m_pHandle.reset(); - } - - PCCERT_CHAIN_CONTEXT* Reset() - { - FreeContext(); - return &m_pHandle->m_pCertChainCtxt; - } - - operator PCCERT_CHAIN_CONTEXT() const - { - return m_pHandle->m_pCertChainCtxt; - } - - PCCERT_CHAIN_CONTEXT operator->() const - { - return m_pHandle->m_pCertChainCtxt; - } - - ScopedCertChainContext& operator=(const ScopedCertChainContext& sh) - { - // Only update the internal handle if it's different - if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt) - { - m_pHandle = sh.m_pHandle; - } - - return *this; - } - - void FreeContext() - { - m_pHandle.reset( new HandleContext ); - } - - private: - boost::shared_ptr<HandleContext> m_pHandle; - }; + // + // Convenience wrapper around the Schannel CredHandle struct. + // + class ScopedCredHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + HandleContext(const CredHandle& h) + { + memcpy(&m_h, &h, sizeof(m_h)); + } + + ~HandleContext() + { + ::FreeCredentialsHandle(&m_h); + } + + CredHandle m_h; + }; + + public: + ScopedCredHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCredHandle(const CredHandle& h) + : m_pHandle( new HandleContext(h) ) + { + } + + // Copy constructor + explicit ScopedCredHandle(const ScopedCredHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCredHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCredHandle& operator=(const ScopedCredHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CtxtHandle struct. + // + class ScopedCtxtHandle + { + private: + struct HandleContext + { + HandleContext() + { + ZeroMemory(&m_h, sizeof(m_h)); + } + + ~HandleContext() + { + ::DeleteSecurityContext(&m_h); + } + + CtxtHandle m_h; + }; + + public: + ScopedCtxtHandle() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCtxtHandle(CredHandle h) + : m_pHandle( new HandleContext ) + { + } + + // Copy constructor + explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCtxtHandle() + { + m_pHandle.reset(); + } + + PCredHandle Reset() + { + CloseHandle(); + return &m_pHandle->m_h; + } + + operator PCredHandle() const + { + return &m_pHandle->m_h; + } + + ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_h != &sh.m_pHandle->m_h) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void CloseHandle() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel ScopedSecBuffer struct. + // + class ScopedSecBuffer : boost::noncopyable + { + public: + ScopedSecBuffer(PSecBuffer pSecBuffer) + : m_pSecBuffer(pSecBuffer) + { + } + + ~ScopedSecBuffer() + { + // Loop through all the output buffers and make sure we free them + if (m_pSecBuffer->pvBuffer) + FreeContextBuffer(m_pSecBuffer->pvBuffer); + } + + PSecBuffer AsPtr() + { + return m_pSecBuffer; + } + + PSecBuffer operator->() + { + return m_pSecBuffer; + } + + private: + PSecBuffer m_pSecBuffer; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel PCCERT_CONTEXT. + // + class ScopedCertContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertCtxt(NULL) + { + } + + HandleContext(PCCERT_CONTEXT pCert) + : m_pCertCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertCtxt) + CertFreeCertificateContext(m_pCertCtxt); + } + + PCCERT_CONTEXT m_pCertCtxt; + }; + + public: + ScopedCertContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertContext(PCCERT_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + ScopedCertContext(const ScopedCertContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertContext() + { + m_pHandle.reset(); + } + + PCCERT_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertCtxt; + } + + operator PCCERT_CONTEXT() const + { + return m_pHandle->m_pCertCtxt; + } + + PCCERT_CONTEXT* GetPointer() const + { + return &m_pHandle->m_pCertCtxt; + } + + PCCERT_CONTEXT operator->() const + { + return m_pHandle->m_pCertCtxt; + } + + ScopedCertContext& operator=(const ScopedCertContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt) + { + // Only update the internal handle if it's different + if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt) + m_pHandle.reset( new HandleContext(pCertCtxt) ); + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel HCERTSTORE. + // + class ScopedCertStore : boost::noncopyable + { + public: + ScopedCertStore(HCERTSTORE hCertStore) + : m_hCertStore(hCertStore) + { + } + + ~ScopedCertStore() + { + // Forcefully free all memory related to the store, i.e. we assume all CertContext's that have been opened via this + // cert store have been closed at this point. + if (m_hCertStore) + CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG); + } + + operator HCERTSTORE() const + { + return m_hCertStore; + } + + private: + HCERTSTORE m_hCertStore; + }; + + //------------------------------------------------------------------------ + + // + // Convenience wrapper around the Schannel CERT_CHAIN_CONTEXT. + // + class ScopedCertChainContext + { + private: + struct HandleContext + { + HandleContext() + : m_pCertChainCtxt(NULL) + { + } + + HandleContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pCertChainCtxt(pCert) + { + } + + ~HandleContext() + { + if (m_pCertChainCtxt) + CertFreeCertificateChain(m_pCertChainCtxt); + } + + PCCERT_CHAIN_CONTEXT m_pCertChainCtxt; + }; + + public: + ScopedCertChainContext() + : m_pHandle( new HandleContext ) + { + } + + explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert) + : m_pHandle( new HandleContext(pCert) ) + { + } + + // Copy constructor + ScopedCertChainContext(const ScopedCertChainContext& rhs) + { + m_pHandle = rhs.m_pHandle; + } + + ~ScopedCertChainContext() + { + m_pHandle.reset(); + } + + PCCERT_CHAIN_CONTEXT* Reset() + { + FreeContext(); + return &m_pHandle->m_pCertChainCtxt; + } + + operator PCCERT_CHAIN_CONTEXT() const + { + return m_pHandle->m_pCertChainCtxt; + } + + PCCERT_CHAIN_CONTEXT operator->() const + { + return m_pHandle->m_pCertChainCtxt; + } + + ScopedCertChainContext& operator=(const ScopedCertChainContext& sh) + { + // Only update the internal handle if it's different + if (&m_pHandle->m_pCertChainCtxt != &sh.m_pHandle->m_pCertChainCtxt) + { + m_pHandle = sh.m_pHandle; + } + + return *this; + } + + void FreeContext() + { + m_pHandle.reset( new HandleContext ); + } + + private: + boost::shared_ptr<HandleContext> m_pHandle; + }; } |