/*
 * Copyright (c) 2010 Remko Tronçon
 * Licensed under the GNU General Public License v3.
 * See Documentation/Licenses/GPLv3.txt for more information.
 */

#include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h>

#include <cassert>
#include <map>
#include <boost/lexical_cast.hpp>

#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/StringCodecs/Base64.h>
#include <Swiften/StringCodecs/PBKDF2.h>
#include <Swiften/IDN/IDNConverter.h>
#include <Swiften/Base/Concat.h>

namespace Swift {

static std::string escape(const std::string& s) {
	std::string result;
	for (size_t i = 0; i < s.size(); ++i) {
		if (s[i] == ',') {
			result += "=2C";
		}
		else if (s[i] == '=') {
			result += "=3D";
		}
		else {
			result += s[i];
		}
	}
	return result;
}


SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const std::string& nonce, bool useChannelBinding, IDNConverter* idnConverter, CryptoProvider* crypto) : ClientAuthenticator(useChannelBinding ? "SCRAM-SHA-1-PLUS" : "SCRAM-SHA-1"), step(Initial), clientnonce(nonce), useChannelBinding(useChannelBinding), idnConverter(idnConverter), crypto(crypto) {
}

boost::optional<SafeByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const {
	if (step == Initial) {
		return createSafeByteArray(concat(getGS2Header(), getInitialBareClientMessage()));
	}
	else if (step == Proof) {
		ByteArray clientKey = crypto->getHMACSHA1(saltedPassword, createByteArray("Client Key"));
		ByteArray storedKey = crypto->getSHA1Hash(clientKey);
		ByteArray clientSignature = crypto->getHMACSHA1(createSafeByteArray(storedKey), authMessage);
		ByteArray clientProof = clientKey;
		for (unsigned int i = 0; i < clientProof.size(); ++i) {
			clientProof[i] ^= clientSignature[i];
		}
		ByteArray result = concat(getFinalMessageWithoutProof(), createByteArray(",p="), createByteArray(Base64::encode(clientProof)));
		return createSafeByteArray(result);
	}
	else {
		return boost::optional<SafeByteArray>();
	}
}

bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray>& challenge) {
	if (step == Initial) {
		if (!challenge) {
			return false;
		}
		initialServerMessage = *challenge;

		std::map<char, std::string> keys = parseMap(byteArrayToString(initialServerMessage));

		// Extract the salt
		ByteArray salt = Base64::decode(keys['s']);

		// Extract the server nonce
		std::string clientServerNonce = keys['r'];
		if (clientServerNonce.size() <= clientnonce.size()) {
			return false;
		}
		std::string receivedClientNonce = clientServerNonce.substr(0, clientnonce.size());
		if (receivedClientNonce != clientnonce) {
			return false;
		}
		serverNonce = createByteArray(clientServerNonce.substr(clientnonce.size(), clientServerNonce.npos));

		// Extract the number of iterations
		int iterations = 0;
		try {
			iterations = boost::lexical_cast<int>(keys['i']);
		}
		catch (const boost::bad_lexical_cast&) {
			return false;
		}
		if (iterations <= 0) {
			return false;
		}

		// Compute all the values needed for the server signature
		try {
			saltedPassword = PBKDF2::encode(idnConverter->getStringPrepared(getPassword(), IDNConverter::SASLPrep), salt, iterations, crypto);
		}
		catch (const std::exception&) {
		}
		authMessage = concat(getInitialBareClientMessage(), createByteArray(","), initialServerMessage, createByteArray(","), getFinalMessageWithoutProof());
		ByteArray serverKey = crypto->getHMACSHA1(saltedPassword, createByteArray("Server Key"));
		serverSignature = crypto->getHMACSHA1(serverKey, authMessage);

		step = Proof;
		return true;
	}
	else if (step == Proof) {
		ByteArray result = concat(createByteArray("v="), createByteArray(Base64::encode(serverSignature)));
		step = Final;
		return challenge && challenge == result;
	}
	else {
		return true;
	}
}

std::map<char, std::string> SCRAMSHA1ClientAuthenticator::parseMap(const std::string& s) {
	std::map<char, std::string> result;
	if (s.size() > 0) {
		char key = 0;
		std::string value;
		size_t i = 0;
		bool expectKey = true;
		while (i < s.size()) {
			if (expectKey) {
				key = s[i];
				expectKey = false;
				i++;
			}
			else if (s[i] == ',') {
				result[key] = value;
				value = "";
				expectKey = true;
			}
			else {
				value += s[i];
			}
			i++;
		}
		result[key] = value;
	}
	return result;
}

ByteArray SCRAMSHA1ClientAuthenticator::getInitialBareClientMessage() const {
	std::string authenticationID;
	try {
		authenticationID = idnConverter->getStringPrepared(getAuthenticationID(), IDNConverter::SASLPrep);
	}
	catch (const std::exception&) {
	}
	return createByteArray(std::string("n=" + escape(authenticationID) + ",r=" + clientnonce));
}

ByteArray SCRAMSHA1ClientAuthenticator::getGS2Header() const {
	ByteArray channelBindingHeader(createByteArray("n"));
	if (tlsChannelBindingData) {
		if (useChannelBinding) {
			channelBindingHeader = createByteArray("p=tls-unique");
		}
		else {
			channelBindingHeader = createByteArray("y");
		}
	}
	return concat(channelBindingHeader, createByteArray(","), (getAuthorizationID().empty() ? ByteArray() : createByteArray("a=" + escape(getAuthorizationID()))), createByteArray(","));
}

void SCRAMSHA1ClientAuthenticator::setTLSChannelBindingData(const ByteArray& channelBindingData) {
	this->tlsChannelBindingData = channelBindingData;
}

ByteArray SCRAMSHA1ClientAuthenticator::getFinalMessageWithoutProof() const {
	ByteArray channelBindData;
	if (useChannelBinding && tlsChannelBindingData) {
		channelBindData = *tlsChannelBindingData;
	}
	return concat(createByteArray("c=" + Base64::encode(concat(getGS2Header(), channelBindData)) + ",r=" + clientnonce), serverNonce);
}


}