diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 34 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 41 |
2 files changed, 61 insertions, 14 deletions
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 7e1f517..48e38b9 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -37,6 +37,7 @@ #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/ResourceBind.h> #include <Swiften/SASL/PLAINClientAuthenticator.h> +#include <Swiften/SASL/EXTERNALClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> #include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> #include <Swiften/Session/SessionStream.h> @@ -48,6 +49,9 @@ #include <Swiften/Base/WindowsRegistry.h> #endif +#define CHECK_STATE_OR_RETURN(a) \ + if (!checkState(a)) { return; } + namespace Swift { ClientSession::ClientSession( @@ -101,7 +105,7 @@ void ClientSession::sendStanza(boost::shared_ptr<Stanza> stanza) { } void ClientSession::handleStreamStart(const ProtocolHeader&) { - checkState(WaitingForStreamStart); + CHECK_STATE_OR_RETURN(WaitingForStreamStart); state = Negotiating; } @@ -182,9 +186,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - if (!checkState(Negotiating)) { - return; - } + CHECK_STATE_OR_RETURN(Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { state = WaitingForEncrypt; @@ -200,6 +202,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { else if (streamFeatures->hasAuthenticationMechanisms()) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + authenticator = new EXTERNALClientAuthenticator(); state = Authenticating; stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } @@ -208,6 +211,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + authenticator = new EXTERNALClientAuthenticator(); state = Authenticating; stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } @@ -262,7 +266,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (boost::dynamic_pointer_cast<Compressed>(element)) { - checkState(Compressing); + CHECK_STATE_OR_RETURN(Compressing); state = WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); @@ -285,7 +289,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { - checkState(Authenticating); + CHECK_STATE_OR_RETURN(Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(boost::make_shared<AuthResponse>(authenticator->getResponse())); @@ -295,10 +299,9 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { - checkState(Authenticating); - if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) { - delete authenticator; - authenticator = NULL; + CHECK_STATE_OR_RETURN(Authenticating); + assert(authenticator); + if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { @@ -310,12 +313,10 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (dynamic_cast<AuthFailure*>(element.get())) { - delete authenticator; - authenticator = NULL; finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - checkState(WaitingForEncrypt); + CHECK_STATE_OR_RETURN(WaitingForEncrypt); state = Encrypting; stream->addTLSEncryption(); } @@ -362,13 +363,14 @@ bool ClientSession::checkState(State state) { void ClientSession::sendCredentials(const SafeByteArray& password) { assert(WaitingForCredentials); + assert(authenticator); state = Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(boost::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - checkState(Encrypting); + CHECK_STATE_OR_RETURN(Encrypting); std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); @@ -448,6 +450,10 @@ void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } + if (authenticator) { + delete authenticator; + authenticator = NULL; + } stream->writeFooter(); stream->close(); } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index d1ca70a..a8cd53c 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -14,6 +14,7 @@ #include <Swiften/Session/SessionStream.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Elements/Message.h> +#include <Swiften/Elements/AuthChallenge.h> #include <Swiften/Elements/StartTLSRequest.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/StreamError.h> @@ -47,8 +48,10 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); CPPUNIT_TEST(testAuthenticate_RequireTLS); + CPPUNIT_TEST(testAuthenticate_EXTERNAL); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); + CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); /* CPPUNIT_TEST(testResourceBind); @@ -247,6 +250,34 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(sessionFinishedError); } + void testAuthenticate_EXTERNAL() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithEXTERNALAuthentication(); + server->receiveAuthRequest("EXTERNAL"); + server->sendAuthSuccess(); + server->receiveStreamStart(); + + session->finish(); + } + + void testUnexpectedChallenge() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithEXTERNALAuthentication(); + server->receiveAuthRequest("EXTERNAL"); + server->sendChallenge(); + server->sendChallenge(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + void testStreamManagement() { boost::shared_ptr<ClientSession> session(createSession()); session->start(); @@ -444,6 +475,10 @@ class ClientSessionTest : public CppUnit::TestFixture { onElementReceived(streamFeatures); } + void sendChallenge() { + onElementReceived(boost::make_shared<AuthChallenge>()); + } + void sendStreamError() { onElementReceived(boost::make_shared<StreamError>()); } @@ -470,6 +505,12 @@ class ClientSessionTest : public CppUnit::TestFixture { onElementReceived(streamFeatures); } + void sendStreamFeaturesWithEXTERNALAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("EXTERNAL"); + onElementReceived(streamFeatures); + } + void sendStreamFeaturesWithUnknownAuthentication() { boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); streamFeatures->addAuthenticationMechanism("UNKNOWN"); |