diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/Client.cpp | 212 | ||||
-rw-r--r-- | Swiften/Client/Client.h | 78 | ||||
-rw-r--r-- | Swiften/Client/ClientError.h | 32 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 276 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 100 | ||||
-rw-r--r-- | Swiften/Client/ClientXMLTracer.h | 26 | ||||
-rw-r--r-- | Swiften/Client/DummyStanzaChannel.h | 38 | ||||
-rw-r--r-- | Swiften/Client/StanzaChannel.h | 20 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 498 |
9 files changed, 1280 insertions, 0 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp new file mode 100644 index 0000000..c704248 --- /dev/null +++ b/Swiften/Client/Client.cpp @@ -0,0 +1,212 @@ +#include "Swiften/Client/Client.h" + +#include <boost/bind.hpp> + +#include "Swiften/Network/MainBoostIOServiceThread.h" +#include "Swiften/Network/BoostIOServiceThread.h" +#include "Swiften/Client/ClientSession.h" +#include "Swiften/StreamStack/PlatformTLSLayerFactory.h" +#include "Swiften/Network/Connector.h" +#include "Swiften/Network/BoostConnectionFactory.h" +#include "Swiften/Network/BoostTimerFactory.h" +#include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Session/BasicSessionStream.h" + +namespace Swift { + +Client::Client(const JID& jid, const String& password) : + IQRouter(this), jid_(jid), password_(password) { + connectionFactory_ = new BoostConnectionFactory(&MainBoostIOServiceThread::getInstance().getIOService()); + timerFactory_ = new BoostTimerFactory(&MainBoostIOServiceThread::getInstance().getIOService()); + tlsLayerFactory_ = new PlatformTLSLayerFactory(); +} + +Client::~Client() { + if (session_ || connection_) { + std::cerr << "Warning: Client not disconnected properly" << std::endl; + } + delete tlsLayerFactory_; + delete timerFactory_; + delete connectionFactory_; +} + +bool Client::isAvailable() { + return session_; +} + +void Client::connect() { + assert(!connector_); + connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_, timerFactory_)); + connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1)); + connector_->setTimeoutMilliseconds(60*1000); + connector_->start(); +} + +void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection) { + // TODO: Add domain name resolver error + connector_.reset(); + if (!connection) { + onError(ClientError::ConnectionError); + } + else { + assert(!connection_); + connection_ = connection; + + assert(!sessionStream_); + sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_, timerFactory_)); + if (!certificate_.isEmpty()) { + sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); + } + sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); + sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); + sessionStream_->initialize(); + + session_ = ClientSession::create(jid_, sessionStream_); + session_->onInitialized.connect(boost::bind(boost::ref(onConnected))); + session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); + session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); + session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); + session_->start(); + } +} + +void Client::disconnect() { + if (session_) { + session_->finish(); + } + else { + closeConnection(); + } +} + +void Client::closeConnection() { + if (sessionStream_) { + sessionStream_.reset(); + } + if (connection_) { + connection_->disconnect(); + connection_.reset(); + } +} + +void Client::send(boost::shared_ptr<Stanza> stanza) { + if (!isAvailable()) { + std::cerr << "Warning: Client: Trying to send a stanza while disconnected." << std::endl; + return; + } + session_->sendElement(stanza); +} + +void Client::sendIQ(boost::shared_ptr<IQ> iq) { + send(iq); +} + +void Client::sendMessage(boost::shared_ptr<Message> message) { + send(message); +} + +void Client::sendPresence(boost::shared_ptr<Presence> presence) { + send(presence); +} + +String Client::getNewIQID() { + return idGenerator_.generateID(); +} + +void Client::handleElement(boost::shared_ptr<Element> element) { + boost::shared_ptr<Message> message = boost::dynamic_pointer_cast<Message>(element); + if (message) { + onMessageReceived(message); + return; + } + + boost::shared_ptr<Presence> presence = boost::dynamic_pointer_cast<Presence>(element); + if (presence) { + onPresenceReceived(presence); + return; + } + + boost::shared_ptr<IQ> iq = boost::dynamic_pointer_cast<IQ>(element); + if (iq) { + onIQReceived(iq); + return; + } +} + +void Client::setCertificate(const String& certificate) { + certificate_ = certificate; +} + +void Client::handleSessionFinished(boost::shared_ptr<Error> error) { + session_.reset(); + closeConnection(); + if (error) { + ClientError clientError; + if (boost::shared_ptr<ClientSession::Error> actualError = boost::dynamic_pointer_cast<ClientSession::Error>(error)) { + switch(actualError->type) { + case ClientSession::Error::AuthenticationFailedError: + clientError = ClientError(ClientError::AuthenticationFailedError); + break; + case ClientSession::Error::CompressionFailedError: + clientError = ClientError(ClientError::CompressionFailedError); + break; + case ClientSession::Error::ServerVerificationFailedError: + clientError = ClientError(ClientError::ServerVerificationFailedError); + break; + case ClientSession::Error::NoSupportedAuthMechanismsError: + clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); + break; + case ClientSession::Error::UnexpectedElementError: + clientError = ClientError(ClientError::UnexpectedElementError); + break; + case ClientSession::Error::ResourceBindError: + clientError = ClientError(ClientError::ResourceBindError); + break; + case ClientSession::Error::SessionStartError: + clientError = ClientError(ClientError::SessionStartError); + break; + case ClientSession::Error::TLSError: + clientError = ClientError(ClientError::TLSError); + break; + case ClientSession::Error::TLSClientCertificateError: + clientError = ClientError(ClientError::ClientCertificateError); + break; + } + } + else if (boost::shared_ptr<SessionStream::Error> actualError = boost::dynamic_pointer_cast<SessionStream::Error>(error)) { + switch(actualError->type) { + case SessionStream::Error::ParseError: + clientError = ClientError(ClientError::XMLError); + break; + case SessionStream::Error::TLSError: + clientError = ClientError(ClientError::TLSError); + break; + case SessionStream::Error::InvalidTLSCertificateError: + clientError = ClientError(ClientError::ClientCertificateLoadError); + break; + case SessionStream::Error::ConnectionReadError: + clientError = ClientError(ClientError::ConnectionReadError); + break; + case SessionStream::Error::ConnectionWriteError: + clientError = ClientError(ClientError::ConnectionWriteError); + break; + } + } + onError(clientError); + } +} + +void Client::handleNeedCredentials() { + assert(session_); + session_->sendCredentials(password_); +} + +void Client::handleDataRead(const String& data) { + onDataRead(data); +} + +void Client::handleDataWritten(const String& data) { + onDataWritten(data); +} + +} diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h new file mode 100644 index 0000000..444c136 --- /dev/null +++ b/Swiften/Client/Client.h @@ -0,0 +1,78 @@ +#pragma once + +#include <boost/signals.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/PlatformDomainNameResolver.h" +#include "Swiften/Base/Error.h" +#include "Swiften/Client/ClientSession.h" +#include "Swiften/Client/ClientError.h" +#include "Swiften/Elements/Presence.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/JID/JID.h" +#include "Swiften/Base/String.h" +#include "Swiften/Base/IDGenerator.h" +#include "Swiften/Client/StanzaChannel.h" +#include "Swiften/Queries/IQRouter.h" +#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" +#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" + +namespace Swift { + class TLSLayerFactory; + class ConnectionFactory; + class TimerFactory; + class ClientSession; + class BasicSessionStream; + class Connector; + + class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable { + public: + Client(const JID& jid, const String& password); + ~Client(); + + void setCertificate(const String& certificate); + + void connect(); + void disconnect(); + + bool isAvailable(); + + virtual void sendIQ(boost::shared_ptr<IQ>); + virtual void sendMessage(boost::shared_ptr<Message>); + virtual void sendPresence(boost::shared_ptr<Presence>); + + public: + boost::signal<void (const ClientError&)> onError; + boost::signal<void ()> onConnected; + boost::signal<void (const String&)> onDataRead; + boost::signal<void (const String&)> onDataWritten; + + private: + void handleConnectorFinished(boost::shared_ptr<Connection>); + void send(boost::shared_ptr<Stanza>); + virtual String getNewIQID(); + void handleElement(boost::shared_ptr<Element>); + void handleSessionFinished(boost::shared_ptr<Error>); + void handleNeedCredentials(); + void handleDataRead(const String&); + void handleDataWritten(const String&); + + void closeConnection(); + + private: + PlatformDomainNameResolver resolver_; + JID jid_; + String password_; + IDGenerator idGenerator_; + boost::shared_ptr<Connector> connector_; + ConnectionFactory* connectionFactory_; + TimerFactory* timerFactory_; + TLSLayerFactory* tlsLayerFactory_; + FullPayloadParserFactoryCollection payloadParserFactories_; + FullPayloadSerializerCollection payloadSerializers_; + boost::shared_ptr<Connection> connection_; + boost::shared_ptr<BasicSessionStream> sessionStream_; + boost::shared_ptr<ClientSession> session_; + String certificate_; + }; +} diff --git a/Swiften/Client/ClientError.h b/Swiften/Client/ClientError.h new file mode 100644 index 0000000..a0557d4 --- /dev/null +++ b/Swiften/Client/ClientError.h @@ -0,0 +1,32 @@ +#pragma once + +namespace Swift { + class ClientError { + public: + enum Type { + UnknownError, + DomainNameResolveError, + ConnectionError, + ConnectionReadError, + ConnectionWriteError, + XMLError, + AuthenticationFailedError, + CompressionFailedError, + ServerVerificationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + ClientCertificateLoadError, + ClientCertificateError + }; + + ClientError(Type type = UnknownError) : type_(type) {} + + Type getType() const { return type_; } + + private: + Type type_; + }; +} diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp new file mode 100644 index 0000000..16bda40 --- /dev/null +++ b/Swiften/Client/ClientSession.cpp @@ -0,0 +1,276 @@ +#include "Swiften/Client/ClientSession.h" + +#include <boost/bind.hpp> + +#include "Swiften/Elements/ProtocolHeader.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/AuthChallenge.h" +#include "Swiften/Elements/AuthResponse.h" +#include "Swiften/Elements/Compressed.h" +#include "Swiften/Elements/CompressFailure.h" +#include "Swiften/Elements/CompressRequest.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/SASL/PLAINClientAuthenticator.h" +#include "Swiften/SASL/SCRAMSHA1ClientAuthenticator.h" +#include "Swiften/Session/SessionStream.h" + +namespace Swift { + +ClientSession::ClientSession( + const JID& jid, + boost::shared_ptr<SessionStream> stream) : + localJID(jid), + state(Initial), + stream(stream), + needSessionStart(false), + authenticator(NULL) { +} + +ClientSession::~ClientSession() { +} + +void ClientSession::start() { + stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); + stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); + stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + + assert(state == Initial); + state = WaitingForStreamStart; + sendStreamHeader(); +} + +void ClientSession::sendStreamHeader() { + ProtocolHeader header; + header.setTo(getRemoteJID()); + stream->writeHeader(header); +} + +void ClientSession::sendElement(boost::shared_ptr<Element> element) { + stream->writeElement(element); +} + +void ClientSession::handleStreamStart(const ProtocolHeader&) { + checkState(WaitingForStreamStart); + state = Negotiating; +} + +void ClientSession::handleElement(boost::shared_ptr<Element> element) { + if (getState() == Initialized) { + onElementReceived(element); + } + else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { + if (!checkState(Negotiating)) { + return; + } + + if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) { + state = WaitingForEncrypt; + stream->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + } + else if (streamFeatures->hasCompressionMethod("zlib")) { + state = Compressing; + stream->writeElement(boost::shared_ptr<CompressRequest>(new CompressRequest("zlib"))); + } + else if (streamFeatures->hasAuthenticationMechanisms()) { + if (stream->hasTLSCertificate()) { + if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + state = Authenticating; + stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); + } + else { + finishSession(Error::TLSClientCertificateError); + } + } + else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1")) { + // FIXME: Use a real nonce + authenticator = new SCRAMSHA1ClientAuthenticator("ClientNonce"); + state = WaitingForCredentials; + onNeedCredentials(); + } + else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { + authenticator = new PLAINClientAuthenticator(); + state = WaitingForCredentials; + onNeedCredentials(); + } + else { + finishSession(Error::NoSupportedAuthMechanismsError); + } + } + else { + // Start the session + stream->setWhitespacePingEnabled(true); + + if (streamFeatures->hasSession()) { + needSessionStart = true; + } + + if (streamFeatures->hasResourceBind()) { + state = BindingResource; + boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); + if (!localJID.getResource().isEmpty()) { + resourceBind->setResource(localJID.getResource()); + } + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + } + else if (needSessionStart) { + sendSessionStart(); + } + else { + state = Initialized; + onInitialized(); + } + } + } + else if (boost::dynamic_pointer_cast<Compressed>(element)) { + checkState(Compressing); + state = WaitingForStreamStart; + stream->addZLibCompression(); + stream->resetXMPPParser(); + sendStreamHeader(); + } + else if (boost::dynamic_pointer_cast<CompressFailure>(element)) { + finishSession(Error::CompressionFailedError); + } + else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { + checkState(Authenticating); + assert(authenticator); + if (authenticator->setChallenge(challenge->getValue())) { + stream->writeElement(boost::shared_ptr<AuthResponse>(new AuthResponse(authenticator->getResponse()))); + } + else { + finishSession(Error::AuthenticationFailedError); + } + } + else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { + checkState(Authenticating); + if (authenticator && !authenticator->setChallenge(authSuccess->getValue())) { + finishSession(Error::ServerVerificationFailedError); + } + else { + state = WaitingForStreamStart; + delete authenticator; + authenticator = NULL; + stream->resetXMPPParser(); + sendStreamHeader(); + } + } + else if (dynamic_cast<AuthFailure*>(element.get())) { + delete authenticator; + authenticator = NULL; + finishSession(Error::AuthenticationFailedError); + } + else if (dynamic_cast<TLSProceed*>(element.get())) { + checkState(WaitingForEncrypt); + state = Encrypting; + stream->addTLSEncryption(); + } + else if (dynamic_cast<StartTLSFailure*>(element.get())) { + finishSession(Error::TLSError); + } + else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { + if (state == BindingResource) { + boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); + if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { + finishSession(Error::ResourceBindError); + } + else if (!resourceBind) { + finishSession(Error::UnexpectedElementError); + } + else if (iq->getType() == IQ::Result) { + localJID = resourceBind->getJID(); + if (!localJID.isValid()) { + finishSession(Error::ResourceBindError); + } + if (needSessionStart) { + sendSessionStart(); + } + else { + state = Initialized; + } + } + else { + finishSession(Error::UnexpectedElementError); + } + } + else if (state == StartingSession) { + if (iq->getType() == IQ::Result) { + state = Initialized; + onInitialized(); + } + else if (iq->getType() == IQ::Error) { + finishSession(Error::SessionStartError); + } + else { + finishSession(Error::UnexpectedElementError); + } + } + else { + finishSession(Error::UnexpectedElementError); + } + } + else { + // FIXME Not correct? + state = Initialized; + onInitialized(); + } +} + +void ClientSession::sendSessionStart() { + state = StartingSession; + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); +} + +bool ClientSession::checkState(State state) { + if (state != state) { + finishSession(Error::UnexpectedElementError); + return false; + } + return true; +} + +void ClientSession::sendCredentials(const String& password) { + assert(WaitingForCredentials); + state = Authenticating; + authenticator->setCredentials(localJID.getNode(), password); + stream->writeElement(boost::shared_ptr<AuthRequest>(new AuthRequest(authenticator->getName(), authenticator->getResponse()))); +} + +void ClientSession::handleTLSEncrypted() { + checkState(WaitingForEncrypt); + state = WaitingForStreamStart; + stream->resetXMPPParser(); + sendStreamHeader(); +} + +void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { + finishSession(error); +} + +void ClientSession::finish() { + if (stream->isAvailable()) { + stream->writeFooter(); + } + finishSession(boost::shared_ptr<Error>()); +} + +void ClientSession::finishSession(Error::Type error) { + finishSession(boost::shared_ptr<Swift::ClientSession::Error>(new Swift::ClientSession::Error(error))); +} + +void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { + state = Finished; + stream->setWhitespacePingEnabled(false); + onFinished(error); +} + + +} diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h new file mode 100644 index 0000000..685672e --- /dev/null +++ b/Swiften/Client/ClientSession.h @@ -0,0 +1,100 @@ +#pragma once + +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Base/Error.h" +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Session/BasicSessionStream.h" +#include "Swiften/Base/String.h" +#include "Swiften/JID/JID.h" +#include "Swiften/Elements/Element.h" + +namespace Swift { + class ClientAuthenticator; + + class ClientSession : public boost::enable_shared_from_this<ClientSession> { + public: + enum State { + Initial, + WaitingForStreamStart, + Negotiating, + Compressing, + WaitingForEncrypt, + Encrypting, + WaitingForCredentials, + Authenticating, + BindingResource, + StartingSession, + Initialized, + Finished + }; + + struct Error : public Swift::Error { + enum Type { + AuthenticationFailedError, + CompressionFailedError, + ServerVerificationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSClientCertificateError, + TLSError, + } type; + Error(Type type) : type(type) {} + }; + + ~ClientSession(); + static boost::shared_ptr<ClientSession> create(const JID& jid, boost::shared_ptr<SessionStream> stream) { + return boost::shared_ptr<ClientSession>(new ClientSession(jid, stream)); + } + + State getState() const { + return state; + } + + void start(); + void finish(); + + void sendCredentials(const String& password); + void sendElement(boost::shared_ptr<Element> element); + + public: + boost::signal<void ()> onNeedCredentials; + boost::signal<void ()> onInitialized; + boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished; + boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; + + private: + ClientSession( + const JID& jid, + boost::shared_ptr<SessionStream>); + + void finishSession(Error::Type error); + void finishSession(boost::shared_ptr<Swift::Error> error); + + JID getRemoteJID() const { + return JID("", localJID.getDomain()); + } + + void sendStreamHeader(); + void sendSessionStart(); + + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(const ProtocolHeader&); + void handleStreamError(boost::shared_ptr<Swift::Error>); + + void handleTLSEncrypted(); + + bool checkState(State); + + private: + JID localJID; + State state; + boost::shared_ptr<SessionStream> stream; + bool needSessionStart; + ClientAuthenticator* authenticator; + }; +} diff --git a/Swiften/Client/ClientXMLTracer.h b/Swiften/Client/ClientXMLTracer.h new file mode 100644 index 0000000..85224c8 --- /dev/null +++ b/Swiften/Client/ClientXMLTracer.h @@ -0,0 +1,26 @@ +#pragma once + +#include "Swiften/Client/Client.h" + +namespace Swift { + class ClientXMLTracer { + public: + ClientXMLTracer(Client* client) { + client->onDataRead.connect(boost::bind(&ClientXMLTracer::printData, '<', _1)); + client->onDataWritten.connect(boost::bind(&ClientXMLTracer::printData, '>', _1)); + } + + private: + static void printData(char direction, const String& data) { + printLine(direction); + std::cerr << data << std::endl; + } + + static void printLine(char c) { + for (unsigned int i = 0; i < 80; ++i) { + std::cerr << c; + } + std::cerr << std::endl; + } + }; +} diff --git a/Swiften/Client/DummyStanzaChannel.h b/Swiften/Client/DummyStanzaChannel.h new file mode 100644 index 0000000..d618167 --- /dev/null +++ b/Swiften/Client/DummyStanzaChannel.h @@ -0,0 +1,38 @@ +#pragma once + +#include <vector> + +#include "Swiften/Client/StanzaChannel.h" + +namespace Swift { + class DummyStanzaChannel : public StanzaChannel { + public: + DummyStanzaChannel() {} + + virtual void sendStanza(boost::shared_ptr<Stanza> stanza) { + sentStanzas.push_back(stanza); + } + + virtual void sendIQ(boost::shared_ptr<IQ> iq) { + sentStanzas.push_back(iq); + } + + virtual void sendMessage(boost::shared_ptr<Message> message) { + sentStanzas.push_back(message); + } + + virtual void sendPresence(boost::shared_ptr<Presence> presence) { + sentStanzas.push_back(presence); + } + + virtual String getNewIQID() { + return "test-id"; + } + + virtual bool isAvailable() { + return true; + } + + std::vector<boost::shared_ptr<Stanza> > sentStanzas; + }; +} diff --git a/Swiften/Client/StanzaChannel.h b/Swiften/Client/StanzaChannel.h new file mode 100644 index 0000000..a0e291d --- /dev/null +++ b/Swiften/Client/StanzaChannel.h @@ -0,0 +1,20 @@ +#pragma once + +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Queries/IQChannel.h" +#include "Swiften/Elements/Message.h" +#include "Swiften/Elements/Presence.h" + +namespace Swift { + class StanzaChannel : public IQChannel { + public: + virtual void sendMessage(boost::shared_ptr<Message>) = 0; + virtual void sendPresence(boost::shared_ptr<Presence>) = 0; + virtual bool isAvailable() = 0; + + boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived; + boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived; + }; +} diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp new file mode 100644 index 0000000..e035ba3 --- /dev/null +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -0,0 +1,498 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> +#include <deque> +#include <boost/bind.hpp> +#include <boost/optional.hpp> + +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Client/ClientSession.h" +#include "Swiften/Elements/StartTLSRequest.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/TLSProceed.h" +#include "Swiften/Elements/StartTLSFailure.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" + +using namespace Swift; + +class ClientSessionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ClientSessionTest); + CPPUNIT_TEST(testStart_Error); + CPPUNIT_TEST(testStartTLS); + CPPUNIT_TEST(testStartTLS_ServerError); + CPPUNIT_TEST(testStartTLS_ConnectError); + CPPUNIT_TEST(testAuthenticate); + CPPUNIT_TEST(testAuthenticate_Unauthorized); + CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); + /* + CPPUNIT_TEST(testResourceBind); + CPPUNIT_TEST(testResourceBind_ChangeResource); + CPPUNIT_TEST(testResourceBind_EmptyResource); + CPPUNIT_TEST(testResourceBind_Error); + CPPUNIT_TEST(testSessionStart); + CPPUNIT_TEST(testSessionStart_Error); + CPPUNIT_TEST(testSessionStart_AfterResourceBind); + CPPUNIT_TEST(testWhitespacePing); + CPPUNIT_TEST(testReceiveElementAfterSessionStarted); + CPPUNIT_TEST(testSendElement); + */ + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + server = boost::shared_ptr<MockSessionStream>(new MockSessionStream()); + sessionFinishedReceived = false; + needCredentials = false; + } + + void testStart_Error() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->breakConnection(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + void testStartTLS() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithStartTLS(); + server->receiveStartTLS(); + CPPUNIT_ASSERT(!server->tlsEncrypted); + server->sendTLSProceed(); + CPPUNIT_ASSERT(server->tlsEncrypted); + server->onTLSEncrypted(); + server->receiveStreamStart(); + server->sendStreamStart(); + } + + void testStartTLS_ServerError() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithStartTLS(); + server->receiveStartTLS(); + server->sendTLSFailure(); + + CPPUNIT_ASSERT(!server->tlsEncrypted); + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + void testStartTLS_ConnectError() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithStartTLS(); + server->receiveStartTLS(); + server->sendTLSProceed(); + server->breakTLS(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + void testAuthenticate() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithPLAINAuthentication(); + CPPUNIT_ASSERT(needCredentials); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + session->sendCredentials("mypass"); + server->receiveAuthRequest("PLAIN"); + server->sendAuthSuccess(); + server->receiveStreamStart(); + } + + void testAuthenticate_Unauthorized() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithPLAINAuthentication(); + CPPUNIT_ASSERT(needCredentials); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + session->sendCredentials("mypass"); + server->receiveAuthRequest("PLAIN"); + server->sendAuthFailure(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + void testAuthenticate_NoValidAuthMechanisms() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithUnknownAuthentication(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + private: + boost::shared_ptr<ClientSession> createSession() { + boost::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server); + session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1)); + session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this)); + return session; + } + + void handleSessionFinished(boost::shared_ptr<Error> error) { + sessionFinishedReceived = true; + sessionFinishedError = error; + } + + void handleSessionNeedCredentials() { + needCredentials = true; + } + + class MockSessionStream : public SessionStream { + public: + struct Event { + Event(boost::shared_ptr<Element> element) : element(element), footer(false) {} + Event(const ProtocolHeader& header) : header(header), footer(false) {} + Event() : footer(true) {} + + boost::shared_ptr<Element> element; + boost::optional<ProtocolHeader> header; + bool footer; + }; + + MockSessionStream() : available(true), canTLSEncrypt(true), tlsEncrypted(false), compressed(false), whitespacePingEnabled(false), resetCount(0) { + } + + virtual bool isAvailable() { + return available; + } + + virtual void writeHeader(const ProtocolHeader& header) { + receivedEvents.push_back(Event(header)); + } + + virtual void writeFooter() { + receivedEvents.push_back(Event()); + } + + virtual void writeElement(boost::shared_ptr<Element> element) { + receivedEvents.push_back(Event(element)); + } + + virtual bool supportsTLSEncryption() { + return canTLSEncrypt; + } + + virtual void addTLSEncryption() { + tlsEncrypted = true; + } + + virtual void addZLibCompression() { + compressed = true; + } + + virtual void setWhitespacePingEnabled(bool enabled) { + whitespacePingEnabled = enabled; + } + + virtual void resetXMPPParser() { + resetCount++; + } + + void breakConnection() { + onError(boost::shared_ptr<SessionStream::Error>(new SessionStream::Error(SessionStream::Error::ConnectionReadError))); + } + + void breakTLS() { + onError(boost::shared_ptr<SessionStream::Error>(new SessionStream::Error(SessionStream::Error::TLSError))); + } + + + void sendStreamStart() { + ProtocolHeader header; + header.setTo("foo.com"); + return onStreamStartReceived(header); + } + + void sendStreamFeaturesWithStartTLS() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->setHasStartTLS(); + onElementReceived(streamFeatures); + } + + void sendTLSProceed() { + onElementReceived(boost::shared_ptr<TLSProceed>(new TLSProceed())); + } + + void sendTLSFailure() { + onElementReceived(boost::shared_ptr<StartTLSFailure>(new StartTLSFailure())); + } + + void sendStreamFeaturesWithPLAINAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("PLAIN"); + onElementReceived(streamFeatures); + } + + void sendStreamFeaturesWithUnknownAuthentication() { + boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures()); + streamFeatures->addAuthenticationMechanism("UNKNOWN"); + onElementReceived(streamFeatures); + } + + void sendAuthSuccess() { + onElementReceived(boost::shared_ptr<AuthSuccess>(new AuthSuccess())); + } + + void sendAuthFailure() { + onElementReceived(boost::shared_ptr<AuthFailure>(new AuthFailure())); + } + + void receiveStreamStart() { + Event event = popEvent(); + CPPUNIT_ASSERT(event.header); + } + + void receiveStartTLS() { + Event event = popEvent(); + CPPUNIT_ASSERT(event.element); + CPPUNIT_ASSERT(boost::dynamic_pointer_cast<StartTLSRequest>(event.element)); + } + + void receiveAuthRequest(const String& mech) { + Event event = popEvent(); + CPPUNIT_ASSERT(event.element); + boost::shared_ptr<AuthRequest> request(boost::dynamic_pointer_cast<AuthRequest>(event.element)); + CPPUNIT_ASSERT(request); + CPPUNIT_ASSERT_EQUAL(mech, request->getMechanism()); + } + + Event popEvent() { + CPPUNIT_ASSERT(receivedEvents.size() > 0); + Event event = receivedEvents.front(); + receivedEvents.pop_front(); + return event; + } + + bool available; + bool canTLSEncrypt; + bool tlsEncrypted; + bool compressed; + bool whitespacePingEnabled; + int resetCount; + std::deque<Event> receivedEvents; + }; + + boost::shared_ptr<MockSessionStream> server; + bool sessionFinishedReceived; + bool needCredentials; + boost::shared_ptr<Error> sessionFinishedError; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); + +#if 0 + void testAuthenticate() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + session->startSession(); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT(needCredentials_); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthSuccess(); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + session->sendCredentials("mypass"); + CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState()); + processEvents(); + CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); + } + + void testAuthenticate_Unauthorized() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithAuthentication(); + session->startSession(); + processEvents(); + + getMockServer()->expectAuth("me", "mypass"); + getMockServer()->sendAuthFailure(); + session->sendCredentials("mypass"); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, *session->getError()); + } + + void testAuthenticate_NoValidAuthMechanisms() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, *session->getError()); + } + + void testResourceBind() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState()); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + session->startSession(); + + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getLocalJID()); + } + + void testResourceBind_ChangeResource() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getLocalJID()); + } + + void testResourceBind_EmptyResource() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getLocalJID()); + } + + void testResourceBind_Error() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBind(); + getMockServer()->expectResourceBind("", "session-bind"); + getMockServer()->sendError("session-bind"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, *session->getError()); + } + + void testSessionStart() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState()); + getMockServer()->sendSessionStartResponse("session-start"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testSessionStart_Error() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithSession(); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendError("session-start"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, *session->getError()); + } + + void testSessionStart_AfterResourceBind() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeaturesWithResourceBindAndSession(); + getMockServer()->expectResourceBind("Bar", "session-bind"); + getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); + getMockServer()->expectSessionStart("session-start"); + getMockServer()->sendSessionStartResponse("session-start"); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); + CPPUNIT_ASSERT(sessionStarted_); + } + + void testWhitespacePing() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + session->startSession(); + processEvents(); + CPPUNIT_ASSERT(session->getWhitespacePingLayer()); + } + + void testReceiveElementAfterSessionStarted() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + session->startSession(); + processEvents(); + + getMockServer()->expectMessage(); + session->sendElement(boost::shared_ptr<Message>(new Message())); + } + + void testSendElement() { + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); + session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1)); + getMockServer()->expectStreamStart(); + getMockServer()->sendStreamStart(); + getMockServer()->sendStreamFeatures(); + getMockServer()->sendMessage(); + session->startSession(); + processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); + CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0])); + } +#endif |