From 23fa0f462ddd0c686c677bfe5d4d743621432b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Wed, 18 May 2011 15:45:41 +0200 Subject: Introduce safe containers for storing passwords. diff --git a/QA/Checker/IO.cpp b/QA/Checker/IO.cpp index 4945791..2e14589 100644 --- a/QA/Checker/IO.cpp +++ b/QA/Checker/IO.cpp @@ -21,3 +21,17 @@ std::ostream& operator<<(std::ostream& os, const Swift::ByteArray& s) { os.flags(oldFlags); return os; } + +std::ostream& operator<<(std::ostream& os, const Swift::SafeByteArray& s) { + std::ios::fmtflags oldFlags = os.flags(); + os << std::hex; + for (Swift::SafeByteArray::const_iterator i = s.begin(); i != s.end(); ++i) { + os << "0x" << static_cast(static_cast(*i)); + if (i + 1 < s.end()) { + os << " "; + } + } + os << std::endl; + os.flags(oldFlags); + return os; +} diff --git a/QA/Checker/IO.h b/QA/Checker/IO.h index 5eb61d8..a369b56 100644 --- a/QA/Checker/IO.h +++ b/QA/Checker/IO.h @@ -7,5 +7,7 @@ #pragma once #include +#include std::ostream& operator<<(std::ostream& os, const Swift::ByteArray& s); +std::ostream& operator<<(std::ostream& os, const Swift::SafeByteArray& s); diff --git a/Slimber/Server.h b/Slimber/Server.h index 365fedf..386365b 100644 --- a/Slimber/Server.h +++ b/Slimber/Server.h @@ -87,7 +87,7 @@ namespace Swift { public: DummyUserRegistry() {} - virtual bool isValidUserPassword(const JID&, const std::string&) const { + virtual bool isValidUserPassword(const JID&, const SafeByteArray&) const { return true; } }; diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp index f85fcd8..6cc862e 100644 --- a/Swift/Controllers/MainController.cpp +++ b/Swift/Controllers/MainController.cpp @@ -404,7 +404,7 @@ void MainController::performLoginFromCachedCredentials() { certificateStorage_ = certificateStorageFactory_->createCertificateStorage(jid_.toBare()); certificateTrustChecker_ = new CertificateStorageTrustChecker(certificateStorage_); - client_ = boost::make_shared(clientJID, password_, networkFactories_, storages_); + client_ = boost::make_shared(clientJID, password_.c_str(), networkFactories_, storages_); clientInitialized_ = true; client_->setCertificateTrustChecker(certificateTrustChecker_); client_->onDataRead.connect(boost::bind(&XMLConsoleController::handleDataRead, xmlConsoleController_, _1)); diff --git a/Swiften/Base/Algorithm.h b/Swiften/Base/Algorithm.h index 4d7f1de..4694823 100644 --- a/Swiften/Base/Algorithm.h +++ b/Swiften/Base/Algorithm.h @@ -88,8 +88,8 @@ namespace Swift { Detail::eraseIfImpl(container, predicate, typename Detail::ContainerTraits::Category()); } - template - void append(C& target, const C& source) { + template + void append(Target& target, const Source& source) { target.insert(target.end(), source.begin(), source.end()); } @@ -104,6 +104,11 @@ namespace Swift { } } + template + void nullify(Container& c) { + std::fill(c.begin(), c.end(), 0); + } + /* * Functors */ diff --git a/Swiften/Base/ByteArray.cpp b/Swiften/Base/ByteArray.cpp index 10da395..6be96aa 100644 --- a/Swiften/Base/ByteArray.cpp +++ b/Swiften/Base/ByteArray.cpp @@ -36,18 +36,6 @@ std::vector createByteArray(const char* c) { return data; } -std::vector createByteArray(const char* c, size_t n) { - std::vector data(n); - std::copy(c, c + n, data.begin()); - return data; -} - -std::vector createByteArray(const unsigned char* c, size_t n) { - std::vector data(n); - std::copy(c, c + n, data.begin()); - return data; -} - std::string byteArrayToString(const ByteArray& b) { size_t i; for (i = b.size(); i > 0; --i) { diff --git a/Swiften/Base/ByteArray.h b/Swiften/Base/ByteArray.h index 8ef8dd6..b368ef8 100644 --- a/Swiften/Base/ByteArray.h +++ b/Swiften/Base/ByteArray.h @@ -12,16 +12,21 @@ namespace Swift { typedef std::vector ByteArray; - ByteArray createByteArray(const unsigned char* c, size_t n); ByteArray createByteArray(const std::string& s); ByteArray createByteArray(const char* c); - ByteArray createByteArray(const char* c, size_t n); + + inline ByteArray createByteArray(const unsigned char* c, size_t n) { + return ByteArray(c, c + n); + } + + inline ByteArray createByteArray(const char* c, size_t n) { + return ByteArray(c, c + n); + } inline ByteArray createByteArray(char c) { return std::vector(1, c); } - template static const T* vecptr(const std::vector& v) { return v.empty() ? NULL : &v[0]; diff --git a/Swiften/Base/SConscript b/Swiften/Base/SConscript index 01252e5..ab78639 100644 --- a/Swiften/Base/SConscript +++ b/Swiften/Base/SConscript @@ -2,10 +2,12 @@ Import("swiften_env") objects = swiften_env.SwiftenObject([ "ByteArray.cpp", + "SafeByteArray.cpp", "Error.cpp", "Log.cpp", "Paths.cpp", "String.cpp", + "SafeString.cpp", "IDGenerator.cpp", "sleep.cpp", ]) diff --git a/Swiften/Base/SafeAllocator.h b/Swiften/Base/SafeAllocator.h new file mode 100644 index 0000000..fc74234 --- /dev/null +++ b/Swiften/Base/SafeAllocator.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include + +namespace Swift { + template + class SafeAllocator : public std::allocator { + public: + template struct rebind { + typedef SafeAllocator other; + }; + + SafeAllocator() throw() {} + SafeAllocator(const SafeAllocator&) throw() : std::allocator() {} + template SafeAllocator(const SafeAllocator&) throw() {} + ~SafeAllocator() throw() {} + + void deallocate (T* p, size_t num) { + std::fill(p, p + num, 0); + std::allocator::deallocate(p, num); + } + }; +}; diff --git a/Swiften/Base/SafeByteArray.cpp b/Swiften/Base/SafeByteArray.cpp new file mode 100644 index 0000000..e09a285 --- /dev/null +++ b/Swiften/Base/SafeByteArray.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include +#include + +using namespace Swift; + +namespace Swift { + +SafeByteArray createSafeByteArray(const char* c) { + SafeByteArray data; + while (*c) { + data.push_back(static_cast(*c)); + ++c; + } + return data; +} + +SafeByteArray createSafeByteArray(const SafeString& s) { + return SafeByteArray(s.begin(), s.end()); +} + +} diff --git a/Swiften/Base/SafeByteArray.h b/Swiften/Base/SafeByteArray.h new file mode 100644 index 0000000..c80a2c0 --- /dev/null +++ b/Swiften/Base/SafeByteArray.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +#include +#include + +namespace Swift { + class SafeString; + + typedef std::vector > SafeByteArray; + + inline SafeByteArray createSafeByteArray(const ByteArray& a) { + return SafeByteArray(a.begin(), a.end()); + } + + SafeByteArray createSafeByteArray(const char* c); + + inline SafeByteArray createSafeByteArray(const std::string& s) { + return SafeByteArray(s.begin(), s.end()); + } + + inline SafeByteArray createSafeByteArray(char c) { + return SafeByteArray(1, c); + } + + inline SafeByteArray createSafeByteArray(const char* c, size_t n) { + return SafeByteArray(c, c + n); + } + + SafeByteArray createSafeByteArray(const SafeString& s); +} + diff --git a/Swiften/Base/SafeString.cpp b/Swiften/Base/SafeString.cpp new file mode 100644 index 0000000..2abcdb0 --- /dev/null +++ b/Swiften/Base/SafeString.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include + +using namespace Swift; + +SafeString::SafeString(const char* rawData) { + for (const char* c = rawData; *c; ++c) { + data.push_back(*c); + } +} diff --git a/Swiften/Base/SafeString.h b/Swiften/Base/SafeString.h new file mode 100644 index 0000000..0bd898d --- /dev/null +++ b/Swiften/Base/SafeString.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include + +#include + +namespace Swift { + class SafeString { + public: + typedef std::vector > Data; + typedef Data::iterator Iterator; + typedef Data::const_iterator ConstIterator; + + SafeString() { + } + + SafeString(const std::string& s) : data(s.begin(), s.end()) { + } + + SafeString(const char*); + + + std::string toString() const { + return data.empty() ? std::string() : std::string(&data[0], data.size()); + } + + void resize(size_t n) { + data.resize(n); + } + + char& operator[](size_t n) { + return data[n]; + } + + Iterator begin() { + return data.begin(); + } + + Iterator end() { + return data.end(); + } + + ConstIterator begin() const { + return data.begin(); + } + + ConstIterator end() const { + return data.end(); + } + + size_t size() const { + return data.size(); + } + + private: + std::vector > data; + }; +}; diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 13071ac..c53bcaf 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -29,7 +29,7 @@ namespace Swift { -Client::Client(const JID& jid, const std::string& password, NetworkFactories* networkFactories, Storages* storages) : CoreClient(jid, password, networkFactories), storages(storages) { +Client::Client(const JID& jid, const SafeString& password, NetworkFactories* networkFactories, Storages* storages) : CoreClient(jid, password, networkFactories), storages(storages) { memoryStorages = new MemoryStorages(); softwareVersionResponder = new SoftwareVersionResponder(getIQRouter()); diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 05c1e6e..bee9d5c 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -47,7 +47,7 @@ namespace Swift { * this is NULL, * all data will be stored in memory (and be lost on shutdown) */ - Client(const JID& jid, const std::string& password, NetworkFactories* networkFactories, Storages* storages = NULL); + Client(const JID& jid, const SafeString& password, NetworkFactories* networkFactories, Storages* storages = NULL); ~Client(); diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 846a5e7..57d9c12 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -185,7 +186,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { state = Authenticating; - stream->writeElement(boost::make_shared("EXTERNAL", createByteArray(""))); + stream->writeElement(boost::make_shared("EXTERNAL", createSafeByteArray(""))); } else { finishSession(Error::TLSClientCertificateError); @@ -193,7 +194,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { state = Authenticating; - stream->writeElement(boost::make_shared("EXTERNAL", createByteArray(""))); + stream->writeElement(boost::make_shared("EXTERNAL", createSafeByteArray(""))); } else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) { std::ostringstream s; @@ -275,6 +276,8 @@ void ClientSession::handleElement(boost::shared_ptr element) { else if (AuthSuccess* authSuccess = dynamic_cast(element.get())) { checkState(Authenticating); if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) { + delete authenticator; + authenticator = NULL; finishSession(Error::ServerVerificationFailedError); } else { @@ -336,7 +339,7 @@ bool ClientSession::checkState(State state) { return true; } -void ClientSession::sendCredentials(const std::string& password) { +void ClientSession::sendCredentials(const SafeString& password) { assert(WaitingForCredentials); state = Authenticating; authenticator->setCredentials(localJID.getNode(), password); diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index 9022b16..246388a 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -21,6 +21,7 @@ namespace Swift { class ClientAuthenticator; class CertificateTrustChecker; + class SafeString; class ClientSession : public boost::enable_shared_from_this { public: @@ -104,7 +105,7 @@ namespace Swift { return getState() == Finished; } - void sendCredentials(const std::string& password); + void sendCredentials(const SafeString& password); void sendStanza(boost::shared_ptr); void setCertificateTrustChecker(CertificateTrustChecker* checker) { diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index a17696f..9521bd9 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -27,7 +27,7 @@ namespace Swift { -CoreClient::CoreClient(const JID& jid, const std::string& password, NetworkFactories* networkFactories) : jid_(jid), password_(password), networkFactories(networkFactories), disconnectRequested_(false), certificateTrustChecker(NULL) { +CoreClient::CoreClient(const JID& jid, const SafeString& password, NetworkFactories* networkFactories) : jid_(jid), password_(password), networkFactories(networkFactories), disconnectRequested_(false), certificateTrustChecker(NULL) { stanzaChannel_ = new ClientSessionStanzaChannel(); stanzaChannel_->onMessageReceived.connect(boost::bind(&CoreClient::handleMessageReceived, this, _1)); stanzaChannel_->onPresenceReceived.connect(boost::bind(&CoreClient::handlePresenceReceived, this, _1)); @@ -97,7 +97,7 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr connectio assert(!sessionStream_); sessionStream_ = boost::make_shared(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), tlsFactories->getTLSContextFactory(), networkFactories->getTimerFactory()); if (!certificate_.empty()) { - sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); + sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_.toString())); } sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h index 7c46fe7..6dc8392 100644 --- a/Swiften/Client/CoreClient.h +++ b/Swiften/Client/CoreClient.h @@ -14,6 +14,7 @@ #include #include #include +#include namespace Swift { class ChainedConnector; @@ -33,6 +34,7 @@ namespace Swift { class CertificateTrustChecker; class NetworkFactories; class ClientSessionStanzaChannel; + class SafeString; /** * The central class for communicating with an XMPP server. @@ -50,7 +52,7 @@ namespace Swift { * Constructs a client for the given JID with the given password. * The given eventLoop will be used to post events to. */ - CoreClient(const JID& jid, const std::string& password, NetworkFactories* networkFactories); + CoreClient(const JID& jid, const SafeString& password, NetworkFactories* networkFactories); ~CoreClient(); void setCertificate(const std::string& certificate); @@ -200,7 +202,7 @@ namespace Swift { private: JID jid_; - std::string password_; + SafeString password_; NetworkFactories* networkFactories; ClientSessionStanzaChannel* stanzaChannel_; IQRouter* iqRouter_; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index 25476c0..6918be8 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include diff --git a/Swiften/Elements/AuthRequest.h b/Swiften/Elements/AuthRequest.h index 5e4e4ab..bfc86c2 100644 --- a/Swiften/Elements/AuthRequest.h +++ b/Swiften/Elements/AuthRequest.h @@ -11,6 +11,7 @@ #include #include +#include namespace Swift { class AuthRequest : public Element { @@ -18,20 +19,20 @@ namespace Swift { AuthRequest(const std::string& mechanism = "") : mechanism_(mechanism) { } - AuthRequest(const std::string& mechanism, const std::vector& message) : + AuthRequest(const std::string& mechanism, const SafeByteArray& message) : mechanism_(mechanism), message_(message) { } - AuthRequest(const std::string& mechanism, const boost::optional >& message) : + AuthRequest(const std::string& mechanism, const boost::optional& message) : mechanism_(mechanism), message_(message) { } - const boost::optional >& getMessage() const { + const boost::optional& getMessage() const { return message_; } - void setMessage(const std::vector& message) { - message_ = boost::optional >(message); + void setMessage(const SafeByteArray& message) { + message_ = boost::optional(message); } const std::string& getMechanism() const { @@ -44,6 +45,6 @@ namespace Swift { private: std::string mechanism_; - boost::optional > message_; + boost::optional message_; }; } diff --git a/Swiften/Elements/AuthResponse.h b/Swiften/Elements/AuthResponse.h index a616005..db2dcea 100644 --- a/Swiften/Elements/AuthResponse.h +++ b/Swiften/Elements/AuthResponse.h @@ -10,6 +10,7 @@ #include #include +#include namespace Swift { class AuthResponse : public Element { @@ -17,21 +18,21 @@ namespace Swift { AuthResponse() { } - AuthResponse(const std::vector& value) : value(value) { + AuthResponse(const SafeByteArray& value) : value(value) { } - AuthResponse(const boost::optional >& value) : value(value) { + AuthResponse(const boost::optional& value) : value(value) { } - const boost::optional >& getValue() const { + const boost::optional& getValue() const { return value; } - void setValue(const std::vector& value) { - this->value = boost::optional >(value); + void setValue(const SafeByteArray& value) { + this->value = boost::optional(value); } private: - boost::optional > value; + boost::optional value; }; } diff --git a/Swiften/IDN/StringPrep.cpp b/Swiften/IDN/StringPrep.cpp index 95f294c..f8ebb2c 100644 --- a/Swiften/IDN/StringPrep.cpp +++ b/Swiften/IDN/StringPrep.cpp @@ -9,32 +9,45 @@ #include #include #include +#include -namespace Swift { +using namespace Swift; + + namespace { + static const int MAX_STRINGPREP_SIZE = 1024; -static const int MAX_STRINGPREP_SIZE = 1024; + const Stringprep_profile* getLibIDNProfile(StringPrep::Profile profile) { + switch(profile) { + case StringPrep::NamePrep: return stringprep_nameprep; break; + case StringPrep::XMPPNodePrep: return stringprep_xmpp_nodeprep; break; + case StringPrep::XMPPResourcePrep: return stringprep_xmpp_resourceprep; break; + case StringPrep::SASLPrep: return stringprep_saslprep; break; + } + assert(false); + return 0; + } -const Stringprep_profile* getLibIDNProfile(StringPrep::Profile profile) { - switch(profile) { - case StringPrep::NamePrep: return stringprep_nameprep; break; - case StringPrep::XMPPNodePrep: return stringprep_xmpp_nodeprep; break; - case StringPrep::XMPPResourcePrep: return stringprep_xmpp_resourceprep; break; - case StringPrep::SASLPrep: return stringprep_saslprep; break; + template + StringType getStringPrepared(const StringType& s, StringPrep::Profile profile) { + ContainerType input(s.begin(), s.end()); + input.resize(MAX_STRINGPREP_SIZE); + if (stringprep(&input[0], MAX_STRINGPREP_SIZE, static_cast(0), getLibIDNProfile(profile)) == 0) { + return StringType(&input[0]); + } + else { + return StringType(); + } } - assert(false); - return 0; } +namespace Swift { + std::string StringPrep::getPrepared(const std::string& s, Profile profile) { - - std::vector input(s.begin(), s.end()); - input.resize(MAX_STRINGPREP_SIZE); - if (stringprep(&input[0], MAX_STRINGPREP_SIZE, static_cast(0), getLibIDNProfile(profile)) == 0) { - return std::string(&input[0]); - } - else { - return ""; - } + return getStringPrepared< std::string, std::vector >(s, profile); +} + +SafeString StringPrep::getPrepared(const SafeString& s, Profile profile) { + return getStringPrepared > >(s, profile); } } diff --git a/Swiften/IDN/StringPrep.h b/Swiften/IDN/StringPrep.h index f40553b..fc75118 100644 --- a/Swiften/IDN/StringPrep.h +++ b/Swiften/IDN/StringPrep.h @@ -7,6 +7,7 @@ #pragma once #include +#include namespace Swift { class StringPrep { @@ -19,5 +20,6 @@ namespace Swift { }; static std::string getPrepared(const std::string& s, Profile profile); + static SafeString getPrepared(const SafeString& s, Profile profile); }; } diff --git a/Swiften/Parser/AuthRequestParser.cpp b/Swiften/Parser/AuthRequestParser.cpp index d5d977f..04d9e4f 100644 --- a/Swiften/Parser/AuthRequestParser.cpp +++ b/Swiften/Parser/AuthRequestParser.cpp @@ -22,7 +22,7 @@ void AuthRequestParser::handleStartElement(const std::string&, const std::string void AuthRequestParser::handleEndElement(const std::string&, const std::string&) { --depth_; if (depth_ == 0) { - getElementGeneric()->setMessage(Base64::decode(text_)); + getElementGeneric()->setMessage(createSafeByteArray(Base64::decode(text_))); } } diff --git a/Swiften/Parser/AuthResponseParser.cpp b/Swiften/Parser/AuthResponseParser.cpp index 32d66fe..7f9a530 100644 --- a/Swiften/Parser/AuthResponseParser.cpp +++ b/Swiften/Parser/AuthResponseParser.cpp @@ -19,7 +19,7 @@ void AuthResponseParser::handleStartElement(const std::string&, const std::strin void AuthResponseParser::handleEndElement(const std::string&, const std::string&) { --depth; if (depth == 0) { - getElementGeneric()->setValue(Base64::decode(text)); + getElementGeneric()->setValue(createSafeByteArray(Base64::decode(text))); } } diff --git a/Swiften/SASL/ClientAuthenticator.h b/Swiften/SASL/ClientAuthenticator.h index 399a9d5..6557b9a 100644 --- a/Swiften/SASL/ClientAuthenticator.h +++ b/Swiften/SASL/ClientAuthenticator.h @@ -7,10 +7,13 @@ #pragma once #include - #include #include +#include +#include +#include + namespace Swift { class ClientAuthenticator { public: @@ -21,14 +24,14 @@ namespace Swift { return name; } - void setCredentials(const std::string& authcid, const std::string& password, const std::string& authzid = std::string()) { + void setCredentials(const std::string& authcid, const SafeString& password, const std::string& authzid = std::string()) { this->authcid = authcid; this->password = password; this->authzid = authzid; } - virtual boost::optional< std::vector > getResponse() const = 0; - virtual bool setChallenge(const boost::optional< std::vector >&) = 0; + virtual boost::optional getResponse() const = 0; + virtual bool setChallenge(const boost::optional&) = 0; const std::string& getAuthenticationID() const { return authcid; @@ -38,14 +41,14 @@ namespace Swift { return authzid; } - const std::string& getPassword() const { + const SafeString& getPassword() const { return password; } private: std::string name; std::string authcid; - std::string password; + SafeString password; std::string authzid; }; } diff --git a/Swiften/SASL/DIGESTMD5ClientAuthenticator.cpp b/Swiften/SASL/DIGESTMD5ClientAuthenticator.cpp index 3ff0893..ffa098c 100644 --- a/Swiften/SASL/DIGESTMD5ClientAuthenticator.cpp +++ b/Swiften/SASL/DIGESTMD5ClientAuthenticator.cpp @@ -18,9 +18,9 @@ namespace Swift { DIGESTMD5ClientAuthenticator::DIGESTMD5ClientAuthenticator(const std::string& host, const std::string& nonce) : ClientAuthenticator("DIGEST-MD5"), step(Initial), host(host), cnonce(nonce) { } -boost::optional DIGESTMD5ClientAuthenticator::getResponse() const { +boost::optional DIGESTMD5ClientAuthenticator::getResponse() const { if (step == Initial) { - return boost::optional(); + return boost::optional(); } else if (step == Response) { std::string realm; @@ -33,7 +33,9 @@ boost::optional DIGESTMD5ClientAuthenticator::getResponse() const { // Compute the response value ByteArray A1 = concat( - MD5::getHash(createByteArray(getAuthenticationID() + ":" + realm + ":" + getPassword())), createByteArray(":"), createByteArray(*challenge.getValue("nonce")), createByteArray(":"), createByteArray(cnonce)); + MD5::getHash( + createSafeByteArray(concat(SafeString(getAuthenticationID().c_str()), SafeString(":"), SafeString(realm.c_str()), SafeString(":"), getPassword()))), + createByteArray(":"), createByteArray(*challenge.getValue("nonce")), createByteArray(":"), createByteArray(cnonce)); if (!getAuthorizationID().empty()) { append(A1, createByteArray(":" + getAuthenticationID())); } @@ -60,10 +62,10 @@ boost::optional DIGESTMD5ClientAuthenticator::getResponse() const { if (!getAuthorizationID().empty()) { response.setValue("authzid", getAuthorizationID()); } - return response.serialize(); + return createSafeByteArray(response.serialize()); } else { - return boost::optional(); + return boost::optional(); } } diff --git a/Swiften/SASL/DIGESTMD5ClientAuthenticator.h b/Swiften/SASL/DIGESTMD5ClientAuthenticator.h index 82c8bc5..55bd592 100644 --- a/Swiften/SASL/DIGESTMD5ClientAuthenticator.h +++ b/Swiften/SASL/DIGESTMD5ClientAuthenticator.h @@ -12,13 +12,14 @@ #include #include #include +#include namespace Swift { class DIGESTMD5ClientAuthenticator : public ClientAuthenticator { public: DIGESTMD5ClientAuthenticator(const std::string& host, const std::string& nonce); - virtual boost::optional > getResponse() const; + virtual boost::optional getResponse() const; virtual bool setChallenge(const boost::optional >&); private: diff --git a/Swiften/SASL/PLAINClientAuthenticator.cpp b/Swiften/SASL/PLAINClientAuthenticator.cpp index 675542f..17f880a 100644 --- a/Swiften/SASL/PLAINClientAuthenticator.cpp +++ b/Swiften/SASL/PLAINClientAuthenticator.cpp @@ -12,8 +12,8 @@ namespace Swift { PLAINClientAuthenticator::PLAINClientAuthenticator() : ClientAuthenticator("PLAIN") { } -boost::optional PLAINClientAuthenticator::getResponse() const { - return concat(createByteArray(getAuthorizationID()), createByteArray('\0'), createByteArray(getAuthenticationID()), createByteArray('\0'), createByteArray(getPassword())); +boost::optional PLAINClientAuthenticator::getResponse() const { + return concat(createSafeByteArray(getAuthorizationID()), createSafeByteArray('\0'), createSafeByteArray(getAuthenticationID()), createSafeByteArray('\0'), createSafeByteArray(getPassword())); } bool PLAINClientAuthenticator::setChallenge(const boost::optional&) { diff --git a/Swiften/SASL/PLAINClientAuthenticator.h b/Swiften/SASL/PLAINClientAuthenticator.h index 4e8f8be..83e45c1 100644 --- a/Swiften/SASL/PLAINClientAuthenticator.h +++ b/Swiften/SASL/PLAINClientAuthenticator.h @@ -14,7 +14,7 @@ namespace Swift { public: PLAINClientAuthenticator(); - virtual boost::optional getResponse() const; + virtual boost::optional getResponse() const; virtual bool setChallenge(const boost::optional&); }; } diff --git a/Swiften/SASL/PLAINMessage.cpp b/Swiften/SASL/PLAINMessage.cpp index 036887c..20ffea7 100644 --- a/Swiften/SASL/PLAINMessage.cpp +++ b/Swiften/SASL/PLAINMessage.cpp @@ -5,13 +5,14 @@ */ #include +#include namespace Swift { -PLAINMessage::PLAINMessage(const std::string& authcid, const std::string& password, const std::string& authzid) : authcid(authcid), authzid(authzid), password(password) { +PLAINMessage::PLAINMessage(const std::string& authcid, const SafeByteArray& password, const std::string& authzid) : authcid(authcid), authzid(authzid), password(password) { } -PLAINMessage::PLAINMessage(const ByteArray& value) { +PLAINMessage::PLAINMessage(const SafeByteArray& value) { size_t i = 0; while (i < value.size() && value[i] != '\0') { authzid += value[i]; @@ -31,14 +32,13 @@ PLAINMessage::PLAINMessage(const ByteArray& value) { } ++i; while (i < value.size()) { - password += value[i]; + password.push_back(value[i]); ++i; } } -ByteArray PLAINMessage::getValue() const { - std::string s = authzid + '\0' + authcid + '\0' + password; - return createByteArray(s); +SafeByteArray PLAINMessage::getValue() const { + return concat(createSafeByteArray(authzid), createSafeByteArray('\0'), createSafeByteArray(authcid), createSafeByteArray('\0'), password); } } diff --git a/Swiften/SASL/PLAINMessage.h b/Swiften/SASL/PLAINMessage.h index 916d267..46ee8f7 100644 --- a/Swiften/SASL/PLAINMessage.h +++ b/Swiften/SASL/PLAINMessage.h @@ -9,21 +9,21 @@ #pragma once #include -#include +#include namespace Swift { class PLAINMessage { public: - PLAINMessage(const std::string& authcid, const std::string& password, const std::string& authzid = ""); - PLAINMessage(const ByteArray& value); + PLAINMessage(const std::string& authcid, const SafeByteArray& password, const std::string& authzid = ""); + PLAINMessage(const SafeByteArray& value); - ByteArray getValue() const; + SafeByteArray getValue() const; const std::string& getAuthenticationID() const { return authcid; } - const std::string& getPassword() const { + const SafeByteArray& getPassword() const { return password; } @@ -34,6 +34,6 @@ namespace Swift { private: std::string authcid; std::string authzid; - std::string password; + SafeByteArray password; }; } diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp index bda35b9..a9855a5 100644 --- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp +++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp @@ -39,23 +39,23 @@ static std::string escape(const std::string& s) { SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const std::string& nonce, bool useChannelBinding) : ClientAuthenticator(useChannelBinding ? "SCRAM-SHA-1-PLUS" : "SCRAM-SHA-1"), step(Initial), clientnonce(nonce), useChannelBinding(useChannelBinding) { } -boost::optional SCRAMSHA1ClientAuthenticator::getResponse() const { +boost::optional SCRAMSHA1ClientAuthenticator::getResponse() const { if (step == Initial) { - return concat(getGS2Header(), getInitialBareClientMessage()); + return createSafeByteArray(concat(getGS2Header(), getInitialBareClientMessage())); } else if (step == Proof) { ByteArray clientKey = HMACSHA1::getResult(saltedPassword, createByteArray("Client Key")); ByteArray storedKey = SHA1::getHash(clientKey); - ByteArray clientSignature = HMACSHA1::getResult(storedKey, authMessage); + ByteArray clientSignature = HMACSHA1::getResult(createSafeByteArray(storedKey), authMessage); ByteArray clientProof = clientKey; for (unsigned int i = 0; i < clientProof.size(); ++i) { clientProof[i] ^= clientSignature[i]; } ByteArray result = concat(getFinalMessageWithoutProof(), createByteArray(",p="), createByteArray(Base64::encode(clientProof))); - return result; + return createSafeByteArray(result); } else { - return boost::optional(); + return boost::optional(); } } @@ -100,7 +100,7 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional } // Compute all the values needed for the server signature - saltedPassword = PBKDF2::encode(createByteArray(StringPrep::getPrepared(getPassword(), StringPrep::SASLPrep)), salt, iterations); + saltedPassword = PBKDF2::encode(createSafeByteArray(StringPrep::getPrepared(getPassword(), StringPrep::SASLPrep)), salt, iterations); authMessage = concat(getInitialBareClientMessage(), createByteArray(","), initialServerMessage, createByteArray(","), getFinalMessageWithoutProof()); ByteArray serverKey = HMACSHA1::getResult(saltedPassword, createByteArray("Server Key")); serverSignature = HMACSHA1::getResult(serverKey, authMessage); diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h index 5780bc4..d140013 100644 --- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h +++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h @@ -20,7 +20,7 @@ namespace Swift { void setTLSChannelBindingData(const ByteArray& channelBindingData); - virtual boost::optional getResponse() const; + virtual boost::optional getResponse() const; virtual bool setChallenge(const boost::optional&); private: diff --git a/Swiften/SASL/UnitTest/DIGESTMD5ClientAuthenticatorTest.cpp b/Swiften/SASL/UnitTest/DIGESTMD5ClientAuthenticatorTest.cpp index a16ffac..a16eda8 100644 --- a/Swiften/SASL/UnitTest/DIGESTMD5ClientAuthenticatorTest.cpp +++ b/Swiften/SASL/UnitTest/DIGESTMD5ClientAuthenticatorTest.cpp @@ -9,6 +9,7 @@ #include #include +#include using namespace Swift; @@ -36,9 +37,9 @@ class DIGESTMD5ClientAuthenticatorTest : public CppUnit::TestFixture { "nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\"," "qop=auth,charset=utf-8,algorithm=md5-sess")); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("charset=utf-8,cnonce=\"abcdefgh\",digest-uri=\"xmpp/xmpp.example.com\",nc=00000001,nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\",qop=auth,realm=\"example.com\",response=088891c800ecff1b842159ad6459104a,username=\"user\""), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("charset=utf-8,cnonce=\"abcdefgh\",digest-uri=\"xmpp/xmpp.example.com\",nc=00000001,nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\",qop=auth,realm=\"example.com\",response=088891c800ecff1b842159ad6459104a,username=\"user\""), response); } void testGetResponse_WithAuthorizationID() { @@ -50,9 +51,9 @@ class DIGESTMD5ClientAuthenticatorTest : public CppUnit::TestFixture { "nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\"," "qop=auth,charset=utf-8,algorithm=md5-sess")); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("authzid=\"myauthzid\",charset=utf-8,cnonce=\"abcdefgh\",digest-uri=\"xmpp/xmpp.example.com\",nc=00000001,nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\",qop=auth,realm=\"example.com\",response=4293834432b6e7889a2dee7e8fe7dd06,username=\"user\""), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("authzid=\"myauthzid\",charset=utf-8,cnonce=\"abcdefgh\",digest-uri=\"xmpp/xmpp.example.com\",nc=00000001,nonce=\"O6skKPuaCZEny3hteI19qXMBXSadoWs840MchORo\",qop=auth,realm=\"example.com\",response=4293834432b6e7889a2dee7e8fe7dd06,username=\"user\""), response); } }; diff --git a/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp b/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp index 5c35e79..d6c4188 100644 --- a/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp +++ b/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp @@ -24,7 +24,7 @@ class PLAINClientAuthenticatorTest : public CppUnit::TestFixture { testling.setCredentials("user", "pass"); - CPPUNIT_ASSERT_EQUAL(*testling.getResponse(), createByteArray("\0user\0pass", 10)); + CPPUNIT_ASSERT_EQUAL(*testling.getResponse(), createSafeByteArray("\0user\0pass", 10)); } void testGetResponse_WithAuthzID() { @@ -32,7 +32,7 @@ class PLAINClientAuthenticatorTest : public CppUnit::TestFixture { testling.setCredentials("user", "pass", "authz"); - CPPUNIT_ASSERT_EQUAL(*testling.getResponse(), createByteArray("authz\0user\0pass", 15)); + CPPUNIT_ASSERT_EQUAL(*testling.getResponse(), createSafeByteArray("authz\0user\0pass", 15)); } }; diff --git a/Swiften/SASL/UnitTest/PLAINMessageTest.cpp b/Swiften/SASL/UnitTest/PLAINMessageTest.cpp index dc3f82f..26331d6 100644 --- a/Swiften/SASL/UnitTest/PLAINMessageTest.cpp +++ b/Swiften/SASL/UnitTest/PLAINMessageTest.cpp @@ -29,39 +29,39 @@ class PLAINMessageTest : public CppUnit::TestFixture PLAINMessageTest() {} void testGetValue_WithoutAuthzID() { - PLAINMessage message("user", "pass"); - CPPUNIT_ASSERT_EQUAL(message.getValue(), createByteArray("\0user\0pass", 10)); + PLAINMessage message("user", createSafeByteArray("pass")); + CPPUNIT_ASSERT_EQUAL(message.getValue(), createSafeByteArray("\0user\0pass", 10)); } void testGetValue_WithAuthzID() { - PLAINMessage message("user", "pass", "authz"); - CPPUNIT_ASSERT_EQUAL(message.getValue(), createByteArray("authz\0user\0pass", 15)); + PLAINMessage message("user", createSafeByteArray("pass"), "authz"); + CPPUNIT_ASSERT_EQUAL(message.getValue(), createSafeByteArray("authz\0user\0pass", 15)); } void testConstructor_WithoutAuthzID() { - PLAINMessage message(createByteArray("\0user\0pass", 10)); + PLAINMessage message(createSafeByteArray("\0user\0pass", 10)); CPPUNIT_ASSERT_EQUAL(std::string(""), message.getAuthorizationID()); CPPUNIT_ASSERT_EQUAL(std::string("user"), message.getAuthenticationID()); - CPPUNIT_ASSERT_EQUAL(std::string("pass"), message.getPassword()); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("pass"), message.getPassword()); } void testConstructor_WithAuthzID() { - PLAINMessage message(createByteArray("authz\0user\0pass", 15)); + PLAINMessage message(createSafeByteArray("authz\0user\0pass", 15)); CPPUNIT_ASSERT_EQUAL(std::string("authz"), message.getAuthorizationID()); CPPUNIT_ASSERT_EQUAL(std::string("user"), message.getAuthenticationID()); - CPPUNIT_ASSERT_EQUAL(std::string("pass"), message.getPassword()); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("pass"), message.getPassword()); } void testConstructor_NoAuthcid() { - PLAINMessage message(createByteArray("authzid", 7)); + PLAINMessage message(createSafeByteArray("authzid", 7)); CPPUNIT_ASSERT_EQUAL(std::string(""), message.getAuthenticationID()); } void testConstructor_NoPassword() { - PLAINMessage message(createByteArray("authzid\0authcid", 15)); + PLAINMessage message(createSafeByteArray("authzid\0authcid", 15)); CPPUNIT_ASSERT_EQUAL(std::string(""), message.getAuthenticationID()); } diff --git a/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp b/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp index 78afaf7..0112691 100644 --- a/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp +++ b/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp @@ -9,6 +9,7 @@ #include #include +#include using namespace Swift; @@ -43,36 +44,36 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { SCRAMSHA1ClientAuthenticator testling("abcdefghABCDEFGH"); testling.setCredentials("user", "pass", ""); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("n,,n=user,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("n,,n=user,r=abcdefghABCDEFGH"), response); } void testGetInitialResponse_UsernameHasSpecialChars() { SCRAMSHA1ClientAuthenticator testling("abcdefghABCDEFGH"); testling.setCredentials(",us=,er=", "pass", ""); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("n,,n==2Cus=3D=2Cer=3D,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("n,,n==2Cus=3D=2Cer=3D,r=abcdefghABCDEFGH"), response); } void testGetInitialResponse_WithAuthorizationID() { SCRAMSHA1ClientAuthenticator testling("abcdefghABCDEFGH"); testling.setCredentials("user", "pass", "auth"); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("n,a=auth,n=user,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("n,a=auth,n=user,r=abcdefghABCDEFGH"), response); } void testGetInitialResponse_WithAuthorizationIDWithSpecialChars() { SCRAMSHA1ClientAuthenticator testling("abcdefghABCDEFGH"); testling.setCredentials("user", "pass", "a=u,th"); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("n,a=a=3Du=2Cth,n=user,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("n,a=a=3Du=2Cth,n=user,r=abcdefghABCDEFGH"), response); } void testGetInitialResponse_WithoutChannelBindingWithTLSChannelBindingData() { @@ -80,9 +81,9 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { testling.setTLSChannelBindingData(createByteArray("xyza")); testling.setCredentials("user", "pass", ""); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("y,,n=user,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("y,,n=user,r=abcdefghABCDEFGH"), response); } void testGetInitialResponse_WithChannelBindingWithTLSChannelBindingData() { @@ -90,9 +91,9 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { testling.setTLSChannelBindingData(createByteArray("xyza")); testling.setCredentials("user", "pass", ""); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("p=tls-unique,,n=user,r=abcdefghABCDEFGH"), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("p=tls-unique,,n=user,r=abcdefghABCDEFGH"), response); } void testGetFinalResponse() { @@ -100,9 +101,9 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { testling.setCredentials("user", "pass", ""); testling.setChallenge(createByteArray("r=abcdefghABCDEFGH,s=MTIzNDU2NzgK,i=4096")); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("c=biws,r=abcdefghABCDEFGH,p=CZbjGDpIteIJwQNBgO0P8pKkMGY="), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("c=biws,r=abcdefghABCDEFGH,p=CZbjGDpIteIJwQNBgO0P8pKkMGY="), response); } void testGetFinalResponse_WithoutChannelBindingWithTLSChannelBindingData() { @@ -111,9 +112,9 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { testling.setTLSChannelBindingData(createByteArray("xyza")); testling.setChallenge(createByteArray("r=abcdefghABCDEFGH,s=MTIzNDU2NzgK,i=4096")); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("c=eSws,r=abcdefghABCDEFGH,p=JNpsiFEcxZvNZ1+FFBBqrYvYxMk="), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("c=eSws,r=abcdefghABCDEFGH,p=JNpsiFEcxZvNZ1+FFBBqrYvYxMk="), response); } void testGetFinalResponse_WithChannelBindingWithTLSChannelBindingData() { @@ -122,9 +123,9 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture { testling.setTLSChannelBindingData(createByteArray("xyza")); testling.setChallenge(createByteArray("r=abcdefghABCDEFGH,s=MTIzNDU2NzgK,i=4096")); - ByteArray response = *testling.getResponse(); + SafeByteArray response = *testling.getResponse(); - CPPUNIT_ASSERT_EQUAL(std::string("c=cD10bHMtdW5pcXVlLCx4eXph,r=abcdefghABCDEFGH,p=i6Rghite81P1ype8XxaVAa5l7v0="), byteArrayToString(response)); + CPPUNIT_ASSERT_EQUAL(createSafeByteArray("c=cD10bHMtdW5pcXVlLCx4eXph,r=abcdefghABCDEFGH,p=i6Rghite81P1ype8XxaVAa5l7v0="), response); } void testSetFinalChallenge() { diff --git a/Swiften/Serializer/AuthRequestSerializer.cpp b/Swiften/Serializer/AuthRequestSerializer.cpp index 415a0ff..33bdd77 100644 --- a/Swiften/Serializer/AuthRequestSerializer.cpp +++ b/Swiften/Serializer/AuthRequestSerializer.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace Swift { @@ -17,17 +18,17 @@ AuthRequestSerializer::AuthRequestSerializer() { std::string AuthRequestSerializer::serialize(boost::shared_ptr element) const { boost::shared_ptr authRequest(boost::dynamic_pointer_cast(element)); - std::string value; - boost::optional > message = authRequest->getMessage(); + SafeString value; + boost::optional message = authRequest->getMessage(); if (message) { if ((*message).empty()) { value = "="; } else { - value = Base64::encode(ByteArray(*message)); + value = Base64::encode(*message); } } - return "getMechanism() + "\">" + value + ""; + return "getMechanism() + "\">" + value.toString() + ""; } } diff --git a/Swiften/Serializer/AuthResponseSerializer.cpp b/Swiften/Serializer/AuthResponseSerializer.cpp index 0d1872b..cfdcc99 100644 --- a/Swiften/Serializer/AuthResponseSerializer.cpp +++ b/Swiften/Serializer/AuthResponseSerializer.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace Swift { @@ -17,17 +18,17 @@ AuthResponseSerializer::AuthResponseSerializer() { std::string AuthResponseSerializer::serialize(boost::shared_ptr element) const { boost::shared_ptr authResponse(boost::dynamic_pointer_cast(element)); - std::string value; - boost::optional > message = authResponse->getValue(); + SafeString value; + boost::optional message = authResponse->getValue(); if (message) { if ((*message).empty()) { value = "="; } else { - value = Base64::encode(ByteArray(*message)); + value = Base64::encode(*message); } } - return "" + value + ""; + return "" + value.toString() + ""; } } diff --git a/Swiften/Serializer/UnitTest/AuthRequestSerializerTest.cpp b/Swiften/Serializer/UnitTest/AuthRequestSerializerTest.cpp index 8270139..d5c0a09 100644 --- a/Swiften/Serializer/UnitTest/AuthRequestSerializerTest.cpp +++ b/Swiften/Serializer/UnitTest/AuthRequestSerializerTest.cpp @@ -24,7 +24,7 @@ class AuthRequestSerializerTest : public CppUnit::TestFixture { void testSerialize() { AuthRequestSerializer testling; boost::shared_ptr authRequest(new AuthRequest("PLAIN")); - authRequest->setMessage(createByteArray("foo")); + authRequest->setMessage(createSafeByteArray("foo")); CPPUNIT_ASSERT_EQUAL(std::string( "" @@ -44,7 +44,7 @@ class AuthRequestSerializerTest : public CppUnit::TestFixture { void testSerialize_EmptyMessage() { AuthRequestSerializer testling; boost::shared_ptr authRequest(new AuthRequest("PLAIN")); - authRequest->setMessage(std::vector()); + authRequest->setMessage(SafeByteArray()); CPPUNIT_ASSERT_EQUAL(std::string( "" diff --git a/Swiften/Serializer/UnitTest/AuthResponseSerializerTest.cpp b/Swiften/Serializer/UnitTest/AuthResponseSerializerTest.cpp index e790cc3..8887b27 100644 --- a/Swiften/Serializer/UnitTest/AuthResponseSerializerTest.cpp +++ b/Swiften/Serializer/UnitTest/AuthResponseSerializerTest.cpp @@ -24,7 +24,7 @@ class AuthResponseSerializerTest : public CppUnit::TestFixture { void testSerialize() { AuthResponseSerializer testling; boost::shared_ptr authResponse(new AuthResponse()); - authResponse->setValue(createByteArray("foo")); + authResponse->setValue(createSafeByteArray("foo")); CPPUNIT_ASSERT_EQUAL(std::string( "" @@ -44,7 +44,7 @@ class AuthResponseSerializerTest : public CppUnit::TestFixture { void testSerialize_EmptyMessage() { AuthResponseSerializer testling; boost::shared_ptr authResponse(new AuthResponse()); - authResponse->setValue(std::vector()); + authResponse->setValue(SafeByteArray()); CPPUNIT_ASSERT_EQUAL(std::string( "" diff --git a/Swiften/Server/ServerFromClientSession.cpp b/Swiften/Server/ServerFromClientSession.cpp index b047f69..dbe9745 100644 --- a/Swiften/Server/ServerFromClientSession.cpp +++ b/Swiften/Server/ServerFromClientSession.cpp @@ -51,7 +51,7 @@ void ServerFromClientSession::handleElement(boost::shared_ptr element) getXMPPLayer()->resetParser(); } else { - PLAINMessage plainMessage(authRequest->getMessage() ? *authRequest->getMessage() : createByteArray("")); + PLAINMessage plainMessage(authRequest->getMessage() ? *authRequest->getMessage() : createSafeByteArray("")); if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), getLocalJID().getDomain()), plainMessage.getPassword())) { getXMPPLayer()->writeElement(boost::shared_ptr(new AuthSuccess())); user_ = plainMessage.getAuthenticationID(); diff --git a/Swiften/Server/SimpleUserRegistry.cpp b/Swiften/Server/SimpleUserRegistry.cpp index 9930a39..a519ac2 100644 --- a/Swiften/Server/SimpleUserRegistry.cpp +++ b/Swiften/Server/SimpleUserRegistry.cpp @@ -11,13 +11,13 @@ namespace Swift { SimpleUserRegistry::SimpleUserRegistry() { } -bool SimpleUserRegistry::isValidUserPassword(const JID& user, const std::string& password) const { - std::map::const_iterator i = users.find(user); +bool SimpleUserRegistry::isValidUserPassword(const JID& user, const SafeByteArray& password) const { + std::map::const_iterator i = users.find(user); return i != users.end() ? i->second == password : false; } void SimpleUserRegistry::addUser(const JID& user, const std::string& password) { - users.insert(std::make_pair(user, password)); + users.insert(std::make_pair(user, createSafeByteArray(password))); } } diff --git a/Swiften/Server/SimpleUserRegistry.h b/Swiften/Server/SimpleUserRegistry.h index ad1791b..324c099 100644 --- a/Swiften/Server/SimpleUserRegistry.h +++ b/Swiften/Server/SimpleUserRegistry.h @@ -19,10 +19,10 @@ namespace Swift { public: SimpleUserRegistry(); - virtual bool isValidUserPassword(const JID& user, const std::string& password) const; + virtual bool isValidUserPassword(const JID& user, const SafeByteArray& password) const; void addUser(const JID& user, const std::string& password); private: - std::map users; + std::map users; }; } diff --git a/Swiften/Server/UserRegistry.h b/Swiften/Server/UserRegistry.h index c021fc4..9584a7e 100644 --- a/Swiften/Server/UserRegistry.h +++ b/Swiften/Server/UserRegistry.h @@ -7,15 +7,15 @@ #pragma once #include +#include namespace Swift { - class JID; class UserRegistry { public: virtual ~UserRegistry(); - virtual bool isValidUserPassword(const JID& user, const std::string& password) const = 0; + virtual bool isValidUserPassword(const JID& user, const SafeByteArray& password) const = 0; }; } diff --git a/Swiften/StringCodecs/Base64.cpp b/Swiften/StringCodecs/Base64.cpp index 4ec2e16..d8511b4 100644 --- a/Swiften/StringCodecs/Base64.cpp +++ b/Swiften/StringCodecs/Base64.cpp @@ -10,42 +10,54 @@ #include #include +#include namespace Swift { #pragma GCC diagnostic ignored "-Wold-style-cast" -std::string Base64::encode(const ByteArray &s) { - int i; - int len = s.size(); - char tbl[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; - int a, b, c; - - std::string p; - p.resize((len+2)/3*4); - int at = 0; - for( i = 0; i < len; i += 3 ) { - a = ((unsigned char) (s[i]) & 3) << 4; - if(i + 1 < len) { - a += (unsigned char) (s[i + 1]) >> 4; - b = ((unsigned char) (s[i + 1]) & 0xF) << 2; - if(i + 2 < len) { - b += (unsigned char) (s[i + 2]) >> 6; - c = (unsigned char) (s[i + 2]) & 0x3F; +namespace { + template + TargetType base64Encode(const SourceType& s) { + int i; + int len = s.size(); + char tbl[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; + int a, b, c; + + TargetType p; + p.resize((len+2)/3*4); + int at = 0; + for( i = 0; i < len; i += 3 ) { + a = ((unsigned char) (s[i]) & 3) << 4; + if(i + 1 < len) { + a += (unsigned char) (s[i + 1]) >> 4; + b = ((unsigned char) (s[i + 1]) & 0xF) << 2; + if(i + 2 < len) { + b += (unsigned char) (s[i + 2]) >> 6; + c = (unsigned char) (s[i + 2]) & 0x3F; + } + else + c = 64; + } + else { + b = c = 64; } - else - c = 64; - } - else { - b = c = 64; - } - p[at++] = tbl[(unsigned char) (s[i]) >> 2]; - p[at++] = tbl[a]; - p[at++] = tbl[b]; - p[at++] = tbl[c]; + p[at++] = tbl[(unsigned char) (s[i]) >> 2]; + p[at++] = tbl[a]; + p[at++] = tbl[b]; + p[at++] = tbl[c]; + } + return p; } - return p; +} + +std::string Base64::encode(const ByteArray &s) { + return base64Encode(s); +} + +SafeString Base64::encode(const SafeByteArray &s) { + return base64Encode(s); } ByteArray Base64::decode(const std::string& input) { diff --git a/Swiften/StringCodecs/Base64.h b/Swiften/StringCodecs/Base64.h index 3b14595..00a290d 100644 --- a/Swiften/StringCodecs/Base64.h +++ b/Swiften/StringCodecs/Base64.h @@ -10,11 +10,14 @@ #include #include +#include namespace Swift { class Base64 { public: static std::string encode(const ByteArray& s); + static SafeString encode(const SafeByteArray& s); + static ByteArray decode(const std::string &s); }; } diff --git a/Swiften/StringCodecs/HMACSHA1.cpp b/Swiften/StringCodecs/HMACSHA1.cpp index e583e3b..fd951ae 100644 --- a/Swiften/StringCodecs/HMACSHA1.cpp +++ b/Swiften/StringCodecs/HMACSHA1.cpp @@ -9,37 +9,51 @@ #include #include -#include #include -namespace Swift { +using namespace Swift; -static const unsigned int B = 64; +namespace { + static const unsigned int B = 64; -ByteArray HMACSHA1::getResult(const ByteArray& key, const ByteArray& data) { - assert(key.size() <= B); + template + ByteArray getHMACSHA1(const SourceType& key, const ByteArray& data) { + assert(key.size() <= B); - // Create the padded key - ByteArray paddedKey(key); - paddedKey.resize(B, 0x0); + // Create the padded key + SourceType paddedKey(key); + paddedKey.resize(B, 0x0); - // Create the first value - ByteArray x(paddedKey); - for (unsigned int i = 0; i < x.size(); ++i) { - x[i] ^= 0x36; - } - append(x, data); + // Create the first value + SourceType x(paddedKey); + for (unsigned int i = 0; i < x.size(); ++i) { + x[i] ^= 0x36; + } + append(x, data); - // Create the second value - ByteArray y(paddedKey); - for (unsigned int i = 0; i < y.size(); ++i) { - y[i] ^= 0x5c; + // Create the second value + SourceType y(paddedKey); + for (unsigned int i = 0; i < y.size(); ++i) { + y[i] ^= 0x5c; + } + append(y, SHA1::getHash(x)); + + return SHA1::getHash(y); } - append(y, SHA1::getHash(x)); +} - return SHA1::getHash(y); +namespace Swift { + +ByteArray HMACSHA1::getResult(const SafeByteArray& key, const ByteArray& data) { + return getHMACSHA1(key, data); +} + +ByteArray HMACSHA1::getResult(const ByteArray& key, const ByteArray& data) { + return getHMACSHA1(key, data); } + + #if 0 // A tweaked version of HMACSHA1 that is more than twice as fast as the one above. diff --git a/Swiften/StringCodecs/HMACSHA1.h b/Swiften/StringCodecs/HMACSHA1.h index 39c6e4e..0463e64 100644 --- a/Swiften/StringCodecs/HMACSHA1.h +++ b/Swiften/StringCodecs/HMACSHA1.h @@ -7,10 +7,12 @@ #pragma once #include +#include namespace Swift { class HMACSHA1 { public: + static ByteArray getResult(const SafeByteArray& key, const ByteArray& data); static ByteArray getResult(const ByteArray& key, const ByteArray& data); }; } diff --git a/Swiften/StringCodecs/MD5.cpp b/Swiften/StringCodecs/MD5.cpp index 9e69172..159eb87 100644 --- a/Swiften/StringCodecs/MD5.cpp +++ b/Swiften/StringCodecs/MD5.cpp @@ -351,16 +351,27 @@ md5_finish(md5_state_t *pms, md5_byte_t digest[16]) digest[i] = (md5_byte_t)(pms->abcd[i >> 2] >> ((i & 3) << 3)); } -ByteArray MD5::getHash(const ByteArray& data) { - ByteArray digest; - digest.resize(16); +namespace { + template + ByteArray getMD5Hash(const SourceType& data) { + ByteArray digest; + digest.resize(16); + + md5_state_t state; + md5_init(&state); + md5_append(&state, reinterpret_cast(vecptr(data)), data.size()); + md5_finish(&state, reinterpret_cast(vecptr(digest))); + + return digest; + } +} - md5_state_t state; - md5_init(&state); - md5_append(&state, reinterpret_cast(vecptr(data)), data.size()); - md5_finish(&state, reinterpret_cast(vecptr(digest))); +ByteArray MD5::getHash(const ByteArray& data) { + return getMD5Hash(data); +} - return digest; +ByteArray MD5::getHash(const SafeByteArray& data) { + return getMD5Hash(data); } } diff --git a/Swiften/StringCodecs/MD5.h b/Swiften/StringCodecs/MD5.h index 93c48e9..b1d610c 100644 --- a/Swiften/StringCodecs/MD5.h +++ b/Swiften/StringCodecs/MD5.h @@ -7,10 +7,12 @@ #pragma once #include +#include namespace Swift { class MD5 { public: static ByteArray getHash(const ByteArray& data); + static ByteArray getHash(const SafeByteArray& data); }; } diff --git a/Swiften/StringCodecs/PBKDF2.cpp b/Swiften/StringCodecs/PBKDF2.cpp index c4a5a7f..81e1208 100644 --- a/Swiften/StringCodecs/PBKDF2.cpp +++ b/Swiften/StringCodecs/PBKDF2.cpp @@ -10,7 +10,7 @@ namespace Swift { -ByteArray PBKDF2::encode(const ByteArray& password, const ByteArray& salt, int iterations) { +ByteArray PBKDF2::encode(const SafeByteArray& password, const ByteArray& salt, int iterations) { ByteArray u = HMACSHA1::getResult(password, concat(salt, createByteArray("\0\0\0\1", 4))); ByteArray result(u); int i = 1; diff --git a/Swiften/StringCodecs/PBKDF2.h b/Swiften/StringCodecs/PBKDF2.h index dd31921..b1a5986 100644 --- a/Swiften/StringCodecs/PBKDF2.h +++ b/Swiften/StringCodecs/PBKDF2.h @@ -6,11 +6,11 @@ #pragma once -#include +#include namespace Swift { class PBKDF2 { public: - static ByteArray encode(const ByteArray& password, const ByteArray& salt, int iterations); + static ByteArray encode(const SafeByteArray& password, const ByteArray& salt, int iterations); }; } diff --git a/Swiften/StringCodecs/SHA1.cpp b/Swiften/StringCodecs/SHA1.cpp index 5001fb2..e4081f4 100644 --- a/Swiften/StringCodecs/SHA1.cpp +++ b/Swiften/StringCodecs/SHA1.cpp @@ -197,11 +197,12 @@ std::vector SHA1::getHash() const { return digest; } -ByteArray SHA1::getHash(const ByteArray& input) { +template +ByteArray SHA1::getHashInternal(const Container& input) { CTX context; Init(&context); - std::vector inputCopy(input); + Container inputCopy(input); Update(&context, (boost::uint8_t*) vecptr(inputCopy), inputCopy.size()); ByteArray digest; @@ -211,4 +212,13 @@ ByteArray SHA1::getHash(const ByteArray& input) { return digest; } +ByteArray SHA1::getHash(const ByteArray& input) { + return getHashInternal(input); +} + +ByteArray SHA1::getHash(const SafeByteArray& input) { + return getHashInternal(input); +} + + } diff --git a/Swiften/StringCodecs/SHA1.h b/Swiften/StringCodecs/SHA1.h index 671d890..25bfa54 100644 --- a/Swiften/StringCodecs/SHA1.h +++ b/Swiften/StringCodecs/SHA1.h @@ -10,6 +10,7 @@ #include #include +#include namespace Swift { class SHA1 { @@ -26,6 +27,8 @@ namespace Swift { */ static ByteArray getHash(const ByteArray& data); + static ByteArray getHash(const SafeByteArray& data); + private: typedef struct { boost::uint32_t state[5]; @@ -37,6 +40,8 @@ namespace Swift { static void Update(CTX* context, boost::uint8_t* data, unsigned int len); static void Final(boost::uint8_t digest[20], CTX* context); + template static ByteArray getHashInternal(const Container& input); + private: CTX context; }; diff --git a/Swiften/StringCodecs/UnitTest/HMACSHA1Test.cpp b/Swiften/StringCodecs/UnitTest/HMACSHA1Test.cpp index efb613f..1c9d158 100644 --- a/Swiften/StringCodecs/UnitTest/HMACSHA1Test.cpp +++ b/Swiften/StringCodecs/UnitTest/HMACSHA1Test.cpp @@ -22,7 +22,7 @@ class HMACSHA1Test : public CppUnit::TestFixture { public: void testGetResult() { - ByteArray result(HMACSHA1::getResult(createByteArray("foo"), createByteArray("foobar"))); + ByteArray result(HMACSHA1::getResult(createSafeByteArray("foo"), createByteArray("foobar"))); CPPUNIT_ASSERT_EQUAL(createByteArray("\xa4\xee\xba\x8e\x63\x3d\x77\x88\x69\xf5\x68\xd0\x5a\x1b\x3d\xc7\x2b\xfd\x4\xdd"), result); } }; diff --git a/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp b/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp index ae55ac8..9d91fea 100644 --- a/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp +++ b/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp @@ -24,19 +24,19 @@ class PBKDF2Test : public CppUnit::TestFixture { public: void testGetResult_I1() { - ByteArray result(PBKDF2::encode(createByteArray("password"), createByteArray("salt"), 1)); + ByteArray result(PBKDF2::encode(createSafeByteArray("password"), createByteArray("salt"), 1)); CPPUNIT_ASSERT_EQUAL(createByteArray("\x0c\x60\xc8\x0f\x96\x1f\x0e\x71\xf3\xa9\xb5\x24\xaf\x60\x12\x06\x2f\xe0\x37\xa6"), result); } void testGetResult_I2() { - ByteArray result(PBKDF2::encode(createByteArray("password"), createByteArray("salt"), 2)); + ByteArray result(PBKDF2::encode(createSafeByteArray("password"), createByteArray("salt"), 2)); CPPUNIT_ASSERT_EQUAL(createByteArray("\xea\x6c\x1\x4d\xc7\x2d\x6f\x8c\xcd\x1e\xd9\x2a\xce\x1d\x41\xf0\xd8\xde\x89\x57"), result); } void testGetResult_I4096() { - ByteArray result(PBKDF2::encode(createByteArray("password"), createByteArray("salt"), 4096)); + ByteArray result(PBKDF2::encode(createSafeByteArray("password"), createByteArray("salt"), 4096)); CPPUNIT_ASSERT_EQUAL(createByteArray("\x4b\x00\x79\x1\xb7\x65\x48\x9a\xbe\xad\x49\xd9\x26\xf7\x21\xd0\x65\xa4\x29\xc1", 20), result); } -- cgit v0.10.2-6-g49f6