diff options
30 files changed, 734 insertions, 28 deletions
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index c80acb5..272a2a7 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -82,7 +82,7 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connectio connection_ = connection; assert(!sessionStream_); - sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_, timerFactory_)); + sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(ClientStreamType, connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_, timerFactory_)); if (!certificate_.isEmpty()) { sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); } diff --git a/Swiften/Component/ComponentConnector.cpp b/Swiften/Component/ComponentConnector.cpp new file mode 100644 index 0000000..e764138 --- /dev/null +++ b/Swiften/Component/ComponentConnector.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include "Swiften/Component/ComponentConnector.h" + +#include <boost/bind.hpp> +#include <iostream> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Network/TimerFactory.h" + +namespace Swift { + +ComponentConnector::ComponentConnector(const String& hostname, int port, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), port(port), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0) { +} + +void ComponentConnector::setTimeoutMilliseconds(int milliseconds) { + timeoutMilliseconds = milliseconds; +} + +void ComponentConnector::start() { + assert(!currentConnection); + assert(!timer); + assert(!addressQuery); + addressQuery = resolver->createAddressQuery(hostname); + addressQuery->onResult.connect(boost::bind(&ComponentConnector::handleAddressQueryResult, shared_from_this(), _1, _2)); + if (timeoutMilliseconds > 0) { + timer = timerFactory->createTimer(timeoutMilliseconds); + timer->onTick.connect(boost::bind(&ComponentConnector::handleTimeout, shared_from_this())); + timer->start(); + } + addressQuery->run(); +} + +void ComponentConnector::stop() { + finish(boost::shared_ptr<Connection>()); +} + + +void ComponentConnector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) { + addressQuery.reset(); + if (error || addresses.empty()) { + finish(boost::shared_ptr<Connection>()); + } + else { + addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end()); + tryNextAddress(); + } +} + +void ComponentConnector::tryNextAddress() { + assert(!addressQueryResults.empty()); + HostAddress address = addressQueryResults.front(); + addressQueryResults.pop_front(); + tryConnect(HostAddressPort(address, port)); +} + +void ComponentConnector::tryConnect(const HostAddressPort& target) { + assert(!currentConnection); + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&ComponentConnector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection->connect(target); +} + +void ComponentConnector::handleConnectionConnectFinished(bool error) { + currentConnection->onConnectFinished.disconnect(boost::bind(&ComponentConnector::handleConnectionConnectFinished, shared_from_this(), _1)); + if (error) { + currentConnection.reset(); + if (!addressQueryResults.empty()) { + tryNextAddress(); + } + else { + finish(boost::shared_ptr<Connection>()); + } + } + else { + finish(currentConnection); + } +} + +void ComponentConnector::finish(boost::shared_ptr<Connection> connection) { + if (timer) { + timer->stop(); + timer->onTick.disconnect(boost::bind(&ComponentConnector::handleTimeout, shared_from_this())); + timer.reset(); + } + if (addressQuery) { + addressQuery->onResult.disconnect(boost::bind(&ComponentConnector::handleAddressQueryResult, shared_from_this(), _1, _2)); + addressQuery.reset(); + } + if (currentConnection) { + currentConnection->onConnectFinished.disconnect(boost::bind(&ComponentConnector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection.reset(); + } + onConnectFinished(connection); +} + +void ComponentConnector::handleTimeout() { + finish(boost::shared_ptr<Connection>()); +} + +}; diff --git a/Swiften/Component/ComponentConnector.h b/Swiften/Component/ComponentConnector.h new file mode 100644 index 0000000..a84d8ba --- /dev/null +++ b/Swiften/Component/ComponentConnector.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <deque> +#include "Swiften/Base/boost_bsignals.h" +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/Timer.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolveError.h" + +namespace Swift { + class DomainNameAddressQuery; + class DomainNameResolver; + class ConnectionFactory; + class TimerFactory; + + class ComponentConnector : public boost::bsignals::trackable, public boost::enable_shared_from_this<ComponentConnector> { + public: + typedef boost::shared_ptr<ComponentConnector> ref; + + static ComponentConnector::ref create(const String& hostname, int port, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) { + return ComponentConnector::ref(new ComponentConnector(hostname, port, resolver, connectionFactory, timerFactory)); + } + + void setTimeoutMilliseconds(int milliseconds); + + void start(); + void stop(); + + boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; + + private: + ComponentConnector(const String& hostname, int port, DomainNameResolver*, ConnectionFactory*, TimerFactory*); + + void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error); + void tryNextAddress(); + void tryConnect(const HostAddressPort& target); + + void handleConnectionConnectFinished(bool error); + void finish(boost::shared_ptr<Connection>); + void handleTimeout(); + + + private: + String hostname; + int port; + DomainNameResolver* resolver; + ConnectionFactory* connectionFactory; + TimerFactory* timerFactory; + int timeoutMilliseconds; + boost::shared_ptr<Timer> timer; + boost::shared_ptr<DomainNameAddressQuery> addressQuery; + std::deque<HostAddress> addressQueryResults; + boost::shared_ptr<Connection> currentConnection; + }; +}; diff --git a/Swiften/Component/ComponentHandshakeGenerator.cpp b/Swiften/Component/ComponentHandshakeGenerator.cpp new file mode 100644 index 0000000..422f986 --- /dev/null +++ b/Swiften/Component/ComponentHandshakeGenerator.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include "Swiften/Component/ComponentHandshakeGenerator.h" +#include "Swiften/StringCodecs/Hexify.h" +#include "Swiften/StringCodecs/SHA1.h" + +namespace Swift { + +String ComponentHandshakeGenerator::getHandshake(const String& streamID, const String& secret) { + String concatenatedString = streamID + secret; + concatenatedString.replaceAll('&', "&"); + concatenatedString.replaceAll('<', "<"); + concatenatedString.replaceAll('>', ">"); + concatenatedString.replaceAll('\'', "'"); + concatenatedString.replaceAll('"', """); + return Hexify::hexify(SHA1::getHash(ByteArray(concatenatedString))); +} + +} diff --git a/Swiften/Component/ComponentHandshakeGenerator.h b/Swiften/Component/ComponentHandshakeGenerator.h new file mode 100644 index 0000000..d71a664 --- /dev/null +++ b/Swiften/Component/ComponentHandshakeGenerator.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include "Swiften/Base/String.h" + +namespace Swift { + class ComponentHandshakeGenerator { + public: + static String getHandshake(const String& streamID, const String& secret); + }; + +} diff --git a/Swiften/Component/SConscript b/Swiften/Component/SConscript new file mode 100644 index 0000000..6d86575 --- /dev/null +++ b/Swiften/Component/SConscript @@ -0,0 +1,8 @@ +Import("swiften_env") + +sources = [ + "ComponentHandshakeGenerator.cpp", + "ComponentConnector.cpp", + ] + +swiften_env.Append(SWIFTEN_OBJECTS = swiften_env.StaticObject(sources))
\ No newline at end of file diff --git a/Swiften/Component/UnitTest/ComponentConnectorTest.cpp b/Swiften/Component/UnitTest/ComponentConnectorTest.cpp new file mode 100644 index 0000000..7b8a4f8 --- /dev/null +++ b/Swiften/Component/UnitTest/ComponentConnectorTest.cpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <boost/optional.hpp> +#include <boost/bind.hpp> + +#include "Swiften/Component/ComponentConnector.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/Network/DummyTimerFactory.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" + +using namespace Swift; + +class ComponentConnectorTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ComponentConnectorTest); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_FirstAddressHostFails); + CPPUNIT_TEST(testConnect_NoHosts); + CPPUNIT_TEST(testConnect_TimeoutDuringResolve); + CPPUNIT_TEST(testConnect_TimeoutDuringConnect); + CPPUNIT_TEST(testConnect_NoTimeout); + CPPUNIT_TEST(testStop_Timeout); + CPPUNIT_TEST_SUITE_END(); + + public: + ComponentConnectorTest() : host1("1.1.1.1"), host2("2.2.2.2") { + } + + void setUp() { + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(); + connectionFactory = new MockConnectionFactory(); + timerFactory = new DummyTimerFactory(); + } + + void tearDown() { + delete timerFactory; + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void testConnect() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + resolver->addAddress("foo.com", host1); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(HostAddressPort(host1, 1234) == *(connections[0]->hostAddressPort)); + } + + void testConnect_FirstAddressHostFails() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + resolver->addAddress("foo.com", host1); + resolver->addAddress("foo.com", host2); + connectionFactory->failingPorts.push_back(HostAddressPort(host1, 1234)); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(HostAddressPort(host2, 1234) == *(connections[0]->hostAddressPort)); + } + + void testConnect_NoHosts() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + + void testConnect_TimeoutDuringResolve() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + + testling->setTimeoutMilliseconds(10); + resolver->setIsResponsive(false); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_TimeoutDuringConnect() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + testling->setTimeoutMilliseconds(10); + resolver->addAddress("foo.com", host1); + connectionFactory->isResponsive = false; + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_NoTimeout() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + testling->setTimeoutMilliseconds(10); + resolver->addAddress("foo.com", host1); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + } + + void testStop_Timeout() { + ComponentConnector::ref testling(createConnector("foo.com", 1234)); + testling->setTimeoutMilliseconds(10); + resolver->addAddress("foo.com", host1); + + testling->start(); + testling->stop(); + + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + private: + ComponentConnector::ref createConnector(const String& hostname, int port) { + ComponentConnector::ref connector = ComponentConnector::create(hostname, port, resolver, connectionFactory, timerFactory); + connector->onConnectFinished.connect(boost::bind(&ComponentConnectorTest::handleConnectorFinished, this, _1)); + return connector; + } + + void handleConnectorFinished(boost::shared_ptr<Connection> connection) { + boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {} + + void listen() { assert(false); } + void connect(const HostAddressPort& address) { + hostAddressPort = address; + if (isResponsive) { + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + } + + void disconnect() { assert(false); } + void write(const ByteArray&) { assert(false); } + + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + bool isResponsive; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory() : isResponsive(true) { + } + + boost::shared_ptr<Connection> createConnection() { + return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive)); + } + + bool isResponsive; + std::vector<HostAddressPort> failingPorts; + }; + + private: + HostAddress host1; + HostAddress host2; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + DummyTimerFactory* timerFactory; + std::vector< boost::shared_ptr<MockConnection> > connections; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ComponentConnectorTest); diff --git a/Swiften/Component/UnitTest/ComponentHandshakeGeneratorTest.cpp b/Swiften/Component/UnitTest/ComponentHandshakeGeneratorTest.cpp new file mode 100644 index 0000000..e72dbea --- /dev/null +++ b/Swiften/Component/UnitTest/ComponentHandshakeGeneratorTest.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include "Swiften/Component/ComponentHandshakeGenerator.h" + +using namespace Swift; + +class ComponentHandshakeGeneratorTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ComponentHandshakeGeneratorTest); + CPPUNIT_TEST(testGetHandshake); + CPPUNIT_TEST(testGetHandshake_SpecialChars); + CPPUNIT_TEST_SUITE_END(); + + public: + void testGetHandshake() { + String result = ComponentHandshakeGenerator::getHandshake("myid", "mysecret"); + CPPUNIT_ASSERT_EQUAL(String("4011cd31f9b99ac089a0cd7ce297da7323fa2525"), result); + } + + void testGetHandshake_SpecialChars() { + String result = ComponentHandshakeGenerator::getHandshake("&<", ">'\""); + CPPUNIT_ASSERT_EQUAL(String("33631b3e0aaeb2a11c4994c917919324028873fe"), result); + } + +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ComponentHandshakeGeneratorTest); diff --git a/Swiften/Elements/ComponentHandshake.h b/Swiften/Elements/ComponentHandshake.h new file mode 100644 index 0000000..1067310 --- /dev/null +++ b/Swiften/Elements/ComponentHandshake.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include "Swiften/Elements/Element.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class ComponentHandshake : public Element { + public: + ComponentHandshake(const String& data = "") : data(data) { + } + + void setData(const String& d) { + data = d; + } + + const String& getData() const { + return data; + } + + private: + String data; + }; +} diff --git a/Swiften/Elements/StreamType.h b/Swiften/Elements/StreamType.h new file mode 100644 index 0000000..7ca9ed5 --- /dev/null +++ b/Swiften/Elements/StreamType.h @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +namespace Swift { + enum StreamType { + ClientStreamType, + ServerStreamType, + ComponentStreamType + }; +} diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index 01875f7..63fc2f9 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -122,6 +122,7 @@ void Connector::tryConnect(const HostAddressPort& target) { void Connector::handleConnectionConnectFinished(bool error) { //std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl; + currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); if (error) { currentConnection.reset(); if (!addressQueryResults.empty()) { @@ -155,6 +156,7 @@ void Connector::finish(boost::shared_ptr<Connection> connection) { } if (currentConnection) { currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection.reset(); } onConnectFinished(connection); } diff --git a/Swiften/Parser/ComponentHandshakeParser.cpp b/Swiften/Parser/ComponentHandshakeParser.cpp new file mode 100644 index 0000000..e88adb3 --- /dev/null +++ b/Swiften/Parser/ComponentHandshakeParser.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include "Swiften/Parser/ComponentHandshakeParser.h" +#include "Swiften/StringCodecs/Base64.h" + +namespace Swift { + +ComponentHandshakeParser::ComponentHandshakeParser() : GenericElementParser<ComponentHandshake>(), depth(0) { +} + +void ComponentHandshakeParser::handleStartElement(const String&, const String&, const AttributeMap&) { + ++depth; +} + +void ComponentHandshakeParser::handleEndElement(const String&, const String&) { + --depth; + if (depth == 0) { + getElementGeneric()->setData(text); + } +} + +void ComponentHandshakeParser::handleCharacterData(const String& text) { + this->text += text; +} + +} diff --git a/Swiften/Parser/ComponentHandshakeParser.h b/Swiften/Parser/ComponentHandshakeParser.h new file mode 100644 index 0000000..de5b8e1 --- /dev/null +++ b/Swiften/Parser/ComponentHandshakeParser.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include "Swiften/Parser/GenericElementParser.h" +#include "Swiften/Elements/ComponentHandshake.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class ComponentHandshakeParser : public GenericElementParser<ComponentHandshake> { + public: + ComponentHandshakeParser(); + + virtual void handleStartElement(const String&, const String& ns, const AttributeMap&); + virtual void handleEndElement(const String&, const String& ns); + virtual void handleCharacterData(const String&); + + private: + int depth; + String text; + }; +} diff --git a/Swiften/Parser/SConscript b/Swiften/Parser/SConscript index 9cd6b31..0256cbf 100644 --- a/Swiften/Parser/SConscript +++ b/Swiften/Parser/SConscript @@ -16,6 +16,7 @@ sources = [ "MessageParser.cpp", "PayloadParser.cpp", "StanzaAckParser.cpp", + "ComponentHandshakeParser.cpp", "PayloadParserFactory.cpp", "PayloadParserFactoryCollection.cpp", "PayloadParsers/BodyParser.cpp", diff --git a/Swiften/Parser/UnitTest/XMPPParserTest.cpp b/Swiften/Parser/UnitTest/XMPPParserTest.cpp index cd42b90..90a4b03 100644 --- a/Swiften/Parser/UnitTest/XMPPParserTest.cpp +++ b/Swiften/Parser/UnitTest/XMPPParserTest.cpp @@ -22,8 +22,7 @@ using namespace Swift; -class XMPPParserTest : public CppUnit::TestFixture -{ +class XMPPParserTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(XMPPParserTest); CPPUNIT_TEST(testParse_SimpleSession); CPPUNIT_TEST(testParse_SimpleClientFromServerSession); @@ -37,8 +36,6 @@ class XMPPParserTest : public CppUnit::TestFixture CPPUNIT_TEST_SUITE_END(); public: - XMPPParserTest() {} - void testParse_SimpleSession() { XMPPParser testling(&client_, &factories_); @@ -51,7 +48,7 @@ class XMPPParserTest : public CppUnit::TestFixture CPPUNIT_ASSERT_EQUAL(5, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::StreamStart, client_.events[0].type); - CPPUNIT_ASSERT_EQUAL(String("example.com"), client_.events[0].to); + CPPUNIT_ASSERT_EQUAL(String("example.com"), client_.events[0].header->getTo()); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[1].type); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[2].type); CPPUNIT_ASSERT_EQUAL(Client::ElementEvent, client_.events[3].type); @@ -66,8 +63,8 @@ class XMPPParserTest : public CppUnit::TestFixture CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(client_.events.size())); CPPUNIT_ASSERT_EQUAL(Client::StreamStart, client_.events[0].type); - CPPUNIT_ASSERT_EQUAL(String("example.com"), client_.events[0].from); - CPPUNIT_ASSERT_EQUAL(String("aeab"), client_.events[0].id); + CPPUNIT_ASSERT_EQUAL(String("example.com"), client_.events[0].header->getFrom()); + CPPUNIT_ASSERT_EQUAL(String("aeab"), client_.events[0].header->getID()); } @@ -159,21 +156,19 @@ class XMPPParserTest : public CppUnit::TestFixture struct Event { Event(Type type, boost::shared_ptr<Element> element) : type(type), element(element) {} - Event(Type type, const String& from, const String& to, const String& id) : type(type), from(from), to(to), id(id) {} + Event(Type type, const ProtocolHeader& header) : type(type), header(header) {} Event(Type type) : type(type) {} Type type; - String from; - String to; - String id; + boost::optional<ProtocolHeader> header; boost::shared_ptr<Element> element; }; Client() {} void handleStreamStart(const ProtocolHeader& header) { - events.push_back(Event(StreamStart, header.getFrom(), header.getTo(), header.getID())); + events.push_back(Event(StreamStart, header)); } void handleElement(boost::shared_ptr<Element> element) { diff --git a/Swiften/Parser/XMPPParser.cpp b/Swiften/Parser/XMPPParser.cpp index 795bee6..93797b3 100644 --- a/Swiften/Parser/XMPPParser.cpp +++ b/Swiften/Parser/XMPPParser.cpp @@ -37,6 +37,7 @@ #include "Swiften/Parser/CompressedParser.h" #include "Swiften/Parser/UnknownElementParser.h" #include "Swiften/Parser/TLSProceedParser.h" +#include "Swiften/Parser/ComponentHandshakeParser.h" // TODO: Whenever an error occurs in the handlers, stop the parser by returing // a bool value, and stopping the XML parser @@ -177,6 +178,9 @@ ElementParser* XMPPParser::createElementParser(const String& element, const Stri else if (element == "r" && ns == "urn:xmpp:sm:2") { return new StanzaAckRequestParser(); } + else if (element == "handshake") { + return new ComponentHandshakeParser(); + } return new UnknownElementParser(); } diff --git a/Swiften/SConscript b/Swiften/SConscript index 839413d..84d9a67 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -69,6 +69,7 @@ if env["SCONS_STAGE"] == "build" : "Serializer/CompressRequestSerializer.cpp", "Serializer/ElementSerializer.cpp", "Serializer/MessageSerializer.cpp", + "Serializer/ComponentHandshakeSerializer.cpp", "Serializer/PayloadSerializer.cpp", "Serializer/PayloadSerializerCollection.cpp", "Serializer/PayloadSerializers/CapsInfoSerializer.cpp", @@ -133,6 +134,7 @@ if env["SCONS_STAGE"] == "build" : "StreamStack", "LinkLocal", "StreamManagement", + "Component", ]) SConscript(test_only = True, dirs = [ "QA", @@ -155,6 +157,8 @@ if env["SCONS_STAGE"] == "build" : File("Client/UnitTest/ClientSessionTest.cpp"), File("Compress/UnitTest/ZLibCompressorTest.cpp"), File("Compress/UnitTest/ZLibDecompressorTest.cpp"), + File("Component/UnitTest/ComponentHandshakeGeneratorTest.cpp"), + File("Component/UnitTest/ComponentConnectorTest.cpp"), File("Disco/UnitTest/CapsInfoGeneratorTest.cpp"), File("Disco/UnitTest/CapsManagerTest.cpp"), File("Disco/UnitTest/EntityCapsManagerTest.cpp"), @@ -234,6 +238,7 @@ if env["SCONS_STAGE"] == "build" : File("Serializer/UnitTest/AuthChallengeSerializerTest.cpp"), File("Serializer/UnitTest/AuthRequestSerializerTest.cpp"), File("Serializer/UnitTest/AuthResponseSerializerTest.cpp"), + File("Serializer/UnitTest/XMPPSerializerTest.cpp"), File("Serializer/XML/UnitTest/XMLElementTest.cpp"), File("StreamManagement/UnitTest/StanzaAckRequesterTest.cpp"), File("StreamManagement/UnitTest/StanzaAckResponderTest.cpp"), diff --git a/Swiften/Serializer/ComponentHandshakeSerializer.cpp b/Swiften/Serializer/ComponentHandshakeSerializer.cpp new file mode 100644 index 0000000..de1958e --- /dev/null +++ b/Swiften/Serializer/ComponentHandshakeSerializer.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include "Swiften/Serializer/ComponentHandshakeSerializer.h" + +#include "Swiften/Elements/ComponentHandshake.h" + +namespace Swift { + +ComponentHandshakeSerializer::ComponentHandshakeSerializer() { +} + +String ComponentHandshakeSerializer::serialize(boost::shared_ptr<Element> element) const { + boost::shared_ptr<ComponentHandshake> handshake(boost::dynamic_pointer_cast<ComponentHandshake>(element)); + return "<handshake>" + handshake->getData() + "</challenge>"; +} + +} diff --git a/Swiften/Serializer/ComponentHandshakeSerializer.h b/Swiften/Serializer/ComponentHandshakeSerializer.h new file mode 100644 index 0000000..5423f08 --- /dev/null +++ b/Swiften/Serializer/ComponentHandshakeSerializer.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> + +#include "Swiften/Elements/ComponentHandshake.h" +#include "Swiften/Serializer/GenericElementSerializer.h" + +namespace Swift { + class ComponentHandshakeSerializer : public GenericElementSerializer<ComponentHandshake> { + public: + ComponentHandshakeSerializer(); + + virtual String serialize(boost::shared_ptr<Element> element) const; + }; +} diff --git a/Swiften/Serializer/GenericElementSerializer.h b/Swiften/Serializer/GenericElementSerializer.h index ffebe40..702e374 100644 --- a/Swiften/Serializer/GenericElementSerializer.h +++ b/Swiften/Serializer/GenericElementSerializer.h @@ -4,8 +4,7 @@ * See Documentation/Licenses/GPLv3.txt for more information. */ -#ifndef SWIFTEN_GenericElementSerializer_H -#define SWIFTEN_GenericElementSerializer_H +#pragma once #include "Swiften/Serializer/ElementSerializer.h" @@ -20,5 +19,3 @@ namespace Swift { } }; } - -#endif diff --git a/Swiften/Serializer/UnitTest/XMPPSerializerTest.cpp b/Swiften/Serializer/UnitTest/XMPPSerializerTest.cpp new file mode 100644 index 0000000..45ffe4b --- /dev/null +++ b/Swiften/Serializer/UnitTest/XMPPSerializerTest.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include "Swiften/Serializer/XMPPSerializer.h" +#include "Swiften/Elements/AuthChallenge.h" +#include "Swiften/Serializer/PayloadSerializerCollection.h" +#include "Swiften/Elements/ProtocolHeader.h" + +using namespace Swift; + +class XMPPSerializerTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(XMPPSerializerTest); + CPPUNIT_TEST(testSerializeHeader_Client); + CPPUNIT_TEST(testSerializeHeader_Component); + CPPUNIT_TEST(testSerializeHeader_Server); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + payloadSerializerCollection = new PayloadSerializerCollection(); + } + + void tearDown() { + delete payloadSerializerCollection; + } + + void testSerializeHeader_Client() { + std::auto_ptr<XMPPSerializer> testling(createSerializer(ClientStreamType)); + ProtocolHeader protocolHeader; + protocolHeader.setFrom("bla@foo.com"); + protocolHeader.setTo("foo.com"); + protocolHeader.setID("myid"); + protocolHeader.setVersion("0.99"); + + CPPUNIT_ASSERT_EQUAL(String("<?xml version=\"1.0\"?><stream:stream xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\" from=\"bla@foo.com\" to=\"foo.com\" id=\"myid\" version=\"0.99\">"), testling->serializeHeader(protocolHeader)); + } + + void testSerializeHeader_Component() { + std::auto_ptr<XMPPSerializer> testling(createSerializer(ComponentStreamType)); + ProtocolHeader protocolHeader; + protocolHeader.setFrom("bla@foo.com"); + protocolHeader.setTo("foo.com"); + protocolHeader.setID("myid"); + protocolHeader.setVersion("0.99"); + + CPPUNIT_ASSERT_EQUAL(String("<?xml version=\"1.0\"?><stream:stream xmlns=\"jabber:component:accept\" xmlns:stream=\"http://etherx.jabber.org/streams\" from=\"bla@foo.com\" to=\"foo.com\" id=\"myid\" version=\"0.99\">"), testling->serializeHeader(protocolHeader)); + } + + void testSerializeHeader_Server() { + std::auto_ptr<XMPPSerializer> testling(createSerializer(ServerStreamType)); + ProtocolHeader protocolHeader; + protocolHeader.setFrom("bla@foo.com"); + protocolHeader.setTo("foo.com"); + protocolHeader.setID("myid"); + protocolHeader.setVersion("0.99"); + + CPPUNIT_ASSERT_EQUAL(String("<?xml version=\"1.0\"?><stream:stream xmlns=\"jabber:server\" xmlns:stream=\"http://etherx.jabber.org/streams\" from=\"bla@foo.com\" to=\"foo.com\" id=\"myid\" version=\"0.99\">"), testling->serializeHeader(protocolHeader)); + } + + private: + XMPPSerializer* createSerializer(StreamType type) { + return new XMPPSerializer(payloadSerializerCollection, type); + } + + private: + PayloadSerializerCollection* payloadSerializerCollection; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(XMPPSerializerTest); diff --git a/Swiften/Serializer/XMPPSerializer.cpp b/Swiften/Serializer/XMPPSerializer.cpp index d2eb520..da4715c 100644 --- a/Swiften/Serializer/XMPPSerializer.cpp +++ b/Swiften/Serializer/XMPPSerializer.cpp @@ -8,6 +8,7 @@ #include <boost/bind.hpp> #include <iostream> +#include <cassert> #include "Swiften/Elements/ProtocolHeader.h" #include "Swiften/Base/foreach.h" @@ -30,10 +31,11 @@ #include "Swiften/Serializer/MessageSerializer.h" #include "Swiften/Serializer/PresenceSerializer.h" #include "Swiften/Serializer/IQSerializer.h" +#include "Swiften/Serializer/ComponentHandshakeSerializer.h" namespace Swift { -XMPPSerializer::XMPPSerializer(PayloadSerializerCollection* payloadSerializers) { +XMPPSerializer::XMPPSerializer(PayloadSerializerCollection* payloadSerializers, StreamType type) : type_(type) { serializers_.push_back(boost::shared_ptr<ElementSerializer>(new PresenceSerializer(payloadSerializers))); serializers_.push_back(boost::shared_ptr<ElementSerializer>(new IQSerializer(payloadSerializers))); serializers_.push_back(boost::shared_ptr<ElementSerializer>(new MessageSerializer(payloadSerializers))); @@ -53,10 +55,11 @@ XMPPSerializer::XMPPSerializer(PayloadSerializerCollection* payloadSerializers) serializers_.push_back(boost::shared_ptr<ElementSerializer>(new StreamManagementFailedSerializer())); serializers_.push_back(boost::shared_ptr<ElementSerializer>(new StanzaAckSerializer())); serializers_.push_back(boost::shared_ptr<ElementSerializer>(new StanzaAckRequestSerializer())); + serializers_.push_back(boost::shared_ptr<ElementSerializer>(new ComponentHandshakeSerializer())); } String XMPPSerializer::serializeHeader(const ProtocolHeader& header) const { - String result = "<?xml version=\"1.0\"?><stream:stream xmlns=\"jabber:client\" xmlns:stream=\"http://etherx.jabber.org/streams\""; + String result = "<?xml version=\"1.0\"?><stream:stream xmlns=\"" + getDefaultNamespace() + "\" xmlns:stream=\"http://etherx.jabber.org/streams\""; if (!header.getFrom().isEmpty()) { result += " from=\"" + header.getFrom() + "\""; } @@ -90,4 +93,14 @@ String XMPPSerializer::serializeFooter() const { return "</stream:stream>"; } +String XMPPSerializer::getDefaultNamespace() const { + switch (type_) { + case ClientStreamType: return "jabber:client"; + case ServerStreamType: return "jabber:server"; + case ComponentStreamType: return "jabber:component:accept"; + } + assert(false); + return ""; +} + } diff --git a/Swiften/Serializer/XMPPSerializer.h b/Swiften/Serializer/XMPPSerializer.h index fbb2cda..13c2cf7 100644 --- a/Swiften/Serializer/XMPPSerializer.h +++ b/Swiften/Serializer/XMPPSerializer.h @@ -10,6 +10,7 @@ #include <vector> #include "Swiften/Elements/Element.h" +#include "Swiften/Elements/StreamType.h" #include "Swiften/Base/String.h" #include "Swiften/Serializer/ElementSerializer.h" @@ -20,13 +21,17 @@ namespace Swift { class XMPPSerializer { public: - XMPPSerializer(PayloadSerializerCollection*); + XMPPSerializer(PayloadSerializerCollection*, StreamType type); String serializeHeader(const ProtocolHeader&) const; String serializeElement(boost::shared_ptr<Element> stanza) const; String serializeFooter() const; private: + String getDefaultNamespace() const; + + private: + StreamType type_; std::vector< boost::shared_ptr<ElementSerializer> > serializers_; }; } diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index d75e2c3..e2c2ebe 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -18,12 +18,12 @@ namespace Swift { -BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory, TimerFactory* timerFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory), timerFactory(timerFactory) { +BasicSessionStream::BasicSessionStream(StreamType streamType, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory, TimerFactory* timerFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory), timerFactory(timerFactory), streamType(streamType) { } void BasicSessionStream::initialize() { xmppLayer = boost::shared_ptr<XMPPLayer>( - new XMPPLayer(payloadParserFactories, payloadSerializers)); + new XMPPLayer(payloadParserFactories, payloadSerializers, streamType)); xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, shared_from_this(), _1)); xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, shared_from_this(), _1)); xmppLayer->onError.connect(boost::bind( diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 7f194ff..bea9406 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -11,6 +11,7 @@ #include "Swiften/Network/Connection.h" #include "Swiften/Session/SessionStream.h" +#include "Swiften/Elements/StreamType.h" namespace Swift { class TLSLayerFactory; @@ -29,6 +30,7 @@ namespace Swift { public boost::enable_shared_from_this<BasicSessionStream> { public: BasicSessionStream( + StreamType streamType, boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, @@ -72,6 +74,7 @@ namespace Swift { PayloadSerializerCollection* payloadSerializers; TLSLayerFactory* tlsLayerFactory; TimerFactory* timerFactory; + StreamType streamType; boost::shared_ptr<XMPPLayer> xmppLayer; boost::shared_ptr<ConnectionLayer> connectionLayer; StreamStack* streamStack; @@ -79,4 +82,5 @@ namespace Swift { boost::shared_ptr<TLSLayer> tlsLayer; boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer; }; + } diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp index 34845dc..747d1d9 100644 --- a/Swiften/Session/Session.cpp +++ b/Swiften/Session/Session.cpp @@ -51,7 +51,7 @@ void Session::finishSession(const SessionError& error) { void Session::initializeStreamStack() { xmppLayer = boost::shared_ptr<XMPPLayer>( - new XMPPLayer(payloadParserFactories, payloadSerializers)); + new XMPPLayer(payloadParserFactories, payloadSerializers, ClientStreamType)); xmppLayer->onStreamStart.connect( boost::bind(&Session::handleStreamStart, shared_from_this(), _1)); xmppLayer->onElement.connect(boost::bind(&Session::handleElement, shared_from_this(), _1)); diff --git a/Swiften/StreamStack/UnitTest/StreamStackTest.cpp b/Swiften/StreamStack/UnitTest/StreamStackTest.cpp index ab716b4..0ea0835 100644 --- a/Swiften/StreamStack/UnitTest/StreamStackTest.cpp +++ b/Swiften/StreamStack/UnitTest/StreamStackTest.cpp @@ -36,7 +36,7 @@ class StreamStackTest : public CppUnit::TestFixture void setUp() { physicalStream_ = boost::shared_ptr<TestLowLayer>(new TestLowLayer()); - xmppStream_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(&parserFactories_, &serializers_)); + xmppStream_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(&parserFactories_, &serializers_, ClientStreamType)); elementsReceived_ = 0; dataWriteReceived_ = 0; } diff --git a/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp b/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp index 8f98de5..6db997e 100644 --- a/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp +++ b/Swiften/StreamStack/UnitTest/XMPPLayerTest.cpp @@ -33,7 +33,7 @@ class XMPPLayerTest : public CppUnit::TestFixture XMPPLayerTest() {} void setUp() { - testling_ = new XMPPLayer(&parserFactories_, &serializers_); + testling_ = new XMPPLayer(&parserFactories_, &serializers_, ClientStreamType); elementsReceived_ = 0; dataReceived_ = ""; errorReceived_ = 0; diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp index 9782240..d4e329b 100644 --- a/Swiften/StreamStack/XMPPLayer.cpp +++ b/Swiften/StreamStack/XMPPLayer.cpp @@ -13,13 +13,14 @@ namespace Swift { XMPPLayer::XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : + PayloadSerializerCollection* payloadSerializers, + StreamType streamType) : payloadParserFactories_(payloadParserFactories), payloadSerializers_(payloadSerializers), resetParserAfterParse_(false), inParser_(false) { xmppParser_ = new XMPPParser(this, payloadParserFactories_); - xmppSerializer_ = new XMPPSerializer(payloadSerializers_); + xmppSerializer_ = new XMPPSerializer(payloadSerializers_, streamType); } XMPPLayer::~XMPPLayer() { diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h index 0df1f14..7316afe 100644 --- a/Swiften/StreamStack/XMPPLayer.h +++ b/Swiften/StreamStack/XMPPLayer.h @@ -12,6 +12,7 @@ #include "Swiften/Base/ByteArray.h" #include "Swiften/Elements/Element.h" +#include "Swiften/Elements/StreamType.h" #include "Swiften/Parser/XMPPParserClient.h" namespace Swift { @@ -25,7 +26,8 @@ namespace Swift { public: XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers); + PayloadSerializerCollection* payloadSerializers, + StreamType streamType); ~XMPPLayer(); void writeHeader(const ProtocolHeader& header); |