From 353e9d5cd422779888d21e3780a0cb8f299f0a93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Fri, 12 Aug 2011 21:52:24 +0200
Subject: Don't hard-code HMAC block size.


diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
index bcd6c5d..7842b4f 100644
--- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
+++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
@@ -12,8 +12,7 @@
 
 #include <Swiften/StringCodecs/SHA1.h>
 #include <Swiften/StringCodecs/Base64.h>
-#include <Swiften/StringCodecs/HMAC.h>
-#include <Swiften/StringCodecs/SHA1.h>
+#include <Swiften/StringCodecs/HMAC_SHA1.h>
 #include <Swiften/StringCodecs/PBKDF2.h>
 #include <Swiften/IDN/StringPrep.h>
 #include <Swiften/Base/Concat.h>
@@ -45,9 +44,9 @@ boost::optional<SafeByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const
 		return createSafeByteArray(concat(getGS2Header(), getInitialBareClientMessage()));
 	}
 	else if (step == Proof) {
-		ByteArray clientKey = HMAC<SHA1>()(saltedPassword, createByteArray("Client Key"));
+		ByteArray clientKey = HMAC_SHA1()(saltedPassword, createByteArray("Client Key"));
 		ByteArray storedKey = SHA1::getHash(clientKey);
-		ByteArray clientSignature = HMAC<SHA1>()(createSafeByteArray(storedKey), authMessage);
+		ByteArray clientSignature = HMAC_SHA1()(createSafeByteArray(storedKey), authMessage);
 		ByteArray clientProof = clientKey;
 		for (unsigned int i = 0; i < clientProof.size(); ++i) {
 			clientProof[i] ^= clientSignature[i];
@@ -102,13 +101,13 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray>
 
 		// Compute all the values needed for the server signature
 		try {
-			saltedPassword = PBKDF2::encode<HMAC<SHA1> >(StringPrep::getPrepared(getPassword(), StringPrep::SASLPrep), salt, iterations);
+			saltedPassword = PBKDF2::encode<HMAC_SHA1>(StringPrep::getPrepared(getPassword(), StringPrep::SASLPrep), salt, iterations);
 		}
 		catch (const std::exception&) {
 		}
 		authMessage = concat(getInitialBareClientMessage(), createByteArray(","), initialServerMessage, createByteArray(","), getFinalMessageWithoutProof());
-		ByteArray serverKey = HMAC<SHA1>()(saltedPassword, createByteArray("Server Key"));
-		serverSignature = HMAC<SHA1>()(serverKey, authMessage);
+		ByteArray serverKey = HMAC_SHA1()(saltedPassword, createByteArray("Server Key"));
+		serverSignature = HMAC_SHA1()(serverKey, authMessage);
 
 		step = Proof;
 		return true;
diff --git a/Swiften/StringCodecs/HMAC.h b/Swiften/StringCodecs/HMAC.h
index 438a3a7..cf0abfe 100644
--- a/Swiften/StringCodecs/HMAC.h
+++ b/Swiften/StringCodecs/HMAC.h
@@ -12,16 +12,25 @@
 
 namespace Swift {
 	namespace HMAC_Detail {
-		static const unsigned int B = 64;
+		template<typename KeyType> struct KeyWrapper;
+		template<> struct KeyWrapper<ByteArray> {
+			ByteArray wrap(const ByteArray& hash) const {
+				return hash;
+			}
+		};
+		template<> struct KeyWrapper<SafeByteArray> {
+			SafeByteArray wrap(const ByteArray& hash) const {
+				return createSafeByteArray(hash);
+			}
+		};
 
-		template<typename Hash, typename KeyType>
+		template<typename Hash, typename KeyType, int BlockSize>
 		static ByteArray getHMAC(const KeyType& key, const ByteArray& data) {
-			assert(key.size() <= B);
 			Hash hash;
 
 			// Create the padded key
-			KeyType paddedKey(key);
-			paddedKey.resize(B, 0x0);
+			KeyType paddedKey(key.size() <= BlockSize ? key : KeyWrapper<KeyType>().wrap(hash(key)));
+			paddedKey.resize(BlockSize, 0x0);
 
 			// Create the first value
 			KeyType x(paddedKey);
@@ -41,17 +50,17 @@ namespace Swift {
 		}
 	};
 
-	template<typename Hash>
+	template<typename Hash, int BlockSize>
 	class HMAC {
 		private:
 
 		public:
 			ByteArray operator()(const ByteArray& key, const ByteArray& data) const {
-				return HMAC_Detail::getHMAC<Hash,ByteArray>(key, data);
+				return HMAC_Detail::getHMAC<Hash,ByteArray,BlockSize>(key, data);
 			}
 
 			ByteArray operator()(const SafeByteArray& key, const ByteArray& data) const {
-				return HMAC_Detail::getHMAC<Hash,SafeByteArray>(key, data);
+				return HMAC_Detail::getHMAC<Hash,SafeByteArray,BlockSize>(key, data);
 			}
 	};
 }
diff --git a/Swiften/StringCodecs/HMAC_SHA1.h b/Swiften/StringCodecs/HMAC_SHA1.h
new file mode 100644
index 0000000..8f403c6
--- /dev/null
+++ b/Swiften/StringCodecs/HMAC_SHA1.h
@@ -0,0 +1,14 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/StringCodecs/HMAC.h>
+#include <Swiften/StringCodecs/SHA1.h>
+
+namespace Swift {
+	typedef HMAC<SHA1, 64> HMAC_SHA1;
+}
diff --git a/Swiften/StringCodecs/SHA256.cpp b/Swiften/StringCodecs/SHA256.cpp
index ff1f5e9..02114ca 100644
--- a/Swiften/StringCodecs/SHA256.cpp
+++ b/Swiften/StringCodecs/SHA256.cpp
@@ -330,6 +330,10 @@ int SHA256::done(State * md, unsigned char *out)
 
 namespace Swift {
 
+SHA256::SHA256() {
+	init(&state);
+}
+
 SHA256& SHA256::update(const std::vector<unsigned char>& input) {
 	std::vector<unsigned char> inputCopy(input);
 	process(&state, (boost::uint8_t*) vecptr(inputCopy), inputCopy.size());
diff --git a/Swiften/StringCodecs/UnitTest/HMACTest.cpp b/Swiften/StringCodecs/UnitTest/HMACTest.cpp
index bdb0d96..75ae23e 100644
--- a/Swiften/StringCodecs/UnitTest/HMACTest.cpp
+++ b/Swiften/StringCodecs/UnitTest/HMACTest.cpp
@@ -11,21 +11,26 @@
 #include <cppunit/extensions/TestFactoryRegistry.h>
 
 #include <Swiften/Base/ByteArray.h>
-#include <Swiften/StringCodecs/HMAC.h>
-#include <Swiften/StringCodecs/SHA1.h>
+#include <Swiften/StringCodecs/HMAC_SHA1.h>
 
 using namespace Swift;
 
 class HMACTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST_SUITE(HMACTest);
 		CPPUNIT_TEST(testGetResult);
+		CPPUNIT_TEST(testGetResult_KeyLongerThanBlockSize);
 		CPPUNIT_TEST_SUITE_END();
 
 	public:
 		void testGetResult() {
-			ByteArray result(HMAC<SHA1>()(createSafeByteArray("foo"), createByteArray("foobar")));
+			ByteArray result(HMAC_SHA1()(createSafeByteArray("foo"), createByteArray("foobar")));
 			CPPUNIT_ASSERT_EQUAL(createByteArray("\xa4\xee\xba\x8e\x63\x3d\x77\x88\x69\xf5\x68\xd0\x5a\x1b\x3d\xc7\x2b\xfd\x4\xdd"), result);
 		}
+
+		void testGetResult_KeyLongerThanBlockSize() {
+			ByteArray result(HMAC_SHA1()(createSafeByteArray("---------|---------|---------|---------|---------|----------|---------|"), createByteArray("foobar")));
+			CPPUNIT_ASSERT_EQUAL(createByteArray("\xd6""n""\x8f""P|1""\xd3"",""\x6"" ""\xb9\xe3""gg""\x8e\xcf"" ]+""\xa"), result);
+		}
 };
 
 CPPUNIT_TEST_SUITE_REGISTRATION(HMACTest);
diff --git a/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp b/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp
index 377e5c9..608ca62 100644
--- a/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp
+++ b/Swiften/StringCodecs/UnitTest/PBKDF2Test.cpp
@@ -12,8 +12,7 @@
 
 #include <Swiften/Base/ByteArray.h>
 #include <Swiften/StringCodecs/PBKDF2.h>
-#include <Swiften/StringCodecs/HMAC.h>
-#include <Swiften/StringCodecs/SHA1.h>
+#include <Swiften/StringCodecs/HMAC_SHA1.h>
 
 using namespace Swift;
 
@@ -26,19 +25,19 @@ class PBKDF2Test : public CppUnit::TestFixture {
 
 	public:
 		void testGetResult_I1() {
-			ByteArray result(PBKDF2::encode<HMAC<SHA1> >(createSafeByteArray("password"), createByteArray("salt"), 1));
+			ByteArray result(PBKDF2::encode<HMAC_SHA1 >(createSafeByteArray("password"), createByteArray("salt"), 1));
 
 			CPPUNIT_ASSERT_EQUAL(createByteArray("\x0c\x60\xc8\x0f\x96\x1f\x0e\x71\xf3\xa9\xb5\x24\xaf\x60\x12\x06\x2f\xe0\x37\xa6"), result);
 		}
 
 		void testGetResult_I2() {
-			ByteArray result(PBKDF2::encode<HMAC<SHA1> >(createSafeByteArray("password"), createByteArray("salt"), 2));
+			ByteArray result(PBKDF2::encode<HMAC_SHA1 >(createSafeByteArray("password"), createByteArray("salt"), 2));
 
 			CPPUNIT_ASSERT_EQUAL(createByteArray("\xea\x6c\x1\x4d\xc7\x2d\x6f\x8c\xcd\x1e\xd9\x2a\xce\x1d\x41\xf0\xd8\xde\x89\x57"), result);
 		}
 
 		void testGetResult_I4096() {
-			ByteArray result(PBKDF2::encode<HMAC<SHA1> >(createSafeByteArray("password"), createByteArray("salt"), 4096));
+			ByteArray result(PBKDF2::encode<HMAC_SHA1 >(createSafeByteArray("password"), createByteArray("salt"), 4096));
 
 			CPPUNIT_ASSERT_EQUAL(createByteArray("\x4b\x00\x79\x1\xb7\x65\x48\x9a\xbe\xad\x49\xd9\x26\xf7\x21\xd0\x65\xa4\x29\xc1", 20), result);
 		}
-- 
cgit v0.10.2-6-g49f6