summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/Client')
-rw-r--r--Swiften/Client/ClientSession.cpp34
-rw-r--r--Swiften/Client/UnitTest/ClientSessionTest.cpp41
2 files changed, 61 insertions, 14 deletions
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 7e1f517..48e38b9 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -37,6 +37,7 @@
#include <Swiften/Elements/IQ.h>
#include <Swiften/Elements/ResourceBind.h>
#include <Swiften/SASL/PLAINClientAuthenticator.h>
+#include <Swiften/SASL/EXTERNALClientAuthenticator.h>
#include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h>
#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
#include <Swiften/Session/SessionStream.h>
@@ -48,6 +49,9 @@
#include <Swiften/Base/WindowsRegistry.h>
#endif
+#define CHECK_STATE_OR_RETURN(a) \
+ if (!checkState(a)) { return; }
+
namespace Swift {
ClientSession::ClientSession(
@@ -101,7 +105,7 @@ void ClientSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
}
void ClientSession::handleStreamStart(const ProtocolHeader&) {
- checkState(WaitingForStreamStart);
+ CHECK_STATE_OR_RETURN(WaitingForStreamStart);
state = Negotiating;
}
@@ -182,9 +186,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
- if (!checkState(Negotiating)) {
- return;
- }
+ CHECK_STATE_OR_RETURN(Negotiating);
if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) {
state = WaitingForEncrypt;
@@ -200,6 +202,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
else if (streamFeatures->hasAuthenticationMechanisms()) {
if (stream->hasTLSCertificate()) {
if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+ authenticator = new EXTERNALClientAuthenticator();
state = Authenticating;
stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
@@ -208,6 +211,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+ authenticator = new EXTERNALClientAuthenticator();
state = Authenticating;
stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
@@ -262,7 +266,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (boost::dynamic_pointer_cast<Compressed>(element)) {
- checkState(Compressing);
+ CHECK_STATE_OR_RETURN(Compressing);
state = WaitingForStreamStart;
stream->addZLibCompression();
stream->resetXMPPParser();
@@ -285,7 +289,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
continueSessionInitialization();
}
else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) {
- checkState(Authenticating);
+ CHECK_STATE_OR_RETURN(Authenticating);
assert(authenticator);
if (authenticator->setChallenge(challenge->getValue())) {
stream->writeElement(boost::make_shared<AuthResponse>(authenticator->getResponse()));
@@ -295,10 +299,9 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) {
- checkState(Authenticating);
- if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) {
- delete authenticator;
- authenticator = NULL;
+ CHECK_STATE_OR_RETURN(Authenticating);
+ assert(authenticator);
+ if (!authenticator->setChallenge(authSuccess->getValue())) {
finishSession(Error::ServerVerificationFailedError);
}
else {
@@ -310,12 +313,10 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (dynamic_cast<AuthFailure*>(element.get())) {
- delete authenticator;
- authenticator = NULL;
finishSession(Error::AuthenticationFailedError);
}
else if (dynamic_cast<TLSProceed*>(element.get())) {
- checkState(WaitingForEncrypt);
+ CHECK_STATE_OR_RETURN(WaitingForEncrypt);
state = Encrypting;
stream->addTLSEncryption();
}
@@ -362,13 +363,14 @@ bool ClientSession::checkState(State state) {
void ClientSession::sendCredentials(const SafeByteArray& password) {
assert(WaitingForCredentials);
+ assert(authenticator);
state = Authenticating;
authenticator->setCredentials(localJID.getNode(), password);
stream->writeElement(boost::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
void ClientSession::handleTLSEncrypted() {
- checkState(Encrypting);
+ CHECK_STATE_OR_RETURN(Encrypting);
std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain();
boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
@@ -448,6 +450,10 @@ void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) {
if (stanzaAckResponder_) {
stanzaAckResponder_->handleAckRequestReceived();
}
+ if (authenticator) {
+ delete authenticator;
+ authenticator = NULL;
+ }
stream->writeFooter();
stream->close();
}
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index d1ca70a..a8cd53c 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -14,6 +14,7 @@
#include <Swiften/Session/SessionStream.h>
#include <Swiften/Client/ClientSession.h>
#include <Swiften/Elements/Message.h>
+#include <Swiften/Elements/AuthChallenge.h>
#include <Swiften/Elements/StartTLSRequest.h>
#include <Swiften/Elements/StreamFeatures.h>
#include <Swiften/Elements/StreamError.h>
@@ -47,8 +48,10 @@ class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS);
CPPUNIT_TEST(testAuthenticate_RequireTLS);
+ CPPUNIT_TEST(testAuthenticate_EXTERNAL);
CPPUNIT_TEST(testStreamManagement);
CPPUNIT_TEST(testStreamManagement_Failed);
+ CPPUNIT_TEST(testUnexpectedChallenge);
CPPUNIT_TEST(testFinishAcksStanzas);
/*
CPPUNIT_TEST(testResourceBind);
@@ -247,6 +250,34 @@ class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_ASSERT(sessionFinishedError);
}
+ void testAuthenticate_EXTERNAL() {
+ boost::shared_ptr<ClientSession> session(createSession());
+ session->start();
+ server->receiveStreamStart();
+ server->sendStreamStart();
+ server->sendStreamFeaturesWithEXTERNALAuthentication();
+ server->receiveAuthRequest("EXTERNAL");
+ server->sendAuthSuccess();
+ server->receiveStreamStart();
+
+ session->finish();
+ }
+
+ void testUnexpectedChallenge() {
+ boost::shared_ptr<ClientSession> session(createSession());
+ session->start();
+ server->receiveStreamStart();
+ server->sendStreamStart();
+ server->sendStreamFeaturesWithEXTERNALAuthentication();
+ server->receiveAuthRequest("EXTERNAL");
+ server->sendChallenge();
+ server->sendChallenge();
+
+ CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT(sessionFinishedReceived);
+ CPPUNIT_ASSERT(sessionFinishedError);
+ }
+
void testStreamManagement() {
boost::shared_ptr<ClientSession> session(createSession());
session->start();
@@ -444,6 +475,10 @@ class ClientSessionTest : public CppUnit::TestFixture {
onElementReceived(streamFeatures);
}
+ void sendChallenge() {
+ onElementReceived(boost::make_shared<AuthChallenge>());
+ }
+
void sendStreamError() {
onElementReceived(boost::make_shared<StreamError>());
}
@@ -470,6 +505,12 @@ class ClientSessionTest : public CppUnit::TestFixture {
onElementReceived(streamFeatures);
}
+ void sendStreamFeaturesWithEXTERNALAuthentication() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->addAuthenticationMechanism("EXTERNAL");
+ onElementReceived(streamFeatures);
+ }
+
void sendStreamFeaturesWithUnknownAuthentication() {
boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
streamFeatures->addAuthenticationMechanism("UNKNOWN");