From 6080dd4915801b45598268c805b62aa6c723a3a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Mon, 18 Jun 2012 21:00:13 +0200
Subject: Handle unexpected challenges.

Resolves: #1132

diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 7e1f517..48e38b9 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -37,6 +37,7 @@
 #include <Swiften/Elements/IQ.h>
 #include <Swiften/Elements/ResourceBind.h>
 #include <Swiften/SASL/PLAINClientAuthenticator.h>
+#include <Swiften/SASL/EXTERNALClientAuthenticator.h>
 #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h>
 #include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
 #include <Swiften/Session/SessionStream.h>
@@ -48,6 +49,9 @@
 #include <Swiften/Base/WindowsRegistry.h>
 #endif
 
+#define CHECK_STATE_OR_RETURN(a) \
+	if (!checkState(a)) { return; }
+
 namespace Swift {
 
 ClientSession::ClientSession(
@@ -101,7 +105,7 @@ void ClientSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
 }
 
 void ClientSession::handleStreamStart(const ProtocolHeader&) {
-	checkState(WaitingForStreamStart);
+	CHECK_STATE_OR_RETURN(WaitingForStreamStart);
 	state = Negotiating;
 }
 
@@ -182,9 +186,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		}
 	}
 	else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
-		if (!checkState(Negotiating)) {
-			return;
-		}
+		CHECK_STATE_OR_RETURN(Negotiating);
 
 		if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) {
 			state = WaitingForEncrypt;
@@ -200,6 +202,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		else if (streamFeatures->hasAuthenticationMechanisms()) {
 			if (stream->hasTLSCertificate()) {
 				if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+					authenticator = new EXTERNALClientAuthenticator();
 					state = Authenticating;
 					stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
 				}
@@ -208,6 +211,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				}
 			}
 			else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+				authenticator = new EXTERNALClientAuthenticator();
 				state = Authenticating;
 				stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
 			}
@@ -262,7 +266,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		}
 	}
 	else if (boost::dynamic_pointer_cast<Compressed>(element)) {
-		checkState(Compressing);
+		CHECK_STATE_OR_RETURN(Compressing);
 		state = WaitingForStreamStart;
 		stream->addZLibCompression();
 		stream->resetXMPPParser();
@@ -285,7 +289,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		continueSessionInitialization();
 	}
 	else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) {
-		checkState(Authenticating);
+		CHECK_STATE_OR_RETURN(Authenticating);
 		assert(authenticator);
 		if (authenticator->setChallenge(challenge->getValue())) {
 			stream->writeElement(boost::make_shared<AuthResponse>(authenticator->getResponse()));
@@ -295,10 +299,9 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		}
 	}
 	else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) {
-		checkState(Authenticating);
-		if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) {
-			delete authenticator;
-			authenticator = NULL;
+		CHECK_STATE_OR_RETURN(Authenticating);
+		assert(authenticator);
+		if (!authenticator->setChallenge(authSuccess->getValue())) {
 			finishSession(Error::ServerVerificationFailedError);
 		}
 		else {
@@ -310,12 +313,10 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		}
 	}
 	else if (dynamic_cast<AuthFailure*>(element.get())) {
-		delete authenticator;
-		authenticator = NULL;
 		finishSession(Error::AuthenticationFailedError);
 	}
 	else if (dynamic_cast<TLSProceed*>(element.get())) {
-		checkState(WaitingForEncrypt);
+		CHECK_STATE_OR_RETURN(WaitingForEncrypt);
 		state = Encrypting;
 		stream->addTLSEncryption();
 	}
