diff options
-rw-r--r-- | Swiften/Client/ClientOptions.h | 3 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 3 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 3 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 3 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 24 |
5 files changed, 34 insertions, 2 deletions
diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h index 0766402..6b15f18 100644 --- a/Swiften/Client/ClientOptions.h +++ b/Swiften/Client/ClientOptions.h @@ -4,19 +4,20 @@ * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once namespace Swift { struct ClientOptions { enum UseTLS { NeverUseTLS, - UseTLSWhenAvailable + UseTLSWhenAvailable, + RequireTLS }; ClientOptions() : useStreamCompression(true), useTLS(UseTLSWhenAvailable), allowPLAINWithoutTLS(false), useStreamResumption(false), forgetPassword(false) { } /** * Whether ZLib stream compression should be used when available. * * Default: true diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 8945e9a..2eeb3c0 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -171,18 +171,21 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { if (!checkState(Negotiating)) { return; } if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { state = WaitingForEncrypt; stream->writeElement(boost::make_shared<StartTLSRequest>()); } + else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { + finishSession(Error::NoSupportedAuthMechanismsError); + } else if (useStreamCompression && streamFeatures->hasCompressionMethod("zlib")) { state = Compressing; stream->writeElement(boost::make_shared<CompressRequest>("zlib")); } else if (streamFeatures->hasAuthenticationMechanisms()) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { state = Authenticating; stream->writeElement(boost::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index ace9868..e58e758 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -53,19 +53,20 @@ namespace Swift { TLSClientCertificateError, TLSError, StreamError, } type; Error(Type type) : type(type) {} }; enum UseTLS { NeverUseTLS, - UseTLSWhenAvailable + UseTLSWhenAvailable, + RequireTLS }; ~ClientSession(); static boost::shared_ptr<ClientSession> create(const JID& jid, boost::shared_ptr<SessionStream> stream) { return boost::shared_ptr<ClientSession>(new ClientSession(jid, stream)); } State getState() const { diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 9aaa66b..cceec74 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -113,18 +113,21 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); break; case ClientOptions::NeverUseTLS: session_->setUseTLS(ClientSession::NeverUseTLS); break; + case ClientOptions::RequireTLS: + session_->setUseTLS(ClientSession::RequireTLS); + break; } stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); session_->start(); } } void CoreClient::disconnect() { diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index 57e53e4..e9d1b21 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -39,18 +39,20 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST(testStart_StreamError); CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_ConnectError); CPPUNIT_TEST(testStartTLS_InvalidIdentity); CPPUNIT_TEST(testStart_StreamFeaturesWithoutResourceBindingFails); CPPUNIT_TEST(testAuthenticate); CPPUNIT_TEST(testAuthenticate_Unauthorized); CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); + CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); + CPPUNIT_TEST(testAuthenticate_RequireTLS); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); CPPUNIT_TEST(testFinishAcksStanzas); /* CPPUNIT_TEST(testResourceBind); CPPUNIT_TEST(testResourceBind_ChangeResource); CPPUNIT_TEST(testResourceBind_EmptyResource); CPPUNIT_TEST(testResourceBind_Error); CPPUNIT_TEST(testSessionStart); @@ -213,18 +215,32 @@ class ClientSessionTest : public CppUnit::TestFixture { server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } + void testAuthenticate_RequireTLS() { + boost::shared_ptr<ClientSession> session(createSession()); + session->setUseTLS(ClientSession::RequireTLS); + session->setAllowPLAINOverNonTLS(true); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithMultipleAuthentication(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + void testAuthenticate_NoValidAuthMechanisms() { boost::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithUnknownAuthentication(); CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); @@ -426,18 +442,26 @@ class ClientSessionTest : public CppUnit::TestFixture { void sendTLSProceed() { onElementReceived(boost::shared_ptr<TLSProceed>(new TLSProceed())); } void sendTLSFailure() { onElementReceived(boost::shared_ptr<StartTLSFailure>(new StartTLSFailure())); } + void sendStreamFeaturesWithMultipleAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("PLAIN"); + streamFeatures->addAuthenticationMechanism("DIGEST-MD5"); + streamFeatures->addAuthenticationMechanism("SCRAM-SHA1"); + onElementReceived(streamFeatures); + } + void sendStreamFeaturesWithPLAINAuthentication() { boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); streamFeatures->addAuthenticationMechanism("PLAIN"); onElementReceived(streamFeatures); } void sendStreamFeaturesWithUnknownAuthentication() { boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); streamFeatures->addAuthenticationMechanism("UNKNOWN"); |