/*
 * Copyright (c) 2010-2013 Remko Tronçon
 * Licensed under the GNU General Public License v3.
 * See Documentation/Licenses/GPLv3.txt for more information.
 */

#include <Swiften/Base/ByteArray.h>
#include <Swiften/Base/Platform.h>
#include <QA/Checker/IO.h>

#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>

// Each provider header is guarded by the same macro that guards its
// CPPUNIT_TEST_SUITE_REGISTRATION at the end of this file.
#ifdef SWIFTEN_PLATFORM_WIN32
#include <Swiften/Crypto/WindowsCryptoProvider.h>
#endif
#ifdef HAVE_OPENSSL_CRYPTO_PROVIDER
#include <Swiften/Crypto/OpenSSLCryptoProvider.h>
#endif
#ifdef HAVE_COMMONCRYPTO_CRYPTO_PROVIDER
#include <Swiften/Crypto/CommonCryptoCryptoProvider.h>
#endif
#include <Swiften/Crypto/Hash.h>

using namespace Swift;

/**
 * Known-answer tests shared by every crypto provider implementation.
 *
 * The template is instantiated once per provider type registered at the end
 * of this file. Each test drives the provider's SHA-1, MD5 or HMAC-SHA1
 * entry points (one-shot and/or incremental) and compares the result with a
 * fixed expected digest.
 */
template <typename CryptoProviderType>
class CryptoProviderTest : public CppUnit::TestFixture {
		CPPUNIT_TEST_SUITE(CryptoProviderTest);

		CPPUNIT_TEST(testGetSHA1Hash);
		CPPUNIT_TEST(testGetSHA1Hash_TwoUpdates);
		CPPUNIT_TEST(testGetSHA1Hash_NoData);
		CPPUNIT_TEST(testGetSHA1HashStatic);
		CPPUNIT_TEST(testGetSHA1HashStatic_Twice);
		CPPUNIT_TEST(testGetSHA1HashStatic_NoData);

		CPPUNIT_TEST(testGetMD5Hash_Empty);
		CPPUNIT_TEST(testGetMD5Hash_Alphabet);
		CPPUNIT_TEST(testMD5Incremental);

		CPPUNIT_TEST(testGetHMACSHA1);
		CPPUNIT_TEST(testGetHMACSHA1_KeyLongerThanBlockSize);

		CPPUNIT_TEST_SUITE_END();

	public:
		/// Creates a fresh provider instance for each test.
		void setUp() {
			provider.reset(new CryptoProviderType());
		}

		/// Releases the provider created in setUp().
		void tearDown() {
			provider.reset();
		}

		////////////////////////////////////////////////////////////
		// SHA-1
		////////////////////////////////////////////////////////////

		/// A single update() with the whole input yields the expected digest.
		void testGetSHA1Hash() {
			boost::shared_ptr<Hash> sha(provider->createSHA1());
			sha->update(capsVerificationString());

			CPPUNIT_ASSERT_EQUAL(capsVerificationSHA1(), sha->getHash());
		}

		/// Splitting the same input across two update() calls must not change the digest.
		void testGetSHA1Hash_TwoUpdates() {
			boost::shared_ptr<Hash> sha(provider->createSHA1());
			sha->update(createByteArray("client/pc//Exodus 0.9.1<http://jabber.org/protocol/caps<"));
			sha->update(createByteArray("http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<"));

			CPPUNIT_ASSERT_EQUAL(capsVerificationSHA1(), sha->getHash());
		}

		/// Feeding an empty buffer must produce the SHA-1 of the empty message.
		void testGetSHA1Hash_NoData() {
			boost::shared_ptr<Hash> sha(provider->createSHA1());
			sha->update(ByteArray());

			CPPUNIT_ASSERT_EQUAL(emptySHA1(), sha->getHash());
		}

		/// The one-shot getSHA1Hash() agrees with the incremental interface.
		void testGetSHA1HashStatic() {
			ByteArray result(provider->getSHA1Hash(capsVerificationString()));
			CPPUNIT_ASSERT_EQUAL(capsVerificationSHA1(), result);
		}

		/// A second one-shot call on the same provider is not affected by the first.
		void testGetSHA1HashStatic_Twice() {
			ByteArray input(capsVerificationString());
			provider->getSHA1Hash(input);
			ByteArray result(provider->getSHA1Hash(input));

			CPPUNIT_ASSERT_EQUAL(capsVerificationSHA1(), result);
		}

		/// The one-shot getSHA1Hash() of an empty input is the empty-message digest.
		void testGetSHA1HashStatic_NoData() {
			ByteArray result(provider->getSHA1Hash(ByteArray()));

			CPPUNIT_ASSERT_EQUAL(emptySHA1(), result);
		}

