From a1fc89586ddfbaccca5fb1ed17f9d62137be25a0 Mon Sep 17 00:00:00 2001
From: Kevin Smith <git@kismith.co.uk>
Date: Tue, 20 Sep 2011 13:05:28 +0100
Subject: Swiften support for requiring TLS


diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h
index 0766402..6b15f18 100644
--- a/Swiften/Client/ClientOptions.h
+++ b/Swiften/Client/ClientOptions.h
@@ -10,7 +10,8 @@ namespace Swift {
 	struct ClientOptions {
 		enum UseTLS {
 			NeverUseTLS,
-			UseTLSWhenAvailable
+			UseTLSWhenAvailable,
+			RequireTLS
 		};
 
 		ClientOptions() : useStreamCompression(true), useTLS(UseTLSWhenAvailable), allowPLAINWithoutTLS(false), useStreamResumption(false), forgetPassword(false) {
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 8945e9a..2eeb3c0 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -177,6 +177,9 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 			state = WaitingForEncrypt;
 			stream->writeElement(boost::make_shared<StartTLSRequest>());
 		}
+		else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) {
+			finishSession(Error::NoSupportedAuthMechanismsError);
+		}
 		else if (useStreamCompression && streamFeatures->hasCompressionMethod("zlib")) {
 			state = Compressing;
 			stream->writeElement(boost::make_shared<CompressRequest>("zlib"));
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index ace9868..e58e758 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -59,7 +59,8 @@ namespace Swift {
 
 			enum UseTLS {
 				NeverUseTLS,
-				UseTLSWhenAvailable
+				UseTLSWhenAvailable,
+				RequireTLS
 			};
 
 			~ClientSession();
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index 9aaa66b..cceec74 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -119,6 +119,9 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio
 			case ClientOptions::NeverUseTLS:
 				session_->setUseTLS(ClientSession::NeverUseTLS);
 				break;
+			case ClientOptions::RequireTLS:
+				session_->setUseTLS(ClientSession::RequireTLS);
+				break;
 		}
 		stanzaChannel_->setSession(session_);
 		session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index 57e53e4..e9d1b21 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -45,6 +45,8 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST(testAuthenticate);
 		CPPUNIT_TEST(testAuthenticate_Unauthorized);
 		CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
+		CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS);
+		CPPUNIT_TEST(testAuthenticate_RequireTLS);
 		CPPUNIT_TEST(testStreamManagement);
 		CPPUNIT_TEST(testStreamManagement_Failed);
 		CPPUNIT_TEST(testFinishAcksStanzas);
@@ -219,6 +221,20 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT(sessionFinishedError);
 		}
 
+		void testAuthenticate_RequireTLS() {
+			boost::shared_ptr<ClientSession> session(createSession());
+			session->setUseTLS(ClientSession::RequireTLS);
+			session->setAllowPLAINOverNonTLS(true);
+			session->start();
+			server->receiveStreamStart();
+			server->sendStreamStart();
+			server->sendStreamFeaturesWithMultipleAuthentication();
+
+			CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+			CPPUNIT_ASSERT(sessionFinishedReceived);
+			CPPUNIT_ASSERT(sessionFinishedError);
+		}
+
 		void testAuthenticate_NoValidAuthMechanisms() {
 			boost::shared_ptr<ClientSession> session(createSession());
 			session->start();
@@ -432,6 +448,14 @@ class ClientSessionTest : public CppUnit::TestFixture {
 					onElementReceived(boost::shared_ptr<StartTLSFailure>(new StartTLSFailure()));
 				}
 
+				void sendStreamFeaturesWithMultipleAuthentication() {
+					boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+					streamFeatures->addAuthenticationMechanism("PLAIN");
+					streamFeatures->addAuthenticationMechanism("DIGEST-MD5");
+					streamFeatures->addAuthenticationMechanism("SCRAM-SHA1");
+					onElementReceived(streamFeatures);
+				}
+
 				void sendStreamFeaturesWithPLAINAuthentication() {
 					boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
 					streamFeatures->addAuthenticationMechanism("PLAIN");
-- 
cgit v0.10.2-6-g49f6