diff options
author | Tobias Markmann <tm@ayena.de> | 2016-07-29 09:47:23 (GMT) |
---|---|---|
committer | Tobias Markmann <tm@ayena.de> | 2016-11-28 10:35:05 (GMT) |
commit | 2039930eadd4756068a8a60c8340d9908a7136d3 (patch) | |
tree | d8aca4bf98a2bb6e3b819305b1f87af3117f4910 /Swiften/Client/UnitTest/ClientSessionTest.cpp | |
parent | 2f90eb7409df91a80c60b189242ac0c1de313910 (diff) | |
download | swift-2039930eadd4756068a8a60c8340d9908a7136d3.zip swift-2039930eadd4756068a8a60c8340d9908a7136d3.tar.bz2 |
Correctly handle server initiated closing of stream
If a server closes the XMPP stream, it sends a </stream:stream>
tag. The client is supposed to respond with the same tag and
then both parties can close the TLS/TCP socket.
Previously Swift(-en) would simply ignore </stream:stream>
tag if it was not directly followed by a shutdown of the TCP
connection.
In addition there is now a timeout timer started as soon as
Swiften or the server initiates a shutdown. It will close
the socket and cleanup the ClientSession if the server does
not respond in time or the network is faulty.
Refactored some code in ClientSession in the process. Moved
ClientSession::State to a C++11 strongly typed enum class.
This also fixes issues where duplicated </stream:stream>
tags would be send by Swift.
Test-Information:
Tested against Prosody ba782a093b14 and M-Link 16.3v6-0,
which provide ad-hoc commands to end a user session.
Previously this was ignored by Swift. Now it correctly responds
to the server, detects it as a disconnect and tries to
reconnect afterwards.
Added unit test for the case where the server closes the
session stream.
Change-Id: I59dfde3aa6b50dc117f340e5db6b9e58b54b3c60
Diffstat (limited to 'Swiften/Client/UnitTest/ClientSessionTest.cpp')
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 111 |
1 files changed, 95 insertions, 16 deletions
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index bd93f4b..44df961 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -1,380 +1,458 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <deque> #include <memory> #include <boost/bind.hpp> #include <boost/optional.hpp> +#include <Swiften/Base/Debug.h> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Crypto/PlatformCryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> #include <Swiften/Elements/AuthFailure.h> #include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/EnableStreamManagement.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/ResourceBind.h> #include <Swiften/Elements/StanzaAck.h> #include <Swiften/Elements/StartTLSFailure.h> #include <Swiften/Elements/StartTLSRequest.h> #include <Swiften/Elements/StreamError.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/StreamManagementEnabled.h> #include <Swiften/Elements/StreamManagementFailed.h> #include <Swiften/Elements/TLSProceed.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Session/SessionStream.h> #include <Swiften/TLS/BlindCertificateTrustChecker.h> #include <Swiften/TLS/SimpleCertificate.h> using namespace Swift; class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ClientSessionTest); CPPUNIT_TEST(testStart_Error); CPPUNIT_TEST(testStart_StreamError); CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_ConnectError); CPPUNIT_TEST(testStartTLS_InvalidIdentity); CPPUNIT_TEST(testStart_StreamFeaturesWithoutResourceBindingFails); CPPUNIT_TEST(testAuthenticate); CPPUNIT_TEST(testAuthenticate_Unauthorized); CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); CPPUNIT_TEST(testAuthenticate_RequireTLS); CPPUNIT_TEST(testAuthenticate_EXTERNAL); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); + + CPPUNIT_TEST(testServerInitiatedSessionClose); + CPPUNIT_TEST(testClientInitiatedSessionClose); + CPPUNIT_TEST(testTimeoutOnShutdown); + /* CPPUNIT_TEST(testResourceBind); CPPUNIT_TEST(testResourceBind_ChangeResource); CPPUNIT_TEST(testResourceBind_EmptyResource); CPPUNIT_TEST(testResourceBind_Error); CPPUNIT_TEST(testSessionStart); CPPUNIT_TEST(testSessionStart_Error); CPPUNIT_TEST(testSessionStart_AfterResourceBind); CPPUNIT_TEST(testWhitespacePing); CPPUNIT_TEST(testReceiveElementAfterSessionStarted); CPPUNIT_TEST(testSendElement); */ CPPUNIT_TEST_SUITE_END(); public: void setUp() { crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()); idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); + timerFactory = std::make_shared<DummyTimerFactory>(); server = std::make_shared<MockSessionStream>(); sessionFinishedReceived = false; needCredentials = false; blindCertificateTrustChecker = new BlindCertificateTrustChecker(); } void tearDown() { delete blindCertificateTrustChecker; } void testStart_Error() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->breakConnection(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStart_StreamError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->sendStreamStart(); server->sendStreamError(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setCertificateTrustChecker(blindCertificateTrustChecker); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); server->receiveStreamStart(); server->sendStreamStart(); session->finish(); } void testStartTLS_ServerError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSFailure(); CPPUNIT_ASSERT(!server->tlsEncrypted); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_ConnectError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSProceed(); server->breakTLS(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_InvalidIdentity() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)); CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType()); } void testStart_StreamFeaturesWithoutResourceBindingFails() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendEmptyStreamFeatures(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testAuthenticate_Unauthorized() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthFailure(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_PLAINOverNonTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setAllowPLAINOverNonTLS(false); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_RequireTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setUseTLS(ClientSession::RequireTLS); session->setAllowPLAINOverNonTLS(true); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithMultipleAuthentication(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_NoValidAuthMechanisms() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithUnknownAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_EXTERNAL() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testUnexpectedChallenge() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendChallenge(); server->sendChallenge(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStreamManagement() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); CPPUNIT_ASSERT(session->getStreamManagementEnabled()); // TODO: Test if the requesters & responders do their work - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testStreamManagement_Failed() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementFailed(); CPPUNIT_ASSERT(!session->getStreamManagementEnabled()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testFinishAcksStanzas() { std::shared_ptr<ClientSession> session(createSession()); initializeSession(session); server->sendMessage(); server->sendMessage(); server->sendMessage(); session->finish(); server->receiveAck(3); } + void testServerInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + } + + void testClientInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + } + + void testTimeoutOnShutdown() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + timerFactory->setTime(60000); + + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + } + private: std::shared_ptr<ClientSession> createSession() { - std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get()); + std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get()); session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1)); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this)); session->setAllowPLAINOverNonTLS(true); return session; } void initializeSession(std::shared_ptr<ClientSession> session) { session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); } void handleSessionFinished(std::shared_ptr<Error> error) { sessionFinishedReceived = true; sessionFinishedError = error; } void handleSessionNeedCredentials() { needCredentials = true; @@ -604,60 +682,61 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(event.element); std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(event.element); CPPUNIT_ASSERT(ack); CPPUNIT_ASSERT_EQUAL(n, ack->getHandledStanzasCount()); } Event popEvent() { CPPUNIT_ASSERT(!receivedEvents.empty()); Event event = receivedEvents.front(); receivedEvents.pop_front(); return event; } bool available; bool canTLSEncrypt; bool tlsEncrypted; bool compressed; bool whitespacePingEnabled; std::string bindID; int resetCount; std::deque<Event> receivedEvents; }; std::shared_ptr<IDNConverter> idnConverter; std::shared_ptr<MockSessionStream> server; bool sessionFinishedReceived; bool needCredentials; std::shared_ptr<Error> sessionFinishedError; BlindCertificateTrustChecker* blindCertificateTrustChecker; std::shared_ptr<CryptoProvider> crypto; + std::shared_ptr<DummyTimerFactory> timerFactory; }; CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); #if 0 void testAuthenticate() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); CPPUNIT_ASSERT(needCredentials_); getMockServer()->expectAuth("me", "mypass"); getMockServer()->sendAuthSuccess(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); session->sendCredentials("mypass"); CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState()); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); } void testAuthenticate_Unauthorized() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); |