From aaf38fe2e6804bd87ea5e99a05ed57070cbe1c57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Sat, 11 Dec 2010 13:43:08 +0100
Subject: Added SCRAM-SHA-1-PLUS support.

Release-Notes: Swift now supports SCRAM-SHA-1-PLUS authentication.

diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 0398012..d4cf065 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -191,10 +191,13 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				state = Authenticating;
 				stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
 			}
-			else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1")) {
+			else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) {
 				std::ostringstream s;
 				s << boost::uuids::random_generator()();
-				SCRAMSHA1ClientAuthenticator* scramAuthenticator = new SCRAMSHA1ClientAuthenticator(s.str(), false);
+				SCRAMSHA1ClientAuthenticator* scramAuthenticator = new SCRAMSHA1ClientAuthenticator(s.str(), streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS"));
+				if (stream->isTLSEncrypted()) {
+					scramAuthenticator->setTLSChannelBindingData(stream->getTLSFinishMessage());
+				}
 				authenticator = scramAuthenticator;
 				state = WaitingForCredentials;
 				onNeedCredentials();
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index dbed9aa..358e308 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -324,6 +324,10 @@ class ClientSessionTest : public CppUnit::TestFixture {
 					return tlsEncrypted;
 				}
 
+				virtual ByteArray getTLSFinishMessage() const {
+					return ByteArray();
+				}
+
 				virtual Certificate::ref getPeerCertificate() const {
 					return Certificate::ref(new SimpleCertificate());
 				}
diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
index 4fe8e87..ac24778 100644
--- a/Swiften/Component/UnitTest/ComponentSessionTest.cpp
+++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
@@ -123,6 +123,10 @@ class ComponentSessionTest : public CppUnit::TestFixture {
 					return false;
 				}
 
+				virtual ByteArray getTLSFinishMessage() const {
+					return ByteArray();
+				}
+
 				virtual Certificate::ref getPeerCertificate() const {
 					return Certificate::ref();
 				}
diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
index 4e00397..2cc7cea 100644
--- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
+++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.cpp
@@ -35,7 +35,7 @@ static String escape(const String& s) {
 }
 
 
-SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const String& nonce, bool useChannelBinding) : ClientAuthenticator("SCRAM-SHA-1"), step(Initial), clientnonce(nonce), useChannelBinding(useChannelBinding) {
+SCRAMSHA1ClientAuthenticator::SCRAMSHA1ClientAuthenticator(const String& nonce, bool useChannelBinding) : ClientAuthenticator(useChannelBinding ? "SCRAM-SHA-1-PLUS" : "SCRAM-SHA-1"), step(Initial), clientnonce(nonce), useChannelBinding(useChannelBinding) {
 }
 
 boost::optional<ByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const {
@@ -50,11 +50,7 @@ boost::optional<ByteArray> SCRAMSHA1ClientAuthenticator::getResponse() const {
 		for (unsigned int i = 0; i < clientProof.getSize(); ++i) {
 			clientProof[i] ^= clientSignature[i];
 		}
-		ByteArray channelBindData;
-		if (useChannelBinding && tlsChannelBindingData) {
-			channelBindData = *tlsChannelBindingData;
-		}
-		ByteArray result = ByteArray("c=") + Base64::encode(getGS2Header() + channelBindData) + ",r=" + clientnonce + serverNonce + ",p=" + Base64::encode(clientProof);
+		ByteArray result = getFinalMessageWithoutProof() + ",p=" + Base64::encode(clientProof);
 		return result;
 	}
 	else {
@@ -97,9 +93,14 @@ bool SCRAMSHA1ClientAuthenticator::setChallenge(const boost::optional<ByteArray>
 			return false;
 		}
 
+		ByteArray channelBindData;
+		if (useChannelBinding && tlsChannelBindingData) {
+			channelBindData = *tlsChannelBindingData;
+		}
+
 		// Compute all the values needed for the server signature
 		saltedPassword = PBKDF2::encode(StringPrep::getPrepared(getPassword(), StringPrep::SASLPrep), salt, iterations);
-		authMessage = getInitialBareClientMessage() + "," + initialServerMessage + "," + "c=" + Base64::encode(getGS2Header()) + ",r=" + clientnonce + serverNonce;
+		authMessage = getInitialBareClientMessage() + "," + initialServerMessage + "," + getFinalMessageWithoutProof();
 		ByteArray serverKey = HMACSHA1::getResult(saltedPassword, "Server Key");
 		serverSignature = HMACSHA1::getResult(serverKey, authMessage);
 
@@ -153,7 +154,7 @@ ByteArray SCRAMSHA1ClientAuthenticator::getGS2Header() const {
 	ByteArray channelBindingHeader("n");
 	if (tlsChannelBindingData) {
 		if (useChannelBinding) {
-			channelBindingHeader = ByteArray("p=tls-server-end-point");
+			channelBindingHeader = ByteArray("p=tls-unique");
 		}
 		else {
 			channelBindingHeader = ByteArray("y");
@@ -166,4 +167,13 @@ void SCRAMSHA1ClientAuthenticator::setTLSChannelBindingData(const ByteArray& cha
 	this->tlsChannelBindingData = channelBindingData;
 }
 
+ByteArray SCRAMSHA1ClientAuthenticator::getFinalMessageWithoutProof() const {
+	ByteArray channelBindData;
+	if (useChannelBinding && tlsChannelBindingData) {
+		channelBindData = *tlsChannelBindingData;
+	}
+	return ByteArray("c=") + Base64::encode(getGS2Header() + channelBindData) + ",r=" + clientnonce + serverNonce;
+}
+
+
 }
diff --git a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
index b44e6b7..2cf3cc7 100644
--- a/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
+++ b/Swiften/SASL/SCRAMSHA1ClientAuthenticator.h
@@ -26,6 +26,7 @@ namespace Swift {
 		private:
 			ByteArray getInitialBareClientMessage() const;
 			ByteArray getGS2Header() const;
+			ByteArray getFinalMessageWithoutProof() const;
 
 			static std::map<char, String> parseMap(const String&);
 
diff --git a/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp b/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp
index 0d12bd3..0e42f38 100644
--- a/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp
+++ b/Swiften/SASL/UnitTest/SCRAMSHA1ClientAuthenticatorTest.cpp
@@ -92,7 +92,7 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture {
 
 			ByteArray response = *testling.getResponse();
 
-			CPPUNIT_ASSERT_EQUAL(String("p=tls-server-end-point,,n=user,r=abcdefghABCDEFGH"), response.toString());
+			CPPUNIT_ASSERT_EQUAL(String("p=tls-unique,,n=user,r=abcdefghABCDEFGH"), response.toString());
 		}
 
 		void testGetFinalResponse() {
@@ -124,7 +124,7 @@ class SCRAMSHA1ClientAuthenticatorTest : public CppUnit::TestFixture {
 
 			ByteArray response = *testling.getResponse();
 
-			CPPUNIT_ASSERT_EQUAL(String("c=cD10bHMtc2VydmVyLWVuZC1wb2ludCwseHl6YQ==,r=abcdefghABCDEFGH,p=ycZyNs03w1HlRzFmXl8dlKx3NAU="), response.toString());
+			CPPUNIT_ASSERT_EQUAL(String("c=cD10bHMtdW5pcXVlLCx4eXph,r=abcdefghABCDEFGH,p=i6Rghite81P1ype8XxaVAa5l7v0="), response.toString());
 		}
 
 		void testSetFinalChallenge() {
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index 32424bc..45489cf 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -15,6 +15,7 @@
 #include "Swiften/StreamStack/CompressionLayer.h"
 #include "Swiften/StreamStack/TLSLayer.h"
 #include "Swiften/TLS/TLSContextFactory.h"
+#include "Swiften/TLS/TLSContext.h"
 
 namespace Swift {
 
@@ -93,6 +94,9 @@ boost::shared_ptr<CertificateVerificationError> BasicSessionStream::getPeerCerti
 	return tlsLayer->getPeerCertificateVerificationError();
 }
 
+ByteArray BasicSessionStream::getTLSFinishMessage() const {
+	return tlsLayer->getContext()->getFinishMessage();
+}
 
 void BasicSessionStream::addZLibCompression() {
 	boost::shared_ptr<CompressionLayer> compressionLayer(new CompressionLayer());
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index fbaa937..6927800 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -54,6 +54,7 @@ namespace Swift {
 			virtual bool isTLSEncrypted();
 			virtual Certificate::ref getPeerCertificate() const;
 			virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
+			virtual ByteArray getTLSFinishMessage() const;
 
 			virtual void setWhitespacePingEnabled(bool);
 
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index d648f91..d3d3ebb 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -63,6 +63,8 @@ namespace Swift {
 			virtual Certificate::ref getPeerCertificate() const = 0;
 			virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
 
+			virtual ByteArray getTLSFinishMessage() const = 0;
+
 			boost::signal<void (const ProtocolHeader&)> onStreamStartReceived;
 			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
 			boost::signal<void (boost::shared_ptr<Error>)> onError;
diff --git a/Swiften/StreamStack/TLSLayer.h b/Swiften/StreamStack/TLSLayer.h
index a69f789..22e9aef 100644
--- a/Swiften/StreamStack/TLSLayer.h
+++ b/Swiften/StreamStack/TLSLayer.h
@@ -30,6 +30,10 @@ namespace Swift {
 			void writeData(const ByteArray& data);
 			void handleDataRead(const ByteArray& data);
 
+			TLSContext* getContext() const {
+				return context;
+			}
+
 		public:
 			boost::signal<void ()> onError;
 			boost::signal<void ()> onConnected;
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 50436c7..6c55a63 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -23,6 +23,7 @@
 
 namespace Swift {
 
+static const int MAX_FINISHED_SIZE = 4096;
 static const int SSL_READ_BUFFERSIZE = 8192;
 
 void freeX509Stack(STACK_OF(X509)* stack) {
@@ -210,6 +211,14 @@ boost::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertifica
 	}
 }
 
+ByteArray OpenSSLContext::getFinishMessage() const {
+	ByteArray data;
+	data.resize(MAX_FINISHED_SIZE);
+	size_t size = SSL_get_finished(handle_, data.getData(), data.getSize());
+	data.resize(size);
+	return data;
+}
+
 CertificateVerificationError::Type OpenSSLContext::getVerificationErrorTypeForResult(int result) {
 	assert(result != 0);
 	switch (result) {
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 9cb287d..40e5483 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -30,6 +30,8 @@ namespace Swift {
 			Certificate::ref getPeerCertificate() const;
 			boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
 
+			virtual ByteArray getFinishMessage() const;
+
 		private:
 			static void ensureLibraryInitialized();	
 
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 2b8ed2d..1279eeb 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -30,6 +30,8 @@ namespace Swift {
 			virtual Certificate::ref getPeerCertificate() const = 0;
 			virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const = 0;
 
+			virtual ByteArray getFinishMessage() const = 0;
+
 		public:
 			boost::signal<void (const ByteArray&)> onDataForNetwork;
 			boost::signal<void (const ByteArray&)> onDataForApplication;
-- 
cgit v0.10.2-6-g49f6