diff options
Diffstat (limited to 'Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp')
-rw-r--r-- | Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp index 33de014..20b3d8a 100644 --- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp +++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp @@ -4,17 +4,18 @@ * See Documentation/Licenses/GPLv3.txt for more information. */ -#include "Swiften/SASL/SCRAMSHA1ClientAuthenticator.h" +#include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> #include <cassert> #include <map> #include <boost/lexical_cast.hpp> -#include "Swiften/StringCodecs/SHA1.h" -#include "Swiften/StringCodecs/Base64.h" -#include "Swiften/StringCodecs/HMACSHA1.h" -#include "Swiften/StringCodecs/PBKDF2.h" -#include "Swiften/IDN/StringPrep.h" +#include <Swiften/StringCodecs/SHA1.h> +#include <Swiften/StringCodecs/Base64.h> +#include <Swiften/StringCodecs/HMACSHA1.h> +#include <Swiften/StringCodecs/PBKDF2.h> +#include <Swiften/IDN/StringPrep.h> +#include <Swiften/Base/Concat.h> namespace Swift { @@ -38,23 +39,23 @@ static std::string escape(const std::string& s) { SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const std::string& nonce, bool useChannelBinding) : ClientAuthenticator(useChannelBinding ? "SCRAM-SHA-1-PLUS" : "SCRAM-SHA-1"), step(Initial), clientnonce(nonce), useChannelBinding(useChannelBinding) { } -boost::optional<ByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const { +boost::optional<SafeByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const { if (step == Initial) { - return getGS2Header() + getInitialBareClientMessage(); + return createSafeByteArray(concat(getGS2Header(), getInitialBareClientMessage())); } else if (step == Proof) { - ByteArray clientKey = HMACSHA1::getResult(saltedPassword, "Client Key"); + ByteArray clientKey = HMACSHA1::getResult(saltedPassword, createByteArray("Client Key")); ByteArray storedKey = SHA1::getHash(clientKey); - ByteArray clientSignature = HMACSHA1::getResult(storedKey, authMessage); + ByteArray clientSignature = HMACSHA1::getResult(createSafeByteArray(storedKey), authMessage); ByteArray clientProof = clientKey; - for (unsigned int i = 0; i < clientProof.getSize(); ++i) { + for (unsigned int i = 0; i < clientProof.size(); ++i) { clientProof[i] ^= clientSignature[i]; } - ByteArray result = getFinalMessageWithoutProof() + ",p=" + Base64::encode(clientProof); - return result; + ByteArray result = concat(getFinalMessageWithoutProof(), createByteArray(",p="), createByteArray(Base64::encode(clientProof))); + return createSafeByteArray(result); } else { - return boost::optional<ByteArray>(); + return boost::optional<SafeByteArray>(); } } @@ -65,7 +66,7 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray> } initialServerMessage = *challenge; - std::map<char, std::string> keys = parseMap(initialServerMessage.toString()); + std::map<char, std::string> keys = parseMap(byteArrayToString(initialServerMessage)); // Extract the salt ByteArray salt = Base64::decode(keys['s']); @@ -79,7 +80,7 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray> if (receivedClientNonce != clientnonce) { return false; } - serverNonce = clientServerNonce.substr(clientnonce.size(), clientServerNonce.npos); + serverNonce = createByteArray(clientServerNonce.substr(clientnonce.size(), clientServerNonce.npos)); // Extract the number of iterations int iterations = 0; @@ -104,15 +105,15 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray> } catch (const std::exception&) { } - authMessage = getInitialBareClientMessage() + "," + initialServerMessage + "," + getFinalMessageWithoutProof(); - ByteArray serverKey = HMACSHA1::getResult(saltedPassword, "Server Key"); + authMessage = concat(getInitialBareClientMessage(), createByteArray(","), initialServerMessage, createByteArray(","), getFinalMessageWithoutProof()); + ByteArray serverKey = HMACSHA1::getResult(saltedPassword, createByteArray("Server Key")); serverSignature = HMACSHA1::getResult(serverKey, authMessage); step = Proof; return true; } else if (step == Proof) { - ByteArray result = ByteArray("v=") + ByteArray(Base64::encode(serverSignature)); + ByteArray result = concat(createByteArray("v="), createByteArray(Base64::encode(serverSignature))); step = Final; return challenge && challenge == result; } @@ -135,7 +136,7 @@ std::map<char, std::string> SCRAMSHA1ClientAuthenticator::parseMap(const std::st i++; } else if (s[i] == ',') { - result[key] = value; + result[static_cast<size_t>(key)] = value; value = ""; expectKey = true; } @@ -152,24 +153,24 @@ std::map<char, std::string> SCRAMSHA1ClientAuthenticator::parseMap(const std::st ByteArray SCRAMSHA1ClientAuthenticator::getInitialBareClientMessage() const { std::string authenticationID; try { - authenticationID = StringPrep::getPrepared(getAuthenticationID(), StringPrep::SASLPrep); + authenticationID = StringPrep::getPrepared(getAuthenticationID(), StringPrep::SASLPrep); } catch (const std::exception&) { } - return ByteArray(std::string("n=" + escape(authenticationID) + ",r=" + clientnonce)); + return createByteArray(std::string("n=" + escape(authenticationID) + ",r=" + clientnonce)); } ByteArray SCRAMSHA1ClientAuthenticator::getGS2Header() const { - ByteArray channelBindingHeader("n"); + ByteArray channelBindingHeader(createByteArray("n")); if (tlsChannelBindingData) { if (useChannelBinding) { - channelBindingHeader = ByteArray("p=tls-unique"); + channelBindingHeader = createByteArray("p=tls-unique"); } else { - channelBindingHeader = ByteArray("y"); + channelBindingHeader = createByteArray("y"); } } - return channelBindingHeader + ByteArray(",") + (getAuthorizationID().empty() ? "" : "a=" + escape(getAuthorizationID())) + ","; + return concat(channelBindingHeader, createByteArray(","), (getAuthorizationID().empty() ? ByteArray() : createByteArray("a=" + escape(getAuthorizationID()))), createByteArray(",")); } void SCRAMSHA1ClientAuthenticator::setTLSChannelBindingData(const ByteArray& channelBindingData) { @@ -181,7 +182,7 @@ ByteArray SCRAMSHA1ClientAuthenticator::getFinalMessageWithoutProof() const { if (useChannelBinding && tlsChannelBindingData) { channelBindData = *tlsChannelBindingData; } - return ByteArray("c=") + Base64::encode(getGS2Header() + channelBindData) + ",r=" + clientnonce + serverNonce; + return concat(createByteArray("c=" + Base64::encode(concat(getGS2Header(), channelBindData)) + ",r=" + clientnonce), serverNonce); } |