diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/Client.cpp | 48 | ||||
-rw-r--r-- | Swiften/Client/Client.h | 2 | ||||
-rw-r--r-- | Swiften/Client/Session.cpp | 46 | ||||
-rw-r--r-- | Swiften/Client/Session.h | 13 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/SessionTest.cpp | 139 |
5 files changed, 99 insertions, 149 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index d87673b..04a24bf 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -2,9 +2,11 @@ #include <boost/bind.hpp> +#include "Swiften/Network/DomainNameResolver.h" #include "Swiften/Client/Session.h" #include "Swiften/StreamStack/PlatformTLSLayerFactory.h" #include "Swiften/Network/BoostConnectionFactory.h" +#include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/TLS/PKCS12Certificate.h" namespace Swift { @@ -23,17 +25,37 @@ Client::~Client() { void Client::connect() { delete session_; - session_ = new Session(jid_, connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); - if (!certificate_.isEmpty()) { - session_->setCertificate(PKCS12Certificate(certificate_, password_)); + session_ = 0; + + DomainNameResolver resolver; + try { + HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String()); + connection_ = connectionFactory_->createConnection(); + connection_->onConnectFinished.connect(boost::bind(&Client::handleConnectionConnectFinished, this, _1)); + connection_->connect(remote); + } + catch (const DomainNameResolveException& e) { + onError(ClientError::DomainNameResolveError); + } +} + +void Client::handleConnectionConnectFinished(bool error) { + if (error) { + onError(ClientError::ConnectionError); + } + else { + session_ = new Session(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + if (!certificate_.isEmpty()) { + session_->setCertificate(PKCS12Certificate(certificate_, password_)); + } + session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); + session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1)); + session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); + session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); + session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); + session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); + session_->start(); } - session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); - session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1)); - session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); - session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); - session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); - session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); - session_->start(); } void Client::disconnect() { @@ -92,12 +114,6 @@ void Client::handleSessionError(Session::SessionError error) { case Session::NoError: assert(false); break; - case Session::DomainNameResolveError: - clientError = ClientError(ClientError::DomainNameResolveError); - break; - case Session::ConnectionError: - clientError = ClientError(ClientError::ConnectionError); - break; case Session::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); break; diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index d876302..66f9b01 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -43,6 +43,7 @@ namespace Swift { boost::signal<void (const String&)> onDataWritten; private: + void handleConnectionConnectFinished(bool error); void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); void handleElement(boost::shared_ptr<Element>); @@ -61,6 +62,7 @@ namespace Swift { FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; Session* session_; + boost::shared_ptr<Connection> connection_; String certificate_; }; } diff --git a/Swiften/Client/Session.cpp b/Swiften/Client/Session.cpp index 6c2a873..1bd2b22 100644 --- a/Swiften/Client/Session.cpp +++ b/Swiften/Client/Session.cpp @@ -24,16 +24,21 @@ namespace Swift { -Session::Session(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : - jid_(jid), - connectionFactory_(connectionFactory), - tlsLayerFactory_(tlsLayerFactory), - payloadParserFactories_(payloadParserFactories), - payloadSerializers_(payloadSerializers), - state_(Initial), - error_(NoError), - streamStack_(0), - needSessionStart_(false) { +Session::Session( + const JID& jid, + boost::shared_ptr<Connection> connection, + TLSLayerFactory* tlsLayerFactory, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers) : + jid_(jid), + tlsLayerFactory_(tlsLayerFactory), + payloadParserFactories_(payloadParserFactories), + payloadSerializers_(payloadSerializers), + state_(Initial), + error_(NoError), + connection_(connection), + streamStack_(0), + needSessionStart_(false) { } Session::~Session() { @@ -42,11 +47,11 @@ Session::~Session() { void Session::start() { assert(state_ == Initial); - state_ = Connecting; - connection_ = connectionFactory_->createConnection(); - connection_->onConnected.connect(boost::bind(&Session::handleConnected, this)); + connection_->onDisconnected.connect(boost::bind(&Session::handleDisconnected, this, _1)); - connection_->connect(jid_.getDomain()); + initializeStreamStack(); + state_ = WaitingForStreamStart; + sendStreamHeader(); } void Session::stop() { @@ -54,13 +59,6 @@ void Session::stop() { connection_->disconnect(); } -void Session::handleConnected() { - assert(state_ == Connecting); - initializeStreamStack(); - state_ = WaitingForStreamStart; - sendStreamHeader(); -} - void Session::sendStreamHeader() { ProtocolHeader header; header.setTo(jid_.getDomain()); @@ -81,18 +79,12 @@ void Session::initializeStreamStack() { void Session::handleDisconnected(const boost::optional<Connection::Error>& error) { if (error) { switch (*error) { - case Connection::DomainNameResolveError: - setError(DomainNameResolveError); - break; case Connection::ReadError: setError(ConnectionReadError); break; case Connection::WriteError: setError(ConnectionWriteError); break; - case Connection::ConnectionError: - setError(ConnectionError); - break; } } } diff --git a/Swiften/Client/Session.h b/Swiften/Client/Session.h index 17c10b9..58531b3 100644 --- a/Swiften/Client/Session.h +++ b/Swiften/Client/Session.h @@ -26,7 +26,6 @@ namespace Swift { public: enum State { Initial, - Connecting, WaitingForStreamStart, Negotiating, Compressing, @@ -40,8 +39,6 @@ namespace Swift { }; enum SessionError { NoError, - DomainNameResolveError, - ConnectionError, ConnectionReadError, ConnectionWriteError, XMLError, @@ -55,7 +52,12 @@ namespace Swift { ClientCertificateError }; - Session(const JID& jid, ConnectionFactory*, TLSLayerFactory*, PayloadParserFactoryCollection*, PayloadSerializerCollection*); + Session( + const JID& jid, + boost::shared_ptr<Connection>, + TLSLayerFactory*, + PayloadParserFactoryCollection*, + PayloadSerializerCollection*); ~Session(); State getState() const { @@ -72,6 +74,7 @@ namespace Swift { void start(); void stop(); + void sendCredentials(const String& password); void sendElement(boost::shared_ptr<Element>); void setCertificate(const PKCS12Certificate& certificate); @@ -86,7 +89,6 @@ namespace Swift { void sendStreamHeader(); void sendSessionStart(); - void handleConnected(); void handleDisconnected(const boost::optional<Connection::Error>&); void handleElement(boost::shared_ptr<Element>); void handleStreamStart(); @@ -106,7 +108,6 @@ namespace Swift { private: JID jid_; - ConnectionFactory* connectionFactory_; TLSLayerFactory* tlsLayerFactory_; PayloadParserFactoryCollection* payloadParserFactories_; PayloadSerializerCollection* payloadSerializers_; diff --git a/Swiften/Client/UnitTest/SessionTest.cpp b/Swiften/Client/UnitTest/SessionTest.cpp index a11ddde..eb7281c 100644 --- a/Swiften/Client/UnitTest/SessionTest.cpp +++ b/Swiften/Client/UnitTest/SessionTest.cpp @@ -39,10 +39,8 @@ using namespace Swift; class SessionTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(SessionTest); CPPUNIT_TEST(testConstructor); - CPPUNIT_TEST(testConnect); - CPPUNIT_TEST(testConnect_Error); - CPPUNIT_TEST(testConnect_ErrorAfterSuccesfulConnect); - CPPUNIT_TEST(testConnect_XMLError); + CPPUNIT_TEST(testStart_Error); + CPPUNIT_TEST(testStart_XMLError); CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_NoTLSSupport); @@ -68,7 +66,7 @@ class SessionTest : public CppUnit::TestFixture { void setUp() { eventLoop_ = new DummyEventLoop(); - connectionFactory_ = new MockConnectionFactory(); + connection_ = boost::shared_ptr<MockConnection>(new MockConnection()); tlsLayerFactory_ = new MockTLSLayerFactory(); sessionStarted_ = false; needCredentials_ = false; @@ -76,7 +74,6 @@ class SessionTest : public CppUnit::TestFixture { void tearDown() { delete tlsLayerFactory_; - delete connectionFactory_; delete eventLoop_; } @@ -85,51 +82,26 @@ class SessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT_EQUAL(Session::Initial, session->getState()); } - void testConnect() { + void testStart_Error() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - CPPUNIT_ASSERT_EQUAL(Session::Connecting, session->getState()); - getMockServer()->expectStreamStart(); - - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); - } - - void testConnect_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); - - connectionFactory_->setCreateFailingConnections(); session->start(); processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT(!sessionStarted_); - CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError()); - } - - void testConnect_ErrorAfterSuccesfulConnect() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - - session->start(); - getMockServer()->expectStreamStart(); - processEvents(); CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); - connectionFactory_->connections_[0]->setError(); + getMockServer()->setError(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError()); + CPPUNIT_ASSERT_EQUAL(Session::ConnectionReadError, session->getError()); } - void testConnect_XMLError() { + void testStart_XMLError() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); getMockServer()->expectStreamStart(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); @@ -143,25 +115,23 @@ class SessionTest : public CppUnit::TestFixture { void testStartTLS_NoTLSSupport() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); tlsLayerFactory_->setTLSSupported(false); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); } void testStartTLS() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); // FIXME: Test 'encrypting' state getMockServer()->sendTLSProceed(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Encrypting, session->getState()); CPPUNIT_ASSERT(session->getTLSLayer()); @@ -178,13 +148,12 @@ class SessionTest : public CppUnit::TestFixture { void testStartTLS_ServerError() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSFailure(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); @@ -193,13 +162,12 @@ class SessionTest : public CppUnit::TestFixture { void testStartTLS_ConnectError() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSProceed(); + session->start(); processEvents(); session->getTLSLayer()->setError(); @@ -209,13 +177,12 @@ class SessionTest : public CppUnit::TestFixture { void testStartTLS_ErrorAfterConnect() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSProceed(); + session->start(); processEvents(); getMockServer()->resetParser(); getMockServer()->expectStreamStart(); @@ -232,11 +199,10 @@ class SessionTest : public CppUnit::TestFixture { void testAuthenticate() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onNeedCredentials.connect(boost::bind(&SessionTest::setNeedCredentials, this)); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::WaitingForCredentials, session->getState()); CPPUNIT_ASSERT(needCredentials_); @@ -253,10 +219,10 @@ class SessionTest : public CppUnit::TestFixture { void testAuthenticate_Unauthorized() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); + session->start(); processEvents(); getMockServer()->expectAuth("me", "mypass"); @@ -270,10 +236,10 @@ class SessionTest : public CppUnit::TestFixture { void testAuthenticate_NoValidAuthMechanisms() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); @@ -282,14 +248,14 @@ class SessionTest : public CppUnit::TestFixture { void testResourceBind() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("Bar", "session-bind"); // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::BindingResource, session->getState()); getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + session->start(); + processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); @@ -298,13 +264,12 @@ class SessionTest : public CppUnit::TestFixture { void testResourceBind_ChangeResource() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("Bar", "session-bind"); getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); @@ -313,13 +278,12 @@ class SessionTest : public CppUnit::TestFixture { void testResourceBind_EmptyResource() { std::auto_ptr<MockSession> session(createSession("me@foo.com")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("", "session-bind"); getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); @@ -328,13 +292,12 @@ class SessionTest : public CppUnit::TestFixture { void testResourceBind_Error() { std::auto_ptr<MockSession> session(createSession("me@foo.com")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("", "session-bind"); getMockServer()->sendError("session-bind"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); @@ -344,14 +307,13 @@ class SessionTest : public CppUnit::TestFixture { void testSessionStart() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithSession(); getMockServer()->expectSessionStart("session-start"); // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::StartingSession, session->getState()); getMockServer()->sendSessionStartResponse("session-start"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); @@ -360,13 +322,12 @@ class SessionTest : public CppUnit::TestFixture { void testSessionStart_Error() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); - getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithSession(); getMockServer()->expectSessionStart("session-start"); getMockServer()->sendError("session-start"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); @@ -376,7 +337,6 @@ class SessionTest : public CppUnit::TestFixture { void testSessionStart_AfterResourceBind() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBindAndSession(); @@ -384,6 +344,7 @@ class SessionTest : public CppUnit::TestFixture { getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); getMockServer()->expectSessionStart("session-start"); getMockServer()->sendSessionStartResponse("session-start"); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); @@ -392,20 +353,20 @@ class SessionTest : public CppUnit::TestFixture { void testWhitespacePing() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); + session->start(); processEvents(); CPPUNIT_ASSERT(session->getWhitespacePingLayer()); } void testReceiveElementAfterSessionStarted() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); + session->start(); processEvents(); getMockServer()->expectMessage(); @@ -415,11 +376,11 @@ class SessionTest : public CppUnit::TestFixture { void testSendElement() { std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onElementReceived.connect(boost::bind(&SessionTest::addReceivedElement, this, _1)); - session->start(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); getMockServer()->sendMessage(); + session->start(); processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); @@ -430,8 +391,7 @@ class SessionTest : public CppUnit::TestFixture { struct MockConnection; boost::shared_ptr<MockConnection> getMockServer() const { - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory_->connections_.size())); - return connectionFactory_->connections_[0]; + return connection_; } void processEvents() { @@ -468,9 +428,9 @@ class SessionTest : public CppUnit::TestFixture { boost::shared_ptr<Element> element; }; - MockConnection(bool fail) : - fail_(fail), + MockConnection() : resetParser_(false), + domain_("foo.com"), parser_(0), serializer_(&payloadSerializers_) { parser_ = new XMPPParser(this, &payloadParserFactories_); @@ -480,26 +440,17 @@ class SessionTest : public CppUnit::TestFixture { delete parser_; } - void disconnect() { - } + void disconnect() { } void listen() { assert(false); } void connect(const HostAddressPort&) { assert(false); } - void connect(const String& domain) { - if (fail_) { - MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ConnectionError)); - } - else { - domain_ = domain; - MainEventLoop::postEvent(boost::bind(boost::ref(onConnected))); - } - } + void connect(const String&) { assert(false); } void setError() { - MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ConnectionError)); + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ReadError)); } void write(const ByteArray& data) { @@ -565,6 +516,9 @@ class SessionTest : public CppUnit::TestFixture { } void assertNoMoreExpectations() { + foreach (const Event& event, events_) { + std::cout << "Unprocessed event: " << serializeEvent(event) << std::endl; + } CPPUNIT_ASSERT(events_.empty()); } @@ -683,7 +637,6 @@ class SessionTest : public CppUnit::TestFixture { events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); } - bool fail_; bool resetParser_; String domain_; FullPayloadParserFactoryCollection payloadParserFactories_; @@ -693,20 +646,6 @@ class SessionTest : public CppUnit::TestFixture { std::deque<Event> events_; }; - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory() : fail_(false) {} - boost::shared_ptr<Connection> createConnection() { - boost::shared_ptr<MockConnection> result(new MockConnection(fail_)); - connections_.push_back(result); - return result; - } - void setCreateFailingConnections() { - fail_ = true; - } - std::vector<boost::shared_ptr<MockConnection> > connections_; - bool fail_; - }; - struct MockTLSLayer : public TLSLayer { MockTLSLayer() : connecting_(false) {} bool setClientCertificate(const PKCS12Certificate&) { return true; } @@ -735,7 +674,7 @@ class SessionTest : public CppUnit::TestFixture { }; struct MockSession : public Session { - MockSession(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : Session(jid, connectionFactory, tlsLayerFactory, payloadParserFactories, payloadSerializers) {} + MockSession(const JID& jid, boost::shared_ptr<Connection> connection, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : Session(jid, connection, tlsLayerFactory, payloadParserFactories, payloadSerializers) {} boost::shared_ptr<MockTLSLayer> getTLSLayer() const { return getStreamStack()->getLayer<MockTLSLayer>(); @@ -746,12 +685,12 @@ class SessionTest : public CppUnit::TestFixture { }; MockSession* createSession(const String& jid) { - return new MockSession(JID(jid), connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); } DummyEventLoop* eventLoop_; - MockConnectionFactory* connectionFactory_; + boost::shared_ptr<MockConnection> connection_; MockTLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; |