From 9ccf1973ec3e23e4ba061b774c3f3e3bde4f1040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be> Date: Sun, 19 Jul 2009 14:27:15 +0200 Subject: Rename Session to ClientSession. diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 04a24bf..a38416a 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -3,7 +3,7 @@ #include <boost/bind.hpp> #include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Client/Session.h" +#include "Swiften/Client/ClientSession.h" #include "Swiften/StreamStack/PlatformTLSLayerFactory.h" #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/DomainNameResolveException.h" @@ -44,7 +44,7 @@ void Client::handleConnectionConnectFinished(bool error) { onError(ClientError::ConnectionError); } else { - session_ = new Session(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + session_ = new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); if (!certificate_.isEmpty()) { session_->setCertificate(PKCS12Certificate(certificate_, password_)); } @@ -108,43 +108,43 @@ void Client::setCertificate(const String& certificate) { certificate_ = certificate; } -void Client::handleSessionError(Session::SessionError error) { +void Client::handleSessionError(ClientSession::SessionError error) { ClientError clientError; switch (error) { - case Session::NoError: + case ClientSession::NoError: assert(false); break; - case Session::ConnectionReadError: + case ClientSession::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); break; - case Session::ConnectionWriteError: + case ClientSession::ConnectionWriteError: clientError = ClientError(ClientError::ConnectionWriteError); break; - case Session::XMLError: + case ClientSession::XMLError: clientError = ClientError(ClientError::XMLError); break; - case Session::AuthenticationFailedError: + case ClientSession::AuthenticationFailedError: clientError = ClientError(ClientError::AuthenticationFailedError); break; - case Session::NoSupportedAuthMechanismsError: + case ClientSession::NoSupportedAuthMechanismsError: clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); break; - case Session::UnexpectedElementError: + case ClientSession::UnexpectedElementError: clientError = ClientError(ClientError::UnexpectedElementError); break; - case Session::ResourceBindError: + case ClientSession::ResourceBindError: clientError = ClientError(ClientError::ResourceBindError); break; - case Session::SessionStartError: + case ClientSession::SessionStartError: clientError = ClientError(ClientError::SessionStartError); break; - case Session::TLSError: + case ClientSession::TLSError: clientError = ClientError(ClientError::TLSError); break; - case Session::ClientCertificateLoadError: + case ClientSession::ClientCertificateLoadError: clientError = ClientError(ClientError::ClientCertificateLoadError); break; - case Session::ClientCertificateError: + case ClientSession::ClientCertificateError: clientError = ClientError(ClientError::ClientCertificateError); break; } diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 66f9b01..48b76d9 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -4,7 +4,7 @@ #include <boost/signals.hpp> #include <boost/shared_ptr.hpp> -#include "Swiften/Client/Session.h" +#include "Swiften/Client/ClientSession.h" #include "Swiften/Client/ClientError.h" #include "Swiften/Elements/Presence.h" #include "Swiften/Elements/Message.h" @@ -20,7 +20,7 @@ namespace Swift { class TLSLayerFactory; class ConnectionFactory; - class Session; + class ClientSession; class Client : public StanzaChannel, public IQRouter { public: @@ -47,7 +47,7 @@ namespace Swift { void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); void handleElement(boost::shared_ptr<Element>); - void handleSessionError(Session::SessionError error); + void handleSessionError(ClientSession::SessionError error); void handleNeedCredentials(); void handleDataRead(const ByteArray&); void handleDataWritten(const ByteArray&); @@ -61,7 +61,7 @@ namespace Swift { TLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; - Session* session_; + ClientSession* session_; boost::shared_ptr<Connection> connection_; String certificate_; }; diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp new file mode 100644 index 0000000..11317e8 --- /dev/null +++ b/Swiften/Client/ClientSession.cpp @@ -0,0 +1,273 @@ +#include "Swiften/Client/ClientSession.h" + +#include <boost/bind.hpp> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Elements/ProtocolHeader.h" +#include "Swiften/StreamStack/StreamStack.h" +#include "Swiften/StreamStack/ConnectionLayer.h" +#include "Swiften/StreamStack/XMPPLayer.h" +#include "Swiften/StreamStack/TLSLayer.h" +#include "Swiften/StreamStack/TLSLayerFactory.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/SASL/PLAINMessage.h" +#include "Swiften/StreamStack/WhitespacePingLayer.h" + +namespace Swift { + +ClientSession::ClientSession( + const JID& jid, + boost::shared_ptr<Connection> connection, + TLSLayerFactory* tlsLayerFactory, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers) : + jid_(jid), + tlsLayerFactory_(tlsLayerFactory), + payloadParserFactories_(payloadParserFactories), + payloadSerializers_(payloadSerializers), + state_(Initial), + error_(NoError), + connection_(connection), + streamStack_(0), + needSessionStart_(false) { +} + +ClientSession::~ClientSession() { + delete streamStack_; +} + +void ClientSession::start() { + assert(state_ == Initial); + + connection_->onDisconnected.connect(boost::bind(&ClientSession::handleDisconnected, this, _1)); + initializeStreamStack(); + state_ = WaitingForStreamStart; + sendStreamHeader(); +} + +void ClientSession::stop() { + // TODO: Send end stream header if applicable + connection_->disconnect(); +} + +void ClientSession::sendStreamHeader() { + ProtocolHeader header; + header.setTo(jid_.getDomain()); + xmppLayer_->writeHeader(header); +} + +void ClientSession::initializeStreamStack() { + xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_)); + xmppLayer_->onStreamStart.connect(boost::bind(&ClientSession::handleStreamStart, this)); + xmppLayer_->onElement.connect(boost::bind(&ClientSession::handleElement, this, _1)); + xmppLayer_->onError.connect(boost::bind(&ClientSession::setError, this, XMLError)); + xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); + xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); + connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_)); + streamStack_ = new StreamStack(xmppLayer_, connectionLayer_); +} + +void ClientSession::handleDisconnected(const boost::optional<Connection::Error>& error) { + if (error) { + switch (*error) { + case Connection::ReadError: + setError(ConnectionReadError); + break; + case Connection::WriteError: + setError(ConnectionWriteError); + break; + } + } +} + +void ClientSession::setCertificate(const PKCS12Certificate& certificate) { + certificate_ = certificate; +} + +void ClientSession::handleStreamStart() { + checkState(WaitingForStreamStart); + state_ = Negotiating; +} + +void ClientSession::handleElement(boost::shared_ptr<Element> element) { + if (getState() == SessionStarted) { + onElementReceived(element); + } + else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { + if (!checkState(Negotiating)) { + return; + } + + if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { + state_ = Encrypting; + xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + } + else if (streamFeatures->hasAuthenticationMechanisms()) { + if (!certificate_.isNull()) { + if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + state_ = Authenticating; + xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); + } + else { + setError(ClientCertificateError); + } + } + else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { + state_ = WaitingForCredentials; + onNeedCredentials(); + } + else { + setError(NoSupportedAuthMechanismsError); + } + } + else { + // Start the session + + // Add a whitespace ping layer + whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); + streamStack_->addLayer(whitespacePingLayer_); + + if (streamFeatures->hasSession()) { + needSessionStart_ = true; + } + + if (streamFeatures->hasResourceBind()) { + state_ = BindingResource; + boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); + if (!jid_.getResource().isEmpty()) { + resourceBind->setResource(jid_.getResource()); + } + xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + } + else if (needSessionStart_) { + sendSessionStart(); + } + else { + state_ = SessionStarted; + onSessionStarted(); + } + } + } + else if (dynamic_cast<AuthSuccess*>(element.get())) { + checkState(Authenticating); + state_ = WaitingForStreamStart; + xmppLayer_->resetParser(); + sendStreamHeader(); + } + else if (dynamic_cast<AuthFailure*>(element.get())) { + setError(AuthenticationFailedError); + } + else if (dynamic_cast<TLSProceed*>(element.get())) { + tlsLayer_ = tlsLayerFactory_->createTLSLayer(); + streamStack_->addLayer(tlsLayer_); + if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { + setError(ClientCertificateLoadError); + } + else { + tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this)); + tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this)); + tlsLayer_->connect(); + } + } + else if (dynamic_cast<StartTLSFailure*>(element.get())) { + setError(TLSError); + } + else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { + if (state_ == BindingResource) { + boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); + if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { + setError(ResourceBindError); + } + else if (!resourceBind) { + setError(UnexpectedElementError); + } + else if (iq->getType() == IQ::Result) { + jid_ = resourceBind->getJID(); + if (!jid_.isValid()) { + setError(ResourceBindError); + } + if (needSessionStart_) { + sendSessionStart(); + } + else { + state_ = SessionStarted; + } + } + else { + setError(UnexpectedElementError); + } + } + else if (state_ == StartingSession) { + if (iq->getType() == IQ::Result) { + state_ = SessionStarted; + onSessionStarted(); + } + else if (iq->getType() == IQ::Error) { + setError(SessionStartError); + } + else { + setError(UnexpectedElementError); + } + } + else { + setError(UnexpectedElementError); + } + } + else { + // FIXME Not correct? + state_ = SessionStarted; + onSessionStarted(); + } +} + +void ClientSession::sendSessionStart() { + state_ = StartingSession; + xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); +} + +void ClientSession::setError(SessionError error) { + assert(error != NoError); + state_ = Error; + error_ = error; + onError(error); +} + +bool ClientSession::checkState(State state) { + if (state_ != state) { + setError(UnexpectedElementError); + return false; + } + return true; +} + +void ClientSession::sendCredentials(const String& password) { + assert(WaitingForCredentials); + state_ = Authenticating; + xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); +} + +void ClientSession::sendElement(boost::shared_ptr<Element> element) { + assert(SessionStarted); + xmppLayer_->writeElement(element); +} + +void ClientSession::handleTLSConnected() { + state_ = WaitingForStreamStart; + xmppLayer_->resetParser(); + sendStreamHeader(); +} + +void ClientSession::handleTLSError() { + setError(TLSError); +} + +} diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h new file mode 100644 index 0000000..50dae24 --- /dev/null +++ b/Swiften/Client/ClientSession.h @@ -0,0 +1,124 @@ +#pragma once + +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Base/String.h" +#include "Swiften/JID/JID.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/TLS/PKCS12Certificate.h" + +namespace Swift { + class PayloadParserFactoryCollection; + class PayloadSerializerCollection; + class ConnectionFactory; + class Connection; + class StreamStack; + class XMPPLayer; + class ConnectionLayer; + class TLSLayerFactory; + class TLSLayer; + class WhitespacePingLayer; + + class ClientSession { + public: + enum State { + Initial, + WaitingForStreamStart, + Negotiating, + Compressing, + Encrypting, + WaitingForCredentials, + Authenticating, + BindingResource, + StartingSession, + SessionStarted, + Error + }; + enum SessionError { + NoError, + ConnectionReadError, + ConnectionWriteError, + XMLError, + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + ClientCertificateLoadError, + ClientCertificateError + }; + + ClientSession( + const JID& jid, + boost::shared_ptr<Connection>, + TLSLayerFactory*, + PayloadParserFactoryCollection*, + PayloadSerializerCollection*); + ~ClientSession(); + + State getState() const { + return state_; + } + + SessionError getError() const { + return error_; + } + + const JID& getJID() const { + return jid_; + } + + void start(); + void stop(); + + void sendCredentials(const String& password); + void sendElement(boost::shared_ptr<Element>); + void setCertificate(const PKCS12Certificate& certificate); + + protected: + StreamStack* getStreamStack() const { + return streamStack_; + } + + private: + void initializeStreamStack(); + void sendStreamHeader(); + void sendSessionStart(); + + void handleDisconnected(const boost::optional<Connection::Error>&); + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(); + void handleTLSConnected(); + void handleTLSError(); + + void setError(SessionError); + bool checkState(State); + + public: + boost::signal<void ()> onSessionStarted; + boost::signal<void (SessionError)> onError; + boost::signal<void ()> onNeedCredentials; + boost::signal<void (boost::shared_ptr<Element>) > onElementReceived; + boost::signal<void (const ByteArray&)> onDataWritten; + boost::signal<void (const ByteArray&)> onDataRead; + + private: + JID jid_; + TLSLayerFactory* tlsLayerFactory_; + PayloadParserFactoryCollection* payloadParserFactories_; + PayloadSerializerCollection* payloadSerializers_; + State state_; + SessionError error_; + boost::shared_ptr<Connection> connection_; + boost::shared_ptr<XMPPLayer> xmppLayer_; + boost::shared_ptr<TLSLayer> tlsLayer_; + boost::shared_ptr<ConnectionLayer> connectionLayer_; + boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; + StreamStack* streamStack_; + bool needSessionStart_; + PKCS12Certificate certificate_; + }; +} diff --git a/Swiften/Client/Makefile.inc b/Swiften/Client/Makefile.inc index 75eb08f..8171ed1 100644 --- a/Swiften/Client/Makefile.inc +++ b/Swiften/Client/Makefile.inc @@ -1,5 +1,5 @@ SWIFTEN_SOURCES += \ Swiften/Client/Client.cpp \ - Swiften/Client/Session.cpp + Swiften/Client/ClientSession.cpp include Swiften/Client/UnitTest/Makefile.inc diff --git a/Swiften/Client/Session.cpp b/Swiften/Client/Session.cpp deleted file mode 100644 index 1bd2b22..0000000 --- a/Swiften/Client/Session.cpp +++ /dev/null @@ -1,273 +0,0 @@ -#include "Swiften/Client/Session.h" - -#include <boost/bind.hpp> - -#include "Swiften/Network/ConnectionFactory.h" -#include "Swiften/Elements/ProtocolHeader.h" -#include "Swiften/StreamStack/StreamStack.h" -#include "Swiften/StreamStack/ConnectionLayer.h" -#include "Swiften/StreamStack/XMPPLayer.h" -#include "Swiften/StreamStack/TLSLayer.h" -#include "Swiften/StreamStack/TLSLayerFactory.h" -#include "Swiften/Elements/StreamFeatures.h" -#include "Swiften/Elements/StartTLSRequest.h" -#include "Swiften/Elements/StartTLSFailure.h" -#include "Swiften/Elements/TLSProceed.h" -#include "Swiften/Elements/AuthRequest.h" -#include "Swiften/Elements/AuthSuccess.h" -#include "Swiften/Elements/AuthFailure.h" -#include "Swiften/Elements/StartSession.h" -#include "Swiften/Elements/IQ.h" -#include "Swiften/Elements/ResourceBind.h" -#include "Swiften/SASL/PLAINMessage.h" -#include "Swiften/StreamStack/WhitespacePingLayer.h" - -namespace Swift { - -Session::Session( - const JID& jid, - boost::shared_ptr<Connection> connection, - TLSLayerFactory* tlsLayerFactory, - PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : - jid_(jid), - tlsLayerFactory_(tlsLayerFactory), - payloadParserFactories_(payloadParserFactories), - payloadSerializers_(payloadSerializers), - state_(Initial), - error_(NoError), - connection_(connection), - streamStack_(0), - needSessionStart_(false) { -} - -Session::~Session() { - delete streamStack_; -} - -void Session::start() { - assert(state_ == Initial); - - connection_->onDisconnected.connect(boost::bind(&Session::handleDisconnected, this, _1)); - initializeStreamStack(); - state_ = WaitingForStreamStart; - sendStreamHeader(); -} - -void Session::stop() { - // TODO: Send end stream header if applicable - connection_->disconnect(); -} - -void Session::sendStreamHeader() { - ProtocolHeader header; - header.setTo(jid_.getDomain()); - xmppLayer_->writeHeader(header); -} - -void Session::initializeStreamStack() { - xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_)); - xmppLayer_->onStreamStart.connect(boost::bind(&Session::handleStreamStart, this)); - xmppLayer_->onElement.connect(boost::bind(&Session::handleElement, this, _1)); - xmppLayer_->onError.connect(boost::bind(&Session::setError, this, XMLError)); - xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); - xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); - connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_)); - streamStack_ = new StreamStack(xmppLayer_, connectionLayer_); -} - -void Session::handleDisconnected(const boost::optional<Connection::Error>& error) { - if (error) { - switch (*error) { - case Connection::ReadError: - setError(ConnectionReadError); - break; - case Connection::WriteError: - setError(ConnectionWriteError); - break; - } - } -} - -void Session::setCertificate(const PKCS12Certificate& certificate) { - certificate_ = certificate; -} - -void Session::handleStreamStart() { - checkState(WaitingForStreamStart); - state_ = Negotiating; -} - -void Session::handleElement(boost::shared_ptr<Element> element) { - if (getState() == SessionStarted) { - onElementReceived(element); - } - else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - if (!checkState(Negotiating)) { - return; - } - - if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { - state_ = Encrypting; - xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); - } - else if (streamFeatures->hasAuthenticationMechanisms()) { - if (!certificate_.isNull()) { - if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { - state_ = Authenticating; - xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); - } - else { - setError(ClientCertificateError); - } - } - else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { - state_ = WaitingForCredentials; - onNeedCredentials(); - } - else { - setError(NoSupportedAuthMechanismsError); - } - } - else { - // Start the session - - // Add a whitespace ping layer - whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); - streamStack_->addLayer(whitespacePingLayer_); - - if (streamFeatures->hasSession()) { - needSessionStart_ = true; - } - - if (streamFeatures->hasResourceBind()) { - state_ = BindingResource; - boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); - if (!jid_.getResource().isEmpty()) { - resourceBind->setResource(jid_.getResource()); - } - xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); - } - else if (needSessionStart_) { - sendSessionStart(); - } - else { - state_ = SessionStarted; - onSessionStarted(); - } - } - } - else if (dynamic_cast<AuthSuccess*>(element.get())) { - checkState(Authenticating); - state_ = WaitingForStreamStart; - xmppLayer_->resetParser(); - sendStreamHeader(); - } - else if (dynamic_cast<AuthFailure*>(element.get())) { - setError(AuthenticationFailedError); - } - else if (dynamic_cast<TLSProceed*>(element.get())) { - tlsLayer_ = tlsLayerFactory_->createTLSLayer(); - streamStack_->addLayer(tlsLayer_); - if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { - setError(ClientCertificateLoadError); - } - else { - tlsLayer_->onConnected.connect(boost::bind(&Session::handleTLSConnected, this)); - tlsLayer_->onError.connect(boost::bind(&Session::handleTLSError, this)); - tlsLayer_->connect(); - } - } - else if (dynamic_cast<StartTLSFailure*>(element.get())) { - setError(TLSError); - } - else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { - if (state_ == BindingResource) { - boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); - if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { - setError(ResourceBindError); - } - else if (!resourceBind) { - setError(UnexpectedElementError); - } - else if (iq->getType() == IQ::Result) { - jid_ = resourceBind->getJID(); - if (!jid_.isValid()) { - setError(ResourceBindError); - } - if (needSessionStart_) { - sendSessionStart(); - } - else { - state_ = SessionStarted; - } - } - else { - setError(UnexpectedElementError); - } - } - else if (state_ == StartingSession) { - if (iq->getType() == IQ::Result) { - state_ = SessionStarted; - onSessionStarted(); - } - else if (iq->getType() == IQ::Error) { - setError(SessionStartError); - } - else { - setError(UnexpectedElementError); - } - } - else { - setError(UnexpectedElementError); - } - } - else { - // FIXME Not correct? - state_ = SessionStarted; - onSessionStarted(); - } -} - -void Session::sendSessionStart() { - state_ = StartingSession; - xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); -} - -void Session::setError(SessionError error) { - assert(error != NoError); - state_ = Error; - error_ = error; - onError(error); -} - -bool Session::checkState(State state) { - if (state_ != state) { - setError(UnexpectedElementError); - return false; - } - return true; -} - -void Session::sendCredentials(const String& password) { - assert(WaitingForCredentials); - state_ = Authenticating; - xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); -} - -void Session::sendElement(boost::shared_ptr<Element> element) { - assert(SessionStarted); - xmppLayer_->writeElement(element); -} - -void Session::handleTLSConnected() { - state_ = WaitingForStreamStart; - xmppLayer_->resetParser(); - sendStreamHeader(); -} - -void Session::handleTLSError() { - setError(TLSError); -} - -} diff --git a/Swiften/Client/Session.h b/Swiften/Client/Session.h deleted file mode 100644 index 58531b3..0000000 --- a/Swiften/Client/Session.h +++ /dev/null @@ -1,128 +0,0 @@ -#ifndef SWIFTEN_Session_H -#define SWIFTEN_Session_H - -#include <boost/signal.hpp> -#include <boost/shared_ptr.hpp> - -#include "Swiften/Base/String.h" -#include "Swiften/JID/JID.h" -#include "Swiften/Elements/Element.h" -#include "Swiften/Network/Connection.h" -#include "Swiften/TLS/PKCS12Certificate.h" - -namespace Swift { - class PayloadParserFactoryCollection; - class PayloadSerializerCollection; - class ConnectionFactory; - class Connection; - class StreamStack; - class XMPPLayer; - class ConnectionLayer; - class TLSLayerFactory; - class TLSLayer; - class WhitespacePingLayer; - - class Session { - public: - enum State { - Initial, - WaitingForStreamStart, - Negotiating, - Compressing, - Encrypting, - WaitingForCredentials, - Authenticating, - BindingResource, - StartingSession, - SessionStarted, - Error - }; - enum SessionError { - NoError, - ConnectionReadError, - ConnectionWriteError, - XMLError, - AuthenticationFailedError, - NoSupportedAuthMechanismsError, - UnexpectedElementError, - ResourceBindError, - SessionStartError, - TLSError, - ClientCertificateLoadError, - ClientCertificateError - }; - - Session( - const JID& jid, - boost::shared_ptr<Connection>, - TLSLayerFactory*, - PayloadParserFactoryCollection*, - PayloadSerializerCollection*); - ~Session(); - - State getState() const { - return state_; - } - - SessionError getError() const { - return error_; - } - - const JID& getJID() const { - return jid_; - } - - void start(); - void stop(); - - void sendCredentials(const String& password); - void sendElement(boost::shared_ptr<Element>); - void setCertificate(const PKCS12Certificate& certificate); - - protected: - StreamStack* getStreamStack() const { - return streamStack_; - } - - private: - void initializeStreamStack(); - void sendStreamHeader(); - void sendSessionStart(); - - void handleDisconnected(const boost::optional<Connection::Error>&); - void handleElement(boost::shared_ptr<Element>); - void handleStreamStart(); - void handleTLSConnected(); - void handleTLSError(); - - void setError(SessionError); - bool checkState(State); - - public: - boost::signal<void ()> onSessionStarted; - boost::signal<void (SessionError)> onError; - boost::signal<void ()> onNeedCredentials; - boost::signal<void (boost::shared_ptr<Element>) > onElementReceived; - boost::signal<void (const ByteArray&)> onDataWritten; - boost::signal<void (const ByteArray&)> onDataRead; - - private: - JID jid_; - TLSLayerFactory* tlsLayerFactory_; - PayloadParserFactoryCollection* payloadParserFactories_; - PayloadSerializerCollection* payloadSerializers_; - State state_; - SessionError error_; - boost::shared_ptr<Connection> connection_; - boost::shared_ptr<XMPPLayer> xmppLayer_; - boost::shared_ptr<TLSLayer> tlsLayer_; - boost::shared_ptr<ConnectionLayer> connectionLayer_; - boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; - StreamStack* streamStack_; - bool needSessionStart_; - PKCS12Certificate certificate_; - }; - -} - -#endif diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp new file mode 100644 index 0000000..1e66019 --- /dev/null +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -0,0 +1,704 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> +#include <boost/bind.hpp> +#include <boost/function.hpp> +#include <boost/optional.hpp> + +#include "Swiften/Parser/XMPPParser.h" +#include "Swiften/Parser/XMPPParserClient.h" +#include "Swiften/Serializer/XMPPSerializer.h" +#include "Swiften/StreamStack/TLSLayerFactory.h" +#include "Swiften/StreamStack/TLSLayer.h" +#include "Swiften/StreamStack/StreamStack.h" +#include "Swiften/StreamStack/WhitespacePingLayer.h" +#include "Swiften/Elements/ProtocolHeader.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Client/ClientSession.h" +#include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" +#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" + +using namespace Swift; + +class ClientSessionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ClientSessionTest); + CPPUNIT_TEST(testConstructor); + CPPUNIT_TEST(testStart_Error); + CPPUNIT_TEST(testStart_XMLError); + CPPUNIT_TEST(testStartTLS); + CPPUNIT_TEST(testStartTLS_ServerError); + CPPUNIT_TEST(testStartTLS_NoTLSSupport); + CPPUNIT_TEST(testStartTLS_ConnectError); + CPPUNIT_TEST(testStartTLS_ErrorAfterConnect); + CPPUNIT_TEST(testAuthenticate); + CPPUNIT_TEST(testAuthenticate_Unauthorized); + CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); + CPPUNIT_TEST(testResourceBind); + CPPUNIT_TEST(testResourceBind_ChangeResource); + CPPUNIT_TEST(testResourceBind_EmptyResource); + CPPUNIT_TEST(testResourceBind_Error); + CPPUNIT_TEST(testSessionStart); + CPPUNIT_TEST(testSessionStart_Error); + CPPUNIT_TEST(testSessionStart_AfterResourceBind); + CPPUNIT_TEST(testWhitespacePing); + CPPUNIT_TEST(testReceiveElementAfterSessionStarted); + CPPUNIT_TEST(testSendElement); + CPPUNIT_TEST_SUITE_END(); + + public: + ClientSessionTest() {} + + void setUp() { + eventLoop_ = new DummyEventLoop(); + connection_ = boost::shared_ptr<MockConnection>(new MockConnection()); + tlsLayerFactory_ = new MockTLSLayerFactory(); + sessionStarted_ = false; + needCredentials_ = false; + } + + void tearDown() { + delete tlsLayerFactory_; + delete eventLoop_; + } + + void testConstructor() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + CPPUNIT_ASSERT_EQUAL(ClientSession::Initial, session->getState()); + } + + void testStart_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + + getMockServer()->expectStreamStart(); + session->start(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState()); + + getMockServer()->setError(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, session->getError()); + } + + void testStart_XMLError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + + getMockServer()->expectStreamStart(); + session->start(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState()); + + getMockServer()->sendInvalidXML(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, session->getError()); + } + + void testStartTLS_NoTLSSupport() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + tlsLayerFactory_->setTLSSupported(false); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + session->start(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + } + + void testStartTLS() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + // FIXME: Test 'encrypting' state + getMockServer()->sendTLSProceed(); + session->start(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::Encrypting, session->getState()); + CPPUNIT_ASSERT(session->getTLSLayer()); + CPPUNIT_ASSERT(session->getTLSLayer()->isConnecting()); + + getMockServer()->resetParser(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->getTLSLayer()->setConnected(); + // FIXME: Test 'WatingForStreamStart' state + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); + } + + void testStartTLS_ServerError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSFailure(); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + } + + void testStartTLS_ConnectError() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSProceed(); + session->start(); + processEvents(); + session->getTLSLayer()->setError(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + } + + void testStartTLS_ErrorAfterConnect() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithStartTLS(); + getMockServer()->expectStartTLS(); + getMockServer()->sendTLSProceed(); + session->start(); + processEvents(); + getMockServer()->resetParser(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->getTLSLayer()->setConnected(); + processEvents(); + + session->getTLSLayer()->setError(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + } + + void testAuthenticate() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + session->start(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT(needCredentials_); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthSuccess(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->sendCredentials("mypass"); + CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState()); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); + } + + void testAuthenticate_Unauthorized() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + session->start(); + processEvents(); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthFailure(); + session->sendCredentials("mypass"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, session->getError()); + } + + void testAuthenticate_NoValidAuthMechanisms() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, session->getError()); + } + + void testResourceBind() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState()); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + session->start(); + + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getJID()); + } + + void testResourceBind_ChangeResource() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getJID()); + } + + void testResourceBind_EmptyResource() { + std::auto_ptr<MockSession> session(createSession("me@foo.com")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getJID()); + } + + void testResourceBind_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendError("session-bind"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, session->getError()); + } + + void testSessionStart() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState()); + getMockServer()->sendSessionStartResponse("session-start"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testSessionStart_Error() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendError("session-start"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, session->getError()); + } + + void testSessionStart_AfterResourceBind() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBindAndSession(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendSessionStartResponse("session-start"); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testWhitespacePing() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + session->start(); + processEvents(); + CPPUNIT_ASSERT(session->getWhitespacePingLayer()); + } + + void testReceiveElementAfterSessionStarted() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + session->start(); + processEvents(); + + getMockServer()->expectMessage(); + session->sendElement(boost::shared_ptr<Message>(new Message())); + } + + void testSendElement() { + std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + getMockServer()->sendMessage(); + session->start(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); + CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0])); + } + + private: + struct MockConnection; + + boost::shared_ptr<MockConnection> getMockServer() const { + return connection_; + } + + void processEvents() { + eventLoop_->processEvents(); + getMockServer()->assertNoMoreExpectations(); + } + + void setSessionStarted() { + sessionStarted_ = true; + } + + void setNeedCredentials() { + needCredentials_ = true; + } + + void addReceivedElement(boost::shared_ptr<Element> element) { + receivedElements_.push_back(element); + } + + private: + struct MockConnection : public Connection, public XMPPParserClient { + struct Event { + enum Direction { In, Out }; + enum Type { StreamStartEvent, StreamEndEvent, ElementEvent }; + + Event( + Direction direction, + Type type, + boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) : + direction(direction), type(type), element(element) {} + + Direction direction; + Type type; + boost::shared_ptr<Element> element; + }; + + MockConnection() : + resetParser_(false), + domain_("foo.com"), + parser_(0), + serializer_(&payloadSerializers_) { + parser_ = new XMPPParser(this, &payloadParserFactories_); + } + + ~MockConnection() { + delete parser_; + } + + void disconnect() { } + + void listen() { + assert(false); + } + + void connect(const HostAddressPort&) { assert(false); } + void connect(const String&) { assert(false); } + + void setError() { + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ReadError)); + } + + void write(const ByteArray& data) { + CPPUNIT_ASSERT(parser_->parse(data.toString())); + if (resetParser_) { + resetParser(); + resetParser_ = false; + } + } + + void resetParser() { + delete parser_; + parser_ = new XMPPParser(this, &payloadParserFactories_); + } + + void handleStreamStart(const ProtocolHeader& header) { + CPPUNIT_ASSERT_EQUAL(domain_, header.getTo()); + handleEvent(Event::StreamStartEvent); + } + + void handleElement(boost::shared_ptr<Swift::Element> element) { + handleEvent(Event::ElementEvent, element); + } + + void handleStreamEnd() { + handleEvent(Event::StreamEndEvent); + } + + void handleEvent(Event::Type type, boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) { + CPPUNIT_ASSERT(!events_.empty()); + CPPUNIT_ASSERT_EQUAL(events_[0].direction, Event::In); + CPPUNIT_ASSERT_EQUAL(events_[0].type, type); + if (type == Event::ElementEvent) { + CPPUNIT_ASSERT_EQUAL(serializer_.serializeElement(events_[0].element), serializer_.serializeElement(element)); + } + events_.pop_front(); + + while (!events_.empty() && events_[0].direction == Event::Out) { + sendData(serializeEvent(events_[0])); + events_.pop_front(); + } + + if (!events_.empty() && events_[0].type == Event::StreamStartEvent) { + resetParser_ = true; + } + } + + String serializeEvent(const Event& event) { + switch (event.type) { + case Event::StreamStartEvent: + { + ProtocolHeader header; + header.setTo(domain_); + return serializer_.serializeHeader(header); + } + case Event::ElementEvent: + return serializer_.serializeElement(event.element); + case Event::StreamEndEvent: + return serializer_.serializeFooter(); + } + assert(false); + return ""; + } + + void assertNoMoreExpectations() { + foreach (const Event& event, events_) { + std::cout << "Unprocessed event: " << serializeEvent(event) << std::endl; + } + CPPUNIT_ASSERT(events_.empty()); + } + + void sendData(const ByteArray& data) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), data)); + } + + void expectStreamStart() { + events_.push_back(Event(Event::In, Event::StreamStartEvent)); + } + + void expectStartTLS() { + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()))); + } + + void expectAuth(const String& user, const String& password) { + String s = String("") + '\0' + user + '\0' + password; + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<AuthRequest>(new AuthRequest("PLAIN", ByteArray(s.getUTF8Data(), s.getUTF8Size()))))); + } + + void expectResourceBind(const String& resource, const String& id) { + boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); + sessionStart->setResource(resource); + events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, sessionStart))); + } + + void expectSessionStart(const String& id) { + events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); + } + + void expectMessage() { + events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); + } + + void sendInvalidXML() { + sendData("<invalid xml/>"); + } + + void sendStreamStart() { + events_.push_back(Event(Event::Out, Event::StreamStartEvent)); + } + + void sendStreamFeatures() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithStartTLS() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasStartTLS(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("PLAIN"); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithUnsupportedAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("MY-UNSUPPORTED-MECH"); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithResourceBind() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasResourceBind(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithSession() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasSession(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendStreamFeaturesWithResourceBindAndSession() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasResourceBind(); + streamFeatures->setHasSession(); + events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); + } + + void sendMessage() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); + } + + void sendTLSProceed() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<TLSProceed>(new TLSProceed()))); + } + + void sendTLSFailure() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<StartTLSFailure>(new StartTLSFailure()))); + } + + void sendAuthSuccess() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthSuccess>(new AuthSuccess()))); + } + + void sendAuthFailure() { + events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthFailure>(new AuthFailure()))); + } + + void sendResourceBindResponse(const String& jid, const String& id) { + boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); + sessionStart->setJID(JID(jid)); + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, sessionStart))); + } + + void sendError(const String& id) { + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createError(JID(), id, Swift::Error::NotAllowed, Swift::Error::Cancel))); + } + + void sendSessionStartResponse(const String& id) { + events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); + } + + bool resetParser_; + String domain_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + XMPPParser* parser_; + XMPPSerializer serializer_; + std::deque<Event> events_; + }; + + struct MockTLSLayer : public TLSLayer { + MockTLSLayer() : connecting_(false) {} + bool setClientCertificate(const PKCS12Certificate&) { return true; } + void writeData(const ByteArray& data) { onWriteData(data); } + void handleDataRead(const ByteArray& data) { onDataRead(data); } + void setConnected() { onConnected(); } + void setError() { onError(); } + void connect() { connecting_ = true; } + bool isConnecting() { return connecting_; } + + bool connecting_; + }; + + struct MockTLSLayerFactory : public TLSLayerFactory { + MockTLSLayerFactory() : haveTLS_(true) {} + void setTLSSupported(bool b) { haveTLS_ = b; } + virtual bool canCreate() const { return haveTLS_; } + virtual boost::shared_ptr<TLSLayer> createTLSLayer() { + assert(haveTLS_); + boost::shared_ptr<MockTLSLayer> result(new MockTLSLayer()); + layers_.push_back(result); + return result; + } + std::vector< boost::shared_ptr<MockTLSLayer> > layers_; + bool haveTLS_; + }; + + struct MockSession : public ClientSession { + MockSession(const JID& jid, boost::shared_ptr<Connection> connection, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : ClientSession(jid, connection, tlsLayerFactory, payloadParserFactories, payloadSerializers) {} + + boost::shared_ptr<MockTLSLayer> getTLSLayer() const { + return getStreamStack()->getLayer<MockTLSLayer>(); + } + boost::shared_ptr<WhitespacePingLayer> getWhitespacePingLayer() const { + return getStreamStack()->getLayer<WhitespacePingLayer>(); + } + }; + + MockSession* createSession(const String& jid) { + return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + } + + + DummyEventLoop* eventLoop_; + boost::shared_ptr<MockConnection> connection_; + MockTLSLayerFactory* tlsLayerFactory_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + bool sessionStarted_; + bool needCredentials_; + std::vector< boost::shared_ptr<Element> > receivedElements_; + typedef std::vector< boost::function<void ()> > EventQueue; + EventQueue events_; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); diff --git a/Swiften/Client/UnitTest/Makefile.inc b/Swiften/Client/UnitTest/Makefile.inc index 3ef87e5..14dac57 100644 --- a/Swiften/Client/UnitTest/Makefile.inc +++ b/Swiften/Client/UnitTest/Makefile.inc @@ -1,2 +1,2 @@ UNITTEST_SOURCES += \ - Swiften/Client/UnitTest/SessionTest.cpp + Swiften/Client/UnitTest/ClientSessionTest.cpp diff --git a/Swiften/Client/UnitTest/SessionTest.cpp b/Swiften/Client/UnitTest/SessionTest.cpp deleted file mode 100644 index eb7281c..0000000 --- a/Swiften/Client/UnitTest/SessionTest.cpp +++ /dev/null @@ -1,704 +0,0 @@ -#include <cppunit/extensions/HelperMacros.h> -#include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/bind.hpp> -#include <boost/function.hpp> -#include <boost/optional.hpp> - -#include "Swiften/Parser/XMPPParser.h" -#include "Swiften/Parser/XMPPParserClient.h" -#include "Swiften/Serializer/XMPPSerializer.h" -#include "Swiften/StreamStack/TLSLayerFactory.h" -#include "Swiften/StreamStack/TLSLayer.h" -#include "Swiften/StreamStack/StreamStack.h" -#include "Swiften/StreamStack/WhitespacePingLayer.h" -#include "Swiften/Elements/ProtocolHeader.h" -#include "Swiften/Elements/StreamFeatures.h" -#include "Swiften/Elements/Element.h" -#include "Swiften/Elements/Error.h" -#include "Swiften/Elements/IQ.h" -#include "Swiften/Elements/AuthRequest.h" -#include "Swiften/Elements/AuthSuccess.h" -#include "Swiften/Elements/AuthFailure.h" -#include "Swiften/Elements/ResourceBind.h" -#include "Swiften/Elements/StartSession.h" -#include "Swiften/Elements/StartTLSRequest.h" -#include "Swiften/Elements/StartTLSFailure.h" -#include "Swiften/Elements/TLSProceed.h" -#include "Swiften/Elements/Message.h" -#include "Swiften/EventLoop/MainEventLoop.h" -#include "Swiften/EventLoop/DummyEventLoop.h" -#include "Swiften/Network/Connection.h" -#include "Swiften/Network/ConnectionFactory.h" -#include "Swiften/Client/Session.h" -#include "Swiften/TLS/PKCS12Certificate.h" -#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" -#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" - -using namespace Swift; - -class SessionTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(SessionTest); - CPPUNIT_TEST(testConstructor); - CPPUNIT_TEST(testStart_Error); - CPPUNIT_TEST(testStart_XMLError); - CPPUNIT_TEST(testStartTLS); - CPPUNIT_TEST(testStartTLS_ServerError); - CPPUNIT_TEST(testStartTLS_NoTLSSupport); - CPPUNIT_TEST(testStartTLS_ConnectError); - CPPUNIT_TEST(testStartTLS_ErrorAfterConnect); - CPPUNIT_TEST(testAuthenticate); - CPPUNIT_TEST(testAuthenticate_Unauthorized); - CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); - CPPUNIT_TEST(testResourceBind); - CPPUNIT_TEST(testResourceBind_ChangeResource); - CPPUNIT_TEST(testResourceBind_EmptyResource); - CPPUNIT_TEST(testResourceBind_Error); - CPPUNIT_TEST(testSessionStart); - CPPUNIT_TEST(testSessionStart_Error); - CPPUNIT_TEST(testSessionStart_AfterResourceBind); - CPPUNIT_TEST(testWhitespacePing); - CPPUNIT_TEST(testReceiveElementAfterSessionStarted); - CPPUNIT_TEST(testSendElement); - CPPUNIT_TEST_SUITE_END(); - - public: - SessionTest() {} - - void setUp() { - eventLoop_ = new DummyEventLoop(); - connection_ = boost::shared_ptr<MockConnection>(new MockConnection()); - tlsLayerFactory_ = new MockTLSLayerFactory(); - sessionStarted_ = false; - needCredentials_ = false; - } - - void tearDown() { - delete tlsLayerFactory_; - delete eventLoop_; - } - - void testConstructor() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - CPPUNIT_ASSERT_EQUAL(Session::Initial, session->getState()); - } - - void testStart_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - - getMockServer()->expectStreamStart(); - session->start(); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); - - getMockServer()->setError(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::ConnectionReadError, session->getError()); - } - - void testStart_XMLError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - - getMockServer()->expectStreamStart(); - session->start(); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::WaitingForStreamStart, session->getState()); - - getMockServer()->sendInvalidXML(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::XMLError, session->getError()); - } - - void testStartTLS_NoTLSSupport() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - tlsLayerFactory_->setTLSSupported(false); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithStartTLS(); - session->start(); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - } - - void testStartTLS() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithStartTLS(); - getMockServer()->expectStartTLS(); - // FIXME: Test 'encrypting' state - getMockServer()->sendTLSProceed(); - session->start(); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::Encrypting, session->getState()); - CPPUNIT_ASSERT(session->getTLSLayer()); - CPPUNIT_ASSERT(session->getTLSLayer()->isConnecting()); - - getMockServer()->resetParser(); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - session->getTLSLayer()->setConnected(); - // FIXME: Test 'WatingForStreamStart' state - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState()); - } - - void testStartTLS_ServerError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithStartTLS(); - getMockServer()->expectStartTLS(); - getMockServer()->sendTLSFailure(); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); - } - - void testStartTLS_ConnectError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithStartTLS(); - getMockServer()->expectStartTLS(); - getMockServer()->sendTLSProceed(); - session->start(); - processEvents(); - session->getTLSLayer()->setError(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); - } - - void testStartTLS_ErrorAfterConnect() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithStartTLS(); - getMockServer()->expectStartTLS(); - getMockServer()->sendTLSProceed(); - session->start(); - processEvents(); - getMockServer()->resetParser(); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - session->getTLSLayer()->setConnected(); - processEvents(); - - session->getTLSLayer()->setError(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::TLSError, session->getError()); - } - - void testAuthenticate() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->onNeedCredentials.connect(boost::bind(&SessionTest::setNeedCredentials, this)); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithAuthentication(); - session->start(); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::WaitingForCredentials, session->getState()); - CPPUNIT_ASSERT(needCredentials_); - - getMockServer()->expectAuth("me", "mypass"); - getMockServer()->sendAuthSuccess(); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - session->sendCredentials("mypass"); - CPPUNIT_ASSERT_EQUAL(Session::Authenticating, session->getState()); - processEvents(); - CPPUNIT_ASSERT_EQUAL(Session::Negotiating, session->getState()); - } - - void testAuthenticate_Unauthorized() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithAuthentication(); - session->start(); - processEvents(); - - getMockServer()->expectAuth("me", "mypass"); - getMockServer()->sendAuthFailure(); - session->sendCredentials("mypass"); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::AuthenticationFailedError, session->getError()); - } - - void testAuthenticate_NoValidAuthMechanisms() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::NoSupportedAuthMechanismsError, session->getError()); - } - - void testResourceBind() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithResourceBind(); - getMockServer()->expectResourceBind("Bar", "session-bind"); - // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::BindingResource, session->getState()); - getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); - session->start(); - - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getJID()); - } - - void testResourceBind_ChangeResource() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithResourceBind(); - getMockServer()->expectResourceBind("Bar", "session-bind"); - getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getJID()); - } - - void testResourceBind_EmptyResource() { - std::auto_ptr<MockSession> session(createSession("me@foo.com")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithResourceBind(); - getMockServer()->expectResourceBind("", "session-bind"); - getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getJID()); - } - - void testResourceBind_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithResourceBind(); - getMockServer()->expectResourceBind("", "session-bind"); - getMockServer()->sendError("session-bind"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::ResourceBindError, session->getError()); - } - - void testSessionStart() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithSession(); - getMockServer()->expectSessionStart("session-start"); - // FIXME: Check CPPUNIT_ASSERT_EQUAL(Session::StartingSession, session->getState()); - getMockServer()->sendSessionStartResponse("session-start"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - CPPUNIT_ASSERT(sessionStarted_); - } - - void testSessionStart_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithSession(); - getMockServer()->expectSessionStart("session-start"); - getMockServer()->sendError("session-start"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(Session::SessionStartError, session->getError()); - } - - void testSessionStart_AfterResourceBind() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->onSessionStarted.connect(boost::bind(&SessionTest::setSessionStarted, this)); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeaturesWithResourceBindAndSession(); - getMockServer()->expectResourceBind("Bar", "session-bind"); - getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); - getMockServer()->expectSessionStart("session-start"); - getMockServer()->sendSessionStartResponse("session-start"); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(Session::SessionStarted, session->getState()); - CPPUNIT_ASSERT(sessionStarted_); - } - - void testWhitespacePing() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeatures(); - session->start(); - processEvents(); - CPPUNIT_ASSERT(session->getWhitespacePingLayer()); - } - - void testReceiveElementAfterSessionStarted() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeatures(); - session->start(); - processEvents(); - - getMockServer()->expectMessage(); - session->sendElement(boost::shared_ptr<Message>(new Message())); - } - - void testSendElement() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); - session->onElementReceived.connect(boost::bind(&SessionTest::addReceivedElement, this, _1)); - getMockServer()->expectStreamStart(); - getMockServer()->sendStreamStart(); - getMockServer()->sendStreamFeatures(); - getMockServer()->sendMessage(); - session->start(); - processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); - CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0])); - } - - private: - struct MockConnection; - - boost::shared_ptr<MockConnection> getMockServer() const { - return connection_; - } - - void processEvents() { - eventLoop_->processEvents(); - getMockServer()->assertNoMoreExpectations(); - } - - void setSessionStarted() { - sessionStarted_ = true; - } - - void setNeedCredentials() { - needCredentials_ = true; - } - - void addReceivedElement(boost::shared_ptr<Element> element) { - receivedElements_.push_back(element); - } - - private: - struct MockConnection : public Connection, public XMPPParserClient { - struct Event { - enum Direction { In, Out }; - enum Type { StreamStartEvent, StreamEndEvent, ElementEvent }; - - Event( - Direction direction, - Type type, - boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) : - direction(direction), type(type), element(element) {} - - Direction direction; - Type type; - boost::shared_ptr<Element> element; - }; - - MockConnection() : - resetParser_(false), - domain_("foo.com"), - parser_(0), - serializer_(&payloadSerializers_) { - parser_ = new XMPPParser(this, &payloadParserFactories_); - } - - ~MockConnection() { - delete parser_; - } - - void disconnect() { } - - void listen() { - assert(false); - } - - void connect(const HostAddressPort&) { assert(false); } - void connect(const String&) { assert(false); } - - void setError() { - MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ReadError)); - } - - void write(const ByteArray& data) { - CPPUNIT_ASSERT(parser_->parse(data.toString())); - if (resetParser_) { - resetParser(); - resetParser_ = false; - } - } - - void resetParser() { - delete parser_; - parser_ = new XMPPParser(this, &payloadParserFactories_); - } - - void handleStreamStart(const ProtocolHeader& header) { - CPPUNIT_ASSERT_EQUAL(domain_, header.getTo()); - handleEvent(Event::StreamStartEvent); - } - - void handleElement(boost::shared_ptr<Swift::Element> element) { - handleEvent(Event::ElementEvent, element); - } - - void handleStreamEnd() { - handleEvent(Event::StreamEndEvent); - } - - void handleEvent(Event::Type type, boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) { - CPPUNIT_ASSERT(!events_.empty()); - CPPUNIT_ASSERT_EQUAL(events_[0].direction, Event::In); - CPPUNIT_ASSERT_EQUAL(events_[0].type, type); - if (type == Event::ElementEvent) { - CPPUNIT_ASSERT_EQUAL(serializer_.serializeElement(events_[0].element), serializer_.serializeElement(element)); - } - events_.pop_front(); - - while (!events_.empty() && events_[0].direction == Event::Out) { - sendData(serializeEvent(events_[0])); - events_.pop_front(); - } - - if (!events_.empty() && events_[0].type == Event::StreamStartEvent) { - resetParser_ = true; - } - } - - String serializeEvent(const Event& event) { - switch (event.type) { - case Event::StreamStartEvent: - { - ProtocolHeader header; - header.setTo(domain_); - return serializer_.serializeHeader(header); - } - case Event::ElementEvent: - return serializer_.serializeElement(event.element); - case Event::StreamEndEvent: - return serializer_.serializeFooter(); - } - assert(false); - return ""; - } - - void assertNoMoreExpectations() { - foreach (const Event& event, events_) { - std::cout << "Unprocessed event: " << serializeEvent(event) << std::endl; - } - CPPUNIT_ASSERT(events_.empty()); - } - - void sendData(const ByteArray& data) { - MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), data)); - } - - void expectStreamStart() { - events_.push_back(Event(Event::In, Event::StreamStartEvent)); - } - - void expectStartTLS() { - events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()))); - } - - void expectAuth(const String& user, const String& password) { - String s = String("") + '\0' + user + '\0' + password; - events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<AuthRequest>(new AuthRequest("PLAIN", ByteArray(s.getUTF8Data(), s.getUTF8Size()))))); - } - - void expectResourceBind(const String& resource, const String& id) { - boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); - sessionStart->setResource(resource); - events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, sessionStart))); - } - - void expectSessionStart(const String& id) { - events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); - } - - void expectMessage() { - events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); - } - - void sendInvalidXML() { - sendData("<invalid xml/>"); - } - - void sendStreamStart() { - events_.push_back(Event(Event::Out, Event::StreamStartEvent)); - } - - void sendStreamFeatures() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithStartTLS() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->setHasStartTLS(); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithAuthentication() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->addAuthenticationMechanism("PLAIN"); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithUnsupportedAuthentication() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->addAuthenticationMechanism("MY-UNSUPPORTED-MECH"); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithResourceBind() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->setHasResourceBind(); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithSession() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->setHasSession(); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendStreamFeaturesWithResourceBindAndSession() { - boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); - streamFeatures->setHasResourceBind(); - streamFeatures->setHasSession(); - events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures)); - } - - void sendMessage() { - events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<Message>(new Message()))); - } - - void sendTLSProceed() { - events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<TLSProceed>(new TLSProceed()))); - } - - void sendTLSFailure() { - events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<StartTLSFailure>(new StartTLSFailure()))); - } - - void sendAuthSuccess() { - events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthSuccess>(new AuthSuccess()))); - } - - void sendAuthFailure() { - events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthFailure>(new AuthFailure()))); - } - - void sendResourceBindResponse(const String& jid, const String& id) { - boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind()); - sessionStart->setJID(JID(jid)); - events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, sessionStart))); - } - - void sendError(const String& id) { - events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createError(JID(), id, Swift::Error::NotAllowed, Swift::Error::Cancel))); - } - - void sendSessionStartResponse(const String& id) { - events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession())))); - } - - bool resetParser_; - String domain_; - FullPayloadParserFactoryCollection payloadParserFactories_; - FullPayloadSerializerCollection payloadSerializers_; - XMPPParser* parser_; - XMPPSerializer serializer_; - std::deque<Event> events_; - }; - - struct MockTLSLayer : public TLSLayer { - MockTLSLayer() : connecting_(false) {} - bool setClientCertificate(const PKCS12Certificate&) { return true; } - void writeData(const ByteArray& data) { onWriteData(data); } - void handleDataRead(const ByteArray& data) { onDataRead(data); } - void setConnected() { onConnected(); } - void setError() { onError(); } - void connect() { connecting_ = true; } - bool isConnecting() { return connecting_; } - - bool connecting_; - }; - - struct MockTLSLayerFactory : public TLSLayerFactory { - MockTLSLayerFactory() : haveTLS_(true) {} - void setTLSSupported(bool b) { haveTLS_ = b; } - virtual bool canCreate() const { return haveTLS_; } - virtual boost::shared_ptr<TLSLayer> createTLSLayer() { - assert(haveTLS_); - boost::shared_ptr<MockTLSLayer> result(new MockTLSLayer()); - layers_.push_back(result); - return result; - } - std::vector< boost::shared_ptr<MockTLSLayer> > layers_; - bool haveTLS_; - }; - - struct MockSession : public Session { - MockSession(const JID& jid, boost::shared_ptr<Connection> connection, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : Session(jid, connection, tlsLayerFactory, payloadParserFactories, payloadSerializers) {} - - boost::shared_ptr<MockTLSLayer> getTLSLayer() const { - return getStreamStack()->getLayer<MockTLSLayer>(); - } - boost::shared_ptr<WhitespacePingLayer> getWhitespacePingLayer() const { - return getStreamStack()->getLayer<WhitespacePingLayer>(); - } - }; - - MockSession* createSession(const String& jid) { - return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); - } - - - DummyEventLoop* eventLoop_; - boost::shared_ptr<MockConnection> connection_; - MockTLSLayerFactory* tlsLayerFactory_; - FullPayloadParserFactoryCollection payloadParserFactories_; - FullPayloadSerializerCollection payloadSerializers_; - bool sessionStarted_; - bool needCredentials_; - std::vector< boost::shared_ptr<Element> > receivedElements_; - typedef std::vector< boost::function<void ()> > EventQueue; - EventQueue events_; -}; - -CPPUNIT_TEST_SUITE_REGISTRATION(SessionTest); diff --git a/Swiften/Session/Makefile.inc b/Swiften/Session/Makefile.inc new file mode 100644 index 0000000..faa73c8 --- /dev/null +++ b/Swiften/Session/Makefile.inc @@ -0,0 +1,2 @@ +SWIFTEN_SOURCES += \ + Swiften/Session/Session.cpp diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp new file mode 100644 index 0000000..9ab8e4d --- /dev/null +++ b/Swiften/Session/Session.cpp @@ -0,0 +1,72 @@ +#include "Swiften/Session/Session.h" + +#include <boost/bind.hpp> + +#include "Swiften/StreamStack/XMPPLayer.h" +#include "Swiften/StreamStack/StreamStack.h" + +namespace Swift { + +Session::Session( + boost::shared_ptr<Connection> connection, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers) : + connection(connection), + payloadParserFactories(payloadParserFactories), + payloadSerializers(payloadSerializers), + initialized(false) { +} + +Session::~Session() { + delete streamStack; +} + +void Session::startSession() { + initializeStreamStack(); + handleSessionStarted(); +} + +void Session::finishSession() { + connection->disconnect(); + onSessionFinished(boost::optional<Error>()); +} + +void Session::finishSession(const Error& error) { + connection->disconnect(); + onSessionFinished(boost::optional<Error>(error)); +} + +void Session::initializeStreamStack() { + xmppLayer = boost::shared_ptr<XMPPLayer>( + new XMPPLayer(payloadParserFactories, payloadSerializers)); + xmppLayer->onStreamStart.connect( + boost::bind(&Session::handleStreamStart, this, _1)); + xmppLayer->onElement.connect( + boost::bind(&Session::handleElement, this, _1)); + xmppLayer->onError.connect( + boost::bind(&Session::finishSession, this, XMLError)); + connection->onDisconnected.connect( + boost::bind(&Session::handleDisconnected, shared_from_this(), _1)); + connectionLayer = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection)); + streamStack = new StreamStack(xmppLayer, connectionLayer); +} + +void Session::sendStanza(boost::shared_ptr<Stanza> stanza) { + xmppLayer->writeElement(stanza); +} + +void Session::handleDisconnected(const boost::optional<Connection::Error>& connectionError) { + if (connectionError) { + finishSession(ConnectionError); + } + else { + finishSession(); + } +} + +void Session::setInitialized() { + initialized = true; + onSessionStarted(); +} + +} diff --git a/Swiften/Session/Session.h b/Swiften/Session/Session.h new file mode 100644 index 0000000..bf8049a --- /dev/null +++ b/Swiften/Session/Session.h @@ -0,0 +1,76 @@ +#pragma once + +#include <boost/shared_ptr.hpp> +#include <boost/signal.hpp> +#include <boost/optional.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/JID/JID.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/StreamStack/ConnectionLayer.h" + +namespace Swift { + class ProtocolHeader; + class StreamStack; + class JID; + class Stanza; + class ByteArray; + class PayloadParserFactoryCollection; + class PayloadSerializerCollection; + class XMPPLayer; + + class Session : public boost::enable_shared_from_this<Session> { + public: + enum Error { + ConnectionError, + XMLError + }; + + Session( + boost::shared_ptr<Connection> connection, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers); + virtual ~Session(); + + void startSession(); + void finishSession(); + void sendStanza(boost::shared_ptr<Stanza>); + + boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived; + boost::signal<void ()> onSessionStarted; + boost::signal<void (const boost::optional<Error>&)> onSessionFinished; + boost::signal<void (const ByteArray&)> onDataWritten; + boost::signal<void (const ByteArray&)> onDataRead; + + protected: + void finishSession(const Error&); + + virtual void handleSessionStarted() {} + virtual void handleElement(boost::shared_ptr<Element>) = 0; + virtual void handleStreamStart(const ProtocolHeader&) = 0; + + void initializeStreamStack(); + + boost::shared_ptr<XMPPLayer> getXMPPLayer() const { + return xmppLayer; + } + + void setInitialized(); + bool isInitialized() const { + return initialized; + } + + private: + void handleDisconnected(const boost::optional<Connection::Error>& error); + + private: + boost::shared_ptr<Connection> connection; + PayloadParserFactoryCollection* payloadParserFactories; + PayloadSerializerCollection* payloadSerializers; + boost::shared_ptr<XMPPLayer> xmppLayer; + boost::shared_ptr<ConnectionLayer> connectionLayer; + StreamStack* streamStack; + bool initialized; + }; +} -- cgit v0.10.2-6-g49f6