From c84fb752cc881dfca9727b69fcdb3230830b7cc4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Fri, 20 Nov 2009 22:14:01 +0100
Subject: Abstracting authenticators.


diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index a95c058..06a7617 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -13,7 +13,7 @@
 #include "Swiften/Elements/StartSession.h"
 #include "Swiften/Elements/IQ.h"
 #include "Swiften/Elements/ResourceBind.h"
-#include "Swiften/SASL/PLAINMessage.h"
+#include "Swiften/SASL/PLAINClientAuthenticator.h"
 #include "Swiften/Session/SessionStream.h"
 
 namespace Swift {
@@ -24,7 +24,8 @@ ClientSession::ClientSession(
 			localJID(jid),	
 			state(Initial), 
 			stream(stream),
-			needSessionStart(false) {
+			needSessionStart(false),
+			authenticator(NULL) {
 }
 
 void ClientSession::start() {
@@ -77,6 +78,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				}
 			}
 			else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
+				authenticator = new PLAINClientAuthenticator();
 				state = WaitingForCredentials;
 				onNeedCredentials();
 			}
@@ -112,10 +114,14 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 	else if (dynamic_cast<AuthSuccess*>(element.get())) {
 		checkState(Authenticating);
 		state = WaitingForStreamStart;
+		delete authenticator;
+		authenticator = NULL;
 		stream->resetXMPPParser();
 		sendStreamHeader();
 	}
 	else if (dynamic_cast<AuthFailure*>(element.get())) {
+		delete authenticator;
+		authenticator = NULL;
 		finishSession(Error::AuthenticationFailedError);
 	}
 	else if (dynamic_cast<TLSProceed*>(element.get())) {
@@ -190,7 +196,8 @@ bool ClientSession::checkState(State state) {
 void ClientSession::sendCredentials(const String& password) {
 	assert(WaitingForCredentials);
 	state = Authenticating;
-	stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(localJID.getNode(), password).getValue())));
+	authenticator->setCredentials(localJID.getNode(), password);
+	stream->writeElement(boost::shared_ptr<AuthRequest>(new AuthRequest(authenticator->getName(), authenticator->getResponse())));
 }
 
 void ClientSession::handleTLSEncrypted() {
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index f980a9e..f3bc119 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -12,6 +12,8 @@
 #include "Swiften/Elements/Element.h"
 
 namespace Swift {
+	class ClientAuthenticator;
+
 	class ClientSession : public boost::enable_shared_from_this<ClientSession> {
 		public:
 			enum State {
@@ -90,5 +92,6 @@ namespace Swift {
 			State state;
 			boost::shared_ptr<SessionStream> stream;
 			bool needSessionStart;
+			ClientAuthenticator* authenticator;
 	};
 }
diff --git a/Swiften/Elements/AuthSuccess.h b/Swiften/Elements/AuthSuccess.h
index f63d0a8..da4d798 100644
--- a/Swiften/Elements/AuthSuccess.h
+++ b/Swiften/Elements/AuthSuccess.h
@@ -1,5 +1,4 @@
-#ifndef SWIFTEN_AuthSuccess_H
-#define SWIFTEN_AuthSuccess_H
+#pragma once
 
 #include "Swiften/Elements/Element.h"
 
@@ -9,5 +8,3 @@ namespace Swift {
 			AuthSuccess() {}
 	};
 }
-
-#endif
diff --git a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
index 4c3a681..77910e6 100644
--- a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
+++ b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
@@ -9,7 +9,6 @@
 #include "Swiften/StreamStack/XMPPLayer.h"
 #include "Swiften/Elements/StreamFeatures.h"
 #include "Swiften/Elements/IQ.h"
-#include "Swiften/SASL/PLAINMessage.h"
 
 namespace Swift {
 
diff --git a/Swiften/SASL/ClientAuthenticator.cpp b/Swiften/SASL/ClientAuthenticator.cpp
new file mode 100644
index 0000000..5fc9e85
--- /dev/null
+++ b/Swiften/SASL/ClientAuthenticator.cpp
@@ -0,0 +1,11 @@
+#include "Swiften/SASL/ClientAuthenticator.h"
+
+namespace Swift {
+
+ClientAuthenticator::ClientAuthenticator(const String& name) : name(name) {
+}
+
+ClientAuthenticator::~ClientAuthenticator() {
+}
+
+}
diff --git a/Swiften/SASL/ClientAuthenticator.h b/Swiften/SASL/ClientAuthenticator.h
new file mode 100644
index 0000000..f42a51e
--- /dev/null
+++ b/Swiften/SASL/ClientAuthenticator.h
@@ -0,0 +1,43 @@
+#pragma once
+
+#include "Swiften/Base/String.h"
+#include "Swiften/Base/ByteArray.h"
+
+namespace Swift {
+	class ClientAuthenticator {
+		public:
+			ClientAuthenticator(const String& name);
+			virtual ~ClientAuthenticator();
+
+			const String& getName() const {
+				return name;
+			}
+
+			void setCredentials(const String& authcid, const String& password, const String& authzid = String()) {
+				this->authcid = authcid;
+				this->password = password;
+				this->authzid = authzid;
+			}
+
+			virtual ByteArray getResponse() const = 0;
+			virtual bool setChallenge(const ByteArray&) = 0;
+
+			const String& getAuthenticationID() const {
+				return authcid;
+			}
+
+			const String& getAuthorizationID() const {
+				return authzid;
+			}
+
+			const String& getPassword() const {
+				return password;
+			}
+		
+		private:
+			String name;
+			String authcid;
+			String password;
+			String authzid;
+	};
+}
diff --git a/Swiften/SASL/PLAINClientAuthenticator.cpp b/Swiften/SASL/PLAINClientAuthenticator.cpp
new file mode 100644
index 0000000..8f88c3c
--- /dev/null
+++ b/Swiften/SASL/PLAINClientAuthenticator.cpp
@@ -0,0 +1,16 @@
+#include "Swiften/SASL/PLAINClientAuthenticator.h"
+
+namespace Swift {
+
+PLAINClientAuthenticator::PLAINClientAuthenticator() : ClientAuthenticator("PLAIN") {
+}
+
+ByteArray PLAINClientAuthenticator::getResponse() const {
+	return ByteArray(getAuthorizationID()) + '\0' + ByteArray(getAuthenticationID()) + '\0' + ByteArray(getPassword());
+}
+
+bool PLAINClientAuthenticator::setChallenge(const ByteArray&) {
+	return true;
+}
+
+}
diff --git a/Swiften/SASL/PLAINClientAuthenticator.h b/Swiften/SASL/PLAINClientAuthenticator.h
new file mode 100644
index 0000000..854eb30
--- /dev/null
+++ b/Swiften/SASL/PLAINClientAuthenticator.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "Swiften/SASL/ClientAuthenticator.h"
+
+namespace Swift {
+	class PLAINClientAuthenticator : public ClientAuthenticator {
+		public:
+			PLAINClientAuthenticator();
+
+			virtual ByteArray getResponse() const;
+			virtual bool setChallenge(const ByteArray&);
+	};
+}
diff --git a/Swiften/SASL/PLAINMessage.h b/Swiften/SASL/PLAINMessage.h
index 76de4f5..dd5e2ee 100644
--- a/Swiften/SASL/PLAINMessage.h
+++ b/Swiften/SASL/PLAINMessage.h
@@ -1,3 +1,5 @@
+// TODO: Get rid of this
+//
 #pragma once
 
 #include "Swiften/Base/String.h"
diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
index b2e85e9..3109f56 100644
--- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
+++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
@@ -7,16 +7,16 @@
 
 namespace Swift {
 
-SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const String& authcid, const String& password, const String& authzid, const ByteArray& nonce) : step(Initial), authcid(authcid), password(password), authzid(authzid), clientnonce(nonce) {
+SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const ByteArray& nonce) : ClientAuthenticator("SCRAM-SHA1"), step(Initial), clientnonce(nonce) {
 }
 
-ByteArray SCRAMSHA1ClientAuthenticator::getMessage() const {
+ByteArray SCRAMSHA1ClientAuthenticator::getResponse() const {
 	if (step == Initial) {
 		return getInitialClientMessage();
 	}
 	else {
 		ByteArray mask = HMACSHA1::getResult(getClientVerifier(), initialServerMessage + getInitialClientMessage());
-		ByteArray p = SHA1::getBinaryHash(password);
+		ByteArray p = SHA1::getBinaryHash(getPassword());
 		for (unsigned int i = 0; i < p.getSize(); ++i) {
 			p[i] ^= mask[i];
 		}
@@ -24,7 +24,7 @@ ByteArray SCRAMSHA1ClientAuthenticator::getMessage() const {
 	}
 }
 
-bool SCRAMSHA1ClientAuthenticator::setResponse(const ByteArray& response) {
+bool SCRAMSHA1ClientAuthenticator::setChallenge(const ByteArray& response) {
 	if (step == Initial) {
 		initialServerMessage = response;
 		step = Proof;
@@ -46,11 +46,11 @@ ByteArray SCRAMSHA1ClientAuthenticator::getSalt() const {
 }
 
 ByteArray SCRAMSHA1ClientAuthenticator::getClientVerifier() const {
-	return HMACSHA1::getResult(SHA1::getBinaryHash(password), getSalt());
+	return HMACSHA1::getResult(SHA1::getBinaryHash(getPassword()), getSalt());
 }
 
 ByteArray SCRAMSHA1ClientAuthenticator::getInitialClientMessage() const {
-	return ByteArray(authzid) + '\0' + ByteArray(authcid) + '\0' + ByteArray(clientnonce);
+	return ByteArray(getAuthorizationID()) + '\0' + ByteArray(getAuthenticationID()) + '\0' + ByteArray(clientnonce);
 }
 
 }
diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
index d129468..161afd1 100644
--- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
+++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
@@ -2,14 +2,15 @@
 
 #include "Swiften/Base/String.h"
 #include "Swiften/Base/ByteArray.h"
+#include "Swiften/SASL/ClientAuthenticator.h"
 
 namespace Swift {
-	class SCRAMSHA1ClientAuthenticator {
+	class SCRAMSHA1ClientAuthenticator : public ClientAuthenticator {
 		public:
-			SCRAMSHA1ClientAuthenticator(const String& authcid, const String& password, const String& authzid, const ByteArray& nonce);
-
-			ByteArray getMessage() const;
-			bool setResponse(const ByteArray&);
+			SCRAMSHA1ClientAuthenticator(const ByteArray& nonce);
+			
+			ByteArray getResponse() const;
+			bool setChallenge(const ByteArray&);
 
 		private:
 			ByteArray getInitialClientMessage() const;
diff --git a/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp b/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp
new file mode 100644
index 0000000..b83e1f5
--- /dev/null
+++ b/Swiften/SASL/UnitTest/PLAINClientAuthenticatorTest.cpp
@@ -0,0 +1,35 @@
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+
+#include "Swiften/SASL/PLAINClientAuthenticator.h"
+
+using namespace Swift;
+
+class PLAINClientAuthenticatorTest : public CppUnit::TestFixture
+{
+		CPPUNIT_TEST_SUITE(PLAINClientAuthenticatorTest);
+		CPPUNIT_TEST(testGetResponse_WithoutAuthzID);
+		CPPUNIT_TEST(testGetResponse_WithAuthzID);
+		CPPUNIT_TEST_SUITE_END();
+
+	public:
+		PLAINClientAuthenticatorTest() {}
+
+		void testGetResponse_WithoutAuthzID() {
+			PLAINClientAuthenticator testling;
+
+			testling.setCredentials("user", "pass");
+
+			CPPUNIT_ASSERT_EQUAL(testling.getResponse(), ByteArray("\0user\0pass", 10));
+		}
+
+		void testGetResponse_WithAuthzID() {
+			PLAINClientAuthenticator testling;
+
+			testling.setCredentials("user", "pass", "authz");
+
+			CPPUNIT_ASSERT_EQUAL(testling.getResponse(), ByteArray("authz\0user\0pass", 15));
+		}
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(PLAINClientAuthenticatorTest);
diff --git a/Swiften/SConscript b/Swiften/SConscript
index be0af24..a5ef56d 100644
--- a/Swiften/SConscript
+++ b/Swiften/SConscript
@@ -42,6 +42,8 @@ sources = [
 		"Roster/ContactRosterItem.cpp",
 		"Roster/Roster.cpp",
 		"Roster/XMPPRoster.cpp",
+		"SASL/ClientAuthenticator.cpp",
+		"SASL/PLAINClientAuthenticator.cpp",
 		"SASL/PLAINMessage.cpp",
 		"SASL/SCRAMSHA1ClientAuthenticator.cpp",
 		"Serializer/AuthRequestSerializer.cpp",
@@ -159,6 +161,7 @@ env.Append(UNITTEST_SOURCES = [
 		File("Roster/UnitTest/OfflineRosterFilterTest.cpp"),
 		File("Roster/UnitTest/RosterTest.cpp"),
 		File("SASL/UnitTest/PLAINMessageTest.cpp"),
+		File("SASL/UnitTest/PLAINClientAuthenticatorTest.cpp"),
 		File("Serializer/PayloadSerializers/UnitTest/PayloadsSerializer.cpp"),
 		File("Serializer/PayloadSerializers/UnitTest/CapsInfoSerializerTest.cpp"),
 		File("Serializer/PayloadSerializers/UnitTest/DiscoInfoSerializerTest.cpp"),
-- 
cgit v0.10.2-6-g49f6