summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/Client')
-rw-r--r--Swiften/Client/Client.cpp147
-rw-r--r--Swiften/Client/Client.h66
-rw-r--r--Swiften/Client/ClientError.h32
-rw-r--r--Swiften/Client/Makefile.inc5
-rw-r--r--Swiften/Client/Session.cpp292
-rw-r--r--Swiften/Client/Session.h126
-rw-r--r--Swiften/Client/StanzaChannel.h22
-rw-r--r--Swiften/Client/UnitTest/Makefile.inc2
-rw-r--r--Swiften/Client/UnitTest/SessionTest.cpp752
9 files changed, 1444 insertions, 0 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
new file mode 100644
index 0000000..e5bbf9d
--- /dev/null
+++ b/Swiften/Client/Client.cpp
@@ -0,0 +1,147 @@
+#include "Swiften/Client/Client.h"
+
+#include <boost/bind.hpp>
+
+#include "Swiften/Client/Session.h"
+#include "Swiften/StreamStack/PlatformTLSLayerFactory.h"
+#include "Swiften/Network/BoostConnectionFactory.h"
+#include "Swiften/TLS/PKCS12Certificate.h"
+
+namespace Swift {
+
+Client::Client(const JID& jid, const String& password) :
+ IQRouter(this), jid_(jid), password_(password), session_(0) {
+ connectionFactory_ = new BoostConnectionFactory();
+ tlsLayerFactory_ = new PlatformTLSLayerFactory();
+}
+
+Client::~Client() {
+ delete session_;
+ delete tlsLayerFactory_;
+ delete connectionFactory_;
+}
+
+void Client::connect() {
+ delete session_;
+ session_ = new Session(jid_, connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+ if (!certificate_.isEmpty()) {
+ session_->setCertificate(PKCS12Certificate(certificate_, password_));
+ }
+ session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected)));
+ session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1));
+ session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this));
+ session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
+ session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
+ session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1));
+ session_->start();
+}
+
+void Client::disconnect() {
+ if (session_) {
+ session_->stop();
+ }
+}
+
+void Client::send(boost::shared_ptr<Stanza> stanza) {
+ session_->sendElement(stanza);
+}
+
+void Client::sendIQ(boost::shared_ptr<IQ> iq) {
+ send(iq);
+}
+
+void Client::sendMessage(boost::shared_ptr<Message> message) {
+ send(message);
+}
+
+void Client::sendPresence(boost::shared_ptr<Presence> presence) {
+ send(presence);
+}
+
+String Client::getNewIQID() {
+ return idGenerator_.generateID();
+}
+
+void Client::handleElement(boost::shared_ptr<Element> element) {
+ boost::shared_ptr<Message> message = boost::dynamic_pointer_cast<Message>(element);
+ if (message) {
+ onMessageReceived(message);
+ return;
+ }
+
+ boost::shared_ptr<Presence> presence = boost::dynamic_pointer_cast<Presence>(element);
+ if (presence) {
+ onPresenceReceived(presence);
+ return;
+ }
+
+ boost::shared_ptr<IQ> iq = boost::dynamic_pointer_cast<IQ>(element);
+ if (iq) {
+ onIQReceived(iq);
+ return;
+ }
+}
+
+void Client::setCertificate(const String& certificate) {
+ certificate_ = certificate;
+}
+
+void Client::handleSessionError(Session::SessionError error) {
+ ClientError clientError;
+ switch (error) {
+ case Session::NoError:
+ assert(false);
+ break;
+ case Session::DomainNameResolveError:
+ clientError = ClientError(ClientError::DomainNameResolveError);
+ break;
+ case Session::ConnectionError:
+ clientError = ClientError(ClientError::ConnectionError);
+ break;
+ case Session::ConnectionReadError:
+ clientError = ClientError(ClientError::ConnectionReadError);
+ break;
+ case Session::XMLError:
+ clientError = ClientError(ClientError::XMLError);
+ break;
+ case Session::AuthenticationFailedError:
+ clientError = ClientError(ClientError::AuthenticationFailedError);
+ break;
+ case Session::NoSupportedAuthMechanismsError:
+ clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
+ break;
+ case Session::UnexpectedElementError:
+ clientError = ClientError(ClientError::UnexpectedElementError);
+ break;
+ case Session::ResourceBindError:
+ clientError = ClientError(ClientError::ResourceBindError);
+ break;
+ case Session::SessionStartError:
+ clientError = ClientError(ClientError::SessionStartError);
+ break;
+ case Session::TLSError:
+ clientError = ClientError(ClientError::TLSError);
+ break;
+ case Session::ClientCertificateLoadError:
+ clientError = ClientError(ClientError::ClientCertificateLoadError);
+ break;
+ case Session::ClientCertificateError:
+ clientError = ClientError(ClientError::ClientCertificateError);
+ break;
+ }
+ onError(clientError);
+}
+
+void Client::handleNeedCredentials() {
+ session_->sendCredentials(password_);
+}
+
+void Client::handleDataRead(const ByteArray& data) {
+ onDataRead(String(data.getData(), data.getSize()));
+}
+
+void Client::handleDataWritten(const ByteArray& data) {
+ onDataWritten(String(data.getData(), data.getSize()));
+}
+
+}
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
new file mode 100644
index 0000000..946bdbd
--- /dev/null
+++ b/Swiften/Client/Client.h
@@ -0,0 +1,66 @@
+#ifndef SWIFTEN_Client_H
+#define SWIFTEN_Client_H
+
+#include <boost/signals.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Client/Session.h"
+#include "Swiften/Client/ClientError.h"
+#include "Swiften/Elements/Presence.h"
+#include "Swiften/Elements/Message.h"
+#include "Swiften/JID/JID.h"
+#include "Swiften/Base/String.h"
+#include "Swiften/Base/IDGenerator.h"
+#include "Swiften/Client/StanzaChannel.h"
+#include "Swiften/Queries/IQRouter.h"
+#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h"
+#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h"
+
+namespace Swift {
+ class TLSLayerFactory;
+ class ConnectionFactory;
+ class Session;
+
+ class Client : public StanzaChannel, public IQRouter {
+ public:
+ Client(const JID& jid, const String& password);
+ ~Client();
+
+ void setCertificate(const String& certificate);
+
+ void connect();
+ void disconnect();
+
+ virtual void sendIQ(boost::shared_ptr<IQ>);
+ virtual void sendMessage(boost::shared_ptr<Message>);
+ virtual void sendPresence(boost::shared_ptr<Presence>);
+
+ public:
+ boost::signal<void (ClientError)> onError;
+ boost::signal<void ()> onConnected;
+ boost::signal<void (const String&)> onDataRead;
+ boost::signal<void (const String&)> onDataWritten;
+
+ private:
+ void send(boost::shared_ptr<Stanza>);
+ virtual String getNewIQID();
+ void handleElement(boost::shared_ptr<Element>);
+ void handleSessionError(Session::SessionError error);
+ void handleNeedCredentials();
+ void handleDataRead(const ByteArray&);
+ void handleDataWritten(const ByteArray&);
+
+ private:
+ JID jid_;
+ String password_;
+ IDGenerator idGenerator_;
+ ConnectionFactory* connectionFactory_;
+ TLSLayerFactory* tlsLayerFactory_;
+ FullPayloadParserFactoryCollection payloadParserFactories_;
+ FullPayloadSerializerCollection payloadSerializers_;
+ Session* session_;
+ String certificate_;
+ };
+}
+
+#endif
diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h
new file mode 100644
index 0000000..38f20c0
--- /dev/null
+++ b/Swiften/Client/ClientError.h
@@ -0,0 +1,32 @@
+#ifndef SWIFTEN_ClientError_H
+#define SWIFTEN_ClientError_H
+
+namespace Swift {
+ class ClientError {
+ public:
+ enum Type {
+ NoError,
+ DomainNameResolveError,
+ ConnectionError,
+ ConnectionReadError,
+ XMLError,
+ AuthenticationFailedError,
+ NoSupportedAuthMechanismsError,
+ UnexpectedElementError,
+ ResourceBindError,
+ SessionStartError,
+ TLSError,
+ ClientCertificateLoadError,
+ ClientCertificateError
+ };
+
+ ClientError(Type type = NoError) : type_(type) {}
+
+ Type getType() const { return type_; }
+
+ private:
+ Type type_;
+ };
+}
+
+#endif
diff --git a/Swiften/Client/Makefile.inc b/Swiften/Client/Makefile.inc
new file mode 100644
index 0000000..75eb08f
--- /dev/null
+++ b/Swiften/Client/Makefile.inc
@@ -0,0 +1,5 @@
+SWIFTEN_SOURCES += \
+ Swiften/Client/Client.cpp \
+ Swiften/Client/Session.cpp
+
+include Swiften/Client/UnitTest/Makefile.inc
diff --git a/Swiften/Client/Session.cpp b/Swiften/Client/Session.cpp
new file mode 100644
index 0000000..aa3cc62
--- /dev/null
+++ b/Swiften/Client/Session.cpp
@@ -0,0 +1,292 @@
+#include "Swiften/Client/Session.h"
+
+#include <boost/bind.hpp>
+
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/StreamStack/StreamStack.h"
+#include "Swiften/StreamStack/ConnectionLayer.h"
+#include "Swiften/StreamStack/XMPPLayer.h"
+#include "Swiften/StreamStack/TLSLayer.h"
+#include "Swiften/StreamStack/TLSLayerFactory.h"
+#include "Swiften/Elements/StreamFeatures.h"
+#include "Swiften/Elements/StartTLSRequest.h"
+#include "Swiften/Elements/StartTLSFailure.h"
+#include "Swiften/Elements/TLSProceed.h"
+#include "Swiften/Elements/AuthRequest.h"
+#include "Swiften/Elements/AuthSuccess.h"
+#include "Swiften/Elements/AuthFailure.h"
+#include "Swiften/Elements/StartSession.h"
+#include "Swiften/Elements/IQ.h"
+#include "Swiften/Elements/ResourceBind.h"
+#include "Swiften/SASL/PLAINMessage.h"
+#include "Swiften/StreamStack/WhitespacePingLayer.h"
+
+namespace Swift {
+
+Session::Session(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) :
+ jid_(jid),
+ connectionFactory_(connectionFactory),
+ tlsLayerFactory_(tlsLayerFactory),
+ payloadParserFactories_(payloadParserFactories),
+ payloadSerializers_(payloadSerializers),
+ state_(Initial),
+ error_(NoError),
+ connection_(0),
+ xmppLayer_(0),
+ tlsLayer_(0),
+ connectionLayer_(0),
+ whitespacePingLayer_(0),
+ streamStack_(0),
+ needSessionStart_(false) {
+}
+
+Session::~Session() {
+ delete streamStack_;
+ delete whitespacePingLayer_;
+ delete connectionLayer_;
+ delete tlsLayer_;
+ delete xmppLayer_;
+ delete connection_;
+}
+
+void Session::start() {
+ assert(state_ == Initial);
+ state_ = Connecting;
+ connection_ = connectionFactory_->createConnection(jid_.getDomain());
+ connection_->onConnected.connect(boost::bind(&Session::handleConnected, this));
+ connection_->onError.connect(boost::bind(&Session::handleConnectionError, this, _1));
+ connection_->connect();
+}
+
+void Session::stop() {
+ // TODO: Send end stream header if applicable
+ connection_->disconnect();
+}
+
+void Session::handleConnected() {
+ assert(state_ == Connecting);
+ initializeStreamStack();
+ state_ = WaitingForStreamStart;
+ sendStreamHeader();
+}
+
+void Session::sendStreamHeader() {
+ xmppLayer_->writeHeader(jid_.getDomain());
+}
+
+void Session::initializeStreamStack() {
+ xmppLayer_ = new XMPPLayer(payloadParserFactories_, payloadSerializers_);
+ xmppLayer_->onStreamStart.connect(boost::bind(&Session::handleStreamStart, this));
+ xmppLayer_->onElement.connect(boost::bind(&Session::handleElement, this, _1));
+ xmppLayer_->onError.connect(boost::bind(&Session::setError, this, XMLError));
+ xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
+ xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
+ connectionLayer_ = new ConnectionLayer(connection_);
+ streamStack_ = new StreamStack(xmppLayer_, connectionLayer_);
+}
+
+void Session::handleConnectionError(Connection::Error error) {
+ switch (error) {
+ case Connection::DomainNameResolveError:
+ setError(DomainNameResolveError);
+ break;
+ case Connection::ReadError:
+ setError(ConnectionReadError);
+ break;
+ case Connection::ConnectionError:
+ setError(ConnectionError);
+ break;
+ }
+}
+
+void Session::setCertificate(const PKCS12Certificate& certificate) {
+ certificate_ = certificate;
+}
+
+void Session::handleStreamStart() {
+ checkState(WaitingForStreamStart);
+ state_ = Negotiating;
+}
+
+void Session::handleElement(boost::shared_ptr<Element> element) {
+ if (getState() == SessionStarted) {
+ onElementReceived(element);
+ }
+ else {
+ StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get());
+ if (streamFeatures) {
+ if (!checkState(Negotiating)) {
+ return;
+ }
+
+ if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) {
+ state_ = Encrypting;
+ xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
+ }
+ else if (streamFeatures->hasAuthenticationMechanisms()) {
+ if (!certificate_.isNull()) {
+ if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+ state_ = Authenticating;
+ xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
+ }
+ else {
+ setError(ClientCertificateError);
+ }
+ }
+ else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
+ state_ = WaitingForCredentials;
+ onNeedCredentials();
+ }
+ else {
+ setError(NoSupportedAuthMechanismsError);
+ }
+ }
+ else {
+ // Start the session
+
+ // Add a whitespace ping layer
+ whitespacePingLayer_ = new WhitespacePingLayer();
+ streamStack_->addLayer(whitespacePingLayer_);
+
+ if (streamFeatures->hasSession()) {
+ needSessionStart_ = true;
+ }
+
+ if (streamFeatures->hasResourceBind()) {
+ state_ = BindingResource;
+ boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind());
+ if (!jid_.getResource().isEmpty()) {
+ resourceBind->setResource(jid_.getResource());
+ }
+ xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
+ }
+ else if (needSessionStart_) {
+ sendSessionStart();
+ }
+ else {
+ state_ = SessionStarted;
+ onSessionStarted();
+ }
+ }
+ }
+ else {
+ AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get());
+ if (authSuccess) {
+ checkState(Authenticating);
+ state_ = WaitingForStreamStart;
+ xmppLayer_->resetParser();
+ sendStreamHeader();
+ }
+ else if (dynamic_cast<AuthFailure*>(element.get())) {
+ setError(AuthenticationFailedError);
+ }
+ else if (dynamic_cast<TLSProceed*>(element.get())) {
+ tlsLayer_ = tlsLayerFactory_->createTLSLayer();
+ streamStack_->addLayer(tlsLayer_);
+ if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) {
+ setError(ClientCertificateLoadError);
+ }
+ else {
+ tlsLayer_->onConnected.connect(boost::bind(&Session::handleTLSConnected, this));
+ tlsLayer_->onError.connect(boost::bind(&Session::handleTLSError, this));
+ tlsLayer_->connect();
+ }
+ }
+ else if (dynamic_cast<StartTLSFailure*>(element.get())) {
+ setError(TLSError);
+ }
+ else {
+ IQ* iq = dynamic_cast<IQ*>(element.get());
+ if (iq) {
+ if (state_ == BindingResource) {
+ boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
+ if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
+ setError(ResourceBindError);
+ }
+ else if (!resourceBind) {
+ setError(UnexpectedElementError);
+ }
+ else if (iq->getType() == IQ::Result) {
+ jid_ = resourceBind->getJID();
+ if (!jid_.isValid()) {
+ setError(ResourceBindError);
+ }
+ if (needSessionStart_) {
+ sendSessionStart();
+ }
+ else {
+ state_ = SessionStarted;
+ }
+ }
+ else {
+ setError(UnexpectedElementError);
+ }
+ }
+ else if (state_ == StartingSession) {
+ if (iq->getType() == IQ::Result) {
+ state_ = SessionStarted;
+ onSessionStarted();
+ }
+ else if (iq->getType() == IQ::Error) {
+ setError(SessionStartError);
+ }
+ else {
+ setError(UnexpectedElementError);
+ }
+ }
+ else {
+ setError(UnexpectedElementError);
+ }
+ }
+ else {
+ // FIXME Not correct?
+ state_ = SessionStarted;
+ onSessionStarted();
+ }
+ }
+ }
+ }
+}
+
+void Session::sendSessionStart() {
+ state_ = StartingSession;
+ xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
+}
+
+void Session::setError(SessionError error) {
+ assert(error != NoError);
+ state_ = Error;
+ error_ = error;
+ onError(error);
+}
+
+bool Session::checkState(State state) {
+ if (state_ != state) {
+ setError(UnexpectedElementError);
+ return false;
+ }
+ return true;
+}
+
+void Session::sendCredentials(const String& password) {
+ assert(WaitingForCredentials);
+ state_ = Authenticating;
+ xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue())));
+}
+
+void Session::sendElement(boost::shared_ptr<Element> element) {
+ assert(SessionStarted);
+ xmppLayer_->writeElement(element);
+}
+
+void Session::handleTLSConnected() {
+ state_ = WaitingForStreamStart;
+ xmppLayer_->resetParser();
+ sendStreamHeader();
+}
+
+void Session::handleTLSError() {
+ setError(TLSError);
+}
+
+}
diff --git a/Swiften/Client/Session.h b/Swiften/Client/Session.h
new file mode 100644
index 0000000..c49d877
--- /dev/null
+++ b/Swiften/Client/Session.h
@@ -0,0 +1,126 @@
+#ifndef SWIFTEN_Session_H
+#define SWIFTEN_Session_H
+
+#include <boost/signal.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/JID/JID.h"
+#include "Swiften/Elements/Element.h"
+#include "Swiften/Network/Connection.h"
+#include "Swiften/TLS/PKCS12Certificate.h"
+
+namespace Swift {
+ class PayloadParserFactoryCollection;
+ class PayloadSerializerCollection;
+ class ConnectionFactory;
+ class Connection;
+ class StreamStack;
+ class XMPPLayer;
+ class ConnectionLayer;
+ class TLSLayerFactory;
+ class TLSLayer;
+ class WhitespacePingLayer;
+
+ class Session {
+ public:
+ enum State {
+ Initial,
+ Connecting,
+ WaitingForStreamStart,
+ Negotiating,
+ Compressing,
+ Encrypting,
+ WaitingForCredentials,
+ Authenticating,
+ BindingResource,
+ StartingSession,
+ SessionStarted,
+ Error
+ };
+ enum SessionError {
+ NoError,
+ DomainNameResolveError,
+ ConnectionError,
+ ConnectionReadError,
+ XMLError,
+ AuthenticationFailedError,
+ NoSupportedAuthMechanismsError,
+ UnexpectedElementError,
+ ResourceBindError,
+ SessionStartError,
+ TLSError,
+ ClientCertificateLoadError,
+ ClientCertificateError
+ };
+
+ Session(const JID& jid, ConnectionFactory*, TLSLayerFactory*, PayloadParserFactoryCollection*, PayloadSerializerCollection*);
+ ~Session();
+
+ State getState() const {
+ return state_;
+ }
+
+ SessionError getError() const {
+ return error_;
+ }
+
+ const JID& getJID() const {
+ return jid_;
+ }
+
+ void start();
+ void stop();
+ void sendCredentials(const String& password);
+ void sendElement(boost::shared_ptr<Element>);
+ void setCertificate(const PKCS12Certificate& certificate);
+
+ protected:
+ StreamStack* getStreamStack() const {
+ return streamStack_;
+ }
+
+ private:
+ void initializeStreamStack();
+ void sendStreamHeader();
+ void sendSessionStart();
+
+ void handleConnected();
+ void handleConnectionError(Connection::Error);
+ void handleElement(boost::shared_ptr<Element>);
+ void handleStreamStart();
+ void handleTLSConnected();
+ void handleTLSError();
+
+ void setError(SessionError);
+ bool checkState(State);
+
+ public:
+ boost::signal<void ()> onSessionStarted;
+ boost::signal<void (SessionError)> onError;
+ boost::signal<void ()> onNeedCredentials;
+ boost::signal<void (boost::shared_ptr<Element>) > onElementReceived;
+ boost::signal<void (const ByteArray&)> onDataWritten;
+ boost::signal<void (const ByteArray&)> onDataRead;
+
+ private:
+ JID jid_;
+ ConnectionFactory* connectionFactory_;
+ TLSLayerFactory* tlsLayerFactory_;
+ PayloadParserFactoryCollection* payloadParserFactories_;
+ PayloadSerializerCollection* payloadSerializers_;
+ State state_;
+ SessionError error_;
+ Connection* connection_;
+ XMPPLayer* xmppLayer_;
+ TLSLayer* tlsLayer_;
+ ConnectionLayer* connectionLayer_;
+ WhitespacePingLayer* whitespacePingLayer_;
+ StreamStack* streamStack_;
+ bool needSessionStart_;
+ PKCS12Certificate certificate_;
+ };
+
+}
+
+#endif
diff --git a/Swiften/Client/StanzaChannel.h b/Swiften/Client/StanzaChannel.h
new file mode 100644
index 0000000..719ed10
--- /dev/null
+++ b/Swiften/Client/StanzaChannel.h
@@ -0,0 +1,22 @@
+#ifndef SWIFTEN_MessageChannel_H
+#define SWIFTEN_MessageChannel_H
+
+#include <boost/signal.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Queries/IQChannel.h"
+#include "Swiften/Elements/Message.h"
+#include "Swiften/Elements/Presence.h"
+
+namespace Swift {
+ class StanzaChannel : public IQChannel {
+ public:
+ virtual void sendMessage(boost::shared_ptr<Message>) = 0;
+ virtual void sendPresence(boost::shared_ptr<Presence>) = 0;
+
+ boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived;
+ boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived;
+ };
+}
+
+#endif
diff --git a/Swiften/Client/UnitTest/Makefile.inc b/Swiften/Client/UnitTest/Makefile.inc
new file mode 100644
index 0000000..3ef87e5
--- /dev/null
+++ b/Swiften/Client/UnitTest/Makefile.inc
@@ -0,0 +1,2 @@
+UNITTEST_SOURCES += \
+ Swiften/Client/UnitTest/SessionTest.cpp
diff --git a/Swiften/Client/UnitTest/SessionTest.cpp b/Swiften/Client/UnitTest/SessionTest.cpp
new file mode 100644
index 0000000..7b7a916
--- /dev/null
+++ b/Swiften/Client/UnitTest/SessionTest.cpp
@@ -0,0 +1,752 @@
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+#include <boost/bind.hpp>
+#include <boost/function.hpp>
+#include <boost/optional.hpp>
+
+#include "Swiften/Parser/XMPPParser.h"
+#include "Swiften/Parser/XMPPParserClient.h"
+#include "Swiften/Serializer/XMPPSerializer.h"
+#include "Swiften/StreamStack/TLSLayerFactory.h"
+#include "Swiften/StreamStack/TLSLayer.h"
+#include "Swiften/StreamStack/StreamStack.h"
+#include "Swiften/StreamStack/WhitespacePingLayer.h"
+#include "Swiften/Elements/StreamFeatures.h"
+#include "Swiften/Elements/Element.h"
+#include "Swiften/Elements/Error.h"
+#include "Swiften/Elements/IQ.h"
+#include "Swiften/Elements/AuthRequest.h"
+#include "Swiften/Elements/AuthSuccess.h"
+#include "Swiften/Elements/AuthFailure.h"
+#include "Swiften/Elements/ResourceBind.h"
+#include "Swiften/Elements/StartSession.h"
+#include "Swiften/Elements/StartTLSRequest.h"
+#include "Swiften/Elements/StartTLSFailure.h"
+#include "Swiften/Elements/TLSProceed.h"
+#include "Swiften/Elements/Message.h"
+#include "Swiften/EventLoop/MainEventLoop.h"
+#include "Swiften/EventLoop/DummyEventLoop.h"
+#include "Swiften/Network/Connection.h"
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/Client/Session.h"
+#include "Swiften/TLS/PKCS12Certificate.h"
+#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h"
+#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h"
+
+using namespace Swift;
+
+class SessionTest : public CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(SessionTest);
+ CPPUNIT_TEST(testConstructor);
+ CPPUNIT_TEST(testConnect);
+ CPPUNIT_TEST(testConnect_Error);
+ CPPUNIT_TEST(testConnect_ErrorAfterSuccesfulConnect);
+ CPPUNIT_TEST(testConnect_XMLError);
+ CPPUNIT_TEST(testStartTLS);
+ CPPUNIT_TEST(testStartTLS_ServerError);
+ CPPUNIT_TEST(testStartTLS_NoTLSSupport);
+ CPPUNIT_TEST(testStartTLS_ConnectError);
+ CPPUNIT_TEST(testStartTLS_ErrorAfterConnect);
+ CPPUNIT_TEST(testAuthenticate);
+ CPPUNIT_TEST(testAuthenticate_Unauthorized);
+ CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
+ CPPUNIT_TEST(testResourceBind);
+ CPPUNIT_TEST(testResourceBind_ChangeResource);
+ CPPUNIT_TEST(testResourceBind_EmptyResource);
+ CPPUNIT_TEST(testResourceBind_Error);
+ CPPUNIT_TEST(testSessionStart);
+ CPPUNIT_TEST(testSessionStart_Error);
+ CPPUNIT_TEST(testSessionStart_AfterResourceBind);
+ CPPUNIT_TEST(testWhitespacePing);
+ CPPUNIT_TEST(testReceiveElementAfterSessionStarted);
+ CPPUNIT_TEST(testSendElement);
+ CPPUNIT_TEST_SUITE_END();
+
+ public:
+ SessionTest() {}
+
+ void setUp() {
+ eventLoop_ = new DummyEventLoop();
+ connectionFactory_ = new MockConnectionFactory();
+ tlsLayerFactory_ = new MockTLSLayerFactory();
+ sessionStarted_ = false;
+ needCredentials_ = false;
+ }
+
+ void tearDown() {
+ delete tlsLayerFactory_;
+ delete connectionFactory_;
+ delete eventLoop_;
+ }
+
+ void testConstructor() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ CPPUNIT_ASSERT_EQUAL(Session::Initial, session->getState());
+ }
+
+ void testConnect() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+
+ session->start();
+ CPPUNIT_ASSERT_EQUAL(Session::Connecting, session->getState());
+
+ getMockServer()->expectStreamStart();
+
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState());
+ }
+
+ void testConnect_Error() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this));
+
+ connectionFactory_->setCreateFailingConnections();
+ session->start();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT(!sessionStarted_);
+ CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError());
+ }
+
+ void testConnect_ErrorAfterSuccesfulConnect() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+
+ session->start();
+ getMockServer()->expectStreamStart();
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState());
+
+ connectionFactory_->connections_[0]->setError();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::ConnectionError, session->getError());
+ }
+
+ void testConnect_XMLError() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+
+ session->start();
+ getMockServer()->expectStreamStart();
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState());
+
+ getMockServer()->sendInvalidXML();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::XMLError, session->getError());
+ }
+
+ void testStartTLS_NoTLSSupport() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ tlsLayerFactory_->setTLSSupported(false);
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithStartTLS();
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ }
+
+ void testStartTLS() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithStartTLS();
+ getMockServer()->expectStartTLS();
+ // FIXME: Test 'encrypting' state
+ getMockServer()->sendTLSProceed();
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::Encrypting, session->getState());
+ CPPUNIT_ASSERT(session->getTLSLayer());
+ CPPUNIT_ASSERT(session->getTLSLayer()->isConnecting());
+
+ getMockServer()->resetParser();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ session->getTLSLayer()->setConnected();
+ // FIXME: Test 'WatingForStreamStart' state
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState());
+ }
+
+ void testStartTLS_ServerError() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithStartTLS();
+ getMockServer()->expectStartTLS();
+ getMockServer()->sendTLSFailure();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError());
+ }
+
+ void testStartTLS_ConnectError() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithStartTLS();
+ getMockServer()->expectStartTLS();
+ getMockServer()->sendTLSProceed();
+ processEvents();
+ session->getTLSLayer()->setError();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError());
+ }
+
+ void testStartTLS_ErrorAfterConnect() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithStartTLS();
+ getMockServer()->expectStartTLS();
+ getMockServer()->sendTLSProceed();
+ processEvents();
+ getMockServer()->resetParser();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ session->getTLSLayer()->setConnected();
+ processEvents();
+
+ session->getTLSLayer()->setError();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError());
+ }
+
+ void testAuthenticate() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->onNeedCredentials.connect(boost::bind(&SessionTest::setNeedCredentials, this));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithAuthentication();
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::WaitingForCredentials, session->getState());
+ CPPUNIT_ASSERT(needCredentials_);
+
+ getMockServer()->expectAuth("me", "mypass");
+ getMockServer()->sendAuthSuccess();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ session->sendCredentials("mypass");
+ CPPUNIT_ASSERT_EQUAL(Session::Authenticating, session->getState());
+ processEvents();
+ CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState());
+ }
+
+ void testAuthenticate_Unauthorized() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithAuthentication();
+ processEvents();
+
+ getMockServer()->expectAuth("me", "mypass");
+ getMockServer()->sendAuthFailure();
+ session->sendCredentials("mypass");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::AuthenticationFailedError, session->getError());
+ }
+
+ void testAuthenticate_NoValidAuthMechanisms() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::NoSupportedAuthMechanismsError, session->getError());
+ }
+
+ void testResourceBind() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithResourceBind();
+ getMockServer()->expectResourceBind("Bar", "session-bind");
+ // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::BindingResource, session->getState());
+ getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getJID());
+ }
+
+ void testResourceBind_ChangeResource() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithResourceBind();
+ getMockServer()->expectResourceBind("Bar", "session-bind");
+ getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getJID());
+ }
+
+ void testResourceBind_EmptyResource() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithResourceBind();
+ getMockServer()->expectResourceBind("", "session-bind");
+ getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getJID());
+ }
+
+ void testResourceBind_Error() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithResourceBind();
+ getMockServer()->expectResourceBind("", "session-bind");
+ getMockServer()->sendError("session-bind");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::ResourceBindError, session->getError());
+ }
+
+ void testSessionStart() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithSession();
+ getMockServer()->expectSessionStart("session-start");
+ // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::StartingSession, session->getState());
+ getMockServer()->sendSessionStartResponse("session-start");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ CPPUNIT_ASSERT(sessionStarted_);
+ }
+
+ void testSessionStart_Error() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithSession();
+ getMockServer()->expectSessionStart("session-start");
+ getMockServer()->sendError("session-start");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState());
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStartError, session->getError());
+ }
+
+ void testSessionStart_AfterResourceBind() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeaturesWithResourceBindAndSession();
+ getMockServer()->expectResourceBind("Bar", "session-bind");
+ getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
+ getMockServer()->expectSessionStart("session-start");
+ getMockServer()->sendSessionStartResponse("session-start");
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState());
+ CPPUNIT_ASSERT(sessionStarted_);
+ }
+
+ void testWhitespacePing() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeatures();
+ processEvents();
+ CPPUNIT_ASSERT(session->getWhitespacePingLayer());
+ }
+
+ void testReceiveElementAfterSessionStarted() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeatures();
+ processEvents();
+
+ getMockServer()->expectMessage();
+ session->sendElement(boost::shared_ptr<Message>(new Message()));
+ }
+
+ void testSendElement() {
+ std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ session->onElementReceived.connect(boost::bind(&SessionTest::addReceivedElement, this, _1));
+ session->start();
+ getMockServer()->expectStreamStart();
+ getMockServer()->sendStreamStart();
+ getMockServer()->sendStreamFeatures();
+ getMockServer()->sendMessage();
+ processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size()));
+ CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0]));
+ }
+
+ private:
+ struct MockConnection;
+
+ MockConnection* getMockServer() const {
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory_->connections_.size()));
+ return connectionFactory_->connections_[0];
+ }
+
+ void processEvents() {
+ eventLoop_->processEvents();
+ getMockServer()->assertNoMoreExpectations();
+ }
+
+ void setSessionStarted() {
+ sessionStarted_ = true;
+ }
+
+ void setNeedCredentials() {
+ needCredentials_ = true;
+ }
+
+ void addReceivedElement(boost::shared_ptr<Element> element) {
+ receivedElements_.push_back(element);
+ }
+
+ private:
+ struct MockConnection : public Connection, public XMPPParserClient {
+ struct Event {
+ enum Direction { In, Out };
+ enum Type { StreamStartEvent, StreamEndEvent, ElementEvent };
+
+ Event(
+ Direction direction,
+ Type type,
+ boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) :
+ direction(direction), type(type), element(element) {}
+
+ Direction direction;
+ Type type;
+ boost::shared_ptr<Element> element;
+ };
+
+ MockConnection(const String& domain, bool fail) :
+ Connection(domain),
+ fail_(fail),
+ resetParser_(false),
+ parser_(0),
+ serializer_(&payloadSerializers_) {
+ parser_ = new XMPPParser(this, &payloadParserFactories_);
+ }
+
+ ~MockConnection() {
+ delete parser_;
+ }
+
+ void disconnect() {
+ }
+
+ void connect() {
+ if (fail_) {
+ MainEventLoop::postEvent(boost::bind(boost::ref(onError), Connection::ConnectionError));
+ }
+ else {
+ MainEventLoop::postEvent(boost::bind(boost::ref(onConnected)));
+ }
+ }
+
+ void setError() {
+ MainEventLoop::postEvent(boost::bind(boost::ref(onError), Connection::ConnectionError));
+ }
+
+ void write(const ByteArray& data) {
+ CPPUNIT_ASSERT(parser_->parse(data.toString()));
+ if (resetParser_) {
+ resetParser();
+ resetParser_ = false;
+ }
+ }
+
+ void resetParser() {
+ delete parser_;
+ parser_ = new XMPPParser(this, &payloadParserFactories_);
+ }
+
+ void handleStreamStart() {
+ handleEvent(Event::StreamStartEvent);
+ }
+
+ void handleElement(boost::shared_ptr<Swift::Element> element) {
+ handleEvent(Event::ElementEvent, element);
+ }
+
+ void handleStreamEnd() {
+ handleEvent(Event::StreamEndEvent);
+ }
+
+ void handleEvent(Event::Type type, boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) {
+ CPPUNIT_ASSERT(!events_.empty());
+ CPPUNIT_ASSERT_EQUAL(events_[0].direction, Event::In);
+ CPPUNIT_ASSERT_EQUAL(events_[0].type, type);
+ if (type == Event::ElementEvent) {
+ CPPUNIT_ASSERT_EQUAL(serializer_.serializeElement(events_[0].element), serializer_.serializeElement(element));
+ }
+ events_.pop_front();
+
+ while (!events_.empty() && events_[0].direction == Event::Out) {
+ sendData(serializeEvent(events_[0]));
+ events_.pop_front();
+ }
+
+ if (!events_.empty() && events_[0].type == Event::StreamStartEvent) {
+ resetParser_ = true;
+ }
+ }
+
+ String serializeEvent(const Event& event) {
+ switch (event.type) {
+ case Event::StreamStartEvent:
+ return serializer_.serializeHeader(getDomain());
+ case Event::ElementEvent:
+ return serializer_.serializeElement(event.element);
+ case Event::StreamEndEvent:
+ return serializer_.serializeFooter();
+ }
+ assert(false);
+ }
+
+ void assertNoMoreExpectations() {
+ CPPUNIT_ASSERT(events_.empty());
+ }
+
+ void sendData(const ByteArray& data) {
+ MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), data));
+ }
+
+ void expectStreamStart() {
+ events_.push_back(Event(Event::In, Event::StreamStartEvent));
+ }
+
+ void expectStartTLS() {
+ events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())));
+ }
+
+ void expectAuth(const String& user, const String& password) {
+ String s = String("") + '\0' + user + '\0' + password;
+ events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<AuthRequest>(new AuthRequest("PLAIN", ByteArray(s.getUTF8Data(), s.getUTF8Size())))));
+ }
+
+ void expectResourceBind(const String& resource, const String& id) {
+ boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind());
+ sessionStart->setResource(resource);
+ events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, sessionStart)));
+ }
+
+ void expectSessionStart(const String& id) {
+ events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, boost::shared_ptr<StartSession>(new StartSession()))));
+ }
+
+ void expectMessage() {
+ events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<Message>(new Message())));
+ }
+
+ void sendInvalidXML() {
+ sendData("<invalid xml/>");
+ }
+
+ void sendStreamStart() {
+ events_.push_back(Event(Event::Out, Event::StreamStartEvent));
+ }
+
+ void sendStreamFeatures() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithStartTLS() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->setHasStartTLS();
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithAuthentication() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->addAuthenticationMechanism("PLAIN");
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithUnsupportedAuthentication() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->addAuthenticationMechanism("MY-UNSUPPORTED-MECH");
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithResourceBind() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->setHasResourceBind();
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithSession() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->setHasSession();
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendStreamFeaturesWithResourceBindAndSession() {
+ boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
+ streamFeatures->setHasResourceBind();
+ streamFeatures->setHasSession();
+ events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
+ }
+
+ void sendMessage() {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<Message>(new Message())));
+ }
+
+ void sendTLSProceed() {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<TLSProceed>(new TLSProceed())));
+ }
+
+ void sendTLSFailure() {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<StartTLSFailure>(new StartTLSFailure())));
+ }
+
+ void sendAuthSuccess() {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthSuccess>(new AuthSuccess())));
+ }
+
+ void sendAuthFailure() {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthFailure>(new AuthFailure())));
+ }
+
+ void sendResourceBindResponse(const String& jid, const String& id) {
+ boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind());
+ sessionStart->setJID(JID(jid));
+ events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, sessionStart)));
+ }
+
+ void sendError(const String& id) {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createError(JID(), id, Swift::Error::NotAllowed, Swift::Error::Cancel)));
+ }
+
+ void sendSessionStartResponse(const String& id) {
+ events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession()))));
+ }
+
+ bool fail_;
+ bool resetParser_;
+ FullPayloadParserFactoryCollection payloadParserFactories_;
+ FullPayloadSerializerCollection payloadSerializers_;
+ XMPPParser* parser_;
+ XMPPSerializer serializer_;
+ std::deque<Event> events_;
+ };
+
+ struct MockConnectionFactory : public ConnectionFactory {
+ MockConnectionFactory() : fail_(false) {}
+ MockConnection* createConnection(const String& domain) {
+ MockConnection* result = new MockConnection(domain, fail_);
+ connections_.push_back(result);
+ return result;
+ }
+ void setCreateFailingConnections() {
+ fail_ = true;
+ }
+ std::vector<MockConnection*> connections_;
+ bool fail_;
+ };
+
+ struct MockTLSLayer : public TLSLayer {
+ MockTLSLayer() : connecting_(false) {}
+ bool setClientCertificate(const PKCS12Certificate&) { return true; }
+ void writeData(const ByteArray& data) { onWriteData(data); }
+ void handleDataRead(const ByteArray& data) { onDataRead(data); }
+ void setConnected() { onConnected(); }
+ void setError() { onError(); }
+ void connect() { connecting_ = true; }
+ bool isConnecting() { return connecting_; }
+
+ bool connecting_;
+ };
+
+ struct MockTLSLayerFactory : public TLSLayerFactory {
+ MockTLSLayerFactory() : haveTLS_(true) {}
+ void setTLSSupported(bool b) { haveTLS_ = b; }
+ virtual bool canCreate() const { return haveTLS_; }
+ virtual TLSLayer* createTLSLayer() {
+ assert(haveTLS_);
+ MockTLSLayer* result = new MockTLSLayer();
+ layers_.push_back(result);
+ return result;
+ }
+ std::vector<MockTLSLayer*> layers_;
+ bool haveTLS_;
+ };
+
+ struct MockSession : public Session {
+ MockSession(const JID& jid, ConnectionFactory* connectionFactory, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : Session(jid, connectionFactory, tlsLayerFactory, payloadParserFactories, payloadSerializers) {}
+
+ MockTLSLayer* getTLSLayer() const {
+ return getStreamStack()->getLayer<MockTLSLayer>();
+ }
+ WhitespacePingLayer* getWhitespacePingLayer() const {
+ return getStreamStack()->getLayer<WhitespacePingLayer>();
+ }
+ };
+
+ MockSession* createSession(const String& jid) {
+ return new MockSession(JID(jid), connectionFactory_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+ }
+
+
+ DummyEventLoop* eventLoop_;
+ MockConnectionFactory* connectionFactory_;
+ MockTLSLayerFactory* tlsLayerFactory_;
+ FullPayloadParserFactoryCollection payloadParserFactories_;
+ FullPayloadSerializerCollection payloadSerializers_;
+ bool sessionStarted_;
+ bool needCredentials_;
+ std::vector< boost::shared_ptr<Element> > receivedElements_;
+ typedef std::vector< boost::function<void ()> > EventQueue;
+ EventQueue events_;
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(SessionTest);