diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/Client.cpp | 147 | ||||
-rw-r--r-- | Swiften/Client/Client.h | 66 | ||||
-rw-r--r-- | Swiften/Client/ClientError.h | 32 | ||||
-rw-r--r-- | Swiften/Client/Makefile.inc | 5 | ||||
-rw-r--r-- | Swiften/Client/Session.cpp | 292 | ||||
-rw-r--r-- | Swiften/Client/Session.h | 126 | ||||
-rw-r--r-- | Swiften/Client/StanzaChannel.h | 22 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/Makefile.inc | 2 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/SessionTest.cpp | 752 |
9 files changed, 1444 insertions, 0 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp new file mode 100644 index 0000000..e5bbf9d --- /dev/null +++ b/Swiften/Client/Client.cpp @@ -0,0 +1,147 @@ +#include "Swiften/Client/Client.h" + +#include <boost/bind.hpp> + +#include "Swiften/Client/Session.h" +#include "Swiften/StreamStack/PlatformTLSLayerFactory.h" +#include "Swiften/Network/BoostConnectionFactory.h" +#include "Swiften/TLS/PKCS12Certificate.h" + +namespace Swift { + +Client::Client(const JID& jid, const String& password) : + IQRouter(this), jid_(jid), password_(password), session_(0) { + connectionFactory_ = new BoostConnectionFactory(); + tlsLayerFactory_ = new PlatformTLSLayerFactory(); +} + +Client::~Client() { + delete session_; + delete tlsLayerFactory_; + delete connectionFactory_; +} + +void Client::connect() { + delete session_; + session_ = new Session(jid_, connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + if (!certificate_.isEmpty()) { + session_->setCertificate(PKCS12Certificate(certificate_, password_)); + } + session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); + session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1)); + session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); + session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); + session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); + session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); + session_->start(); +} + +void Client::disconnect() { + if (session_) { + session_->stop(); + } +} + +void Client::send(boost::shared_ptr<Stanza> stanza) { + session_->sendElement(stanza); +} + +void Client::sendIQ(boost::shared_ptr<IQ> iq) { + send(iq); +} + +void Client::sendMessage(boost::shared_ptr<Message> message) { + send(message); +} + +void Client::sendPresence(boost::shared_ptr<Presence> presence) { + send(presence); +} + +String Client::getNewIQID() { + return idGenerator_.generateID(); +} + +void Client::handleElement(boost::shared_ptr<Element> element) { + boost::shared_ptr<Message> message = boost::dynamic_pointer_cast<Message>(element); + if (message) { + onMessageReceived(message); + return; + } + + boost::shared_ptr<Presence> presence = boost::dynamic_pointer_cast<Presence>(element); + if (presence) { + onPresenceReceived(presence); + return; + } + + boost::shared_ptr<IQ> iq = boost::dynamic_pointer_cast<IQ>(element); + if (iq) { + onIQReceived(iq); + return; + } +} + +void Client::setCertificate(const String& certificate) { + certificate_ = certificate; +} + +void Client::handleSessionError(Session::SessionError error) { + ClientError clientError; + switch (error) { + case Session::NoError: + assert(false); + break; + case Session::DomainNameResolveError: + clientError = ClientError(ClientError::DomainNameResolveError); + break; + case Session::ConnectionError: + clientError = ClientError(ClientError::ConnectionError); + break; + case Session::ConnectionReadError: + clientError = ClientError(ClientError::ConnectionReadError); + break; + case Session::XMLError: + clientError = ClientError(ClientError::XMLError); + break; + case Session::AuthenticationFailedError: + clientError = ClientError(ClientError::AuthenticationFailedError); + break; + case Session::NoSupportedAuthMechanismsError: + clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); + break; + case Session::UnexpectedElementError: + clientError = ClientError(ClientError::UnexpectedElementError); + break; + case Session::ResourceBindError: + clientError = ClientError(ClientError::ResourceBindError); + break; + case Session::SessionStartError: + clientError = ClientError(ClientError::SessionStartError); + break; + case Session::TLSError: + clientError = ClientError(ClientError::TLSError); + break; + case Session::ClientCertificateLoadError: + clientError = ClientError(ClientError::ClientCertificateLoadError); + break; + case Session::ClientCertificateError: + clientError = ClientError(ClientError::ClientCertificateError); + break; + } + onError(clientError); +} + +void Client::handleNeedCredentials() { + session_->sendCredentials(password_); +} + +void Client::handleDataRead(const ByteArray& data) { + onDataRead(String(data.getData(), data.getSize())); +} + +void Client::handleDataWritten(const ByteArray& data) { + onDataWritten(String(data.getData(), data.getSize())); +} + +} diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h new file mode 100644 index 0000000..946bdbd --- /dev/null +++ b/Swiften/Client/Client.h @@ -0,0 +1,66 @@ +#ifndef SWIFTEN_Client_H +#define SWIFTEN_Client_H + +#include <boost/signals.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Client/Session.h" +#include "Swiften/Client/ClientError.h" +#include "Swiften/Elements/Presence.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/JID/JID.h" +#include "Swiften/Base/String.h" +#include "Swiften/Base/IDGenerator.h" +#include "Swiften/Client/StanzaChannel.h" +#include "Swiften/Queries/IQRouter.h" +#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" +#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" + +namespace Swift { + class TLSLayerFactory; + class ConnectionFactory; + class Session; + + class Client : public StanzaChannel, public IQRouter { + public: + Client(const JID& jid, const String& password); + ~Client(); + + void setCertificate(const String& certificate); + + void connect(); + void disconnect(); + + virtual void sendIQ(boost::shared_ptr<IQ>); + virtual void sendMessage(boost::shared_ptr<Message>); + virtual void sendPresence(boost::shared_ptr<Presence>); + + public: + boost::signal<void (ClientError)> onError; + boost::signal<void ()> onConnected; + boost::signal<void (const String&)> onDataRead; + boost::signal<void (const String&)> onDataWritten; + + private: + void send(boost::shared_ptr<Stanza>); + virtual String getNewIQID(); + void handleElement(boost::shared_ptr<Element>); + void handleSessionError(Session::SessionError error); + void handleNeedCredentials(); + void handleDataRead(const ByteArray&); + void handleDataWritten(const ByteArray&); + + private: + JID jid_; + String password_; + IDGenerator idGenerator_; + ConnectionFactory* connectionFactory_; + TLSLayerFactory* tlsLayerFactory_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + Session* session_; + String certificate_; + }; +} + +#endif diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h new file mode 100644 index 0000000..38f20c0 --- /dev/null +++ b/Swiften/Client/ClientError.h @@ -0,0 +1,32 @@ +#ifndef SWIFTEN_ClientError_H +#define SWIFTEN_ClientError_H + +namespace Swift { + class ClientError { + public: + enum Type { + NoError, + DomainNameResolveError, + ConnectionError, + ConnectionReadError, + XMLError, + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + ClientCertificateLoadError, + ClientCertificateError + }; + + ClientError(Type type = NoError) : type_(type) {} + + Type getType() const { return type_; } + + private: + Type type_; + }; +} + +#endif diff --git a/Swiften/Client/Makefile.inc b/Swiften/Client/Makefile.inc new file mode 100644 index 0000000..75eb08f --- /dev/null +++ b/Swiften/Client/Makefile.inc @@ -0,0 +1,5 @@ +SWIFTEN_SOURCES += \ + Swiften/Client/Client.cpp \ + Swiften/Client/Session.cpp + +include Swiften/Client/UnitTest/Makefile.inc diff --git a/Swiften/Client/Session.cpp b/Swiften/Client/Session.cpp new file mode 100644 index 0000000..aa3cc62 --- /dev/null +++ b/Swiften/Client/Session.cpp @@ -0,0 +1,292 @@ +#include "Swiften/Client/Session.h" + +#include <boost/bind.hpp> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/StreamStack/StreamStack.h" +#include "Swiften/StreamStack/ConnectionLayer.h" +#include "Swiften/StreamStack/XMPPLayer.h" +#include "Swiften/StreamStack/TLSLayer.h" +#include "Swiften/StreamStack/TLSLayerFactory.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/SASL/PLAINMessage.h" +#include "Swiften/StreamStack/WhitespacePingLayer.h" + +namespace Swift { + +Session::Session(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : + jid_(jid), + connectionFactory_(connectionFactory), + tlsLayerFactory_(tlsLayerFactory), + payloadParserFactories_(payloadParserFactories), + payloadSerializers_(payloadSerializers), + state_(Initial), + error_(NoError), + connection_(0), + xmppLayer_(0), + tlsLayer_(0), + connectionLayer_(0), + whitespacePingLayer_(0), + streamStack_(0), + needSessionStart_(false) { +} + +Session::~Session() { + delete streamStack_; + delete whitespacePingLayer_; + delete connectionLayer_; + delete tlsLayer_; + delete xmppLayer_; + delete connection_; +} + +void Session::start() { + assert(state_ == Initial); + state_ = Connecting; + connection_ = connectionFactory_->createConnection(jid_.getDomain()); + connection_->onConnected.connect(boost::bind(&Session::handleConnected, this)); + connection_->onError.connect(boost::bind(&Session::handleConnectionError, this, _1)); + connection_->connect(); +} + +void Session::stop() { + // TODO: Send end stream header if applicable + connection_->disconnect(); +} + +void Session::handleConnected() { + assert(state_ == Connecting); + initializeStreamStack(); + state_ = WaitingForStreamStart; + sendStreamHeader(); +} + +void Session::sendStreamHeader() { + xmppLayer_->writeHeader(jid_.getDomain()); +} + +void Session::initializeStreamStack() { + xmppLayer_ = new XMPPLayer(payloadParserFactories_, payloadSerializers_); + xmppLayer_->onStreamStart.connect(boost::bind(&Session::handleStreamStart, this)); + xmppLayer_->onElement.connect(boost::bind(&Session::handleElement, this, _1)); + xmppLayer_->onError.connect(boost::bind(&Session::setError, this, XMLError)); + xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); + xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); + connectionLayer_ = new ConnectionLayer(connection_); + streamStack_ = new StreamStack(xmppLayer_, connectionLayer_); +} + +void Session::handleConnectionError(Connection::Error error) { + switch (error) { + case Connection::DomainNameResolveError: + setError(DomainNameResolveError); + break; + case Connection::ReadError: + setError(ConnectionReadError); + break; + case Connection::ConnectionError: + setError(ConnectionError); + break; + } +} + +void Session::setCertificate(const PKCS12Certificate& certificate) { + certificate_ = certificate; +} + +void Session::handleStreamStart() { + checkState(WaitingForStreamStart); + state_ = Negotiating; +} + +void Session::handleElement(boost::shared_ptr<Element> element) { + if (getState() == SessionStarted) { + onElementReceived(element); + } + else { + StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get()); + if (streamFeatures) { + if (!checkState(Negotiating)) { + return; + } + + if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { + state_ = Encrypting; + xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + } + else if (streamFeatures->hasAuthenticationMechanisms()) { + if (!certificate_.isNull()) { + if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + state_ = Authenticating; + xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); + } + else { + setError(ClientCertificateError); + } + } + else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { + state_ = WaitingForCredentials; + onNeedCredentials(); + } + else { + setError(NoSupportedAuthMechanismsError); + } + } + else { + // Start the session + + // Add a whitespace ping layer + whitespacePingLayer_ = new WhitespacePingLayer(); + streamStack_->addLayer(whitespacePingLayer_); + + if (streamFeatures->hasSession()) { + needSessionStart_ = true; + } + + if (streamFeatures->hasResourceBind()) { + state_ = BindingResource; + boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); + if (!jid_.getResource().isEmpty()) { + resourceBind->setResource(jid_.getResource()); + } + xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + } + else if (needSessionStart_) { + sendSessionStart(); + } + else { + state_ = SessionStarted; + onSessionStarted(); + } + } + } + else { + AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get()); + if (authSuccess) { + checkState(Authenticating); + state_ = WaitingForStreamStart; + xmppLayer_->resetParser(); + sendStreamHeader(); + } + else if (dynamic_cast<AuthFailure*>(element.get())) { + setError(AuthenticationFailedError); + } + else if (dynamic_cast<TLSProceed*>(element.get())) { + tlsLayer_ = tlsLayerFactory_->createTLSLayer(); + streamStack_->addLayer(tlsLayer_); + if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { + setError(ClientCertificateLoadError); + } + else { + tlsLayer_->onConnected.connect(boost::bind(&Session::handleTLSConnected, this)); + tlsLayer_->onError.connect(boost::bind(&Session::handleTLSError, this)); + tlsLayer_->connect(); + } + } + else if (dynamic_cast<StartTLSFailure*>(element.get())) { + setError(TLSError); + } + else { + IQ* iq = dynamic_cast<IQ*>(element.get()); + if (iq) { + if (state_ == BindingResource) { + boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); + if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { + setError(ResourceBindError); + } + else if (!resourceBind) { + setError(UnexpectedElementError); + } + else if (iq->getType() == IQ::Result) { + jid_ = resourceBind->getJID(); + if (!jid_.isValid()) { + setError(ResourceBindError); + } + if (needSessionStart_) { + sendSessionStart(); + } + else { + state_ = SessionStarted; + } + } + else { + setError(UnexpectedElementError); + } + } + else if (state_ == StartingSession) { + if (iq->getType() == IQ::Result) { + state_ = SessionStarted; + onSessionStarted(); + } + else if (iq->getType() == IQ::Error) { + setError(SessionStartError); + } + else { + setError(UnexpectedElementError); + } + } + else { + setError(UnexpectedElementError); + } + } + else { + // FIXME Not correct? + state_ = SessionStarted; + onSessionStarted(); + } + } + } + } +} + +void Session::sendSessionStart() { + state_ = StartingSession; + xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); +} + +void Session::setError(SessionError error) { + assert(error != NoError); + state_ = Error; + error_ = error; + onError(error); +} + +bool Session::checkState(State state) { + if (state_ != state) { + setError(UnexpectedElementError); + return false; + } + return true; +} + +void Session::sendCredentials(const String& password) { + assert(WaitingForCredentials); + state_ = Authenticating; + xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); +} + +void Session::sendElement(boost::shared_ptr<Element> element) { + assert(SessionStarted); + xmppLayer_->writeElement(element); +} + +void Session::handleTLSConnected() { + state_ = WaitingForStreamStart; + xmppLayer_->resetParser(); + sendStreamHeader(); +} + +void Session::handleTLSError() { + setError(TLSError); +} + +} diff --git a/Swiften/Client/Session.h b/Swiften/Client/Session.h new file mode 100644 index 0000000..c49d877 --- /dev/null +++ b/Swiften/Client/Session.h @@ -0,0 +1,126 @@ +#ifndef SWIFTEN_Session_H +#define SWIFTEN_Session_H + +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Base/String.h" +#include "Swiften/JID/JID.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/TLS/PKCS12Certificate.h" + +namespace Swift { + class PayloadParserFactoryCollection; + class PayloadSerializerCollection; + class ConnectionFactory; + class Connection; + class StreamStack; + class XMPPLayer; + class ConnectionLayer; + class TLSLayerFactory; + class TLSLayer; + class WhitespacePingLayer; + + class Session { + public: + enum State { + Initial, + Connecting, + WaitingForStreamStart, + Negotiating, + Compressing, + Encrypting, + WaitingForCredentials, + Authenticating, + BindingResource, + StartingSession, + SessionStarted, + Error + }; + enum SessionError { + NoError, + DomainNameResolveError, + ConnectionError, + ConnectionReadError, + XMLError, + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + ClientCertificateLoadError, + ClientCertificateError + }; + + Session(const JID& jid, ConnectionFactory*, TLSLayerFactory*, PayloadParserFactoryCollection*, PayloadSerializerCollection*); + ~Session(); + + State getState() const { + return state_; + } + + SessionError getError() const { + return error_; + } + + const JID& getJID() const { + return jid_; + } + + void start(); + void stop(); + void sendCredentials(const String& password); + void sendElement(boost::shared_ptr<Element>); + void setCertificate(const PKCS12Certificate& certificate); + + protected: + StreamStack* getStreamStack() const { + return streamStack_; + } + + private: + void initializeStreamStack(); + void sendStreamHeader(); + void sendSessionStart(); + + void handleConnected(); + void handleConnectionError(Connection::Error); + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(); + void handleTLSConnected(); + void handleTLSError(); + + void setError(SessionError); + bool checkState(State); + + public: + boost::signal<void ()> onSessionStarted; + boost::signal<void (SessionError)> onError; + boost::signal<void ()> onNeedCredentials; + boost::signal<void (boost::shared_ptr<Element>) > onElementReceived; + boost::signal<void (const ByteArray&)> onDataWritten; + boost::signal<void (const ByteArray&)> onDataRead; + + private: + JID jid_; + ConnectionFactory* connectionFactory_; + TLSLayerFactory* tlsLayerFactory_; + PayloadParserFactoryCollection* payloadParserFactories_; + PayloadSerializerCollection* payloadSerializers_; + State state_; + SessionError error_; + Connection* connection_; + XMPPLayer* xmppLayer_; + TLSLayer* tlsLayer_; + ConnectionLayer* connectionLayer_; + WhitespacePingLayer* whitespacePingLayer_; + StreamStack* streamStack_; + bool needSessionStart_; + PKCS12Certificate certificate_; + }; + +} + +#endif diff --git a/Swiften/Client/StanzaChannel.h b/Swiften/Client/StanzaChannel.h new file mode 100644 index 0000000..719ed10 --- /dev/null +++ b/Swiften/Client/StanzaChannel.h @@ -0,0 +1,22 @@ +#ifndef SWIFTEN_MessageChannel_H +#define SWIFTEN_MessageChannel_H + +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Queries/IQChannel.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/Elements/Presence.h" + +namespace Swift { + class StanzaChannel : public IQChannel { + public: + virtual void sendMessage(boost::shared_ptr<Message>) = 0; + virtual void sendPresence(boost::shared_ptr<Presence>) = 0; + + boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived; + boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived; + }; +} + +#endif diff --git a/Swiften/Client/UnitTest/Makefile.inc b/Swiften/Client/UnitTest/Makefile.inc new file mode 100644 index 0000000..3ef87e5 --- /dev/null +++ b/Swiften/Client/UnitTest/Makefile.inc @@ -0,0 +1,2 @@ +UNITTEST_SOURCES += \ + Swiften/Client/UnitTest/SessionTest.cpp diff --git a/Swiften/Client/UnitTest/SessionTest.cpp b/Swiften/Client/UnitTest/SessionTest.cpp new file mode 100644 index 0000000..7b7a916 --- /dev/null +++ b/Swiften/Client/UnitTest/SessionTest.cpp @@ -0,0 +1,752 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> +#include <boost/bind.hpp> +#include <boost/function.hpp> +#include <boost/optional.hpp> + +#include "Swiften/Parser/XMPPParser.h" +#include "Swiften/Parser/XMPPParserClient.h" +#include "Swiften/Serializer/XMPPSerializer.h" +#include "Swiften/StreamStack/TLSLayerFactory.h" +#include "Swiften/StreamStack/TLSLayer.h" +#include "Swiften/StreamStack/StreamStack.h" +#include "Swiften/StreamStack/WhitespacePingLayer.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Client/Session.h" +#include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" +#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" + +using namespace Swift; + +class SessionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(SessionTest); + CPPUNIT_TEST(testConstructor); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_Error); + CPPUNIT_TEST(testConnect_ErrorAfterSuccesfulConnect); + CPPUNIT_TEST(testConnect_XMLError); + CPPUNIT_TEST(testStartTLS); + CPPUNIT_TEST(testStartTLS_ServerError); + CPPUNIT_TEST(testStartTLS_NoTLSSupport); + CPPUNIT_TEST(testStartTLS_ConnectError); + CPPUNIT_TEST(testStartTLS_ErrorAfterConnect); + CPPUNIT_TEST(testAuthenticate); + CPPUNIT_TEST(testAuthenticate_Unauthorized); + CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); + CPPUNIT_TEST(testResourceBind); + CPPUNIT_TEST(testResourceBind_ChangeResource); + CPPUNIT_TEST(testResourceBind_EmptyResource); + CPPUNIT_TEST(testResourceBind_Error); + CPPUNIT_TEST(testSessionStart); + CPPUNIT_TEST(testSessionStart_Error); + CPPUNIT_TEST(testSessionStart_AfterResourceBind); + CPPUNIT_TEST(testWhitespacePing); + CPPUNIT_TEST(testReceiveElementAfterSessionStarted); + CPPUNIT_TEST(testSendElement); + CPPUNIT_TEST_SUITE_END(); + + public: + SessionTest() {} + + void setUp() { + eventLoop_ = new DummyEventLoop(); + connectionFactory_ = new MockConnectionFactory(); + tlsLayerFactory_ = new MockTLSLayerFactory(); + sessionStarted_ = false; + needCredentials_ = false; + } + + void tearDown() { + delete tlsLayerFactory_; + delete connectionFactory_; + delete eventLoop_; + } + + void testConstructor() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + CPPUNIT_ASSERT_EQUAL(Session::Initial, session->getState()); + } + + void testConnect() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + + session->start(); + CPPUNIT_ASSERT_EQUAL(Session::Connecting, session->getState()); + + getMockServer()->expectStreamStart(); + + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); + } + + void testConnect_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); + + connectionFactory_->setCreateFailingConnections(); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT(!sessionStarted_); + CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError()); + } + + void testConnect_ErrorAfterSuccesfulConnect() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + + session->start(); + getMockServer()->expectStreamStart(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); + + connectionFactory_->connections_[0]->setError(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError()); + } + + void testConnect_XMLError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + + session->start(); + getMockServer()->expectStreamStart(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); + + getMockServer()->sendInvalidXML(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::XMLError, session->getError()); + } + + void testStartTLS_NoTLSSupport() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + tlsLayerFactory_->setTLSSupported(false); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + } + + void testStartTLS() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + // FIXME: Test 'encrypting' state + getMockServer()->sendTLSProceed(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::Encrypting, session->getState()); + CPPUNIT_ASSERT(session->getTLSLayer()); + CPPUNIT_ASSERT(session->getTLSLayer()->isConnecting()); + + getMockServer()->resetParser(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->getTLSLayer()->setConnected(); + // FIXME: Test 'WatingForStreamStart' state + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState()); + } + + void testStartTLS_ServerError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSFailure(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); + } + + void testStartTLS_ConnectError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSProceed(); + processEvents(); + session->getTLSLayer()->setError(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); + } + + void testStartTLS_ErrorAfterConnect() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSProceed(); + processEvents(); + getMockServer()->resetParser(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->getTLSLayer()->setConnected(); + processEvents(); + + session->getTLSLayer()->setError(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); + } + + void testAuthenticate() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onNeedCredentials.connect(boost::bind(&SessionTest::setNeedCredentials, this)); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT(needCredentials_); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthSuccess(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->sendCredentials("mypass"); + CPPUNIT_ASSERT_EQUAL(Session::Authenticating, session->getState()); + processEvents(); + CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState()); + } + + void testAuthenticate_Unauthorized() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + processEvents(); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthFailure(); + session->sendCredentials("mypass"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::AuthenticationFailedError, session->getError()); + } + + void testAuthenticate_NoValidAuthMechanisms() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::NoSupportedAuthMechanismsError, session->getError()); + } + + void testResourceBind() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::BindingResource, session->getState()); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getJID()); + } + + void testResourceBind_ChangeResource() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getJID()); + } + + void testResourceBind_EmptyResource() { + std::auto_ptr<MockSession> session(createSession("me@foo.com")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getJID()); + } + + void testResourceBind_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendError("session-bind"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::ResourceBindError, session->getError()); + } + + void testSessionStart() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::StartingSession, session->getState()); + getMockServer()->sendSessionStartResponse("session-start"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testSessionStart_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendError("session-start"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(Session::SessionStartError, session->getError()); + } + + void testSessionStart_AfterResourceBind() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBindAndSession(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendSessionStartResponse("session-start"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testWhitespacePing() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + processEvents(); + CPPUNIT_ASSERT(session->getWhitespacePingLayer()); + } + + void testReceiveElementAfterSessionStarted() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + processEvents(); + + getMockServer()->expectMessage(); + session->sendElement(boost::shared_ptr<Message>(new Message())); + } + + void testSendElement() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onElementReceived.connect(boost::bind(&SessionTest::addReceivedElement, this, _1)); + session->start(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + getMockServer()->sendMessage(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); + CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0])); + } + + private: + struct MockConnection; + + MockConnection* getMockServer() const { + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory_->connections_.size())); + return connectionFactory_->connections_[0]; + } + + void processEvents() { + eventLoop_->processEvents(); + getMockServer()->assertNoMoreExpectations(); + } + + void setSessionStarted() { + sessionStarted_ = true; + } + + void setNeedCredentials() { + needCredentials_ = true; + } + + void addReceivedElement(boost::shared_ptr<Element> element) { + receivedElements_.push_back(element); + } + + private: + struct MockConnection : public Connection, public XMPPParserClient { + struct Event { + enum Direction { In, Out }; + enum Type { StreamStartEvent, StreamEndEvent, ElementEvent }; + + Event( + Direction direction, + Type type, + boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) : + direction(direction), type(type), element(element) {} + + Direction direction; + Type type; + boost::shared_ptr<Element> element; + }; + + MockConnection(const String& domain, bool fail) : + Connection(domain), + fail_(fail), + resetParser_(false), + parser_(0), + serializer_(&payloadSerializers_) { + parser_ = new XMPPParser(this, &payloadParserFactories_); + } + + ~MockConnection() { + delete parser_; + } + + void disconnect() { + } + + void connect() { + if (fail_) { + MainEventLoop::postEvent(boost::bind(boost::ref(onError), Connection::ConnectionError)); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onConnected))); + } + } + + void setError() { + MainEventLoop::postEvent(boost::bind(boost::ref(onError), Connection::ConnectionError)); + } + + void write(const ByteArray& data) { + CPPUNIT_ASSERT(parser_->parse(data.toString())); + if (resetParser_) { + resetParser(); + resetParser_ = false; + } + } + + void resetParser() { + delete parser_; + parser_ = new XMPPParser(this, &payloadParserFactories_); + } + + void handleStreamStart() { + handleEvent(Event::StreamStartEvent); + } + + void handleElement(boost::shared_ptr<Swift::Element> element) { + handleEvent(Event::ElementEvent, element); + } + + void handleStreamEnd() { + handleEvent(Event::StreamEndEvent); + } + + void handleEvent(Event::Type type, boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) { + CPPUNIT_ASSERT(!events_.empty()); + CPPUNIT_ASSERT_EQUAL(events_[0].direction, Event::In); + CPPUNIT_ASSERT_EQUAL(events_[0].type, type); + if (type == Event::ElementEvent) { + CPPUNIT_ASSERT_EQUAL(serializer_.serializeElement(events_[0].element), serializer_.serializeElement(element)); + } + events_.pop_front(); + + while (!events_.empty() && events_[0].direction == Event::Out) { + sendData(serializeEvent(events_[0])); + events_.pop_front(); + } + + if (!events_.empty() && events_[0].type == Event::StreamStartEvent) { + resetParser_ = true; + } + } + + String serializeEvent(const Event& event) { + switch (event.type) { + case Event::StreamStartEvent: + return serializer_.serializeHeader(getDomain()); + case Event::ElementEvent: + return serializer_.serializeElement(event.element); + case Event::StreamEndEvent: + return serializer_.serializeFooter(); + } + assert(false); + } + + void assertNoMoreExpectations() { + CPPUNIT_ASSERT(events_.empty()); + } + + void sendData(const ByteArray& data) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), data)); + } + + void expectStreamStart() { + events_.push_back(Event(Event::In, Event::StreamStartEvent)); + } + + void expectStartTLS() { + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()))); + } + + void expectAuth(const String& user, const String& password) { + String s = String("") + '\0' + user + '\0' + password; + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<AuthRequest>(new AuthRequest("PLAIN", ByteArray(s.getUTF8Data(), s.getUTF8Size()))))); + } + + void expectResourceBind(const String& resource, const String& id) { + boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); + sessionStart->setResource(resource); + events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, sessionStart))); + } + + void expectSessionStart(const String& id) { + events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); + } + + void expectMessage() { + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); + } + + void sendInvalidXML() { + sendData("<invalid xml/>"); + } + + void sendStreamStart() { + events_.push_back(Event(Event::Out, Event::StreamStartEvent)); + } + + void sendStreamFeatures() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithStartTLS() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasStartTLS(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("PLAIN"); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithUnsupportedAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("MY-UNSUPPORTED-MECH"); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithResourceBind() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasResourceBind(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithSession() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasSession(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithResourceBindAndSession() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasResourceBind(); + streamFeatures->setHasSession(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendMessage() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); + } + + void sendTLSProceed() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<TLSProceed>(new TLSProceed()))); + } + + void sendTLSFailure() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<StartTLSFailure>(new StartTLSFailure()))); + } + + void sendAuthSuccess() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthSuccess>(new AuthSuccess()))); + } + + void sendAuthFailure() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthFailure>(new AuthFailure()))); + } + + void sendResourceBindResponse(const String& jid, const String& id) { + boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); + sessionStart->setJID(JID(jid)); + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, sessionStart))); + } + + void sendError(const String& id) { + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createError(JID(), id, Swift::Error::NotAllowed, Swift::Error::Cancel))); + } + + void sendSessionStartResponse(const String& id) { + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); + } + + bool fail_; + bool resetParser_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + XMPPParser* parser_; + XMPPSerializer serializer_; + std::deque<Event> events_; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory() : fail_(false) {} + MockConnection* createConnection(const String& domain) { + MockConnection* result = new MockConnection(domain, fail_); + connections_.push_back(result); + return result; + } + void setCreateFailingConnections() { + fail_ = true; + } + std::vector<MockConnection*> connections_; + bool fail_; + }; + + struct MockTLSLayer : public TLSLayer { + MockTLSLayer() : connecting_(false) {} + bool setClientCertificate(const PKCS12Certificate&) { return true; } + void writeData(const ByteArray& data) { onWriteData(data); } + void handleDataRead(const ByteArray& data) { onDataRead(data); } + void setConnected() { onConnected(); } + void setError() { onError(); } + void connect() { connecting_ = true; } + bool isConnecting() { return connecting_; } + + bool connecting_; + }; + + struct MockTLSLayerFactory : public TLSLayerFactory { + MockTLSLayerFactory() : haveTLS_(true) {} + void setTLSSupported(bool b) { haveTLS_ = b; } + virtual bool canCreate() const { return haveTLS_; } + virtual TLSLayer* createTLSLayer() { + assert(haveTLS_); + MockTLSLayer* result = new MockTLSLayer(); + layers_.push_back(result); + return result; + } + std::vector<MockTLSLayer*> layers_; + bool haveTLS_; + }; + + struct MockSession : public Session { + MockSession(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : Session(jid, connectionFactory, tlsLayerFactory, payloadParserFactories, payloadSerializers) {} + + MockTLSLayer* getTLSLayer() const { + return getStreamStack()->getLayer<MockTLSLayer>(); + } + WhitespacePingLayer* getWhitespacePingLayer() const { + return getStreamStack()->getLayer<WhitespacePingLayer>(); + } + }; + + MockSession* createSession(const String& jid) { + return new MockSession(JID(jid), connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + } + + + DummyEventLoop* eventLoop_; + MockConnectionFactory* connectionFactory_; + MockTLSLayerFactory* tlsLayerFactory_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + bool sessionStarted_; + bool needCredentials_; + std::vector< boost::shared_ptr<Element> > receivedElements_; + typedef std::vector< boost::function<void ()> > EventQueue; + EventQueue events_; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(SessionTest); |