diff options
| author | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) |
|---|---|---|
| committer | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) |
| commit | 54781ce12f7654f8136e645d4ebc5934d90c6bea (patch) | |
| tree | 90bad869f9f64d57a3c0af209b83a538a47c7762 | |
| parent | fcfac59db5cb4503554f2b30854b2e91928296f6 (diff) | |
| parent | 66ced3654ad295478b33d3e4f1716f66ab4048b5 (diff) | |
| download | swift-54781ce12f7654f8136e645d4ebc5934d90c6bea.zip swift-54781ce12f7654f8136e645d4ebc5934d90c6bea.tar.bz2 | |
Refactored session management.
46 files changed, 479 insertions, 351 deletions
diff --git a/Limber/main.cpp b/Limber/main.cpp index 965abc2..25cccec 100644 --- a/Limber/main.cpp +++ b/Limber/main.cpp @@ -60,17 +60,17 @@ class Server { if (iq->getType() == IQ::Get) { boost::shared_ptr<VCard> vcard(new VCard()); vcard->setNickname(iq->getFrom().getNode()); session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), vcard)); } else { - session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel)); + session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::Forbidden, ErrorPayload::Cancel)); } } else { - session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); + session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel)); } } } } private: diff --git a/Slimber/Server.cpp b/Slimber/Server.cpp index e07fb41..278a572 100644 --- a/Slimber/Server.cpp +++ b/Slimber/Server.cpp @@ -208,13 +208,13 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh rosterRequested = true; foreach(const boost::shared_ptr<Presence> presence, presenceManager->getAllPresence()) { session->sendElement(presence); } } else { - session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel)); + session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::Forbidden, ErrorPayload::Cancel)); } } if (boost::shared_ptr<VCard> vcard = iq->getPayload<VCard>()) { if (iq->getType() == IQ::Get) { session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), vCardCollection->getOwnVCard())); } @@ -224,13 +224,13 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh if (lastPresence) { linkLocalServiceBrowser->updateService(getLinkLocalServiceInfo(lastPresence)); } } } else { - session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); + session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel)); } } } else { JID toJID = stanza->getTo(); boost::shared_ptr<Session> outgoingSession = @@ -257,13 +257,13 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh } connector->queueElement(element); } else { session->sendElement(IQ::createError( stanza->getFrom(), stanza->getID(), - Error::RecipientUnavailable, Error::Wait)); + ErrorPayload::RecipientUnavailable, ErrorPayload::Wait)); } } } } void Server::handleNewLinkLocalConnection(boost::shared_ptr<Connection> connection) { diff --git a/Swift/Controllers/ChatControllerBase.cpp b/Swift/Controllers/ChatControllerBase.cpp index baa715b..2b873f1 100644 --- a/Swift/Controllers/ChatControllerBase.cpp +++ b/Swift/Controllers/ChatControllerBase.cpp @@ -64,13 +64,13 @@ void ChatControllerBase::handleSendMessageRequest(const String &body) { } preSendMessageRequest(message); stanzaChannel_->sendMessage(message); postSendMessage(message->getBody()); } -void ChatControllerBase::handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog> catalog, const boost::optional<Error>& error) { +void ChatControllerBase::handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog> catalog, const boost::optional<ErrorPayload>& error) { if (!error) { if (catalog->getLabels().size() == 0) { chatWindow_->setSecurityLabelsEnabled(false); labelsEnabled_ = false; } else { chatWindow_->setAvailableSecurityLabels(catalog->getLabels()); @@ -94,53 +94,53 @@ void ChatControllerBase::handleIncomingMessage(boost::shared_ptr<MessageEvent> m chatWindow_->setUnreadMessageCount(unreadMessages_.size()); boost::shared_ptr<Message> message = messageEvent->getStanza(); preHandleIncomingMessage(message); String body = message->getBody(); if (message->isError()) { - String errorMessage = getErrorMessage(message->getPayload<Error>()); + String errorMessage = getErrorMessage(message->getPayload<ErrorPayload>()); chatWindow_->addErrorMessage(errorMessage); } else { showChatWindow(); boost::shared_ptr<SecurityLabel> label = message->getPayload<SecurityLabel>(); boost::optional<SecurityLabel> maybeLabel = label ? boost::optional<SecurityLabel>(*label) : boost::optional<SecurityLabel>(); JID from = message->getFrom(); chatWindow_->addMessage(body, senderDisplayNameFromMessage(from), isIncomingMessageFromMe(message), maybeLabel, String(avatarManager_->getAvatarPath(from).string())); } } -String ChatControllerBase::getErrorMessage(boost::shared_ptr<Error> error) { +String ChatControllerBase::getErrorMessage(boost::shared_ptr<ErrorPayload> error) { String defaultMessage = "Error sending message"; if (!error->getText().isEmpty()) { return error->getText(); } else { switch (error->getCondition()) { - case Error::BadRequest: return defaultMessage; break; - case Error::Conflict: return defaultMessage; break; - case Error::FeatureNotImplemented: return defaultMessage; break; - case Error::Forbidden: return defaultMessage; break; - case Error::Gone: return "Recipient can no longer be contacted"; break; - case Error::InternalServerError: return "Internal server error"; break; - case Error::ItemNotFound: return defaultMessage; break; - case Error::JIDMalformed: return defaultMessage; break; - case Error::NotAcceptable: return "Message was rejected"; break; - case Error::NotAllowed: return defaultMessage; break; - case Error::NotAuthorized: return defaultMessage; break; - case Error::PaymentRequired: return defaultMessage; break; - case Error::RecipientUnavailable: return "Recipient is unavailable."; break; - case Error::Redirect: return defaultMessage; break; - case Error::RegistrationRequired: return defaultMessage; break; - case Error::RemoteServerNotFound: return "Recipient's server not found."; break; - case Error::RemoteServerTimeout: return defaultMessage; break; - case Error::ResourceConstraint: return defaultMessage; break; - case Error::ServiceUnavailable: return defaultMessage; break; - case Error::SubscriptionRequired: return defaultMessage; break; - case Error::UndefinedCondition: return defaultMessage; break; - case Error::UnexpectedRequest: return defaultMessage; break; + case ErrorPayload::BadRequest: return defaultMessage; break; + case ErrorPayload::Conflict: return defaultMessage; break; + case ErrorPayload::FeatureNotImplemented: return defaultMessage; break; + case ErrorPayload::Forbidden: return defaultMessage; break; + case ErrorPayload::Gone: return "Recipient can no longer be contacted"; break; + case ErrorPayload::InternalServerError: return "Internal server error"; break; + case ErrorPayload::ItemNotFound: return defaultMessage; break; + case ErrorPayload::JIDMalformed: return defaultMessage; break; + case ErrorPayload::NotAcceptable: return "Message was rejected"; break; + case ErrorPayload::NotAllowed: return defaultMessage; break; + case ErrorPayload::NotAuthorized: return defaultMessage; break; + case ErrorPayload::PaymentRequired: return defaultMessage; break; + case ErrorPayload::RecipientUnavailable: return "Recipient is unavailable."; break; + case ErrorPayload::Redirect: return defaultMessage; break; + case ErrorPayload::RegistrationRequired: return defaultMessage; break; + case ErrorPayload::RemoteServerNotFound: return "Recipient's server not found."; break; + case ErrorPayload::RemoteServerTimeout: return defaultMessage; break; + case ErrorPayload::ResourceConstraint: return defaultMessage; break; + case ErrorPayload::ServiceUnavailable: return defaultMessage; break; + case ErrorPayload::SubscriptionRequired: return defaultMessage; break; + case ErrorPayload::UndefinedCondition: return defaultMessage; break; + case ErrorPayload::UnexpectedRequest: return defaultMessage; break; } } return defaultMessage; } } diff --git a/Swift/Controllers/ChatControllerBase.h b/Swift/Controllers/ChatControllerBase.h index 601e56b..91b72a8 100644 --- a/Swift/Controllers/ChatControllerBase.h +++ b/Swift/Controllers/ChatControllerBase.h @@ -9,13 +9,13 @@ #include "Swiften/Base/String.h" #include "Swiften/Elements/DiscoInfo.h" #include "Swiften/Events/MessageEvent.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/SecurityLabelsCatalog.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Presence/PresenceOracle.h" #include "Swiften/Queries/IQRouter.h" namespace Swift { class IQRouter; class StanzaChannel; @@ -41,14 +41,14 @@ namespace Swift { virtual void preSendMessageRequest(boost::shared_ptr<Message>) {}; private: void handleSendMessageRequest(const String &body); void handleAllMessagesRead(); - void handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog>, const boost::optional<Error>& error); - String getErrorMessage(boost::shared_ptr<Error>); + void handleSecurityLabelsCatalogResponse(boost::shared_ptr<SecurityLabelsCatalog>, const boost::optional<ErrorPayload>& error); + String getErrorMessage(boost::shared_ptr<ErrorPayload>); protected: JID selfJID_; std::vector<boost::shared_ptr<MessageEvent> > unreadMessages_; StanzaChannel* stanzaChannel_; IQRouter* iqRouter_; diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp index 9df2308..6c60783 100644 --- a/Swift/Controllers/MainController.cpp +++ b/Swift/Controllers/MainController.cpp @@ -386,13 +386,13 @@ void MainController::handleIncomingMessage(boost::shared_ptr<Message> message) { // FIXME: This logic should go into a chat manager if (event->isReadable()) { getChatController(jid)->handleIncomingMessage(event); } } -void MainController::handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo> info, const boost::optional<Error>& error) { +void MainController::handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo> info, const boost::optional<ErrorPayload>& error) { if (!error) { serverDiscoInfo_ = info; foreach (JIDChatControllerPair pair, chatControllers_) { pair.second->setAvailableServerFeatures(info); } foreach (JIDMUCControllerPair pair, mucControllers_) { @@ -402,13 +402,13 @@ void MainController::handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo> } bool MainController::isMUC(const JID& jid) const { return mucControllers_.find(jid.toBare()) != mucControllers_.end(); } -void MainController::handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error) { +void MainController::handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error) { if (!error && !vCard->getPhoto().isEmpty()) { vCardPhotoHash_ = SHA1::getHexHash(vCard->getPhoto()); if (lastSentPresence_) { sendPresence(lastSentPresence_); } avatarManager_->setAvatar(jid_, vCard->getPhoto()); diff --git a/Swift/Controllers/MainController.h b/Swift/Controllers/MainController.h index 3179df9..db6a110 100644 --- a/Swift/Controllers/MainController.h +++ b/Swift/Controllers/MainController.h @@ -7,13 +7,13 @@ #include "Swiften/Base/String.h" #include "Swiften/Client/ClientError.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/VCard.h" #include "Swiften/Elements/DiscoInfo.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Elements/Presence.h" #include "Swiften/Elements/Message.h" #include "Swiften/Settings/SettingsProvider.h" #include "Swiften/Elements/CapsInfo.h" #include "Swiften/MUC/MUCRegistry.h" #include "Swiften/Roster/XMPPRoster.h" @@ -61,15 +61,15 @@ namespace Swift { void handleJoinMUCRequest(const JID& muc, const String& nick); void handleIncomingPresence(boost::shared_ptr<Presence> presence); void handleChatControllerJIDChanged(const JID& from, const JID& to); void handleIncomingMessage(boost::shared_ptr<Message> message); void handleChangeStatusRequest(StatusShow::Type show, const String &statusText); void handleError(const ClientError& error); - void handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo>, const boost::optional<Error>&); + void handleServerDiscoInfoResponse(boost::shared_ptr<DiscoInfo>, const boost::optional<ErrorPayload>&); void handleEventQueueLengthChange(int count); - void handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error); + void handleOwnVCardReceived(boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error); ChatController* getChatController(const JID &contact); void sendPresence(boost::shared_ptr<Presence> presence); void handleInputIdle(); void handleInputNotIdle(); void logout(); void signOut(); diff --git a/Swiften/Avatars/AvatarManager.cpp b/Swiften/Avatars/AvatarManager.cpp index 6a1efc6..574e199 100644 --- a/Swiften/Avatars/AvatarManager.cpp +++ b/Swiften/Avatars/AvatarManager.cpp @@ -32,13 +32,13 @@ void AvatarManager::handlePresenceReceived(boost::shared_ptr<Presence> presence) request->onResponse.connect(boost::bind(&AvatarManager::handleVCardReceived, this, from, newHash, _1, _2)); request->send(); } } } -void AvatarManager::handleVCardReceived(const JID& from, const String& promisedHash, boost::shared_ptr<VCard> vCard, const boost::optional<Error>& error) { +void AvatarManager::handleVCardReceived(const JID& from, const String& promisedHash, boost::shared_ptr<VCard> vCard, const boost::optional<ErrorPayload>& error) { if (error) { // FIXME: What to do here? std::cerr << "Warning: " << from << ": Could not get vCard" << std::endl; return; } String realHash = SHA1::getHexHash(vCard->getPhoto()); diff --git a/Swiften/Avatars/AvatarManager.h b/Swiften/Avatars/AvatarManager.h index 3ac4433..65ec372 100644 --- a/Swiften/Avatars/AvatarManager.h +++ b/Swiften/Avatars/AvatarManager.h @@ -6,13 +6,13 @@ #include <boost/signal.hpp> #include <map> #include "Swiften/JID/JID.h" #include "Swiften/Elements/Presence.h" #include "Swiften/Elements/VCard.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { class MUCRegistry; class AvatarStorage; class StanzaChannel; class IQRouter; @@ -27,13 +27,13 @@ namespace Swift { public: boost::signal<void (const JID&, const String&)> onAvatarChanged; private: void handlePresenceReceived(boost::shared_ptr<Presence>); - void handleVCardReceived(const JID& from, const String& hash, boost::shared_ptr<VCard>, const boost::optional<Error>&); + void handleVCardReceived(const JID& from, const String& hash, boost::shared_ptr<VCard>, const boost::optional<ErrorPayload>&); void setAvatarHash(const JID& from, const String& hash); JID getAvatarJID(const JID& o) const; private: StanzaChannel* stanzaChannel_; IQRouter* iqRouter_; diff --git a/Swiften/Base/Error.cpp b/Swiften/Base/Error.cpp new file mode 100644 index 0000000..597c155 --- /dev/null +++ b/Swiften/Base/Error.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Base/Error.h" + +namespace Swift { + +Error::~Error() { +} + +} diff --git a/Swiften/Base/Error.h b/Swiften/Base/Error.h new file mode 100644 index 0000000..4c729ff --- /dev/null +++ b/Swiften/Base/Error.h @@ -0,0 +1,8 @@ +#pragma once + +namespace Swift { + class Error { + public: + virtual ~Error(); + }; +}; diff --git a/Swiften/Base/SConscript b/Swiften/Base/SConscript index d308e11..a0984e5 100644 --- a/Swiften/Base/SConscript +++ b/Swiften/Base/SConscript @@ -1,9 +1,10 @@ Import("swiften_env") objects = swiften_env.StaticObject([ "ByteArray.cpp", + "Error.cpp", "IDGenerator.cpp", "String.cpp", "sleep.cpp", ]) swiften_env.Append(SWIFTEN_OBJECTS = [objects]) diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 60dfade..9e38626 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -7,22 +7,26 @@ #include "Swiften/Network/BoostIOServiceThread.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/StreamStack/PlatformTLSLayerFactory.h" #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Session/BasicSessionStream.h" namespace Swift { Client::Client(const JID& jid, const String& password) : IQRouter(this), jid_(jid), password_(password) { connectionFactory_ = new BoostConnectionFactory(&MainBoostIOServiceThread::getInstance().getIOService()); tlsLayerFactory_ = new PlatformTLSLayerFactory(); } Client::~Client() { + if (session_ || connection_) { + std::cerr << "Warning: Client not disconnected properly" << std::endl; + } delete tlsLayerFactory_; delete connectionFactory_; } bool Client::isAvailable() { return session_; @@ -43,29 +47,38 @@ void Client::connect() { void Client::handleConnectionConnectFinished(bool error) { if (error) { onError(ClientError::ConnectionError); } else { - session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_)); + assert(!sessionStream_); + sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_)); if (!certificate_.isEmpty()) { - session_->setCertificate(PKCS12Certificate(certificate_, password_)); + sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); } - session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); - session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); + sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); + sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); + sessionStream_->initialize(); + + session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, sessionStream_)); + session_->onInitialized.connect(boost::bind(boost::ref(onConnected))); + session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); - session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); - session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); - session_->startSession(); + session_->start(); } } void Client::disconnect() { if (session_) { - session_->finishSession(); + session_->finish(); + session_.reset(); + } + if (connection_) { + connection_->disconnect(); + connection_.reset(); } } void Client::send(boost::shared_ptr<Stanza> stanza) { session_->sendElement(stanza); } @@ -107,15 +120,16 @@ void Client::handleElement(boost::shared_ptr<Element> element) { } void Client::setCertificate(const String& certificate) { certificate_ = certificate; } -void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) { +void Client::handleSessionFinished(boost::shared_ptr<Error> error) { if (error) { ClientError clientError; + /* switch (*error) { case Session::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); break; case Session::ConnectionWriteError: clientError = ClientError(ClientError::ConnectionWriteError); @@ -145,23 +159,24 @@ void Client::handleSessionFinished(const boost::optional<Session::SessionError>& clientError = ClientError(ClientError::ClientCertificateLoadError); break; case Session::ClientCertificateError: clientError = ClientError(ClientError::ClientCertificateError); break; } + */ onError(clientError); } } void Client::handleNeedCredentials() { session_->sendCredentials(password_); } -void Client::handleDataRead(const ByteArray& data) { - onDataRead(String(data.getData(), data.getSize())); +void Client::handleDataRead(const String& data) { + onDataRead(data); } -void Client::handleDataWritten(const ByteArray& data) { - onDataWritten(String(data.getData(), data.getSize())); +void Client::handleDataWritten(const String& data) { + onDataWritten(data); } } diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 59e1c05..5188789 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -1,12 +1,13 @@ #ifndef SWIFTEN_Client_H #define SWIFTEN_Client_H #include <boost/signals.hpp> #include <boost/shared_ptr.hpp> +#include "Swiften/Base/Error.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/Client/ClientError.h" #include "Swiften/Elements/Presence.h" #include "Swiften/Elements/Message.h" #include "Swiften/JID/JID.h" #include "Swiften/Base/String.h" @@ -17,12 +18,13 @@ #include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" namespace Swift { class TLSLayerFactory; class ConnectionFactory; class ClientSession; + class BasicSessionStream; class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable { public: Client(const JID& jid, const String& password); ~Client(); @@ -35,36 +37,39 @@ namespace Swift { virtual void sendIQ(boost::shared_ptr<IQ>); virtual void sendMessage(boost::shared_ptr<Message>); virtual void sendPresence(boost::shared_ptr<Presence>); public: - boost::signal<void (ClientError)> onError; + boost::signal<void (const ClientError&)> onError; boost::signal<void ()> onConnected; boost::signal<void (const String&)> onDataRead; boost::signal<void (const String&)> onDataWritten; private: void handleConnectionConnectFinished(bool error); void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); void handleElement(boost::shared_ptr<Element>); - void handleSessionFinished(const boost::optional<Session::SessionError>& error); + void handleSessionFinished(boost::shared_ptr<Error>); void handleNeedCredentials(); - void handleDataRead(const ByteArray&); - void handleDataWritten(const ByteArray&); + void handleDataRead(const String&); + void handleDataWritten(const String&); + + void reset(); private: JID jid_; String password_; IDGenerator idGenerator_; ConnectionFactory* connectionFactory_; TLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; - boost::shared_ptr<ClientSession> session_; boost::shared_ptr<Connection> connection_; + boost::shared_ptr<BasicSessionStream> sessionStream_; + boost::shared_ptr<ClientSession> session_; String certificate_; }; } #endif diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index a0e1289..a185ea0 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -1,238 +1,219 @@ #include "Swiften/Client/ClientSession.h" #include <boost/bind.hpp> -#include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Elements/ProtocolHeader.h" -#include "Swiften/StreamStack/StreamStack.h" -#include "Swiften/StreamStack/ConnectionLayer.h" -#include "Swiften/StreamStack/XMPPLayer.h" -#include "Swiften/StreamStack/TLSLayer.h" -#include "Swiften/StreamStack/TLSLayerFactory.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/StartTLSRequest.h" #include "Swiften/Elements/StartTLSFailure.h" #include "Swiften/Elements/TLSProceed.h" #include "Swiften/Elements/AuthRequest.h" #include "Swiften/Elements/AuthSuccess.h" #include "Swiften/Elements/AuthFailure.h" #include "Swiften/Elements/StartSession.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/ResourceBind.h" #include "Swiften/SASL/PLAINMessage.h" -#include "Swiften/StreamStack/WhitespacePingLayer.h" +#include "Swiften/Session/SessionStream.h" namespace Swift { ClientSession::ClientSession( const JID& jid, - boost::shared_ptr<Connection> connection, - TLSLayerFactory* tlsLayerFactory, - PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : - Session(connection, payloadParserFactories, payloadSerializers), - tlsLayerFactory_(tlsLayerFactory), - state_(Initial), - needSessionStart_(false) { - setLocalJID(jid); - setRemoteJID(JID("", jid.getDomain())); + boost::shared_ptr<SessionStream> stream) : + localJID(jid), + state(Initial), + stream(stream), + needSessionStart(false) { } -void ClientSession::handleSessionStarted() { - assert(state_ == Initial); - state_ = WaitingForStreamStart; +void ClientSession::start() { + stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); + stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); + stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + + assert(state == Initial); + state = WaitingForStreamStart; sendStreamHeader(); } void ClientSession::sendStreamHeader() { ProtocolHeader header; header.setTo(getRemoteJID()); - getXMPPLayer()->writeHeader(header); + stream->writeHeader(header); } -void ClientSession::setCertificate(const PKCS12Certificate& certificate) { - certificate_ = certificate; +void ClientSession::sendElement(boost::shared_ptr<Element> element) { + stream->writeElement(element); } void ClientSession::handleStreamStart(const ProtocolHeader&) { checkState(WaitingForStreamStart); - state_ = Negotiating; + state = Negotiating; } void ClientSession::handleElement(boost::shared_ptr<Element> element) { - if (getState() == SessionStarted) { + if (getState() == Initialized) { onElementReceived(element); } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { if (!checkState(Negotiating)) { return; } - if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { - state_ = Encrypting; - getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) { + state = WaitingForEncrypt; + stream->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); } else if (streamFeatures->hasAuthenticationMechanisms()) { - if (!certificate_.isNull()) { - if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { - state_ = Authenticating; - getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); - } - else { - finishSession(ClientCertificateError); - } + if (stream->hasTLSCertificate() && streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + state = Authenticating; + stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); } else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { - state_ = WaitingForCredentials; + state = WaitingForCredentials; onNeedCredentials(); } else { - finishSession(NoSupportedAuthMechanismsError); + finishSession(Error::NoSupportedAuthMechanismsError); } } else { // Start the session - - // Add a whitespace ping layer - whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); - getStreamStack()->addLayer(whitespacePingLayer_); - whitespacePingLayer_->setActive(); + stream->setWhitespacePingEnabled(true); if (streamFeatures->hasSession()) { - needSessionStart_ = true; + needSessionStart = true; } if (streamFeatures->hasResourceBind()) { - state_ = BindingResource; + state = BindingResource; boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); - if (!getLocalJID().getResource().isEmpty()) { - resourceBind->setResource(getLocalJID().getResource()); + if (!localJID.getResource().isEmpty()) { + resourceBind->setResource(localJID.getResource()); } - getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } - else if (needSessionStart_) { + else if (needSessionStart) { sendSessionStart(); } else { - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } } } else if (dynamic_cast<AuthSuccess*>(element.get())) { checkState(Authenticating); - state_ = WaitingForStreamStart; - getXMPPLayer()->resetParser(); + state = WaitingForStreamStart; + stream->resetXMPPParser(); sendStreamHeader(); } else if (dynamic_cast<AuthFailure*>(element.get())) { - finishSession(AuthenticationFailedError); + finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - tlsLayer_ = tlsLayerFactory_->createTLSLayer(); - getStreamStack()->addLayer(tlsLayer_); - if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { - finishSession(ClientCertificateLoadError); - } - else { - tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this)); - tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this)); - tlsLayer_->connect(); - } + checkState(WaitingForEncrypt); + state = Encrypting; + stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { - finishSession(TLSError); + finishSession(Error::TLSError); } else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { - if (state_ == BindingResource) { + if (state == BindingResource) { boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { - finishSession(ResourceBindError); + finishSession(Error::ResourceBindError); } else if (!resourceBind) { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } else if (iq->getType() == IQ::Result) { - setLocalJID(resourceBind->getJID()); - if (!getLocalJID().isValid()) { - finishSession(ResourceBindError); + localJID = resourceBind->getJID(); + if (!localJID.isValid()) { + finishSession(Error::ResourceBindError); } - if (needSessionStart_) { + if (needSessionStart) { sendSessionStart(); } else { - state_ = SessionStarted; + state = Initialized; } } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } - else if (state_ == StartingSession) { + else if (state == StartingSession) { if (iq->getType() == IQ::Result) { - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } else if (iq->getType() == IQ::Error) { - finishSession(SessionStartError); + finishSession(Error::SessionStartError); } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } else { // FIXME Not correct? - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } } void ClientSession::sendSessionStart() { - state_ = StartingSession; - getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); -} - -void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) { - if (whitespacePingLayer_) { - whitespacePingLayer_->setInactive(); - } - - if (error) { - //assert(!error_); - state_ = Error; - error_ = error; - } - else { - state_ = Finished; - } + state = StartingSession; + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); } bool ClientSession::checkState(State state) { - if (state_ != state) { - finishSession(UnexpectedElementError); + if (state != state) { + finishSession(Error::UnexpectedElementError); return false; } return true; } void ClientSession::sendCredentials(const String& password) { assert(WaitingForCredentials); - state_ = Authenticating; - getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(getLocalJID().getNode(), password).getValue()))); + state = Authenticating; + stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(localJID.getNode(), password).getValue()))); } -void ClientSession::handleTLSConnected() { - state_ = WaitingForStreamStart; - getXMPPLayer()->resetParser(); +void ClientSession::handleTLSEncrypted() { + checkState(WaitingForEncrypt); + state = WaitingForStreamStart; + stream->resetXMPPParser(); sendStreamHeader(); } -void ClientSession::handleTLSError() { - finishSession(TLSError); +void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { + finishSession(error); +} + +void ClientSession::finish() { + if (stream->isAvailable()) { + stream->writeFooter(); + } + finishSession(boost::shared_ptr<Error>()); } +void ClientSession::finishSession(Error::Type error) { + finishSession(boost::shared_ptr<Swift::ClientSession::Error>(new Swift::ClientSession::Error(error))); +} + +void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { + stream->setWhitespacePingEnabled(false); + onFinished(error); +} + + } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index fead182..e09861b 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -1,88 +1,89 @@ #pragma once #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> -#include "Swiften/Session/Session.h" +#include "Swiften/Base/Error.h" +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Session/BasicSessionStream.h" #include "Swiften/Base/String.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/Element.h" -#include "Swiften/Network/Connection.h" -#include "Swiften/TLS/PKCS12Certificate.h" namespace Swift { - class PayloadParserFactoryCollection; - class PayloadSerializerCollection; - class ConnectionFactory; - class Connection; - class StreamStack; - class XMPPLayer; - class ConnectionLayer; - class TLSLayerFactory; - class TLSLayer; - class WhitespacePingLayer; - - class ClientSession : public Session { + class ClientSession : public boost::enable_shared_from_this<ClientSession> { public: enum State { Initial, WaitingForStreamStart, Negotiating, Compressing, + WaitingForEncrypt, Encrypting, WaitingForCredentials, Authenticating, BindingResource, StartingSession, - SessionStarted, - Error, + Initialized, Finished }; + struct Error : public Swift::Error { + enum Type { + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + } type; + Error(Type type) : type(type) {} + }; + ClientSession( const JID& jid, - boost::shared_ptr<Connection>, - TLSLayerFactory*, - PayloadParserFactoryCollection*, - PayloadSerializerCollection*); + boost::shared_ptr<SessionStream>); State getState() const { - return state_; + return state; } - boost::optional<SessionError> getError() const { - return error_; - } + void start(); + void finish(); void sendCredentials(const String& password); - void setCertificate(const PKCS12Certificate& certificate); + void sendElement(boost::shared_ptr<Element> element); private: + void finishSession(Error::Type error); + void finishSession(boost::shared_ptr<Swift::Error> error); + + JID getRemoteJID() const { + return JID("", localJID.getDomain()); + } + void sendStreamHeader(); void sendSessionStart(); - virtual void handleSessionStarted(); - virtual void handleSessionFinished(const boost::optional<SessionError>& error); - virtual void handleElement(boost::shared_ptr<Element>); - virtual void handleStreamStart(const ProtocolHeader&); + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(const ProtocolHeader&); + void handleStreamError(boost::shared_ptr<Swift::Error>); - void handleTLSConnected(); - void handleTLSError(); + void handleTLSEncrypted(); - void setError(SessionError); bool checkState(State); public: boost::signal<void ()> onNeedCredentials; - boost::signal<void ()> onSessionStarted; + boost::signal<void ()> onInitialized; + boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished; + boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; private: - TLSLayerFactory* tlsLayerFactory_; - State state_; - boost::optional<SessionError> error_; - boost::shared_ptr<TLSLayer> tlsLayer_; - boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; - bool needSessionStart_; - PKCS12Certificate certificate_; + JID localJID; + State state; + boost::shared_ptr<SessionStream> stream; + bool needSessionStart; }; } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index cbf20d2..70d4ba9 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -11,13 +11,13 @@ #include "Swiften/StreamStack/TLSLayer.h" #include "Swiften/StreamStack/StreamStack.h" #include "Swiften/StreamStack/WhitespacePingLayer.h" #include "Swiften/Elements/ProtocolHeader.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/Element.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/AuthRequest.h" #include "Swiften/Elements/AuthSuccess.h" #include "Swiften/Elements/AuthFailure.h" #include "Swiften/Elements/ResourceBind.h" #include "Swiften/Elements/StartSession.h" diff --git a/Swiften/Elements/Error.h b/Swiften/Elements/ErrorPayload.h index 8793f35..32fd067 100644 --- a/Swiften/Elements/Error.h +++ b/Swiften/Elements/ErrorPayload.h @@ -1,14 +1,13 @@ -#ifndef SWIFTEN_Error_H -#define SWIFTEN_Error_H +#pragma once #include "Swiften/Elements/Payload.h" #include "Swiften/Base/String.h" namespace Swift { - class Error : public Payload { + class ErrorPayload : public Payload { public: enum Type { Cancel, Continue, Modify, Auth, Wait }; enum Condition { BadRequest, Conflict, @@ -31,13 +30,13 @@ namespace Swift { ServiceUnavailable, SubscriptionRequired, UndefinedCondition, UnexpectedRequest }; - Error(Condition condition = UndefinedCondition, Type type = Cancel, const String& text = String()) : type_(type), condition_(condition), text_(text) { } + ErrorPayload(Condition condition = UndefinedCondition, Type type = Cancel, const String& text = String()) : type_(type), condition_(condition), text_(text) { } Type getType() const { return type_; } void setType(Type type) { @@ -63,8 +62,6 @@ namespace Swift { private: Type type_; Condition condition_; String text_; }; } - -#endif diff --git a/Swiften/Elements/IQ.cpp b/Swiften/Elements/IQ.cpp index 3f47182..53dec53 100644 --- a/Swiften/Elements/IQ.cpp +++ b/Swiften/Elements/IQ.cpp @@ -23,15 +23,15 @@ boost::shared_ptr<IQ> IQ::createResult( if (payload) { iq->addPayload(payload); } return iq; } -boost::shared_ptr<IQ> IQ::createError(const JID& to, const String& id, Error::Condition condition, Error::Type type) { +boost::shared_ptr<IQ> IQ::createError(const JID& to, const String& id, ErrorPayload::Condition condition, ErrorPayload::Type type) { boost::shared_ptr<IQ> iq(new IQ(IQ::Error)); iq->setTo(to); iq->setID(id); - iq->addPayload(boost::shared_ptr<Swift::Error>(new Swift::Error(condition, type))); + iq->addPayload(boost::shared_ptr<Swift::ErrorPayload>(new Swift::ErrorPayload(condition, type))); return iq; } } diff --git a/Swiften/Elements/IQ.h b/Swiften/Elements/IQ.h index 231439f..80c2913 100644 --- a/Swiften/Elements/IQ.h +++ b/Swiften/Elements/IQ.h @@ -1,11 +1,10 @@ -#ifndef SWIFTEN_IQ_H -#define SWIFTEN_IQ_H +#pragma once #include "Swiften/Elements/Stanza.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { class IQ : public Stanza { public: @@ -25,15 +24,13 @@ namespace Swift const JID& to, const String& id, boost::shared_ptr<Payload> payload = boost::shared_ptr<Payload>()); static boost::shared_ptr<IQ> createError( const JID& to, const String& id, - Error::Condition condition, - Error::Type type); + ErrorPayload::Condition condition, + ErrorPayload::Type type); private: Type type_; }; } - -#endif diff --git a/Swiften/Elements/Message.h b/Swiften/Elements/Message.h index a49f496..6d9171f 100644 --- a/Swiften/Elements/Message.h +++ b/Swiften/Elements/Message.h @@ -1,14 +1,13 @@ -#ifndef SWIFTEN_STANZAS_MESSAGE_H -#define SWIFTEN_STANZAS_MESSAGE_H +#pragma once #include <boost/optional.hpp> #include "Swiften/Base/String.h" #include "Swiften/Elements/Body.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Elements/Stanza.h" namespace Swift { class Message : public Stanza { @@ -27,20 +26,18 @@ namespace Swift void setBody(const String& body) { updatePayload(boost::shared_ptr<Body>(new Body(body))); } bool isError() { - boost::shared_ptr<Swift::Error> error(getPayload<Swift::Error>()); - return getType() == Message::Error || error.get() != NULL; + boost::shared_ptr<Swift::ErrorPayload> error(getPayload<Swift::ErrorPayload>()); + return getType() == Message::Error || error; } Type getType() const { return type_; } void setType(Type type) { type_ = type; } private: String body_; Type type_; }; } - -#endif diff --git a/Swiften/Elements/UnitTest/IQTest.cpp b/Swiften/Elements/UnitTest/IQTest.cpp index bc22c81..a5e6dc8 100644 --- a/Swiften/Elements/UnitTest/IQTest.cpp +++ b/Swiften/Elements/UnitTest/IQTest.cpp @@ -34,19 +34,19 @@ class IQTest : public CppUnit::TestFixture CPPUNIT_ASSERT_EQUAL(JID("foo@bar/fum"), iq->getTo()); CPPUNIT_ASSERT_EQUAL(String("myid"), iq->getID()); CPPUNIT_ASSERT(!iq->getPayload<SoftwareVersion>()); } void testCreateError() { - boost::shared_ptr<IQ> iq(IQ::createError(JID("foo@bar/fum"), "myid", Error::BadRequest, Error::Modify)); + boost::shared_ptr<IQ> iq(IQ::createError(JID("foo@bar/fum"), "myid", ErrorPayload::BadRequest, ErrorPayload::Modify)); CPPUNIT_ASSERT_EQUAL(JID("foo@bar/fum"), iq->getTo()); CPPUNIT_ASSERT_EQUAL(String("myid"), iq->getID()); - boost::shared_ptr<Error> error(iq->getPayload<Error>()); + boost::shared_ptr<ErrorPayload> error(iq->getPayload<ErrorPayload>()); CPPUNIT_ASSERT(error); - CPPUNIT_ASSERT_EQUAL(Error::BadRequest, error->getCondition()); - CPPUNIT_ASSERT_EQUAL(Error::Modify, error->getType()); + CPPUNIT_ASSERT_EQUAL(ErrorPayload::BadRequest, error->getCondition()); + CPPUNIT_ASSERT_EQUAL(ErrorPayload::Modify, error->getType()); } }; CPPUNIT_TEST_SUITE_REGISTRATION(IQTest); diff --git a/Swiften/EventLoop/SimpleEventLoop.cpp b/Swiften/EventLoop/SimpleEventLoop.cpp index 8191747..7c46ed3 100644 --- a/Swiften/EventLoop/SimpleEventLoop.cpp +++ b/Swiften/EventLoop/SimpleEventLoop.cpp @@ -9,12 +9,18 @@ namespace Swift { void nop() {} SimpleEventLoop::SimpleEventLoop() : isRunning_(true) { } +SimpleEventLoop::~SimpleEventLoop() { + if (!events_.empty()) { + std::cerr << "Warning: Pending events in SimpleEventLoop at destruction time" << std::endl; + } +} + void SimpleEventLoop::run() { while (isRunning_) { std::vector<Event> events; { boost::unique_lock<boost::mutex> lock(eventsMutex_); while (events_.size() == 0) { diff --git a/Swiften/EventLoop/SimpleEventLoop.h b/Swiften/EventLoop/SimpleEventLoop.h index 01afdb2..bd0a07f 100644 --- a/Swiften/EventLoop/SimpleEventLoop.h +++ b/Swiften/EventLoop/SimpleEventLoop.h @@ -1,20 +1,20 @@ -#ifndef SWIFTEN_SimpleEventLoop_H -#define SWIFTEN_SimpleEventLoop_H +#pragma once #include <vector> #include <boost/function.hpp> #include <boost/thread/mutex.hpp> #include <boost/thread/condition_variable.hpp> #include "Swiften/EventLoop/EventLoop.h" namespace Swift { class SimpleEventLoop : public EventLoop { public: SimpleEventLoop(); + ~SimpleEventLoop(); void run(); void stop(); virtual void post(const Event& event); @@ -25,7 +25,6 @@ namespace Swift { bool isRunning_; std::vector<Event> events_; boost::mutex eventsMutex_; boost::condition_variable eventsAvailable_; }; } -#endif diff --git a/Swiften/Parser/PayloadParsers/ErrorParser.cpp b/Swiften/Parser/PayloadParsers/ErrorParser.cpp index 13380c8..ae85265 100644 --- a/Swiften/Parser/PayloadParsers/ErrorParser.cpp +++ b/Swiften/Parser/PayloadParsers/ErrorParser.cpp @@ -6,101 +6,101 @@ ErrorParser::ErrorParser() : level_(TopLevel) { } void ErrorParser::handleStartElement(const String&, const String&, const AttributeMap& attributes) { if (level_ == TopLevel) { String type = attributes.getAttribute("type"); if (type == "continue") { - getPayloadInternal()->setType(Error::Continue); + getPayloadInternal()->setType(ErrorPayload::Continue); } else if (type == "modify") { - getPayloadInternal()->setType(Error::Modify); + getPayloadInternal()->setType(ErrorPayload::Modify); } else if (type == "auth") { - getPayloadInternal()->setType(Error::Auth); + getPayloadInternal()->setType(ErrorPayload::Auth); } else if (type == "wait") { - getPayloadInternal()->setType(Error::Wait); + getPayloadInternal()->setType(ErrorPayload::Wait); } else { - getPayloadInternal()->setType(Error::Cancel); + getPayloadInternal()->setType(ErrorPayload::Cancel); } } ++level_; } void ErrorParser::handleEndElement(const String& element, const String&) { --level_; if (level_ == PayloadLevel) { if (element == "text") { getPayloadInternal()->setText(currentText_); } else if (element == "bad-request") { - getPayloadInternal()->setCondition(Error::BadRequest); + getPayloadInternal()->setCondition(ErrorPayload::BadRequest); } else if (element == "conflict") { - getPayloadInternal()->setCondition(Error::Conflict); + getPayloadInternal()->setCondition(ErrorPayload::Conflict); } else if (element == "feature-not-implemented") { - getPayloadInternal()->setCondition(Error::FeatureNotImplemented); + getPayloadInternal()->setCondition(ErrorPayload::FeatureNotImplemented); } else if (element == "forbidden") { - getPayloadInternal()->setCondition(Error::Forbidden); + getPayloadInternal()->setCondition(ErrorPayload::Forbidden); } else if (element == "gone") { - getPayloadInternal()->setCondition(Error::Gone); + getPayloadInternal()->setCondition(ErrorPayload::Gone); } else if (element == "internal-server-error") { - getPayloadInternal()->setCondition(Error::InternalServerError); + getPayloadInternal()->setCondition(ErrorPayload::InternalServerError); } else if (element == "item-not-found") { - getPayloadInternal()->setCondition(Error::ItemNotFound); + getPayloadInternal()->setCondition(ErrorPayload::ItemNotFound); } else if (element == "jid-malformed") { - getPayloadInternal()->setCondition(Error::JIDMalformed); + getPayloadInternal()->setCondition(ErrorPayload::JIDMalformed); } else if (element == "not-acceptable") { - getPayloadInternal()->setCondition(Error::NotAcceptable); + getPayloadInternal()->setCondition(ErrorPayload::NotAcceptable); } else if (element == "not-allowed") { - getPayloadInternal()->setCondition(Error::NotAllowed); + getPayloadInternal()->setCondition(ErrorPayload::NotAllowed); } else if (element == "not-authorized") { - getPayloadInternal()->setCondition(Error::NotAuthorized); + getPayloadInternal()->setCondition(ErrorPayload::NotAuthorized); } else if (element == "payment-required") { - getPayloadInternal()->setCondition(Error::PaymentRequired); + getPayloadInternal()->setCondition(ErrorPayload::PaymentRequired); } else if (element == "recipient-unavailable") { - getPayloadInternal()->setCondition(Error::RecipientUnavailable); + getPayloadInternal()->setCondition(ErrorPayload::RecipientUnavailable); } else if (element == "redirect") { - getPayloadInternal()->setCondition(Error::Redirect); + getPayloadInternal()->setCondition(ErrorPayload::Redirect); } else if (element == "registration-required") { - getPayloadInternal()->setCondition(Error::RegistrationRequired); + getPayloadInternal()->setCondition(ErrorPayload::RegistrationRequired); } else if (element == "remote-server-not-found") { - getPayloadInternal()->setCondition(Error::RemoteServerNotFound); + getPayloadInternal()->setCondition(ErrorPayload::RemoteServerNotFound); } else if (element == "remote-server-timeout") { - getPayloadInternal()->setCondition(Error::RemoteServerTimeout); + getPayloadInternal()->setCondition(ErrorPayload::RemoteServerTimeout); } else if (element == "resource-constraint") { - getPayloadInternal()->setCondition(Error::ResourceConstraint); + getPayloadInternal()->setCondition(ErrorPayload::ResourceConstraint); } else if (element == "service-unavailable") { - getPayloadInternal()->setCondition(Error::ServiceUnavailable); + getPayloadInternal()->setCondition(ErrorPayload::ServiceUnavailable); } else if (element == "subscription-required") { - getPayloadInternal()->setCondition(Error::SubscriptionRequired); + getPayloadInternal()->setCondition(ErrorPayload::SubscriptionRequired); } else if (element == "unexpected-request") { - getPayloadInternal()->setCondition(Error::UnexpectedRequest); + getPayloadInternal()->setCondition(ErrorPayload::UnexpectedRequest); } else { - getPayloadInternal()->setCondition(Error::UndefinedCondition); + getPayloadInternal()->setCondition(ErrorPayload::UndefinedCondition); } } } void ErrorParser::handleCharacterData(const String& data) { currentText_ += data; diff --git a/Swiften/Parser/PayloadParsers/ErrorParser.h b/Swiften/Parser/PayloadParsers/ErrorParser.h index 76db205..17b78b9 100644 --- a/Swiften/Parser/PayloadParsers/ErrorParser.h +++ b/Swiften/Parser/PayloadParsers/ErrorParser.h @@ -1,14 +1,14 @@ #ifndef SWIFTEN_ErrorParser_H #define SWIFTEN_ErrorParser_H -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Parser/GenericPayloadParser.h" namespace Swift { - class ErrorParser : public GenericPayloadParser<Error> { + class ErrorParser : public GenericPayloadParser<ErrorPayload> { public: ErrorParser(); virtual void handleStartElement(const String& element, const String&, const AttributeMap& attributes); virtual void handleEndElement(const String& element, const String&); virtual void handleCharacterData(const String& data); diff --git a/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp b/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp index 338fb3f..dcd3172 100644 --- a/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp +++ b/Swiften/Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp @@ -21,14 +21,14 @@ class ErrorParserTest : public CppUnit::TestFixture CPPUNIT_ASSERT(parser.parse( "<error type=\"modify\">" "<bad-request xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\"/>" "<text xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\">boo</text>" "</error>")); - Error* payload = dynamic_cast<Error*>(parser.getPayload().get()); - CPPUNIT_ASSERT_EQUAL(Error::BadRequest, payload->getCondition()); - CPPUNIT_ASSERT_EQUAL(Error::Modify, payload->getType()); + ErrorPayload* payload = dynamic_cast<ErrorPayload*>(parser.getPayload().get()); + CPPUNIT_ASSERT_EQUAL(ErrorPayload::BadRequest, payload->getCondition()); + CPPUNIT_ASSERT_EQUAL(ErrorPayload::Modify, payload->getType()); CPPUNIT_ASSERT_EQUAL(String("boo"), payload->getText()); } }; CPPUNIT_TEST_SUITE_REGISTRATION(ErrorParserTest); diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp index 412eb53..b50a0bf 100644 --- a/Swiften/QA/ClientTest/ClientTest.cpp +++ b/Swiften/QA/ClientTest/ClientTest.cpp @@ -16,12 +16,13 @@ SimpleEventLoop eventLoop; Client* client = 0; bool rosterReceived = false; void handleRosterReceived(boost::shared_ptr<Payload>) { rosterReceived = true; + client->disconnect(); eventLoop.stop(); } void handleConnected() { boost::shared_ptr<GetRosterRequest> rosterRequest(new GetRosterRequest(client)); rosterRequest->onResponse.connect(boost::bind(&handleRosterReceived, _1)); @@ -43,16 +44,17 @@ int main(int, char**) { client = new Swift::Client(JID(jid), String(pass)); ClientXMLTracer* tracer = new ClientXMLTracer(client); client->onConnected.connect(&handleConnected); client->connect(); { - boost::shared_ptr<Timer> timer(new Timer(10000, &MainBoostIOServiceThread::getInstance().getIOService())); + boost::shared_ptr<Timer> timer(new Timer(30000, &MainBoostIOServiceThread::getInstance().getIOService())); timer->onTick.connect(boost::bind(&SimpleEventLoop::stop, &eventLoop)); timer->start(); eventLoop.run(); } + delete tracer; delete client; return !rosterReceived; } diff --git a/Swiften/Queries/GenericRequest.h b/Swiften/Queries/GenericRequest.h index b4a1918..77dae52 100644 --- a/Swiften/Queries/GenericRequest.h +++ b/Swiften/Queries/GenericRequest.h @@ -1,8 +1,7 @@ -#ifndef SWIFTEN_GenericRequest_H -#define SWIFTEN_GenericRequest_H +#pragma once #include <boost/signal.hpp> #include "Swiften/Queries/Request.h" namespace Swift { @@ -14,16 +13,14 @@ namespace Swift { const JID& receiver, boost::shared_ptr<Payload> payload, IQRouter* router) : Request(type, receiver, payload, router) { } - virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { + virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) { onResponse(boost::dynamic_pointer_cast<PAYLOAD_TYPE>(payload), error); } public: - boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<Error>&)> onResponse; + boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<ErrorPayload>&)> onResponse; }; } - -#endif diff --git a/Swiften/Queries/IQRouter.cpp b/Swiften/Queries/IQRouter.cpp index ffed5f7..fdfa00b 100644 --- a/Swiften/Queries/IQRouter.cpp +++ b/Swiften/Queries/IQRouter.cpp @@ -3,13 +3,13 @@ #include <algorithm> #include <boost/bind.hpp> #include "Swiften/Base/foreach.h" #include "Swiften/Queries/IQHandler.h" #include "Swiften/Queries/IQChannel.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { static void noop(IQHandler*) {} IQRouter::IQRouter(IQChannel* channel) : channel_(channel), queueRemoves_(false) { @@ -31,13 +31,13 @@ void IQRouter::handleIQ(boost::shared_ptr<IQ> iq) { handled |= handler->handleIQ(iq); if (handled) { break; } } if (!handled && (iq->getType() == IQ::Get || iq->getType() == IQ::Set) ) { - channel_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel)); + channel_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel)); } processPendingRemoves(); queueRemoves_ = false; } diff --git a/Swiften/Queries/Request.cpp b/Swiften/Queries/Request.cpp index 90aa295..18446ae 100644 --- a/Swiften/Queries/Request.cpp +++ b/Swiften/Queries/Request.cpp @@ -32,17 +32,17 @@ void Request::send() { } bool Request::handleIQ(boost::shared_ptr<IQ> iq) { bool handled = false; if (sent_ && iq->getID() == id_) { if (iq->getType() == IQ::Result) { - handleResponse(iq->getPayloadOfSameType(payload_), boost::optional<Error>()); + handleResponse(iq->getPayloadOfSameType(payload_), boost::optional<ErrorPayload>()); } else { // FIXME: Get proper error - handleResponse(boost::shared_ptr<Payload>(), boost::optional<Error>(Error::UndefinedCondition)); + handleResponse(boost::shared_ptr<Payload>(), boost::optional<ErrorPayload>(ErrorPayload::UndefinedCondition)); } router_->removeHandler(this); handled = true; } return handled; } diff --git a/Swiften/Queries/Request.h b/Swiften/Queries/Request.h index 8f7a1d1..cc4a58e 100644 --- a/Swiften/Queries/Request.h +++ b/Swiften/Queries/Request.h @@ -6,13 +6,13 @@ #include <boost/enable_shared_from_this.hpp> #include "Swiften/Base/String.h" #include "Swiften/Queries/IQHandler.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/Payload.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/JID/JID.h" namespace Swift { class Request : public IQHandler, public boost::enable_shared_from_this<Request> { public: Request( @@ -29,13 +29,13 @@ namespace Swift { protected: virtual void setPayload(boost::shared_ptr<Payload> p) { payload_ = p; } - virtual void handleResponse(boost::shared_ptr<Payload>, boost::optional<Error>) = 0; + virtual void handleResponse(boost::shared_ptr<Payload>, boost::optional<ErrorPayload>) = 0; private: bool handleIQ(boost::shared_ptr<IQ>); private: IQRouter* router_; diff --git a/Swiften/Queries/Requests/GetPrivateStorageRequest.h b/Swiften/Queries/Requests/GetPrivateStorageRequest.h index c5f8aef..5d6440e 100644 --- a/Swiften/Queries/Requests/GetPrivateStorageRequest.h +++ b/Swiften/Queries/Requests/GetPrivateStorageRequest.h @@ -2,29 +2,29 @@ #include <boost/signals.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Queries/Request.h" #include "Swiften/Elements/PrivateStorage.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { template<typename PAYLOAD_TYPE> class GetPrivateStorageRequest : public Request { public: GetPrivateStorageRequest(IQRouter* router) : Request(IQ::Get, JID(), boost::shared_ptr<PrivateStorage>(new PrivateStorage(boost::shared_ptr<Payload>(new PAYLOAD_TYPE()))), router) { } - virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { + virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) { boost::shared_ptr<PrivateStorage> storage = boost::dynamic_pointer_cast<PrivateStorage>(payload); if (storage) { onResponse(boost::dynamic_pointer_cast<PAYLOAD_TYPE>(storage->getPayload()), error); } else { onResponse(boost::shared_ptr<PAYLOAD_TYPE>(), error); } } public: - boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<Error>&)> onResponse; + boost::signal<void (boost::shared_ptr<PAYLOAD_TYPE>, const boost::optional<ErrorPayload>&)> onResponse; }; } diff --git a/Swiften/Queries/Requests/SetPrivateStorageRequest.h b/Swiften/Queries/Requests/SetPrivateStorageRequest.h index 63ac8dc..834ddd8 100644 --- a/Swiften/Queries/Requests/SetPrivateStorageRequest.h +++ b/Swiften/Queries/Requests/SetPrivateStorageRequest.h @@ -2,23 +2,23 @@ #include <boost/signals.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Queries/Request.h" #include "Swiften/Elements/PrivateStorage.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { template<typename PAYLOAD_TYPE> class SetPrivateStorageRequest : public Request { public: SetPrivateStorageRequest(boost::shared_ptr<PAYLOAD_TYPE> payload, IQRouter* router) : Request(IQ::Set, JID(), boost::shared_ptr<PrivateStorage>(new PrivateStorage(payload)), router) { } - virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<Error> error) { + virtual void handleResponse(boost::shared_ptr<Payload> payload, boost::optional<ErrorPayload> error) { onResponse(error); } public: - boost::signal<void (const boost::optional<Error>&)> onResponse; + boost::signal<void (const boost::optional<ErrorPayload>&)> onResponse; }; } diff --git a/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp b/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp index 14e04cf..a86a111 100644 --- a/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp +++ b/Swiften/Queries/Requests/UnitTest/GetPrivateStorageRequestTest.cpp @@ -69,13 +69,13 @@ class GetPrivateStorageRequestTest : public CppUnit::TestFixture CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(responses.size())); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(errors.size())); } private: - void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<Error>& e) { + void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<ErrorPayload>& e) { if (e) { errors.push_back(*e); } else { responses.push_back(p); } @@ -96,11 +96,11 @@ class GetPrivateStorageRequestTest : public CppUnit::TestFixture return iq; } private: IQRouter* router; DummyIQChannel* channel; - std::vector< Error > errors; + std::vector< ErrorPayload > errors; std::vector< boost::shared_ptr<Payload> > responses; }; CPPUNIT_TEST_SUITE_REGISTRATION(GetPrivateStorageRequestTest); diff --git a/Swiften/Queries/Responder.h b/Swiften/Queries/Responder.h index e6e8ca6..9c025eb 100644 --- a/Swiften/Queries/Responder.h +++ b/Swiften/Queries/Responder.h @@ -1,12 +1,12 @@ #ifndef SWIFTEN_Responder_H #define SWIFTEN_Responder_H #include "Swiften/Queries/IQHandler.h" #include "Swiften/Queries/IQRouter.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { template<typename PAYLOAD_TYPE> class Responder : public IQHandler { public: Responder(IQRouter* router) : router_(router) { @@ -22,13 +22,13 @@ namespace Swift { virtual bool handleSetRequest(const JID& from, const String& id, boost::shared_ptr<PAYLOAD_TYPE> payload) = 0; void sendResponse(const JID& to, const String& id, boost::shared_ptr<Payload> payload) { router_->sendIQ(IQ::createResult(to, id, payload)); } - void sendError(const JID& to, const String& id, Error::Condition condition, Error::Type type) { + void sendError(const JID& to, const String& id, ErrorPayload::Condition condition, ErrorPayload::Type type) { router_->sendIQ(IQ::createError(to, id, condition, type)); } private: virtual bool handleIQ(boost::shared_ptr<IQ> iq) { if (iq->getType() == IQ::Set || iq->getType() == IQ::Get) { @@ -39,13 +39,13 @@ namespace Swift { result = handleSetRequest(iq->getFrom(), iq->getID(), payload); } else { result = handleGetRequest(iq->getFrom(), iq->getID(), payload); } if (!result) { - router_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), Error::NotAllowed, Error::Cancel)); + router_->sendIQ(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::NotAllowed, ErrorPayload::Cancel)); } return true; } } return false; } diff --git a/Swiften/Queries/Responders/DiscoInfoResponder.cpp b/Swiften/Queries/Responders/DiscoInfoResponder.cpp index a114fbc..572f83f 100644 --- a/Swiften/Queries/Responders/DiscoInfoResponder.cpp +++ b/Swiften/Queries/Responders/DiscoInfoResponder.cpp @@ -24,13 +24,13 @@ bool DiscoInfoResponder::handleGetRequest(const JID& from, const String& id, boo else { std::map<String,DiscoInfo>::const_iterator i = nodeInfo_.find(info->getNode()); if (i != nodeInfo_.end()) { sendResponse(from, id, boost::shared_ptr<DiscoInfo>(new DiscoInfo((*i).second))); } else { - sendError(from, id, Error::ItemNotFound, Error::Cancel); + sendError(from, id, ErrorPayload::ItemNotFound, ErrorPayload::Cancel); } } return true; } } diff --git a/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp b/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp index 6ed7b9e..5993d0c 100644 --- a/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp +++ b/Swiften/Queries/Responders/UnitTest/DiscoInfoResponderTest.cpp @@ -69,13 +69,13 @@ class DiscoInfoResponderTest : public CppUnit::TestFixture { boost::shared_ptr<DiscoInfo> query(new DiscoInfo()); query->setNode("bar-node"); channel_->onIQReceived(IQ::createRequest(IQ::Get, JID("foo@bar.com"), "id-1", query)); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); - boost::shared_ptr<Error> payload(channel_->iqs_[0]->getPayload<Error>()); + boost::shared_ptr<ErrorPayload> payload(channel_->iqs_[0]->getPayload<ErrorPayload>()); CPPUNIT_ASSERT(payload); } private: IQRouter* router_; DummyIQChannel* channel_; diff --git a/Swiften/Queries/UnitTest/IQRouterTest.cpp b/Swiften/Queries/UnitTest/IQRouterTest.cpp index 94b7de8..5760b09 100644 --- a/Swiften/Queries/UnitTest/IQRouterTest.cpp +++ b/Swiften/Queries/UnitTest/IQRouterTest.cpp @@ -84,13 +84,13 @@ class IQRouterTest : public CppUnit::TestFixture IQRouter testling(channel_); DummyIQHandler handler(false, &testling); channel_->onIQReceived(boost::shared_ptr<IQ>(new IQ())); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); - CPPUNIT_ASSERT(channel_->iqs_[0]->getPayload<Error>()); + CPPUNIT_ASSERT(channel_->iqs_[0]->getPayload<ErrorPayload>()); } void testHandleIQ_HandlerRemovedDuringHandle() { IQRouter testling(channel_); RemovingIQHandler handler1(&testling); diff --git a/Swiften/Queries/UnitTest/RequestTest.cpp b/Swiften/Queries/UnitTest/RequestTest.cpp index ea6dee6..51d5a51 100644 --- a/Swiften/Queries/UnitTest/RequestTest.cpp +++ b/Swiften/Queries/UnitTest/RequestTest.cpp @@ -110,13 +110,13 @@ class RequestTest : public CppUnit::TestFixture CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, errorsReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(channel_->iqs_.size())); } private: - void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<Error>& e) { + void handleResponse(boost::shared_ptr<Payload> p, const boost::optional<ErrorPayload>& e) { if (e) { ++errorsReceived_; } else { boost::shared_ptr<MyPayload> payload(boost::dynamic_pointer_cast<MyPayload>(p)); CPPUNIT_ASSERT(payload); diff --git a/Swiften/SConscript b/Swiften/SConscript index 148f1f8..d5ddce4 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -72,12 +72,14 @@ sources = [ "Server/ServerFromClientSession.cpp", "Server/ServerSession.cpp", "Server/ServerStanzaRouter.cpp", "Server/SimpleUserRegistry.cpp", "Server/UserRegistry.cpp", "Session/Session.cpp", + "Session/SessionStream.cpp", + "Session/BasicSessionStream.cpp", "StringCodecs/Base64.cpp", "StringCodecs/SHA1.cpp", ] # "Notifier/GrowlNotifier.cpp", if myenv.get("HAVE_OPENSSL", 0) : @@ -100,13 +102,13 @@ myenv.StaticLibrary("Swiften", sources + swiften_env["SWIFTEN_OBJECTS"]) env.Append(UNITTEST_SOURCES = [ File("Application/UnitTest/ApplicationTest.cpp"), File("Base/UnitTest/IDGeneratorTest.cpp"), File("Base/UnitTest/StringTest.cpp"), File("Base/UnitTest/ByteArrayTest.cpp"), - File("Client/UnitTest/ClientSessionTest.cpp"), + #File("Client/UnitTest/ClientSessionTest.cpp"), File("Compress/UnitTest/ZLibCompressorTest.cpp"), File("Compress/UnitTest/ZLibDecompressorTest.cpp"), File("Disco/UnitTest/CapsInfoGeneratorTest.cpp"), File("Elements/UnitTest/IQTest.cpp"), File("Elements/UnitTest/StanzaTest.cpp"), File("Elements/UnitTest/StanzasTest.cpp"), diff --git a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp index 347e1a5..f5ce478 100644 --- a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp +++ b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.cpp @@ -1,48 +1,48 @@ #include "Swiften/Serializer/PayloadSerializers/ErrorSerializer.h" #include "Swiften/Serializer/XML/XMLTextNode.h" namespace Swift { -ErrorSerializer::ErrorSerializer() : GenericPayloadSerializer<Error>() { +ErrorSerializer::ErrorSerializer() : GenericPayloadSerializer<ErrorPayload>() { } -String ErrorSerializer::serializePayload(boost::shared_ptr<Error> error) const { +String ErrorSerializer::serializePayload(boost::shared_ptr<ErrorPayload> error) const { String result("<error type=\""); switch (error->getType()) { - case Error::Continue: result += "continue"; break; - case Error::Modify: result += "modify"; break; - case Error::Auth: result += "auth"; break; - case Error::Wait: result += "wait"; break; + case ErrorPayload::Continue: result += "continue"; break; + case ErrorPayload::Modify: result += "modify"; break; + case ErrorPayload::Auth: result += "auth"; break; + case ErrorPayload::Wait: result += "wait"; break; default: result += "cancel"; break; } result += "\">"; String conditionElement; switch (error->getCondition()) { - case Error::BadRequest: conditionElement = "bad-request"; break; - case Error::Conflict: conditionElement = "conflict"; break; - case Error::FeatureNotImplemented: conditionElement = "feature-not-implemented"; break; - case Error::Forbidden: conditionElement = "forbidden"; break; - case Error::Gone: conditionElement = "gone"; break; - case Error::InternalServerError: conditionElement = "internal-server-error"; break; - case Error::ItemNotFound: conditionElement = "item-not-found"; break; - case Error::JIDMalformed: conditionElement = "jid-malformed"; break; - case Error::NotAcceptable: conditionElement = "not-acceptable"; break; - case Error::NotAllowed: conditionElement = "not-allowed"; break; - case Error::NotAuthorized: conditionElement = "not-authorized"; break; - case Error::PaymentRequired: conditionElement = "payment-required"; break; - case Error::RecipientUnavailable: conditionElement = "recipient-unavailable"; break; - case Error::Redirect: conditionElement = "redirect"; break; - case Error::RegistrationRequired: conditionElement = "registration-required"; break; - case Error::RemoteServerNotFound: conditionElement = "remote-server-not-found"; break; - case Error::RemoteServerTimeout: conditionElement = "remote-server-timeout"; break; - case Error::ResourceConstraint: conditionElement = "resource-constraint"; break; - case Error::ServiceUnavailable: conditionElement = "service-unavailable"; break; - case Error::SubscriptionRequired: conditionElement = "subscription-required"; break; - case Error::UnexpectedRequest: conditionElement = "unexpected-request"; break; + case ErrorPayload::BadRequest: conditionElement = "bad-request"; break; + case ErrorPayload::Conflict: conditionElement = "conflict"; break; + case ErrorPayload::FeatureNotImplemented: conditionElement = "feature-not-implemented"; break; + case ErrorPayload::Forbidden: conditionElement = "forbidden"; break; + case ErrorPayload::Gone: conditionElement = "gone"; break; + case ErrorPayload::InternalServerError: conditionElement = "internal-server-error"; break; + case ErrorPayload::ItemNotFound: conditionElement = "item-not-found"; break; + case ErrorPayload::JIDMalformed: conditionElement = "jid-malformed"; break; + case ErrorPayload::NotAcceptable: conditionElement = "not-acceptable"; break; + case ErrorPayload::NotAllowed: conditionElement = "not-allowed"; break; + case ErrorPayload::NotAuthorized: conditionElement = "not-authorized"; break; + case ErrorPayload::PaymentRequired: conditionElement = "payment-required"; break; + case ErrorPayload::RecipientUnavailable: conditionElement = "recipient-unavailable"; break; + case ErrorPayload::Redirect: conditionElement = "redirect"; break; + case ErrorPayload::RegistrationRequired: conditionElement = "registration-required"; break; + case ErrorPayload::RemoteServerNotFound: conditionElement = "remote-server-not-found"; break; + case ErrorPayload::RemoteServerTimeout: conditionElement = "remote-server-timeout"; break; + case ErrorPayload::ResourceConstraint: conditionElement = "resource-constraint"; break; + case ErrorPayload::ServiceUnavailable: conditionElement = "service-unavailable"; break; + case ErrorPayload::SubscriptionRequired: conditionElement = "subscription-required"; break; + case ErrorPayload::UnexpectedRequest: conditionElement = "unexpected-request"; break; default: conditionElement = "undefined-condition"; break; } result += "<" + conditionElement + " xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\"/>"; if (!error->getText().isEmpty()) { XMLTextNode textNode(error->getText()); diff --git a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h index ecf73dc..931596f 100644 --- a/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h +++ b/Swiften/Serializer/PayloadSerializers/ErrorSerializer.h @@ -1,16 +1,16 @@ #ifndef SWIFTEN_ErrorSerializer_H #define SWIFTEN_ErrorSerializer_H #include "Swiften/Serializer/GenericPayloadSerializer.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" namespace Swift { - class ErrorSerializer : public GenericPayloadSerializer<Error> { + class ErrorSerializer : public GenericPayloadSerializer<ErrorPayload> { public: ErrorSerializer(); - virtual String serializePayload(boost::shared_ptr<Error> error) const; + virtual String serializePayload(boost::shared_ptr<ErrorPayload> error) const; }; } #endif diff --git a/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp b/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp index 2d68a3d..ecd904a 100644 --- a/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp +++ b/Swiften/Serializer/PayloadSerializers/UnitTest/ErrorSerializerTest.cpp @@ -13,13 +13,13 @@ class ErrorSerializerTest : public CppUnit::TestFixture public: ErrorSerializerTest() {} void testSerialize() { ErrorSerializer testling; - boost::shared_ptr<Error> error(new Error(Error::BadRequest, Error::Cancel, "My Error")); + boost::shared_ptr<ErrorPayload> error(new ErrorPayload(ErrorPayload::BadRequest, ErrorPayload::Cancel, "My Error")); CPPUNIT_ASSERT_EQUAL(String("<error type=\"cancel\"><bad-request xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\"/><text xmlns=\"urn:ietf:params:xml:ns:xmpp-stanzas\">My Error</text></error>"), testling.serialize(error)); } }; CPPUNIT_TEST_SUITE_REGISTRATION(ErrorSerializerTest); diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 46d4e16..8b14367 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -1,8 +1,6 @@ -// TODO: whitespacePingLayer_->setInactive(); - #include "Swiften/Session/BasicSessionStream.h" #include <boost/bind.hpp> #include "Swiften/StreamStack/XMPPLayer.h" #include "Swiften/StreamStack/StreamStack.h" @@ -10,63 +8,122 @@ #include "Swiften/StreamStack/WhitespacePingLayer.h" #include "Swiften/StreamStack/TLSLayer.h" #include "Swiften/StreamStack/TLSLayerFactory.h" namespace Swift { -BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : tlsLayerFactory(tlsLayerFactory) { +BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory) { +} + +void BasicSessionStream::initialize() { xmppLayer = boost::shared_ptr<XMPPLayer>( new XMPPLayer(payloadParserFactories, payloadSerializers)); - xmppLayer->onStreamStart.connect(boost::ref(onStreamStartReceived)); - xmppLayer->onElement.connect(boost::ref(onElementReceived)); + xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, shared_from_this(), _1)); + xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, shared_from_this(), _1)); xmppLayer->onError.connect(boost::bind( - &BasicSessionStream::handleXMPPError, this)); + &BasicSessionStream::handleXMPPError, shared_from_this())); + xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1)); + xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1)); + connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionError, shared_from_this(), _1)); connectionLayer = boost::shared_ptr<ConnectionLayer>( new ConnectionLayer(connection)); streamStack = new StreamStack(xmppLayer, connectionLayer); + + available = true; } BasicSessionStream::~BasicSessionStream() { delete streamStack; } void BasicSessionStream::writeHeader(const ProtocolHeader& header) { + assert(available); xmppLayer->writeHeader(header); } void BasicSessionStream::writeElement(boost::shared_ptr<Element> element) { + assert(available); xmppLayer->writeElement(element); } +void BasicSessionStream::writeFooter() { + assert(available); + xmppLayer->writeFooter(); +} + +bool BasicSessionStream::isAvailable() { + return available; +} + bool BasicSessionStream::supportsTLSEncryption() { return tlsLayerFactory && tlsLayerFactory->canCreate(); } void BasicSessionStream::addTLSEncryption() { + assert(available); tlsLayer = tlsLayerFactory->createTLSLayer(); - streamStack->addLayer(tlsLayer); - // TODO: Add tls layer certificate if needed - tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, this)); - tlsLayer->connect(); + if (hasTLSCertificate() && !tlsLayer->setClientCertificate(getTLSCertificate())) { + onError(boost::shared_ptr<Error>(new Error(Error::InvalidTLSCertificateError))); + } + else { + streamStack->addLayer(tlsLayer); + tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, shared_from_this())); + tlsLayer->onConnected.connect(boost::bind(&BasicSessionStream::handleTLSConnected, shared_from_this())); + tlsLayer->connect(); + } } -void BasicSessionStream::addWhitespacePing() { - whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); - streamStack->addLayer(whitespacePingLayer); - whitespacePingLayer->setActive(); +void BasicSessionStream::setWhitespacePingEnabled(bool enabled) { + if (enabled && !whitespacePingLayer) { + whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); + streamStack->addLayer(whitespacePingLayer); + } + if (enabled) { + whitespacePingLayer->setActive(); + } + else { + whitespacePingLayer->setInactive(); + } } void BasicSessionStream::resetXMPPParser() { xmppLayer->resetParser(); } +void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { + onStreamStartReceived(header); +} + +void BasicSessionStream::handleElementReceived(boost::shared_ptr<Element> element) { + onElementReceived(element); +} + void BasicSessionStream::handleXMPPError() { - // TODO + available = false; + onError(boost::shared_ptr<Error>(new Error(Error::ParseError))); +} + +void BasicSessionStream::handleTLSConnected() { + onTLSEncrypted(); } void BasicSessionStream::handleTLSError() { - // TODO + available = false; + onError(boost::shared_ptr<Error>(new Error(Error::TLSError))); +} + +void BasicSessionStream::handleConnectionError(const boost::optional<Connection::Error>&) { + available = false; + onError(boost::shared_ptr<Error>(new Error(Error::ConnectionError))); +} + +void BasicSessionStream::handleDataRead(const ByteArray& data) { + onDataRead(String(data.getData(), data.getSize())); +} + +void BasicSessionStream::handleDataWritten(const ByteArray& data) { + onDataWritten(String(data.getData(), data.getSize())); } }; diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index bf92bbb..0cb50eb 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -1,9 +1,10 @@ #pragma once #include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> #include "Swiften/Network/Connection.h" #include "Swiften/Session/SessionStream.h" namespace Swift { class TLSLayerFactory; @@ -14,39 +15,54 @@ namespace Swift { class StreamStack; class XMPPLayer; class ConnectionLayer; class BasicSessionStream : public SessionStream, - public boost::BOOST_SIGNALS_NAMESPACE::trackable { + public boost::enable_shared_from_this<BasicSessionStream> { public: BasicSessionStream( boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory ); ~BasicSessionStream(); + void initialize(); + + virtual bool isAvailable(); + virtual void writeHeader(const ProtocolHeader& header); virtual void writeElement(boost::shared_ptr<Element>); + virtual void writeFooter(); virtual bool supportsTLSEncryption(); virtual void addTLSEncryption(); - virtual void addWhitespacePing(); + virtual void setWhitespacePingEnabled(bool); virtual void resetXMPPParser(); private: + void handleConnectionError(const boost::optional<Connection::Error>& error); void handleXMPPError(); + void handleTLSConnected(); void handleTLSError(); + void handleStreamStartReceived(const ProtocolHeader&); + void handleElementReceived(boost::shared_ptr<Element>); + void handleDataRead(const ByteArray& data); + void handleDataWritten(const ByteArray& data); private: + bool available; + boost::shared_ptr<Connection> connection; + PayloadParserFactoryCollection* payloadParserFactories; + PayloadSerializerCollection* payloadSerializers; + TLSLayerFactory* tlsLayerFactory; boost::shared_ptr<XMPPLayer> xmppLayer; boost::shared_ptr<ConnectionLayer> connectionLayer; StreamStack* streamStack; - TLSLayerFactory* tlsLayerFactory; boost::shared_ptr<TLSLayer> tlsLayer; boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer; }; } diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index 17d9a24..6bba237 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -2,26 +2,65 @@ #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Elements/ProtocolHeader.h" #include "Swiften/Elements/Element.h" +#include "Swiften/Base/Error.h" +#include "Swiften/TLS/PKCS12Certificate.h" namespace Swift { class SessionStream { public: + class Error : public Swift::Error { + public: + enum Type { + ParseError, + TLSError, + InvalidTLSCertificateError, + ConnectionError + }; + + Error(Type type) : type(type) {} + + Type type; + }; + virtual ~SessionStream(); + virtual bool isAvailable() = 0; + virtual void writeHeader(const ProtocolHeader& header) = 0; + virtual void writeFooter() = 0; virtual void writeElement(boost::shared_ptr<Element>) = 0; virtual bool supportsTLSEncryption() = 0; virtual void addTLSEncryption() = 0; - - virtual void addWhitespacePing() = 0; + virtual void setWhitespacePingEnabled(bool enabled) = 0; virtual void resetXMPPParser() = 0; + void setTLSCertificate(const PKCS12Certificate& cert) { + certificate = cert; + } + + virtual bool hasTLSCertificate() { + return !certificate.isNull(); + } + + boost::signal<void (const ProtocolHeader&)> onStreamStartReceived; boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; + boost::signal<void (boost::shared_ptr<Error>)> onError; + boost::signal<void ()> onTLSEncrypted; + boost::signal<void (const String&)> onDataRead; + boost::signal<void (const String&)> onDataWritten; + + protected: + const PKCS12Certificate& getTLSCertificate() const { + return certificate; + } + + private: + PKCS12Certificate certificate; }; } |
Swift