@@ -362,13 +363,14 @@ bool ClientSession::checkState(State state) {
 
 void ClientSession::sendCredentials(const SafeByteArray& password) {
 	assert(WaitingForCredentials);
+	assert(authenticator);
 	state = Authenticating;
 	authenticator->setCredentials(localJID.getNode(), password);
 	stream->writeElement(boost::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
 }
 
 void ClientSession::handleTLSEncrypted() {
-	checkState(Encrypting);
+	CHECK_STATE_OR_RETURN(Encrypting);
 
 	std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain();
 	boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
@@ -448,6 +450,10 @@ void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) {
 	if (stanzaAckResponder_) {
 		stanzaAckResponder_->handleAckRequestReceived();
 	}
+	if (authenticator) {
+		delete authenticator;
+		authenticator = NULL;
+	}
 	stream->writeFooter();
 	stream->close();
 }
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index d1ca70a..a8cd53c 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -14,6 +14,7 @@
 #include <Swiften/Session/SessionStream.h>
 #include <Swiften/Client/ClientSession.h>
 #include <Swiften/Elements/Message.h>
+#include <Swiften/Elements/AuthChallenge.h>
 #include <Swiften/Elements/StartTLSRequest.h>
 #include <Swiften/Elements/StreamFeatures.h>
 #include <Swiften/Elements/StreamError.h>
@@ -47,8 +48,10 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
 		CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS);
 		CPPUNIT_TEST(testAuthenticate_RequireTLS);
+		CPPUNIT_TEST(testAuthenticate_EXTERNAL);
 		CPPUNIT_TEST(testStreamManagement);
 		CPPUNIT_TEST(testStreamManagement_Failed);
+		CPPUNIT_TEST(testUnexpectedChallenge);
 		CPPUNIT_TEST(testFinishAcksStanzas);
 		/*
 		CPPUNIT_TEST(testResourceBind);
@@ -247,6 +250,34 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT(sessionFinishedError);
 		}
 
+		void testAuthenticate_EXTERNAL() {
+			boost::shared_ptr<ClientSession> session(createSession());
+			session->start();
+			server->receiveStreamStart();
+			server->sendStreamStart();
+			server->sendStreamFeaturesWithEXTERNALAuthentication();
+			server->receiveAuthRequest("EXTERNAL");
+			server->sendAuthSuccess();
+			server->receiveStreamStart();
+
+			session->finish();
+		}
+
+		void testUnexpectedChallenge() {
+			boost::shared_ptr<ClientSession> session(createSession());
+			session->start();
+			server->receiveStreamStart();
+			server->sendStreamStart();
+			server->sendStreamFeaturesWithEXTERNALAuthentication();
+			server->receiveAuthRequest("EXTERNAL");
+			server->sendChallenge();
+			server->sendChallenge();
+
+			CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+			CPPUNIT_ASSERT(sessionFinishedReceived);
+			CPPUNIT_ASSERT(sessionFinishedError);
+		}
+
 		void testStreamManagement() {
 			boost::shared_ptr<ClientSession> session(createSession());
 			session->start();
@@ -444,6 +475,10 @@ class ClientSessionTest : public CppUnit::TestFixture {
 					onElementReceived(streamFeatures);
 				}
 
+				void sendChallenge() {
+					onElementReceived(boost::make_shared<AuthChallenge>());
+				}
+
 				void sendStreamError() {
 					onElementReceived(boost::make_shared<StreamError>());
 				}
@@ -470,6 +505,12 @@ class ClientSessionTest : public CppUnit::TestFixture {
 					onElementReceived(streamFeatures);
 				}
 
+				void sendStreamFeaturesWithEXTERNALAuthentication() {
+					boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+					streamFeatures->addAuthenticationMechanism("EXTERNAL");
+					onElementReceived(streamFeatures);
+				}
+
 				void sendStreamFeaturesWithUnknownAuthentication() {
 					boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
 					streamFeatures->addAuthenticationMechanism("UNKNOWN");
diff --git a/Swiften/SASL/EXTERNALClientAuthenticator.cpp b/Swiften/SASL/EXTERNALClientAuthenticator.cpp
new file mode 100644
index 0000000..a3016d1
--- /dev/null
+++ b/Swiften/SASL/EXTERNALClientAuthenticator.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <Swiften/SASL/EXTERNALClientAuthenticator.h>
+
+namespace Swift {
+
+EXTERNALClientAuthenticator::EXTERNALClientAuthenticator() : ClientAuthenticator("EXTERNAL"), finished(false) {
+}
+
+boost::optional<SafeByteArray> EXTERNALClientAuthenticator::getResponse() const {
+	return boost::optional<SafeByteArray>();
+}
+
+bool EXTERNALClientAuthenticator::setChallenge(const boost::optional<ByteArray>&) {
+	if (finished) {
+		return false;
+	}
+	finished = true;
+	return true;
+}
+
+}
diff --git a/Swiften/SASL/EXTERNALClientAuthenticator.h b/Swiften/SASL/EXTERNALClientAuthenticator.h
new file mode 100644
index 0000000..b986295
--- /dev/null
+++ b/Swiften/SASL/EXTERNALClientAuthenticator.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/SASL/ClientAuthenticator.h>
+#include <Swiften/Base/ByteArray.h>
+
+namespace Swift {
+	class EXTERNALClientAuthenticator : public ClientAuthenticator {
+		public:
+			EXTERNALClientAuthenticator();
+
+			virtual boost::optional<SafeByteArray> getResponse() const;
+			virtual bool setChallenge(const boost::optional<ByteArray>&);
+
+		private:
+			bool finished;
+	};
+}
diff --git a/Swiften/SASL/SConscript b/Swiften/SASL/SConscript
index 3a67938..6509547 100644
--- a/Swiften/SASL/SConscript
+++ b/Swiften/SASL/SConscript
@@ -4,6 +4,7 @@ myenv = swiften_env.Clone()
 
 objects = myenv.SwiftenObject([
 		"ClientAuthenticator.cpp",
+		"EXTERNALClientAuthenticator.cpp",
 		"PLAINClientAuthenticator.cpp",
 		"PLAINMessage.cpp",
 		"SCRAMSHA1ClientAuthenticator.cpp",
-- 
cgit v0.10.2-6-g49f6