		////////////////////////////////////////////////////////////
		// MD5
		////////////////////////////////////////////////////////////

		/// MD5 of the empty string. The expected digest contains a NUL byte,
		/// so its length is passed explicitly.
		void testGetMD5Hash_Empty() {
			ByteArray result(provider->getMD5Hash(createByteArray("")));

			CPPUNIT_ASSERT_EQUAL(createByteArray("\xd4\x1d\x8c\xd9\x8f\x00\xb2\x04\xe9\x80\x09\x98\xec\xf8\x42\x7e", 16), result);
		}

		/// One-shot MD5 of the alphanumeric alphabet.
		void testGetMD5Hash_Alphabet() {
			ByteArray result(provider->getMD5Hash(createByteArray("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")));

			CPPUNIT_ASSERT_EQUAL(alphabetMD5(), result);
		}

		/// Incremental MD5 over the same alphabet, split in two, matches the one-shot result.
		void testMD5Incremental() {
			boost::shared_ptr<Hash> testling(provider->createMD5());
			testling->update(createByteArray("ABCDEFGHIJKLMNOPQRSTUVWXYZ"));
			testling->update(createByteArray("abcdefghijklmnopqrstuvwxyz0123456789"));

			ByteArray result = testling->getHash();

			CPPUNIT_ASSERT_EQUAL(alphabetMD5(), result);
		}

		////////////////////////////////////////////////////////////
		// HMAC-SHA1
		////////////////////////////////////////////////////////////

		/// HMAC-SHA1 with a short key.
		void testGetHMACSHA1() {
			ByteArray result(provider->getHMACSHA1(createSafeByteArray("foo"), createByteArray("foobar")));
			CPPUNIT_ASSERT_EQUAL(createByteArray("\xa4\xee\xba\x8e\x63\x3d\x77\x88\x69\xf5\x68\xd0\x5a\x1b\x3d\xc7\x2b\xfd\x04\xdd"), result);
		}

		/// HMAC-SHA1 with a key longer than SHA-1's 64-byte block (per the test
		/// name). This exercises the RFC 2104 path where the key is hashed first.
		void testGetHMACSHA1_KeyLongerThanBlockSize() {
			ByteArray result(provider->getHMACSHA1(createSafeByteArray("---------|---------|---------|---------|---------|----------|---------|"), createByteArray("foobar")));
			CPPUNIT_ASSERT_EQUAL(createByteArray("\xd6\x6e\x8f\x50\x7c\x31\xd3\x2c\x06\x20\xb9\xe3\x67\x67\x8e\xcf\x20\x5d\x2b\x0a"), result);
		}

	private:
		/// Caps-style verification string for an "Exodus 0.9.1" client, used as
		/// the SHA-1 input in several tests.
		static ByteArray capsVerificationString() {
			return createByteArray("client/pc//Exodus 0.9.1<http://jabber.org/protocol/caps<http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<");
		}

		/// Expected SHA-1 digest of capsVerificationString().
		static ByteArray capsVerificationSHA1() {
			return createByteArray("\x42\x06\xb2\x3c\xa6\xb0\xa6\x43\xd2\x0d\x89\xb0\x4f\xf5\x8c\xf7\x8b\x80\x96\xed");
		}

		/// Expected SHA-1 digest of the empty message.
		static ByteArray emptySHA1() {
			return createByteArray("\xda\x39\xa3\xee\x5e\x6b\x4b\x0d\x32\x55\xbf\xef\x95\x60\x18\x90\xaf\xd8\x07\x09");
		}

		/// Expected MD5 digest of "A-Z" + "a-z" + "0-9".
		static ByteArray alphabetMD5() {
			return createByteArray("\xd1\x74\xab\x98\xd2\x77\xd9\xf5\xa5\x61\x1c\x2c\x9f\x41\x9d\x9f", 16);
		}

		/// Provider under test. It is owned here, created in setUp() and
		/// released in tearDown().
		boost::shared_ptr<CryptoProviderType> provider;
};

// Register one instantiation of the shared suite per crypto backend.
// Each registration is compiled only when that backend's availability macro
// is defined.
#if defined(SWIFTEN_PLATFORM_WIN32)
CPPUNIT_TEST_SUITE_REGISTRATION(CryptoProviderTest<WindowsCryptoProvider>);
#endif

#if defined(HAVE_OPENSSL_CRYPTO_PROVIDER)
CPPUNIT_TEST_SUITE_REGISTRATION(CryptoProviderTest<OpenSSLCryptoProvider>);
#endif

#if defined(HAVE_COMMONCRYPTO_CRYPTO_PROVIDER)
CPPUNIT_TEST_SUITE_REGISTRATION(CryptoProviderTest<CommonCryptoCryptoProvider>);
#endif