diff options
26 files changed, 87 insertions, 42 deletions
diff --git a/Limber/Server/ServerFromClientSession.cpp b/Limber/Server/ServerFromClientSession.cpp index 3a37c65..fd361b7 100644 --- a/Limber/Server/ServerFromClientSession.cpp +++ b/Limber/Server/ServerFromClientSession.cpp @@ -1,67 +1,68 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include "Limber/Server/ServerFromClientSession.h" #include <boost/bind.hpp> #include "Swiften/Elements/ProtocolHeader.h" #include "Limber/Server/UserRegistry.h" #include "Swiften/Network/Connection.h" #include "Swiften/StreamStack/XMPPLayer.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/ResourceBind.h" #include "Swiften/Elements/StartSession.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/AuthSuccess.h" #include "Swiften/Elements/AuthFailure.h" #include "Swiften/Elements/AuthRequest.h" #include "Swiften/SASL/PLAINMessage.h" namespace Swift { ServerFromClientSession::ServerFromClientSession( const std::string& id, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory, UserRegistry* userRegistry) : - Session(connection, payloadParserFactories, payloadSerializers), + Session(connection, payloadParserFactories, payloadSerializers, xmlParserFactory), id_(id), userRegistry_(userRegistry), authenticated_(false), initialized(false), allowSASLEXTERNAL(false) { } void ServerFromClientSession::handleElement(boost::shared_ptr<Element> element) { if (isInitialized()) { onElementReceived(element); } else { if (AuthRequest* authRequest = dynamic_cast<AuthRequest*>(element.get())) { if (authRequest->getMechanism() == "PLAIN" || (allowSASLEXTERNAL && authRequest->getMechanism() == "EXTERNAL")) { if (authRequest->getMechanism() == "EXTERNAL") { getXMPPLayer()->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess())); authenticated_ = true; getXMPPLayer()->resetParser(); } else { PLAINMessage plainMessage(authRequest->getMessage() ? *authRequest->getMessage() : createSafeByteArray("")); if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), getLocalJID().getDomain()), plainMessage.getPassword())) { getXMPPLayer()->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess())); user_ = plainMessage.getAuthenticationID(); authenticated_ = true; getXMPPLayer()->resetParser(); } else { getXMPPLayer()->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure)); finishSession(AuthenticationFailedError); } } } else { diff --git a/Limber/Server/ServerFromClientSession.h b/Limber/Server/ServerFromClientSession.h index 1a0e109..34ea40e 100644 --- a/Limber/Server/ServerFromClientSession.h +++ b/Limber/Server/ServerFromClientSession.h @@ -1,60 +1,62 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/boost_bsignals.h> #include <boost/enable_shared_from_this.hpp> #include <string> #include <Swiften/Session/Session.h> #include <Swiften/JID/JID.h> #include <Swiften/Network/Connection.h> #include <Swiften/Base/ByteArray.h> namespace Swift { class ProtocolHeader; class Element; class Stanza; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class StreamStack; class UserRegistry; class XMPPLayer; class ConnectionLayer; class Connection; + class XMLParserFactory; class ServerFromClientSession : public Session { public: ServerFromClientSession( const std::string& id, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory, UserRegistry* userRegistry); boost::signal<void ()> onSessionStarted; void setAllowSASLEXTERNAL(); private: void handleElement(boost::shared_ptr<Element>); void handleStreamStart(const ProtocolHeader& header); void setInitialized(); bool isInitialized() const { return initialized; } private: std::string id_; UserRegistry* userRegistry_; bool authenticated_; bool initialized; bool allowSASLEXTERNAL; std::string user_; }; } diff --git a/Limber/main.cpp b/Limber/main.cpp index e6bc45d..350b357 100644 --- a/Limber/main.cpp +++ b/Limber/main.cpp @@ -1,102 +1,104 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <string> #include <boost/bind.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/RosterPayload.h" #include "Swiften/Elements/VCard.h" #include "Swiften/Base/IDGenerator.h" #include "Swiften/EventLoop/EventLoop.h" #include "Swiften/EventLoop/SimpleEventLoop.h" #include "Swiften/EventLoop/EventOwner.h" #include "Swiften/Elements/Stanza.h" #include "Swiften/Network/ConnectionServer.h" #include "Swiften/Network/BoostConnection.h" #include "Swiften/Network/BoostIOServiceThread.h" #include "Swiften/Network/BoostConnectionServer.h" #include "Limber/Server/SimpleUserRegistry.h" #include "Limber/Server/ServerFromClientSession.h" #include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" +#include "Swiften/Parser/PlatformXMLParserFactory.h" #include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" using namespace Swift; class Server { public: Server(UserRegistry* userRegistry, EventLoop* eventLoop) : userRegistry_(userRegistry) { serverFromClientConnectionServer_ = BoostConnectionServer::create(5222, boostIOServiceThread_.getIOService(), eventLoop); serverFromClientConnectionServer_->onNewConnection.connect(boost::bind(&Server::handleNewConnection, this, _1)); serverFromClientConnectionServer_->start(); } private: void handleNewConnection(boost::shared_ptr<Connection> c) { - boost::shared_ptr<ServerFromClientSession> session(new ServerFromClientSession(idGenerator_.generateID(), c, &payloadParserFactories_, &payloadSerializers_, userRegistry_)); + boost::shared_ptr<ServerFromClientSession> session(new ServerFromClientSession(idGenerator_.generateID(), c, &payloadParserFactories_, &payloadSerializers_, &xmlParserFactory, userRegistry_)); serverFromClientSessions_.push_back(session); session->onElementReceived.connect(boost::bind(&Server::handleElementReceived, this, _1, session)); session->onSessionFinished.connect(boost::bind(&Server::handleSessionFinished, this, session)); session->startSession(); } void handleSessionFinished(boost::shared_ptr<ServerFromClientSession> session) { serverFromClientSessions_.erase(std::remove(serverFromClientSessions_.begin(), serverFromClientSessions_.end(), session), serverFromClientSessions_.end()); } void handleElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<ServerFromClientSession> session) { boost::shared_ptr<Stanza> stanza(boost::dynamic_pointer_cast<Stanza>(element)); if (!stanza) { return; } stanza->setFrom(session->getRemoteJID()); if (!stanza->getTo().isValid()) { stanza->setTo(JID(session->getLocalJID())); } if (!stanza->getTo().isValid() || stanza->getTo() == session->getLocalJID() || stanza->getTo() == session->getRemoteJID().toBare()) { if (boost::shared_ptr<IQ> iq = boost::dynamic_pointer_cast<IQ>(stanza)) { if (iq->getPayload<RosterPayload>()) { session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), boost::shared_ptr<RosterPayload>(new RosterPayload()))); } if (iq->getPayload<VCard>()) { if (iq->getType() == IQ::Get) { boost::shared_ptr<VCard> vcard(new VCard()); vcard->setNickname(iq->getFrom().getNode()); session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), vcard)); } else { session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::Forbidden, ErrorPayload::Cancel)); } } else { session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), ErrorPayload::FeatureNotImplemented, ErrorPayload::Cancel)); } } } } private: IDGenerator idGenerator_; + PlatformXMLParserFactory xmlParserFactory; UserRegistry* userRegistry_; BoostIOServiceThread boostIOServiceThread_; boost::shared_ptr<BoostConnectionServer> serverFromClientConnectionServer_; std::vector< boost::shared_ptr<ServerFromClientSession> > serverFromClientSessions_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; }; int main() { SimpleEventLoop eventLoop; SimpleUserRegistry userRegistry; userRegistry.addUser(JID("remko@localhost"), "remko"); userRegistry.addUser(JID("kevin@localhost"), "kevin"); userRegistry.addUser(JID("remko@limber.swift.im"), "remko"); userRegistry.addUser(JID("kevin@limber.swift.im"), "kevin"); Server server(&userRegistry, &eventLoop); eventLoop.run(); return 0; } diff --git a/Slimber/Server.cpp b/Slimber/Server.cpp index 84b33fa..769217f 100644 --- a/Slimber/Server.cpp +++ b/Slimber/Server.cpp @@ -115,71 +115,71 @@ void Server::stop(boost::optional<ServerError> e) { } linkLocalSessions.clear(); foreach(boost::shared_ptr<LinkLocalConnector> connector, connectors) { connector->cancel(); } connectors.clear(); tracers.clear(); if (serverFromNetworkConnectionServer) { serverFromNetworkConnectionServer->stop(); foreach(boost::bsignals::connection& connection, serverFromNetworkConnectionServerSignalConnections) { connection.disconnect(); } serverFromNetworkConnectionServerSignalConnections.clear(); serverFromNetworkConnectionServer.reset(); } if (serverFromClientConnectionServer) { serverFromClientConnectionServer->stop(); foreach(boost::bsignals::connection& connection, serverFromClientConnectionServerSignalConnections) { connection.disconnect(); } serverFromClientConnectionServerSignalConnections.clear(); serverFromClientConnectionServer.reset(); } stopping = false; onStopped(e); } void Server::handleNewClientConnection(boost::shared_ptr<Connection> connection) { if (serverFromClientSession) { connection->disconnect(); } serverFromClientSession = boost::shared_ptr<ServerFromClientSession>( new ServerFromClientSession(idGenerator.generateID(), connection, - &payloadParserFactories, &payloadSerializers, &userRegistry)); + &payloadParserFactories, &payloadSerializers, &xmlParserFactory, &userRegistry)); serverFromClientSession->setAllowSASLEXTERNAL(); serverFromClientSession->onSessionStarted.connect( boost::bind(&Server::handleSessionStarted, this)); serverFromClientSession->onElementReceived.connect( boost::bind(&Server::handleElementReceived, this, _1, serverFromClientSession)); serverFromClientSession->onSessionFinished.connect( boost::bind(&Server::handleSessionFinished, this, serverFromClientSession)); //tracers.push_back(boost::shared_ptr<SessionTracer>( // new SessionTracer(serverFromClientSession))); serverFromClientSession->startSession(); } void Server::handleSessionStarted() { onSelfConnected(true); } void Server::handleSessionFinished(boost::shared_ptr<ServerFromClientSession>) { serverFromClientSession.reset(); unregisterService(); selfJID = JID(); rosterRequested = false; onSelfConnected(false); lastPresence.reset(); } void Server::unregisterService() { if (linkLocalServiceRegistered) { linkLocalServiceRegistered = false; linkLocalServiceBrowser->unregisterService(); } } void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<ServerFromClientSession> session) { @@ -249,102 +249,102 @@ void Server::handleElementReceived(boost::shared_ptr<Element> element, boost::sh if (outgoingSession) { outgoingSession->sendElement(stanza); } else { boost::optional<LinkLocalService> service = presenceManager->getServiceForJID(toJID); if (service) { boost::shared_ptr<LinkLocalConnector> connector = getLinkLocalConnectorForJID(toJID); if (!connector) { connector = boost::shared_ptr<LinkLocalConnector>( new LinkLocalConnector( *service, linkLocalServiceBrowser->getQuerier(), BoostConnection::create(boostIOServiceThread.getIOService(), eventLoop))); connector->onConnectFinished.connect( boost::bind(&Server::handleConnectFinished, this, connector, _1)); connectors.push_back(connector); connector->connect(); } connector->queueElement(element); } else { session->sendElement(IQ::createError( stanza->getFrom(), stanza->getID(), ErrorPayload::RecipientUnavailable, ErrorPayload::Wait)); } } } } void Server::handleNewLinkLocalConnection(boost::shared_ptr<Connection> connection) { boost::shared_ptr<IncomingLinkLocalSession> session( new IncomingLinkLocalSession( selfJID, connection, - &payloadParserFactories, &payloadSerializers)); + &payloadParserFactories, &payloadSerializers, &xmlParserFactory)); registerLinkLocalSession(session); } void Server::handleLinkLocalSessionFinished(boost::shared_ptr<Session> session) { //std::cout << "Link local session from " << session->getRemoteJID() << " ended" << std::endl; linkLocalSessions.erase( std::remove(linkLocalSessions.begin(), linkLocalSessions.end(), session), linkLocalSessions.end()); } void Server::handleLinkLocalElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<Session> session) { if (boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element)) { JID fromJID = session->getRemoteJID(); if (!presenceManager->getServiceForJID(fromJID.toBare())) { return; // TODO: Send error back } stanza->setFrom(fromJID); serverFromClientSession->sendElement(stanza); } } void Server::handleConnectFinished(boost::shared_ptr<LinkLocalConnector> connector, bool error) { if (error) { std::cerr << "Error connecting" << std::endl; // TODO: Send back queued stanzas } else { boost::shared_ptr<OutgoingLinkLocalSession> outgoingSession( new OutgoingLinkLocalSession( selfJID, connector->getService().getJID(), connector->getConnection(), - &payloadParserFactories, &payloadSerializers)); + &payloadParserFactories, &payloadSerializers, &xmlParserFactory)); foreach(const boost::shared_ptr<Element> element, connector->getQueuedElements()) { outgoingSession->queueElement(element); } registerLinkLocalSession(outgoingSession); } connectors.erase(std::remove(connectors.begin(), connectors.end(), connector), connectors.end()); } void Server::registerLinkLocalSession(boost::shared_ptr<Session> session) { session->onSessionFinished.connect( boost::bind(&Server::handleLinkLocalSessionFinished, this, session)); session->onElementReceived.connect( boost::bind(&Server::handleLinkLocalElementReceived, this, _1, session)); linkLocalSessions.push_back(session); //tracers.push_back(boost::shared_ptr<SessionTracer>(new SessionTracer(session))); session->startSession(); } boost::shared_ptr<Session> Server::getLinkLocalSessionForJID(const JID& jid) { foreach(const boost::shared_ptr<Session> session, linkLocalSessions) { if (session->getRemoteJID() == jid) { return session; } } return boost::shared_ptr<Session>(); } boost::shared_ptr<LinkLocalConnector> Server::getLinkLocalConnectorForJID(const JID& jid) { foreach(const boost::shared_ptr<LinkLocalConnector> connector, connectors) { if (connector->getService().getJID() == jid) { return connector; } } return boost::shared_ptr<LinkLocalConnector>(); } diff --git a/Slimber/Server.h b/Slimber/Server.h index 58b1e7c..96401d9 100644 --- a/Slimber/Server.h +++ b/Slimber/Server.h @@ -1,51 +1,52 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <boost/optional.hpp> #include <vector> #include "Swiften/Network/BoostIOServiceThread.h" #include "Swiften/Network/BoostConnectionServer.h" #include "Limber/Server/UserRegistry.h" #include "Swiften/Base/IDGenerator.h" +#include "Swiften/Parser/PlatformXMLParserFactory.h" #include "Limber/Server/ServerFromClientSession.h" #include "Swiften/JID/JID.h" #include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h" #include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h" #include "Swiften/LinkLocal/LinkLocalServiceInfo.h" #include "Slimber/ServerError.h" namespace Swift { class DNSSDServiceID; class VCardCollection; class LinkLocalConnector; class LinkLocalServiceBrowser; class LinkLocalPresenceManager; class BoostConnectionServer; class SessionTracer; class RosterPayload; class Presence; class EventLoop; class Server { public: Server( int clientConnectionPort, int linkLocalConnectionPort, LinkLocalServiceBrowser* browser, VCardCollection* vCardCollection, EventLoop* eventLoop); ~Server(); void start(); void stop(); int getLinkLocalPort() const { return linkLocalConnectionPort; @@ -66,56 +67,57 @@ namespace Swift { void handleSessionFinished(boost::shared_ptr<ServerFromClientSession>); void handleElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<ServerFromClientSession> session); void handleRosterChanged(boost::shared_ptr<RosterPayload> roster); void handlePresenceChanged(boost::shared_ptr<Presence> presence); void handleServiceRegistered(const DNSSDServiceID& service); void handleNewLinkLocalConnection(boost::shared_ptr<Connection> connection); void handleLinkLocalSessionFinished(boost::shared_ptr<Session> session); void handleLinkLocalElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<Session> session); void handleConnectFinished(boost::shared_ptr<LinkLocalConnector> connector, bool error); void handleClientConnectionServerStopped( boost::optional<BoostConnectionServer::Error>); void handleLinkLocalConnectionServerStopped( boost::optional<BoostConnectionServer::Error>); boost::shared_ptr<Session> getLinkLocalSessionForJID(const JID& jid); boost::shared_ptr<LinkLocalConnector> getLinkLocalConnectorForJID(const JID& jid); void registerLinkLocalSession(boost::shared_ptr<Session> session); void unregisterService(); LinkLocalServiceInfo getLinkLocalServiceInfo(boost::shared_ptr<Presence> presence); private: class DummyUserRegistry : public UserRegistry { public: DummyUserRegistry() {} virtual bool isValidUserPassword(const JID&, const SafeByteArray&) const { return true; } }; private: IDGenerator idGenerator; FullPayloadParserFactoryCollection payloadParserFactories; FullPayloadSerializerCollection payloadSerializers; BoostIOServiceThread boostIOServiceThread; DummyUserRegistry userRegistry; + PlatformXMLParserFactory xmlParserFactory; bool linkLocalServiceRegistered; bool rosterRequested; int clientConnectionPort; int linkLocalConnectionPort; LinkLocalServiceBrowser* linkLocalServiceBrowser; VCardCollection* vCardCollection; EventLoop* eventLoop; LinkLocalPresenceManager* presenceManager; bool stopping; boost::shared_ptr<BoostConnectionServer> serverFromClientConnectionServer; std::vector<boost::bsignals::connection> serverFromClientConnectionServerSignalConnections; boost::shared_ptr<ServerFromClientSession> serverFromClientSession; boost::shared_ptr<Presence> lastPresence; JID selfJID; boost::shared_ptr<BoostConnectionServer> serverFromNetworkConnectionServer; std::vector<boost::bsignals::connection> serverFromNetworkConnectionServerSignalConnections; std::vector< boost::shared_ptr<Session> > linkLocalSessions; std::vector< boost::shared_ptr<LinkLocalConnector> > connectors; std::vector< boost::shared_ptr<SessionTracer> > tracers; }; } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index f6a3f20..dbc6de2 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -69,71 +69,71 @@ void CoreClient::connect(const std::string& host) { PlatformProxyProvider proxyProvider; if(proxyProvider.getSOCKS5Proxy().isValid()) { proxyConnectionFactories.push_back(new SOCKS5ProxiedConnectionFactory(networkFactories->getConnectionFactory(), proxyProvider.getSOCKS5Proxy())); } if(proxyProvider.getHTTPConnectProxy().isValid()) { proxyConnectionFactories.push_back(new HTTPConnectProxiedConnectionFactory(networkFactories->getConnectionFactory(), proxyProvider.getHTTPConnectProxy())); } std::vector<ConnectionFactory*> connectionFactories(proxyConnectionFactories); connectionFactories.push_back(networkFactories->getConnectionFactory()); connector_ = boost::make_shared<ChainedConnector>(host, networkFactories->getDomainNameResolver(), connectionFactories, networkFactories->getTimerFactory()); connector_->onConnectFinished.connect(boost::bind(&CoreClient::handleConnectorFinished, this, _1)); connector_->setTimeoutMilliseconds(60*1000); connector_->start(); } void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connection) { connector_->onConnectFinished.disconnect(boost::bind(&CoreClient::handleConnectorFinished, this, _1)); connector_.reset(); foreach(ConnectionFactory* f, proxyConnectionFactories) { delete f; } proxyConnectionFactories.clear(); if (!connection) { if (options.forgetPassword) { purgePassword(); } onDisconnected(disconnectRequested_ ? boost::optional<ClientError>() : boost::optional<ClientError>(ClientError::ConnectionError)); } else { assert(!connection_); connection_ = connection; assert(!sessionStream_); - sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), tlsFactories->getTLSContextFactory(), networkFactories->getTimerFactory()); + sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), tlsFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory()); if (!certificate_.empty()) { sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); } sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); session_ = ClientSession::create(jid_, sessionStream_); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); break; case ClientOptions::NeverUseTLS: session_->setUseTLS(ClientSession::NeverUseTLS); break; case ClientOptions::RequireTLS: session_->setUseTLS(ClientSession::RequireTLS); break; } session_->setUseAcks(options.useAcks); stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); session_->start(); } } void CoreClient::disconnect() { // FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between // connector finishing without a connection due to an error or because of a disconnect. disconnectRequested_ = true; if (session_ && !session_->isFinished()) { session_->finish(); diff --git a/Swiften/Component/CoreComponent.cpp b/Swiften/Component/CoreComponent.cpp index 7ee1ff5..e630ddf 100644 --- a/Swiften/Component/CoreComponent.cpp +++ b/Swiften/Component/CoreComponent.cpp @@ -31,71 +31,71 @@ CoreComponent::CoreComponent(EventLoop* eventLoop, NetworkFactories* networkFact } CoreComponent::~CoreComponent() { if (session_ || connection_) { std::cerr << "Warning: Component not disconnected properly" << std::endl; } delete iqRouter_; stanzaChannel_->onAvailableChanged.disconnect(boost::bind(&CoreComponent::handleStanzaChannelAvailableChanged, this, _1)); stanzaChannel_->onMessageReceived.disconnect(boost::ref(onMessageReceived)); stanzaChannel_->onPresenceReceived.disconnect(boost::ref(onPresenceReceived)); delete stanzaChannel_; } void CoreComponent::connect(const std::string& host, int port) { assert(!connector_); connector_ = ComponentConnector::create(host, port, &resolver_, networkFactories->getConnectionFactory(), networkFactories->getTimerFactory()); connector_->onConnectFinished.connect(boost::bind(&CoreComponent::handleConnectorFinished, this, _1)); connector_->setTimeoutMilliseconds(60*1000); connector_->start(); } void CoreComponent::handleConnectorFinished(boost::shared_ptr<Connection> connection) { connector_->onConnectFinished.disconnect(boost::bind(&CoreComponent::handleConnectorFinished, this, _1)); connector_.reset(); if (!connection) { if (!disconnectRequested_) { onError(ComponentError::ConnectionError); } } else { assert(!connection_); connection_ = connection; assert(!sessionStream_); - sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(ComponentStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), NULL, networkFactories->getTimerFactory())); + sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(ComponentStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), NULL, networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory())); sessionStream_->onDataRead.connect(boost::bind(&CoreComponent::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreComponent::handleDataWritten, this, _1)); session_ = ComponentSession::create(jid_, secret_, sessionStream_); stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreComponent::handleSessionFinished, this, _1)); session_->start(); } } void CoreComponent::disconnect() { // FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between // connector finishing without a connection due to an error or because of a disconnect. disconnectRequested_ = true; if (session_) { session_->finish(); } else if (connector_) { connector_->stop(); assert(!session_); } assert(!session_); assert(!sessionStream_); assert(!connector_); disconnectRequested_ = false; } void CoreComponent::handleSessionFinished(boost::shared_ptr<Error> error) { session_->onFinished.disconnect(boost::bind(&CoreComponent::handleSessionFinished, this, _1)); session_.reset(); sessionStream_->onDataRead.disconnect(boost::bind(&CoreComponent::handleDataRead, this, _1)); sessionStream_->onDataWritten.disconnect(boost::bind(&CoreComponent::handleDataWritten, this, _1)); sessionStream_.reset(); diff --git a/Swiften/Examples/ParserTester/ParserTester.cpp b/Swiften/Examples/ParserTester/ParserTester.cpp index 211d44f..009eef4 100644 --- a/Swiften/Examples/ParserTester/ParserTester.cpp +++ b/Swiften/Examples/ParserTester/ParserTester.cpp @@ -1,58 +1,60 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <iostream> #include <fstream> #include <typeinfo> #include <Swiften/Parser/UnitTest/ParserTester.h> #include <Swiften/Parser/XMPPParser.h> #include <Swiften/Parser/XMPPParserClient.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h> using namespace Swift; class MyXMPPParserClient : public XMPPParserClient { public: virtual void handleStreamStart(const ProtocolHeader&) { std::cout << "-> Stream start" << std::endl; } virtual void handleElement(boost::shared_ptr<Element> element) { std::cout << "-> Element " << typeid(*element.get()).name() << std::endl; } virtual void handleStreamEnd() { std::cout << "-> Stream end" << std::endl; } }; int main(int argc, char* argv[]) { if (argc != 2) { std::cerr << "Usage: " << argv[0] << " file" << std::endl; return 0; } FullPayloadParserFactoryCollection factories; MyXMPPParserClient parserClient; - XMPPParser parser(&parserClient, &factories); + PlatformXMLParserFactory xmlParserFactory; + XMPPParser parser(&parserClient, &factories, &xmlParserFactory); ParserTester<XMLParserClient> tester(&parser); std::string line; std::ifstream myfile (argv[1]); if (myfile.is_open()) { while (!myfile.eof()) { getline (myfile,line); std::cout << "Parsing: " << line << std::endl; if (!tester.parse(line)) { std::cerr << "PARSE ERROR" << std::endl; return -1; } } myfile.close(); } else { std::cerr << "Unable to open file " << argv[1] << std::endl; } return 0; } diff --git a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp index c4dea64..b89de81 100644 --- a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp +++ b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp @@ -1,61 +1,62 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/LinkLocal/IncomingLinkLocalSession.h> #include <boost/bind.hpp> #include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Network/Connection.h> #include <Swiften/StreamStack/StreamStack.h> #include <Swiften/StreamStack/ConnectionLayer.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/IQ.h> namespace Swift { IncomingLinkLocalSession::IncomingLinkLocalSession( const JID& localJID, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : - Session(connection, payloadParserFactories, payloadSerializers), + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory) : + Session(connection, payloadParserFactories, payloadSerializers, xmlParserFactory), initialized(false) { setLocalJID(localJID); } void IncomingLinkLocalSession::handleStreamStart(const ProtocolHeader& incomingHeader) { setRemoteJID(JID(incomingHeader.getFrom())); if (!getRemoteJID().isValid()) { finishSession(); return; } ProtocolHeader header; header.setFrom(getLocalJID()); getXMPPLayer()->writeHeader(header); if (incomingHeader.getVersion() == "1.0") { getXMPPLayer()->writeElement(boost::shared_ptr<StreamFeatures>(new StreamFeatures())); } else { setInitialized(); } } void IncomingLinkLocalSession::handleElement(boost::shared_ptr<Element> element) { boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element); // If we get our first stanza before streamfeatures, our session is implicitly // initialized if (stanza && !isInitialized()) { setInitialized(); } onElementReceived(element); } void IncomingLinkLocalSession::setInitialized() { diff --git a/Swiften/LinkLocal/IncomingLinkLocalSession.h b/Swiften/LinkLocal/IncomingLinkLocalSession.h index 68e21a5..f00c166 100644 --- a/Swiften/LinkLocal/IncomingLinkLocalSession.h +++ b/Swiften/LinkLocal/IncomingLinkLocalSession.h @@ -1,43 +1,44 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/boost_bsignals.h> #include <Swiften/Session/Session.h> #include <Swiften/JID/JID.h> #include <Swiften/Network/Connection.h> namespace Swift { class ProtocolHeader; - + class XMLParserFactory; class Element; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class IncomingLinkLocalSession : public Session { public: IncomingLinkLocalSession( const JID& localJID, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers); + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory); boost::signal<void ()> onSessionStarted; private: void handleElement(boost::shared_ptr<Element>); void handleStreamStart(const ProtocolHeader&); void setInitialized(); bool isInitialized() const { return initialized; } bool initialized; }; } diff --git a/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp b/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp index 9d712f8..7a59715 100644 --- a/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp +++ b/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp @@ -1,52 +1,53 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/LinkLocal/OutgoingLinkLocalSession.h> #include <boost/bind.hpp> #include <Swiften/Base/foreach.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/IQ.h> namespace Swift { OutgoingLinkLocalSession::OutgoingLinkLocalSession( const JID& localJID, const JID& remoteJID, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : - Session(connection, payloadParserFactories, payloadSerializers) { + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory) : + Session(connection, payloadParserFactories, payloadSerializers, xmlParserFactory) { setLocalJID(localJID); setRemoteJID(remoteJID); } void OutgoingLinkLocalSession::handleSessionStarted() { ProtocolHeader header; header.setFrom(getLocalJID()); getXMPPLayer()->writeHeader(header); } void OutgoingLinkLocalSession::handleStreamStart(const ProtocolHeader&) { foreach(const boost::shared_ptr<Element>& stanza, queuedElements_) { sendElement(stanza); } queuedElements_.clear(); } void OutgoingLinkLocalSession::handleElement(boost::shared_ptr<Element> element) { onElementReceived(element); } void OutgoingLinkLocalSession::queueElement(boost::shared_ptr<Element> element) { queuedElements_.push_back(element); } } diff --git a/Swiften/LinkLocal/OutgoingLinkLocalSession.h b/Swiften/LinkLocal/OutgoingLinkLocalSession.h index 430c446..b97f2bf 100644 --- a/Swiften/LinkLocal/OutgoingLinkLocalSession.h +++ b/Swiften/LinkLocal/OutgoingLinkLocalSession.h @@ -1,43 +1,44 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/boost_bsignals.h> #include <boost/enable_shared_from_this.hpp> #include <vector> #include <Swiften/Session/Session.h> #include <Swiften/JID/JID.h> namespace Swift { class ConnectionFactory; - + class XMLParserFactory; class Element; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class OutgoingLinkLocalSession : public Session { public: OutgoingLinkLocalSession( const JID& localJID, const JID& remoteJID, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers); + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory); void queueElement(boost::shared_ptr<Element> element); private: void handleSessionStarted(); void handleElement(boost::shared_ptr<Element>); void handleStreamStart(const ProtocolHeader&); private: std::vector<boost::shared_ptr<Element> > queuedElements_; }; } diff --git a/Swiften/Network/BoostNetworkFactories.cpp b/Swiften/Network/BoostNetworkFactories.cpp index 315290c..56be2b7 100644 --- a/Swiften/Network/BoostNetworkFactories.cpp +++ b/Swiften/Network/BoostNetworkFactories.cpp @@ -1,37 +1,40 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/BoostNetworkFactories.h> #include <Swiften/Network/BoostTimerFactory.h> #include <Swiften/Network/BoostConnectionFactory.h> #include <Swiften/Network/PlatformDomainNameResolver.h> #include <Swiften/Network/BoostConnectionServerFactory.h> #include <Swiften/Network/PlatformNATTraversalWorker.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Network/NullNATTraverser.h> namespace Swift { BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) { timerFactory = new BoostTimerFactory(ioServiceThread.getIOService(), eventLoop); connectionFactory = new BoostConnectionFactory(ioServiceThread.getIOService(), eventLoop); domainNameResolver = new PlatformDomainNameResolver(eventLoop); connectionServerFactory = new BoostConnectionServerFactory(ioServiceThread.getIOService(), eventLoop); #ifdef SWIFT_EXPERIMENTAL_FT natTraverser = new PlatformNATTraversalWorker(eventLoop); #else natTraverser = new NullNATTraverser(eventLoop); #endif + xmlParserFactory = new PlatformXMLParserFactory(); } BoostNetworkFactories::~BoostNetworkFactories() { + delete xmlParserFactory; delete natTraverser; delete connectionServerFactory; delete domainNameResolver; delete connectionFactory; delete timerFactory; } } diff --git a/Swiften/Network/BoostNetworkFactories.h b/Swiften/Network/BoostNetworkFactories.h index bc7a963..c9ecb59 100644 --- a/Swiften/Network/BoostNetworkFactories.h +++ b/Swiften/Network/BoostNetworkFactories.h @@ -10,44 +10,49 @@ #include <Swiften/Network/BoostIOServiceThread.h> namespace Swift { class EventLoop; class NATTraverser; class BoostNetworkFactories : public NetworkFactories { public: BoostNetworkFactories(EventLoop* eventLoop); ~BoostNetworkFactories(); virtual TimerFactory* getTimerFactory() const { return timerFactory; } virtual ConnectionFactory* getConnectionFactory() const { return connectionFactory; } BoostIOServiceThread* getIOServiceThread() { return &ioServiceThread; } DomainNameResolver* getDomainNameResolver() const { return domainNameResolver; } ConnectionServerFactory* getConnectionServerFactory() const { return connectionServerFactory; } NATTraverser* getNATTraverser() const { return natTraverser; } + virtual XMLParserFactory* getXMLParserFactory() const { + return xmlParserFactory; + } + private: BoostIOServiceThread ioServiceThread; TimerFactory* timerFactory; ConnectionFactory* connectionFactory; DomainNameResolver* domainNameResolver; ConnectionServerFactory* connectionServerFactory; NATTraverser* natTraverser; + XMLParserFactory* xmlParserFactory; }; } diff --git a/Swiften/Network/NetworkFactories.h b/Swiften/Network/NetworkFactories.h index 05ddfe3..42c9f6a 100644 --- a/Swiften/Network/NetworkFactories.h +++ b/Swiften/Network/NetworkFactories.h @@ -1,29 +1,31 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once namespace Swift { class TimerFactory; class ConnectionFactory; class DomainNameResolver; class ConnectionServerFactory; class NATTraverser; + class XMLParserFactory; /** * An interface collecting network factories. */ class NetworkFactories { public: virtual ~NetworkFactories(); virtual TimerFactory* getTimerFactory() const = 0; virtual ConnectionFactory* getConnectionFactory() const = 0; virtual DomainNameResolver* getDomainNameResolver() const = 0; virtual ConnectionServerFactory* getConnectionServerFactory() const = 0; virtual NATTraverser* getNATTraverser() const = 0; + virtual XMLParserFactory* getXMLParserFactory() const = 0; }; } diff --git a/Swiften/Parser/UnitTest/XMPPParserTest.cpp b/Swiften/Parser/UnitTest/XMPPParserTest.cpp index dbee18a..f8d60f2 100644 --- a/Swiften/Parser/UnitTest/XMPPParserTest.cpp +++ b/Swiften/Parser/UnitTest/XMPPParserTest.cpp @@ -1,194 +1,196 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <vector> #include <Swiften/Elements/ProtocolHeader.h> #include <string> #include <Swiften/Parser/XMPPParser.h> #include <Swiften/Parser/ElementParser.h> #include <Swiften/Parser/XMPPParserClient.h> #include <Swiften/Parser/PayloadParserFactoryCollection.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Elements/Presence.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/UnknownElement.h> using namespace Swift; class XMPPParserTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(XMPPParserTest); CPPUNIT_TEST(testParse_SimpleSession); CPPUNIT_TEST(testParse_SimpleClientFromServerSession); CPPUNIT_TEST(testParse_Presence); CPPUNIT_TEST(testParse_IQ); CPPUNIT_TEST(testParse_Message); CPPUNIT_TEST(testParse_StreamFeatures); CPPUNIT_TEST(testParse_UnknownElement); CPPUNIT_TEST(testParse_StrayCharacterData); CPPUNIT_TEST(testParse_InvalidStreamStart); CPPUNIT_TEST(testParse_ElementEndAfterInvalidStreamStart); CPPUNIT_TEST_SUITE_END(); public: void testParse_SimpleSession() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<?xml version='1.0'?>")); CPPUNIT_ASSERT(testling.parse("<stream:stream to='example.com' xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' >")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT(testling.parse("<iq/>")); CPPUNIT_ASSERT(testling.parse("</stream:stream>")); CPPUNIT_ASSERT_EQUAL(5, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::StreamStart, client_.events[0].type); CPPUNIT_ASSERT_EQUAL(std::string("example.com"), client_.events[0].header->getTo()); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[2].type); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[3].type); CPPUNIT_ASSERT_EQUAL(Client::StreamEnd, client_.events[4].type); } void testParse_SimpleClientFromServerSession() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<?xml version='1.0'?>")); CPPUNIT_ASSERT(testling.parse("<stream:stream from='example.com' xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' id='aeab'>")); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::StreamStart, client_.events[0].type); CPPUNIT_ASSERT_EQUAL(std::string("example.com"), client_.events[0].header->getFrom()); CPPUNIT_ASSERT_EQUAL(std::string("aeab"), client_.events[0].header->getID()); } void testParse_Presence() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT(dynamic_cast<Presence*>(client_.events[1].element.get())); } void testParse_IQ() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<iq/>")); CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT(dynamic_cast<IQ*>(client_.events[1].element.get())); } void testParse_Message() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<message/>")); CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT(dynamic_cast<Message*>(client_.events[1].element.get())); } void testParse_StreamFeatures() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<stream:features/>")); CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT(dynamic_cast<StreamFeatures*>(client_.events[1].element.get())); } void testParse_UnknownElement() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT(testling.parse("<foo/>")); CPPUNIT_ASSERT(testling.parse("<bar/>")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT_EQUAL(5, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[2].type); CPPUNIT_ASSERT(dynamic_cast<UnknownElement*>(client_.events[2].element.get())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[3].type); CPPUNIT_ASSERT(dynamic_cast<UnknownElement*>(client_.events[3].element.get())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[4].type); CPPUNIT_ASSERT(dynamic_cast<Presence*>(client_.events[4].element.get())); } void testParse_StrayCharacterData() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(testling.parse("<stream:stream xmlns:stream='http://etherx.jabber.org/streams'>")); CPPUNIT_ASSERT(testling.parse("<presence/>")); CPPUNIT_ASSERT(testling.parse("bla")); CPPUNIT_ASSERT(testling.parse("<iq/>")); CPPUNIT_ASSERT_EQUAL(3, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[2].type); CPPUNIT_ASSERT(dynamic_cast<IQ*>(client_.events[2].element.get())); } void testParse_InvalidStreamStart() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(!testling.parse("<tream>")); } void testParse_ElementEndAfterInvalidStreamStart() { - XMPPParser testling(&client_, &factories_); + XMPPParser testling(&client_, &factories_, &xmlParserFactory_); CPPUNIT_ASSERT(!testling.parse("<tream/>")); } private: class Client : public XMPPParserClient { public: enum Type { StreamStart, ElementEvent, StreamEnd }; struct Event { Event(Type type, boost::shared_ptr<Element> element) : type(type), element(element) {} Event(Type type, const ProtocolHeader& header) : type(type), header(header) {} Event(Type type) : type(type) {} Type type; boost::optional<ProtocolHeader> header; boost::shared_ptr<Element> element; }; Client() {} void handleStreamStart(const ProtocolHeader& header) { events.push_back(Event(StreamStart, header)); } void handleElement(boost::shared_ptr<Element> element) { events.push_back(Event(ElementEvent, element)); } void handleStreamEnd() { events.push_back(Event(StreamEnd)); } std::vector<Event> events; } client_; PayloadParserFactoryCollection factories_; + PlatformXMLParserFactory xmlParserFactory_; }; CPPUNIT_TEST_SUITE_REGISTRATION(XMPPParserTest); diff --git a/Swiften/Parser/XMPPParser.cpp b/Swiften/Parser/XMPPParser.cpp index 6779b86..069a5bd 100644 --- a/Swiften/Parser/XMPPParser.cpp +++ b/Swiften/Parser/XMPPParser.cpp @@ -1,94 +1,95 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Parser/XMPPParser.h> #include <iostream> #include <cassert> #include <Swiften/Elements/ProtocolHeader.h> #include <string> #include <Swiften/Parser/XMLParser.h> -#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Parser/XMPPParserClient.h> #include <Swiften/Parser/XMPPParser.h> #include <Swiften/Parser/ElementParser.h> #include <Swiften/Parser/PresenceParser.h> #include <Swiften/Parser/IQParser.h> #include <Swiften/Parser/MessageParser.h> #include <Swiften/Parser/StreamFeaturesParser.h> #include <Swiften/Parser/StreamErrorParser.h> #include <Swiften/Parser/AuthRequestParser.h> #include <Swiften/Parser/AuthSuccessParser.h> #include <Swiften/Parser/AuthFailureParser.h> #include <Swiften/Parser/AuthChallengeParser.h> #include <Swiften/Parser/AuthResponseParser.h> #include <Swiften/Parser/EnableStreamManagementParser.h> #include <Swiften/Parser/StreamManagementEnabledParser.h> #include <Swiften/Parser/StreamManagementFailedParser.h> #include <Swiften/Parser/StreamResumeParser.h> #include <Swiften/Parser/StreamResumedParser.h> #include <Swiften/Parser/StanzaAckParser.h> #include <Swiften/Parser/StanzaAckRequestParser.h> #include <Swiften/Parser/StartTLSParser.h> #include <Swiften/Parser/StartTLSFailureParser.h> #include <Swiften/Parser/CompressParser.h> #include <Swiften/Parser/CompressFailureParser.h> #include <Swiften/Parser/CompressedParser.h> #include <Swiften/Parser/UnknownElementParser.h> #include <Swiften/Parser/TLSProceedParser.h> #include <Swiften/Parser/ComponentHandshakeParser.h> +#include <Swiften/Parser/XMLParserFactory.h> // TODO: Whenever an error occurs in the handlers, stop the parser by returing // a bool value, and stopping the XML parser namespace Swift { XMPPParser::XMPPParser( XMPPParserClient* client, - PayloadParserFactoryCollection* payloadParserFactories) : + PayloadParserFactoryCollection* payloadParserFactories, + XMLParserFactory* xmlParserFactory) : xmlParser_(0), client_(client), payloadParserFactories_(payloadParserFactories), level_(0), currentElementParser_(0), parseErrorOccurred_(false) { - xmlParser_ = PlatformXMLParserFactory().createXMLParser(this); + xmlParser_ = xmlParserFactory->createXMLParser(this); } XMPPParser::~XMPPParser() { delete currentElementParser_; delete xmlParser_; } bool XMPPParser::parse(const std::string& data) { bool xmlParseResult = xmlParser_->parse(data); return xmlParseResult && !parseErrorOccurred_; } void XMPPParser::handleStartElement(const std::string& element, const std::string& ns, const AttributeMap& attributes) { if (!parseErrorOccurred_) { if (level_ == TopLevel) { if (element == "stream" && ns == "http://etherx.jabber.org/streams") { ProtocolHeader header; header.setFrom(attributes.getAttribute("from")); header.setTo(attributes.getAttribute("to")); header.setID(attributes.getAttribute("id")); header.setVersion(attributes.getAttribute("version")); client_->handleStreamStart(header); } else { parseErrorOccurred_ = true; } } else { if (level_ == StreamLevel) { assert(!currentElementParser_); currentElementParser_ = createElementParser(element, ns); } currentElementParser_->handleStartElement(element, ns, attributes); } } diff --git a/Swiften/Parser/XMPPParser.h b/Swiften/Parser/XMPPParser.h index b5d6d24..6cce2bd 100644 --- a/Swiften/Parser/XMPPParser.h +++ b/Swiften/Parser/XMPPParser.h @@ -1,54 +1,55 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <boost/noncopyable.hpp> #include <Swiften/Parser/XMLParserClient.h> #include <Swiften/Parser/AttributeMap.h> namespace Swift { class XMLParser; class XMPPParserClient; - + class XMLParserFactory; class ElementParser; class PayloadParserFactoryCollection; class XMPPParser : public XMLParserClient, boost::noncopyable { public: XMPPParser( XMPPParserClient* parserClient, - PayloadParserFactoryCollection* payloadParserFactories); + PayloadParserFactoryCollection* payloadParserFactories, + XMLParserFactory* xmlParserFactory); ~XMPPParser(); bool parse(const std::string&); private: virtual void handleStartElement( const std::string& element, const std::string& ns, const AttributeMap& attributes); virtual void handleEndElement(const std::string& element, const std::string& ns); virtual void handleCharacterData(const std::string& data); ElementParser* createElementParser(const std::string& element, const std::string& xmlns); private: XMLParser* xmlParser_; XMPPParserClient* client_; PayloadParserFactoryCollection* payloadParserFactories_; enum Level { TopLevel = 0, StreamLevel = 1, ElementLevel = 2 }; int level_; ElementParser* currentElementParser_; bool parseErrorOccurred_; }; } diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index d08be4f..07a04b8 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -1,74 +1,75 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Session/BasicSessionStream.h> #include <boost/bind.hpp> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/StreamStack/StreamStack.h> #include <Swiften/StreamStack/ConnectionLayer.h> #include <Swiften/StreamStack/WhitespacePingLayer.h> #include <Swiften/StreamStack/CompressionLayer.h> #include <Swiften/StreamStack/TLSLayer.h> #include <Swiften/TLS/TLSContextFactory.h> #include <Swiften/TLS/TLSContext.h> namespace Swift { BasicSessionStream::BasicSessionStream( StreamType streamType, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSContextFactory* tlsContextFactory, - TimerFactory* timerFactory) : + TimerFactory* timerFactory, + XMLParserFactory* xmlParserFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsContextFactory(tlsContextFactory), timerFactory(timerFactory), streamType(streamType), compressionLayer(NULL), tlsLayer(NULL), whitespacePingLayer(NULL) { - xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, streamType); + xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType); xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.connect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); connectionLayer = new ConnectionLayer(connection); streamStack = new StreamStack(xmppLayer, connectionLayer); available = true; } BasicSessionStream::~BasicSessionStream() { delete compressionLayer; if (tlsLayer) { tlsLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleTLSError, this)); tlsLayer->onConnected.disconnect(boost::bind(&BasicSessionStream::handleTLSConnected, this)); delete tlsLayer; } delete whitespacePingLayer; delete streamStack; connection->onDisconnected.disconnect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); delete connectionLayer; xmppLayer->onStreamStart.disconnect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); xmppLayer->onElement.disconnect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.disconnect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.disconnect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); delete xmppLayer; } diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 2a1ed8a..2ed5ac6 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -1,71 +1,73 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/Network/Connection.h> #include <Swiften/Session/SessionStream.h> #include <Swiften/Elements/StreamType.h> namespace Swift { class TLSContextFactory; class TLSLayer; class TimerFactory; class WhitespacePingLayer; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class StreamStack; class XMPPLayer; class ConnectionLayer; class CompressionLayer; + class XMLParserFactory; class BasicSessionStream : public SessionStream { public: BasicSessionStream( StreamType streamType, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSContextFactory* tlsContextFactory, - TimerFactory* whitespacePingLayerFactory + TimerFactory* whitespacePingLayerFactory, + XMLParserFactory* xmlParserFactory ); ~BasicSessionStream(); virtual void close(); virtual bool isOpen(); virtual void writeHeader(const ProtocolHeader& header); virtual void writeElement(boost::shared_ptr<Element>); virtual void writeFooter(); virtual void writeData(const std::string& data); virtual void addZLibCompression(); virtual bool supportsTLSEncryption(); virtual void addTLSEncryption(); virtual bool isTLSEncrypted(); virtual Certificate::ref getPeerCertificate() const; virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; virtual ByteArray getTLSFinishMessage() const; virtual void setWhitespacePingEnabled(bool); virtual void resetXMPPParser(); private: void handleConnectionFinished(const boost::optional<Connection::Error>& error); void handleXMPPError(); void handleTLSConnected(); void handleTLSError(); void handleStreamStartReceived(const ProtocolHeader&); void handleElementReceived(boost::shared_ptr<Element>); void handleDataRead(const SafeByteArray& data); void handleDataWritten(const SafeByteArray& data); private: diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp index e8b8308..661cb8d 100644 --- a/Swiften/Session/Session.cpp +++ b/Swiften/Session/Session.cpp @@ -1,101 +1,103 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Session/Session.h> #include <boost/bind.hpp> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/StreamStack/StreamStack.h> namespace Swift { Session::Session( boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory) : connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), + xmlParserFactory(xmlParserFactory), xmppLayer(NULL), connectionLayer(NULL), streamStack(0), finishing(false) { } Session::~Session() { delete streamStack; delete connectionLayer; delete xmppLayer; } void Session::startSession() { initializeStreamStack(); handleSessionStarted(); } void Session::finishSession() { if (finishing) { return; } finishing = true; if (xmppLayer) { xmppLayer->writeFooter(); } connection->disconnect(); handleSessionFinished(boost::optional<SessionError>()); onSessionFinished(boost::optional<SessionError>()); } void Session::finishSession(const SessionError& error) { if (finishing) { return; } finishing = true; if (xmppLayer) { xmppLayer->writeFooter(); } connection->disconnect(); handleSessionFinished(boost::optional<SessionError>(error)); onSessionFinished(boost::optional<SessionError>(error)); } void Session::initializeStreamStack() { - xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, ClientStreamType); + xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, ClientStreamType); xmppLayer->onStreamStart.connect( boost::bind(&Session::handleStreamStart, shared_from_this(), _1)); xmppLayer->onElement.connect(boost::bind(&Session::handleElement, shared_from_this(), _1)); xmppLayer->onError.connect( boost::bind(&Session::finishSession, shared_from_this(), XMLError)); xmppLayer->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); xmppLayer->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); connection->onDisconnected.connect( boost::bind(&Session::handleDisconnected, shared_from_this(), _1)); connectionLayer = new ConnectionLayer(connection); streamStack = new StreamStack(xmppLayer, connectionLayer); } void Session::sendElement(boost::shared_ptr<Element> stanza) { xmppLayer->writeElement(stanza); } void Session::handleDisconnected(const boost::optional<Connection::Error>& connectionError) { if (connectionError) { switch (*connectionError) { case Connection::ReadError: finishSession(ConnectionReadError); break; case Connection::WriteError: finishSession(ConnectionWriteError); break; } } else { finishSession(); } } } diff --git a/Swiften/Session/Session.h b/Swiften/Session/Session.h index 9e954c7..c937430 100644 --- a/Swiften/Session/Session.h +++ b/Swiften/Session/Session.h @@ -1,111 +1,114 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/boost_bsignals.h> #include <boost/optional.hpp> #include <boost/enable_shared_from_this.hpp> #include <Swiften/JID/JID.h> #include <Swiften/Elements/Element.h> #include <Swiften/Network/Connection.h> #include <Swiften/StreamStack/ConnectionLayer.h> #include <Swiften/Base/SafeByteArray.h> namespace Swift { class ProtocolHeader; class StreamStack; class JID; class Element; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class XMPPLayer; + class XMLParserFactory; class Session : public boost::enable_shared_from_this<Session> { public: enum SessionError { ConnectionReadError, ConnectionWriteError, XMLError, AuthenticationFailedError, NoSupportedAuthMechanismsError, UnexpectedElementError, ResourceBindError, SessionStartError, TLSError, ClientCertificateLoadError, ClientCertificateError }; Session( boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers); + PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory); virtual ~Session(); void startSession(); void finishSession(); void sendElement(boost::shared_ptr<Element>); const JID& getLocalJID() const { return localJID; } const JID& getRemoteJID() const { return remoteJID; } boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; boost::signal<void (const boost::optional<SessionError>&)> onSessionFinished; boost::signal<void (const SafeByteArray&)> onDataWritten; boost::signal<void (const SafeByteArray&)> onDataRead; protected: void setRemoteJID(const JID& j) { remoteJID = j; } void setLocalJID(const JID& j) { localJID = j; } void finishSession(const SessionError&); virtual void handleSessionStarted() {} virtual void handleSessionFinished(const boost::optional<SessionError>&) {} virtual void handleElement(boost::shared_ptr<Element>) = 0; virtual void handleStreamStart(const ProtocolHeader&) = 0; void initializeStreamStack(); XMPPLayer* getXMPPLayer() const { return xmppLayer; } StreamStack* getStreamStack() const { return streamStack; } void setFinished(); private: void handleDisconnected(const boost::optional<Connection::Error>& error); private: JID localJID; JID remoteJID; boost::shared_ptr<Connection> connection; PayloadParserFactoryCollection* payloadParserFactories; PayloadSerializerCollection* payloadSerializers; + XMLParserFactory* xmlParserFactory; XMPPLayer* xmppLayer; ConnectionLayer* connectionLayer; StreamStack* streamStack; bool finishing; }; } diff --git a/Swiften/StreamStack/UnitTest/StreamStackTest.cpp b/Swiften/StreamStack/UnitTest/StreamStackTest.cpp index d3c0a7c..213948a 100644 --- a/Swiften/StreamStack/UnitTest/StreamStackTest.cpp +++ b/Swiften/StreamStack/UnitTest/StreamStackTest.cpp @@ -1,76 +1,77 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Base/ByteArray.h> #include <QA/Checker/IO.h> #include <vector> #include <boost/bind.hpp> #include <boost/smart_ptr.hpp> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Base/ByteArray.h> #include <Swiften/Base/Concat.h> #include <Swiften/StreamStack/StreamStack.h> #include <Swiften/StreamStack/LowLayer.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/StreamStack/StreamLayer.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h> #include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h> using namespace Swift; class StreamStackTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(StreamStackTest); CPPUNIT_TEST(testWriteData_NoIntermediateStreamStack); CPPUNIT_TEST(testWriteData_OneIntermediateStream); CPPUNIT_TEST(testWriteData_TwoIntermediateStreamStack); CPPUNIT_TEST(testReadData_NoIntermediateStreamStack); CPPUNIT_TEST(testReadData_OneIntermediateStream); CPPUNIT_TEST(testReadData_TwoIntermediateStreamStack); CPPUNIT_TEST(testAddLayer_ExistingOnWriteDataSlot); CPPUNIT_TEST_SUITE_END(); public: void setUp() { physicalStream_ = new TestLowLayer(); - xmppStream_ = new XMPPLayer(&parserFactories_, &serializers_, ClientStreamType); + xmppStream_ = new XMPPLayer(&parserFactories_, &serializers_, &xmlParserFactory_, ClientStreamType); elementsReceived_ = 0; dataWriteReceived_ = 0; } void tearDown() { delete physicalStream_; delete xmppStream_; } void testWriteData_NoIntermediateStreamStack() { StreamStack testling(xmppStream_, physicalStream_); xmppStream_->writeData("foo"); CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), physicalStream_->data_.size()); CPPUNIT_ASSERT_EQUAL(createSafeByteArray("foo"), physicalStream_->data_[0]); } void testWriteData_OneIntermediateStream() { StreamStack testling(xmppStream_, physicalStream_); boost::shared_ptr<MyStreamLayer> xStream(new MyStreamLayer("X")); testling.addLayer(xStream.get()); xmppStream_->writeData("foo"); CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), physicalStream_->data_.size()); CPPUNIT_ASSERT_EQUAL(createSafeByteArray("Xfoo"), physicalStream_->data_[0]); } void testWriteData_TwoIntermediateStreamStack() { StreamStack testling(xmppStream_, physicalStream_); boost::shared_ptr<MyStreamLayer> xStream(new MyStreamLayer("X")); boost::shared_ptr<MyStreamLayer> yStream(new MyStreamLayer("Y")); testling.addLayer(xStream.get()); testling.addLayer(yStream.get()); @@ -140,41 +141,42 @@ class StreamStackTest : public CppUnit::TestFixture { } virtual void writeData(const SafeByteArray& data) { writeDataToChildLayer(concat(createSafeByteArray(prepend_), data)); } virtual void handleDataRead(const SafeByteArray& data) { writeDataToParentLayer(concat(createSafeByteArray(prepend_), data)); } private: std::string prepend_; }; class TestLowLayer : public LowLayer { public: TestLowLayer() { } virtual void writeData(const SafeByteArray& data) { data_.push_back(data); } void onDataRead(const SafeByteArray& data) { writeDataToParentLayer(data); } std::vector<SafeByteArray> data_; }; private: FullPayloadParserFactoryCollection parserFactories_; FullPayloadSerializerCollection serializers_; TestLowLayer* physicalStream_; + PlatformXMLParserFactory xmlParserFactory_; XMPPLayer* xmppStream_; int elementsReceived_; int dataWriteReceived_; }; CPPUNIT_TEST_SUITE_REGISTRATION(StreamStackTest); diff --git a/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp b/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp index bb0ce61..8123c00 100644 --- a/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp +++ b/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp @@ -1,70 +1,71 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <vector> #include <boost/bind.hpp> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/Presence.h> #include <Swiften/Base/ByteArray.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/StreamStack/LowLayer.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> #include <Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h> #include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h> using namespace Swift; class XMPPLayerTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(XMPPLayerTest); CPPUNIT_TEST(testParseData_Error); CPPUNIT_TEST(testResetParser); CPPUNIT_TEST(testResetParser_FromSlot); CPPUNIT_TEST(testWriteHeader); CPPUNIT_TEST(testWriteElement); CPPUNIT_TEST(testWriteFooter); CPPUNIT_TEST_SUITE_END(); public: void setUp() { lowLayer_ = new DummyLowLayer(); - testling_ = new XMPPLayerExposed(&parserFactories_, &serializers_, ClientStreamType); + testling_ = new XMPPLayerExposed(&parserFactories_, &serializers_, &xmlParserFactory_, ClientStreamType); testling_->setChildLayer(lowLayer_); elementsReceived_ = 0; errorReceived_ = 0; } void tearDown() { delete testling_; delete lowLayer_; } void testParseData_Error() { testling_->onError.connect(boost::bind(&XMPPLayerTest::handleError, this)); testling_->handleDataRead(createSafeByteArray("<iq>")); CPPUNIT_ASSERT_EQUAL(1, errorReceived_); } void testResetParser() { testling_->onElement.connect(boost::bind(&XMPPLayerTest::handleElement, this, _1)); testling_->onError.connect(boost::bind(&XMPPLayerTest::handleError, this)); testling_->handleDataRead(createSafeByteArray("<stream:stream to=\"example.com\" xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" >")); testling_->resetParser(); testling_->handleDataRead(createSafeByteArray("<stream:stream to=\"example.com\" xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" >")); testling_->handleDataRead(createSafeByteArray("<presence/>")); CPPUNIT_ASSERT_EQUAL(1, elementsReceived_); CPPUNIT_ASSERT_EQUAL(0, errorReceived_); } void testResetParser_FromSlot() { testling_->onElement.connect(boost::bind(&XMPPLayerTest::handleElementAndReset, this, _1)); testling_->handleDataRead(createSafeByteArray("<stream:stream to=\"example.com\" xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" ><presence/>")); testling_->handleDataRead(createSafeByteArray("<stream:stream to=\"example.com\" xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" ><presence/>")); @@ -80,59 +81,61 @@ class XMPPLayerTest : public CppUnit::TestFixture { CPPUNIT_ASSERT_EQUAL(std::string("<?xml version=\"1.0\"?><stream:stream xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" to=\"example.com\" version=\"1.0\">"), lowLayer_->writtenData); } void testWriteElement() { testling_->writeElement(boost::shared_ptr<Presence>(new Presence())); CPPUNIT_ASSERT_EQUAL(std::string("<presence/>"), lowLayer_->writtenData); } void testWriteFooter() { testling_->writeFooter(); CPPUNIT_ASSERT_EQUAL(std::string("</stream:stream>"), lowLayer_->writtenData); } void handleElement(boost::shared_ptr<Element>) { ++elementsReceived_; } void handleElementAndReset(boost::shared_ptr<Element>) { ++elementsReceived_; testling_->resetParser(); } void handleError() { ++errorReceived_; } private: class XMPPLayerExposed : public XMPPLayer { public: XMPPLayerExposed( PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, - StreamType streamType) : XMPPLayer(payloadParserFactories, payloadSerializers, streamType) {} + XMLParserFactory* xmlParserFactory, + StreamType streamType) : XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType) {} using XMPPLayer::handleDataRead; using HighLayer::setChildLayer; }; class DummyLowLayer : public LowLayer { public: virtual void writeData(const SafeByteArray& data) { writtenData += byteArrayToString(ByteArray(data.begin(), data.end())); } std::string writtenData; }; FullPayloadParserFactoryCollection parserFactories_; FullPayloadSerializerCollection serializers_; DummyLowLayer* lowLayer_; XMPPLayerExposed* testling_; + PlatformXMLParserFactory xmlParserFactory_; int elementsReceived_; int errorReceived_; }; CPPUNIT_TEST_SUITE_REGISTRATION(XMPPLayerTest); diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp index 1dcd84f..94afcf9 100644 --- a/Swiften/StreamStack/XMPPLayer.cpp +++ b/Swiften/StreamStack/XMPPLayer.cpp @@ -1,95 +1,97 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/Parser/XMPPParser.h> #include <Swiften/Serializer/XMPPSerializer.h> #include <Swiften/Elements/ProtocolHeader.h> namespace Swift { XMPPLayer::XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory, StreamType streamType) : payloadParserFactories_(payloadParserFactories), payloadSerializers_(payloadSerializers), + xmlParserFactory_(xmlParserFactory), resetParserAfterParse_(false), inParser_(false) { - xmppParser_ = new XMPPParser(this, payloadParserFactories_); + xmppParser_ = new XMPPParser(this, payloadParserFactories_, xmlParserFactory); xmppSerializer_ = new XMPPSerializer(payloadSerializers_, streamType); } XMPPLayer::~XMPPLayer() { delete xmppSerializer_; delete xmppParser_; } void XMPPLayer::writeHeader(const ProtocolHeader& header) { writeDataInternal(createSafeByteArray(xmppSerializer_->serializeHeader(header))); } void XMPPLayer::writeFooter() { writeDataInternal(createSafeByteArray(xmppSerializer_->serializeFooter())); } void XMPPLayer::writeElement(boost::shared_ptr<Element> element) { writeDataInternal(xmppSerializer_->serializeElement(element)); } void XMPPLayer::writeData(const std::string& data) { writeDataInternal(createSafeByteArray(data)); } void XMPPLayer::writeDataInternal(const SafeByteArray& data) { onWriteData(data); writeDataToChildLayer(data); } void XMPPLayer::handleDataRead(const SafeByteArray& data) { onDataRead(data); inParser_ = true; // FIXME: Converting to unsafe string. Should be ok, since we don't take passwords // from the stream in clients. If servers start using this, and require safe storage, // we need to fix this. if (!xmppParser_->parse(byteArrayToString(ByteArray(data.begin(), data.end())))) { inParser_ = false; onError(); return; } inParser_ = false; if (resetParserAfterParse_) { doResetParser(); } } void XMPPLayer::doResetParser() { delete xmppParser_; - xmppParser_ = new XMPPParser(this, payloadParserFactories_); + xmppParser_ = new XMPPParser(this, payloadParserFactories_, xmlParserFactory_); resetParserAfterParse_ = false; } void XMPPLayer::handleStreamStart(const ProtocolHeader& header) { onStreamStart(header); } void XMPPLayer::handleElement(boost::shared_ptr<Element> stanza) { onElement(stanza); } void XMPPLayer::handleStreamEnd() { } void XMPPLayer::resetParser() { if (inParser_) { resetParserAfterParse_ = true; } else { doResetParser(); } } } diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h index 54bdd42..9be00b2 100644 --- a/Swiften/StreamStack/XMPPLayer.h +++ b/Swiften/StreamStack/XMPPLayer.h @@ -1,67 +1,70 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Base/boost_bsignals.h> #include <boost/noncopyable.hpp> #include <Swiften/StreamStack/HighLayer.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/Elements/Element.h> #include <Swiften/Elements/StreamType.h> #include <Swiften/Parser/XMPPParserClient.h> namespace Swift { class ProtocolHeader; class XMPPParser; class PayloadParserFactoryCollection; class XMPPSerializer; class PayloadSerializerCollection; + class XMLParserFactory; class XMPPLayer : public XMPPParserClient, public HighLayer, boost::noncopyable { public: XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, + XMLParserFactory* xmlParserFactory, StreamType streamType); ~XMPPLayer(); void writeHeader(const ProtocolHeader& header); void writeFooter(); void writeElement(boost::shared_ptr<Element>); void writeData(const std::string& data); void resetParser(); protected: void handleDataRead(const SafeByteArray& data); void writeDataInternal(const SafeByteArray& data); public: boost::signal<void (const ProtocolHeader&)> onStreamStart; boost::signal<void (boost::shared_ptr<Element>)> onElement; boost::signal<void (const SafeByteArray&)> onWriteData; boost::signal<void (const SafeByteArray&)> onDataRead; boost::signal<void ()> onError; private: void handleStreamStart(const ProtocolHeader&); void handleElement(boost::shared_ptr<Element>); void handleStreamEnd(); void doResetParser(); private: PayloadParserFactoryCollection* payloadParserFactories_; XMPPParser* xmppParser_; PayloadSerializerCollection* payloadSerializers_; + XMLParserFactory* xmlParserFactory_; XMPPSerializer* xmppSerializer_; bool resetParserAfterParse_; bool inParser_; }; } |
Swift