From 6080dd4915801b45598268c805b62aa6c723a3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Mon, 18 Jun 2012 21:00:13 +0200 Subject: Handle unexpected challenges. Resolves: #1132 diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 7e1f517..48e38b9 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,9 @@ #include #endif +#define CHECK_STATE_OR_RETURN(a) \ + if (!checkState(a)) { return; } + namespace Swift { ClientSession::ClientSession( @@ -101,7 +105,7 @@ void ClientSession::sendStanza(boost::shared_ptr stanza) { } void ClientSession::handleStreamStart(const ProtocolHeader&) { - checkState(WaitingForStreamStart); + CHECK_STATE_OR_RETURN(WaitingForStreamStart); state = Negotiating; } @@ -182,9 +186,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { } } else if (StreamFeatures* streamFeatures = dynamic_cast(element.get())) { - if (!checkState(Negotiating)) { - return; - } + CHECK_STATE_OR_RETURN(Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { state = WaitingForEncrypt; @@ -200,6 +202,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { else if (streamFeatures->hasAuthenticationMechanisms()) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + authenticator = new EXTERNALClientAuthenticator(); state = Authenticating; stream->writeElement(boost::make_shared("EXTERNAL", createSafeByteArray(""))); } @@ -208,6 +211,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { } } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + authenticator = new EXTERNALClientAuthenticator(); state = Authenticating; stream->writeElement(boost::make_shared("EXTERNAL", createSafeByteArray(""))); } @@ -262,7 +266,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { } } else if (boost::dynamic_pointer_cast(element)) { - checkState(Compressing); + CHECK_STATE_OR_RETURN(Compressing); state = WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); @@ -285,7 +289,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast(element.get())) { - checkState(Authenticating); + CHECK_STATE_OR_RETURN(Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(boost::make_shared(authenticator->getResponse())); @@ -295,10 +299,9 @@ void ClientSession::handleElement(boost::shared_ptr element) { } } else if (AuthSuccess* authSuccess = dynamic_cast(element.get())) { - checkState(Authenticating); - if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) { - delete authenticator; - authenticator = NULL; + CHECK_STATE_OR_RETURN(Authenticating); + assert(authenticator); + if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { @@ -310,12 +313,10 @@ void ClientSession::handleElement(boost::shared_ptr element) { } } else if (dynamic_cast(element.get())) { - delete authenticator; - authenticator = NULL; finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast(element.get())) { - checkState(WaitingForEncrypt); + CHECK_STATE_OR_RETURN(WaitingForEncrypt); state = Encrypting; stream->addTLSEncryption(); } @@ -362,13 +363,14 @@ bool ClientSession::checkState(State state) { void ClientSession::sendCredentials(const SafeByteArray& password) { assert(WaitingForCredentials); + assert(authenticator); state = Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(boost::make_shared(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - checkState(Encrypting); + CHECK_STATE_OR_RETURN(Encrypting); std::vector certificateChain = stream->getPeerCertificateChain(); boost::shared_ptr verificationError = stream->getPeerCertificateVerificationError(); @@ -448,6 +450,10 @@ void ClientSession::finishSession(boost::shared_ptr error) { if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } + if (authenticator) { + delete authenticator; + authenticator = NULL; + } stream->writeFooter(); stream->close(); } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index d1ca70a..a8cd53c 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -47,8 +48,10 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); CPPUNIT_TEST(testAuthenticate_RequireTLS); + CPPUNIT_TEST(testAuthenticate_EXTERNAL); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); + CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); /* CPPUNIT_TEST(testResourceBind); @@ -247,6 +250,34 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(sessionFinishedError); } + void testAuthenticate_EXTERNAL() { + boost::shared_ptr session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithEXTERNALAuthentication(); + server->receiveAuthRequest("EXTERNAL"); + server->sendAuthSuccess(); + server->receiveStreamStart(); + + session->finish(); + } + + void testUnexpectedChallenge() { + boost::shared_ptr session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithEXTERNALAuthentication(); + server->receiveAuthRequest("EXTERNAL"); + server->sendChallenge(); + server->sendChallenge(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + void testStreamManagement() { boost::shared_ptr session(createSession()); session->start(); @@ -444,6 +475,10 @@ class ClientSessionTest : public CppUnit::TestFixture { onElementReceived(streamFeatures); } + void sendChallenge() { + onElementReceived(boost::make_shared()); + } + void sendStreamError() { onElementReceived(boost::make_shared()); } @@ -470,6 +505,12 @@ class ClientSessionTest : public CppUnit::TestFixture { onElementReceived(streamFeatures); } + void sendStreamFeaturesWithEXTERNALAuthentication() { + boost::shared_ptr streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("EXTERNAL"); + onElementReceived(streamFeatures); + } + void sendStreamFeaturesWithUnknownAuthentication() { boost::shared_ptr streamFeatures(new StreamFeatures()); streamFeatures->addAuthenticationMechanism("UNKNOWN"); diff --git a/Swiften/SASL/EXTERNALClientAuthenticator.cpp b/Swiften/SASL/EXTERNALClientAuthenticator.cpp new file mode 100644 index 0000000..a3016d1 --- /dev/null +++ b/Swiften/SASL/EXTERNALClientAuthenticator.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2012 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +namespace Swift { + +EXTERNALClientAuthenticator::EXTERNALClientAuthenticator() : ClientAuthenticator("EXTERNAL"), finished(false) { +} + +boost::optional EXTERNALClientAuthenticator::getResponse() const { + return boost::optional(); +} + +bool EXTERNALClientAuthenticator::setChallenge(const boost::optional&) { + if (finished) { + return false; + } + finished = true; + return true; +} + +} diff --git a/Swiften/SASL/EXTERNALClientAuthenticator.h b/Swiften/SASL/EXTERNALClientAuthenticator.h new file mode 100644 index 0000000..b986295 --- /dev/null +++ b/Swiften/SASL/EXTERNALClientAuthenticator.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2012 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include + +namespace Swift { + class EXTERNALClientAuthenticator : public ClientAuthenticator { + public: + EXTERNALClientAuthenticator(); + + virtual boost::optional getResponse() const; + virtual bool setChallenge(const boost::optional&); + + private: + bool finished; + }; +} diff --git a/Swiften/SASL/SConscript b/Swiften/SASL/SConscript index 3a67938..6509547 100644 --- a/Swiften/SASL/SConscript +++ b/Swiften/SASL/SConscript @@ -4,6 +4,7 @@ myenv = swiften_env.Clone() objects = myenv.SwiftenObject([ "ClientAuthenticator.cpp", + "EXTERNALClientAuthenticator.cpp", "PLAINClientAuthenticator.cpp", "PLAINMessage.cpp", "SCRAMSHA1ClientAuthenticator.cpp", -- cgit v0.10.2-6-g49f6