diff options
Diffstat (limited to 'Swiften/Network')
50 files changed, 2151 insertions, 0 deletions
diff --git a/Swiften/Network/BoostConnection.cpp b/Swiften/Network/BoostConnection.cpp new file mode 100644 index 0000000..0d62300 --- /dev/null +++ b/Swiften/Network/BoostConnection.cpp @@ -0,0 +1,109 @@ +#include "Swiften/Network/BoostConnection.h" + +#include <iostream> +#include <boost/bind.hpp> +#include <boost/thread.hpp> + +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Base/String.h" +#include "Swiften/Base/ByteArray.h" +#include "Swiften/Network/HostAddressPort.h" + +namespace Swift { + +static const size_t BUFFER_SIZE = 4096; + +// ----------------------------------------------------------------------------- + +// A reference-counted non-modifiable buffer class. +class SharedBuffer { + public: + SharedBuffer(const ByteArray& data) : + data_(new std::vector<char>(data.begin(), data.end())), + buffer_(boost::asio::buffer(*data_)) { + } + + // ConstBufferSequence requirements. + typedef boost::asio::const_buffer value_type; + typedef const boost::asio::const_buffer* const_iterator; + const boost::asio::const_buffer* begin() const { return &buffer_; } + const boost::asio::const_buffer* end() const { return &buffer_ + 1; } + + private: + boost::shared_ptr< std::vector<char> > data_; + boost::asio::const_buffer buffer_; +}; + +// ----------------------------------------------------------------------------- + +BoostConnection::BoostConnection(boost::asio::io_service* ioService) : + socket_(*ioService), readBuffer_(BUFFER_SIZE) { +} + +BoostConnection::~BoostConnection() { +} + +void BoostConnection::listen() { + doRead(); +} + +void BoostConnection::connect(const HostAddressPort& addressPort) { + boost::asio::ip::tcp::endpoint endpoint( + boost::asio::ip::address::from_string(addressPort.getAddress().toString()), addressPort.getPort()); + socket_.async_connect( + endpoint, + boost::bind(&BoostConnection::handleConnectFinished, shared_from_this(), boost::asio::placeholders::error)); +} + +void BoostConnection::disconnect() { + //MainEventLoop::removeEventsFromOwner(shared_from_this()); + socket_.close(); +} + +void BoostConnection::write(const ByteArray& data) { + boost::asio::async_write(socket_, SharedBuffer(data), + boost::bind(&BoostConnection::handleDataWritten, shared_from_this(), boost::asio::placeholders::error)); +} + +void BoostConnection::handleConnectFinished(const boost::system::error_code& error) { + if (!error) { + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), false), shared_from_this()); + doRead(); + } + else if (error != boost::asio::error::operation_aborted) { + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), true), shared_from_this()); + } +} + +void BoostConnection::doRead() { + socket_.async_read_some( + boost::asio::buffer(readBuffer_), + boost::bind(&BoostConnection::handleSocketRead, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); +} + +void BoostConnection::handleSocketRead(const boost::system::error_code& error, size_t bytesTransferred) { + if (!error) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), ByteArray(&readBuffer_[0], bytesTransferred)), shared_from_this()); + doRead(); + } + else if (error == boost::asio::error::eof) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); + } + else if (error != boost::asio::error::operation_aborted) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), ReadError), shared_from_this()); + } +} + +void BoostConnection::handleDataWritten(const boost::system::error_code& error) { + if (!error) { + return; + } + if (error == boost::asio::error::eof) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); + } + else if (error && error != boost::asio::error::operation_aborted) { + MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), WriteError), shared_from_this()); + } +} + +} diff --git a/Swiften/Network/BoostConnection.h b/Swiften/Network/BoostConnection.h new file mode 100644 index 0000000..ae09fb8 --- /dev/null +++ b/Swiften/Network/BoostConnection.h @@ -0,0 +1,42 @@ +#pragma once + +#include <boost/asio.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Network/Connection.h" +#include "Swiften/EventLoop/EventOwner.h" + +namespace boost { + class thread; + namespace system { + class error_code; + } +} + +namespace Swift { + class BoostConnection : public Connection, public EventOwner, public boost::enable_shared_from_this<BoostConnection> { + public: + BoostConnection(boost::asio::io_service* ioService); + ~BoostConnection(); + + virtual void listen(); + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const ByteArray& data); + + boost::asio::ip::tcp::socket& getSocket() { + return socket_; + } + + private: + void handleConnectFinished(const boost::system::error_code& error); + void handleSocketRead(const boost::system::error_code& error, size_t bytesTransferred); + void handleDataWritten(const boost::system::error_code& error); + void doRead(); + + private: + boost::asio::ip::tcp::socket socket_; + std::vector<char> readBuffer_; + bool disconnecting_; + }; +} diff --git a/Swiften/Network/BoostConnectionFactory.cpp b/Swiften/Network/BoostConnectionFactory.cpp new file mode 100644 index 0000000..3f62730 --- /dev/null +++ b/Swiften/Network/BoostConnectionFactory.cpp @@ -0,0 +1,13 @@ +#include "Swiften/Network/BoostConnectionFactory.h" +#include "Swiften/Network/BoostConnection.h" + +namespace Swift { + +BoostConnectionFactory::BoostConnectionFactory(boost::asio::io_service* ioService) : ioService(ioService) { +} + +boost::shared_ptr<Connection> BoostConnectionFactory::createConnection() { + return boost::shared_ptr<Connection>(new BoostConnection(ioService)); +} + +} diff --git a/Swiften/Network/BoostConnectionFactory.h b/Swiften/Network/BoostConnectionFactory.h new file mode 100644 index 0000000..5695c6c --- /dev/null +++ b/Swiften/Network/BoostConnectionFactory.h @@ -0,0 +1,20 @@ +#pragma once + +#include <boost/asio.hpp> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/BoostConnection.h" + +namespace Swift { + class BoostConnection; + + class BoostConnectionFactory : public ConnectionFactory { + public: + BoostConnectionFactory(boost::asio::io_service*); + + virtual boost::shared_ptr<Connection> createConnection(); + + private: + boost::asio::io_service* ioService; + }; +} diff --git a/Swiften/Network/BoostConnectionServer.cpp b/Swiften/Network/BoostConnectionServer.cpp new file mode 100644 index 0000000..cea016d --- /dev/null +++ b/Swiften/Network/BoostConnectionServer.cpp @@ -0,0 +1,68 @@ +#include "Swiften/Network/BoostConnectionServer.h" + +#include <boost/bind.hpp> +#include <boost/system/system_error.hpp> + +#include "Swiften/EventLoop/MainEventLoop.h" + +namespace Swift { + +BoostConnectionServer::BoostConnectionServer(int port, boost::asio::io_service* ioService) : port_(port), ioService_(ioService), acceptor_(NULL) { +} + + +void BoostConnectionServer::start() { + try { + assert(!acceptor_); + acceptor_ = new boost::asio::ip::tcp::acceptor( + *ioService_, + boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), port_)); + acceptNextConnection(); + } + catch (const boost::system::system_error& e) { + if (e.code() == boost::asio::error::address_in_use) { + MainEventLoop::postEvent(boost::bind(boost::ref(onStopped), Conflict), shared_from_this()); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onStopped), UnknownError), shared_from_this()); + } + } +} + + +void BoostConnectionServer::stop() { + stop(boost::optional<Error>()); +} + +void BoostConnectionServer::stop(boost::optional<Error> e) { + if (acceptor_) { + acceptor_->close(); + delete acceptor_; + acceptor_ = NULL; + } + MainEventLoop::postEvent(boost::bind(boost::ref(onStopped), e), shared_from_this()); +} + +void BoostConnectionServer::acceptNextConnection() { + boost::shared_ptr<BoostConnection> newConnection(new BoostConnection(&acceptor_->io_service())); + acceptor_->async_accept(newConnection->getSocket(), + boost::bind(&BoostConnectionServer::handleAccept, shared_from_this(), newConnection, boost::asio::placeholders::error)); +} + +void BoostConnectionServer::handleAccept(boost::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error) { + if (error) { + MainEventLoop::postEvent( + boost::bind( + &BoostConnectionServer::stop, shared_from_this(), UnknownError), + shared_from_this()); + } + else { + MainEventLoop::postEvent( + boost::bind(boost::ref(onNewConnection), newConnection), + shared_from_this()); + newConnection->listen(); + acceptNextConnection(); + } +} + +} diff --git a/Swiften/Network/BoostConnectionServer.h b/Swiften/Network/BoostConnectionServer.h new file mode 100644 index 0000000..d8e5eb4 --- /dev/null +++ b/Swiften/Network/BoostConnectionServer.h @@ -0,0 +1,36 @@ +#pragma once + +#include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> +#include <boost/asio.hpp> +#include <boost/signal.hpp> + +#include "Swiften/Network/BoostConnection.h" +#include "Swiften/Network/ConnectionServer.h" +#include "Swiften/EventLoop/EventOwner.h" + +namespace Swift { + class BoostConnectionServer : public ConnectionServer, public EventOwner, public boost::enable_shared_from_this<BoostConnectionServer> { + public: + enum Error { + Conflict, + UnknownError + }; + BoostConnectionServer(int port, boost::asio::io_service* ioService); + + void start(); + void stop(); + + boost::signal<void (boost::optional<Error>)> onStopped; + + private: + void stop(boost::optional<Error> e); + void acceptNextConnection(); + void handleAccept(boost::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error); + + private: + int port_; + boost::asio::io_service* ioService_; + boost::asio::ip::tcp::acceptor* acceptor_; + }; +} diff --git a/Swiften/Network/BoostIOServiceThread.cpp b/Swiften/Network/BoostIOServiceThread.cpp new file mode 100644 index 0000000..01c3bf3 --- /dev/null +++ b/Swiften/Network/BoostIOServiceThread.cpp @@ -0,0 +1,18 @@ +#include "Swiften/Network/BoostIOServiceThread.h" + +namespace Swift { + +BoostIOServiceThread::BoostIOServiceThread() : thread_(boost::bind(&BoostIOServiceThread::doRun, this)) { +} + +BoostIOServiceThread::~BoostIOServiceThread() { + ioService_.stop(); + thread_.join(); +} + +void BoostIOServiceThread::doRun() { + boost::asio::io_service::work work(ioService_); + ioService_.run(); +} + +} diff --git a/Swiften/Network/BoostIOServiceThread.h b/Swiften/Network/BoostIOServiceThread.h new file mode 100644 index 0000000..ddc90bf --- /dev/null +++ b/Swiften/Network/BoostIOServiceThread.h @@ -0,0 +1,23 @@ +#pragma once + +#include <boost/asio.hpp> +#include <boost/thread.hpp> + +namespace Swift { + class BoostIOServiceThread { + public: + BoostIOServiceThread(); + ~BoostIOServiceThread(); + + boost::asio::io_service& getIOService() { + return ioService_; + } + + private: + void doRun(); + + private: + boost::asio::io_service ioService_; + boost::thread thread_; + }; +} diff --git a/Swiften/Network/BoostTimer.cpp b/Swiften/Network/BoostTimer.cpp new file mode 100644 index 0000000..fdbd45d --- /dev/null +++ b/Swiften/Network/BoostTimer.cpp @@ -0,0 +1,34 @@ +#include "Swiften/Network/BoostTimer.h" + +#include <boost/date_time/posix_time/posix_time.hpp> +#include <boost/asio.hpp> + +#include "Swiften/EventLoop/MainEventLoop.h" + +namespace Swift { + +BoostTimer::BoostTimer(int milliseconds, boost::asio::io_service* service) : + timeout(milliseconds), timer(*service) { +} + +void BoostTimer::start() { + timer.expires_from_now(boost::posix_time::milliseconds(timeout)); + timer.async_wait(boost::bind(&BoostTimer::handleTimerTick, shared_from_this(), boost::asio::placeholders::error)); +} + +void BoostTimer::stop() { + timer.cancel(); +} + +void BoostTimer::handleTimerTick(const boost::system::error_code& error) { + if (error) { + assert(error == boost::asio::error::operation_aborted); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onTick)), shared_from_this()); + timer.expires_from_now(boost::posix_time::milliseconds(timeout)); + timer.async_wait(boost::bind(&BoostTimer::handleTimerTick, shared_from_this(), boost::asio::placeholders::error)); + } +} + +} diff --git a/Swiften/Network/BoostTimer.h b/Swiften/Network/BoostTimer.h new file mode 100644 index 0000000..9b27cf9 --- /dev/null +++ b/Swiften/Network/BoostTimer.h @@ -0,0 +1,25 @@ +#pragma once + +#include <boost/asio.hpp> +#include <boost/thread.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/EventLoop/EventOwner.h" +#include "Swiften/Network/Timer.h" + +namespace Swift { + class BoostTimer : public Timer, public EventOwner, public boost::enable_shared_from_this<BoostTimer> { + public: + BoostTimer(int milliseconds, boost::asio::io_service* service); + + virtual void start(); + virtual void stop(); + + private: + void handleTimerTick(const boost::system::error_code& error); + + private: + int timeout; + boost::asio::deadline_timer timer; + }; +} diff --git a/Swiften/Network/BoostTimerFactory.cpp b/Swiften/Network/BoostTimerFactory.cpp new file mode 100644 index 0000000..bbcd83f --- /dev/null +++ b/Swiften/Network/BoostTimerFactory.cpp @@ -0,0 +1,13 @@ +#include "Swiften/Network/BoostTimerFactory.h" +#include "Swiften/Network/BoostTimer.h" + +namespace Swift { + +BoostTimerFactory::BoostTimerFactory(boost::asio::io_service* ioService) : ioService(ioService) { +} + +boost::shared_ptr<Timer> BoostTimerFactory::createTimer(int milliseconds) { + return boost::shared_ptr<Timer>(new BoostTimer(milliseconds, ioService)); +} + +} diff --git a/Swiften/Network/BoostTimerFactory.h b/Swiften/Network/BoostTimerFactory.h new file mode 100644 index 0000000..e98c9de --- /dev/null +++ b/Swiften/Network/BoostTimerFactory.h @@ -0,0 +1,20 @@ +#pragma once + +#include <boost/asio.hpp> + +#include "Swiften/Network/TimerFactory.h" +#include "Swiften/Network/BoostTimer.h" + +namespace Swift { + class BoostTimer; + + class BoostTimerFactory : public TimerFactory { + public: + BoostTimerFactory(boost::asio::io_service*); + + virtual boost::shared_ptr<Timer> createTimer(int milliseconds); + + private: + boost::asio::io_service* ioService; + }; +} diff --git a/Swiften/Network/CAresDomainNameResolver.cpp b/Swiften/Network/CAresDomainNameResolver.cpp new file mode 100644 index 0000000..c0bf8a0 --- /dev/null +++ b/Swiften/Network/CAresDomainNameResolver.cpp @@ -0,0 +1,162 @@ +// TODO: Check the second param of postEvent. We sometimes omit it. Same +// goes for the PlatformDomainNameResolver. + +#include "Swiften/Network/CAresDomainNameResolver.h" +#include "Swiften/Base/Platform.h" + +#ifndef SWIFTEN_PLATFORM_WINDOWS +#include <netdb.h> +#include <arpa/inet.h> +#endif +#include <algorithm> + +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Base/ByteArray.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Base/foreach.h" + +namespace Swift { + +class CAresQuery : public boost::enable_shared_from_this<CAresQuery>, public EventOwner { + public: + CAresQuery(const String& query, int dnsclass, int type, CAresDomainNameResolver* resolver) : query(query), dnsclass(dnsclass), type(type), resolver(resolver) { + } + + virtual ~CAresQuery() { + } + + void addToQueue() { + resolver->addToQueue(shared_from_this()); + } + + void doRun(ares_channel* channel) { + ares_query(*channel, query.getUTF8Data(), dnsclass, type, &CAresQuery::handleResult, this); + } + + static void handleResult(void* arg, int status, int timeouts, unsigned char* buffer, int len) { + reinterpret_cast<CAresQuery*>(arg)->handleResult(status, timeouts, buffer, len); + } + + virtual void handleResult(int status, int, unsigned char* buffer, int len) = 0; + + private: + String query; + int dnsclass; + int type; + CAresDomainNameResolver* resolver; +}; + +class CAresDomainNameServiceQuery : public DomainNameServiceQuery, public CAresQuery { + public: + CAresDomainNameServiceQuery(const String& service, CAresDomainNameResolver* resolver) : CAresQuery(service, 1, 33, resolver) { + } + + virtual void run() { + addToQueue(); + } + + void handleResult(int status, int, unsigned char* buffer, int len) { + if (status == ARES_SUCCESS) { + std::vector<DomainNameServiceQuery::Result> records; + ares_srv_reply* rawRecords; + if (ares_parse_srv_reply(buffer, len, &rawRecords) == ARES_SUCCESS) { + for( ; rawRecords != NULL; rawRecords = rawRecords->next) { + DomainNameServiceQuery::Result record; + record.priority = rawRecords->priority; + record.weight = rawRecords->weight; + record.port = rawRecords->port; + record.hostname = String(rawRecords->host); + records.push_back(record); + } + } + std::sort(records.begin(), records.end(), ResultPriorityComparator()); + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), records)); + } + else if (status != ARES_EDESTRUCTION) { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); + } + } +}; + +class CAresDomainNameAddressQuery : public DomainNameAddressQuery, public CAresQuery { + public: + CAresDomainNameAddressQuery(const String& host, CAresDomainNameResolver* resolver) : CAresQuery(host, 1, 1, resolver) { + } + + virtual void run() { + addToQueue(); + } + + void handleResult(int status, int, unsigned char* buffer, int len) { + if (status == ARES_SUCCESS) { + struct hostent* hosts; + if (ares_parse_a_reply(buffer, len, &hosts, NULL, NULL) == ARES_SUCCESS) { + // Check whether the different fields are what we expect them to be + struct in_addr addr; + addr.s_addr = *(unsigned int*)hosts->h_addr_list[0]; + HostAddress result(inet_ntoa(addr)); + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), result, boost::optional<DomainNameResolveError>()), boost::dynamic_pointer_cast<CAresDomainNameAddressQuery>(shared_from_this())); + ares_free_hostent(hosts); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); + } + } + else if (status != ARES_EDESTRUCTION) { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); + } + } +}; + +CAresDomainNameResolver::CAresDomainNameResolver() : stopRequested(false) { + ares_init(&channel); + thread = new boost::thread(boost::bind(&CAresDomainNameResolver::run, this)); +} + +CAresDomainNameResolver::~CAresDomainNameResolver() { + stopRequested = true; + thread->join(); + ares_destroy(channel); +} + +boost::shared_ptr<DomainNameServiceQuery> CAresDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr<DomainNameServiceQuery>(new CAresDomainNameServiceQuery(getNormalized(name), this)); +} + +boost::shared_ptr<DomainNameAddressQuery> CAresDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr<DomainNameAddressQuery>(new CAresDomainNameAddressQuery(getNormalized(name), this)); +} + +void CAresDomainNameResolver::addToQueue(boost::shared_ptr<CAresQuery> query) { + boost::lock_guard<boost::mutex> lock(pendingQueriesMutex); + pendingQueries.push_back(query); +} + +void CAresDomainNameResolver::run() { + fd_set readers, writers; + struct timeval timeout; + timeout.tv_sec = 0; + timeout.tv_usec = 100000; + while(!stopRequested) { + { + boost::unique_lock<boost::mutex> lock(pendingQueriesMutex); + foreach(const boost::shared_ptr<CAresQuery>& query, pendingQueries) { + query->doRun(&channel); + } + pendingQueries.clear(); + } + FD_ZERO(&readers); + FD_ZERO(&writers); + int nfds = ares_fds(channel, &readers, &writers); + //if (nfds) { + // break; + //} + struct timeval tv; + struct timeval* tvp = ares_timeout(channel, &timeout, &tv); + select(nfds, &readers, &writers, NULL, tvp); + ares_process(channel, &readers, &writers); + } +} + +} diff --git a/Swiften/Network/CAresDomainNameResolver.h b/Swiften/Network/CAresDomainNameResolver.h new file mode 100644 index 0000000..0cdd163 --- /dev/null +++ b/Swiften/Network/CAresDomainNameResolver.h @@ -0,0 +1,34 @@ +#pragma once + +#include <ares.h> +#include <boost/thread.hpp> +#include <boost/thread/mutex.hpp> +#include <list> + +#include "Swiften/Network/DomainNameResolver.h" + +namespace Swift { + class CAresQuery; + + class CAresDomainNameResolver : public DomainNameResolver { + public: + CAresDomainNameResolver(); + ~CAresDomainNameResolver(); + + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); + + private: + friend class CAresQuery; + + void run(); + void addToQueue(boost::shared_ptr<CAresQuery>); + + private: + bool stopRequested; + ares_channel channel; + boost::thread* thread; + boost::mutex pendingQueriesMutex; + std::list< boost::shared_ptr<CAresQuery> > pendingQueries; + }; +} diff --git a/Swiften/Network/Connection.h b/Swiften/Network/Connection.h new file mode 100644 index 0000000..a995774 --- /dev/null +++ b/Swiften/Network/Connection.h @@ -0,0 +1,31 @@ +#pragma once + +#include <boost/signals.hpp> + +#include "Swiften/Base/ByteArray.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class HostAddressPort; + + class Connection { + public: + enum Error { + ReadError, + WriteError + }; + + Connection() {} + virtual ~Connection() {} + + virtual void listen() = 0; + virtual void connect(const HostAddressPort& address) = 0; + virtual void disconnect() = 0; + virtual void write(const ByteArray& data) = 0; + + public: + boost::signal<void (bool /* error */)> onConnectFinished; + boost::signal<void (const boost::optional<Error>&)> onDisconnected; + boost::signal<void (const ByteArray&)> onDataRead; + }; +} diff --git a/Swiften/Network/ConnectionFactory.cpp b/Swiften/Network/ConnectionFactory.cpp new file mode 100644 index 0000000..686a165 --- /dev/null +++ b/Swiften/Network/ConnectionFactory.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/ConnectionFactory.h" + +namespace Swift { + +ConnectionFactory::~ConnectionFactory() { +} + +} diff --git a/Swiften/Network/ConnectionFactory.h b/Swiften/Network/ConnectionFactory.h new file mode 100644 index 0000000..e78f6ab --- /dev/null +++ b/Swiften/Network/ConnectionFactory.h @@ -0,0 +1,14 @@ +#pragma once + +#include <boost/shared_ptr.hpp> + +namespace Swift { + class Connection; + + class ConnectionFactory { + public: + virtual ~ConnectionFactory(); + + virtual boost::shared_ptr<Connection> createConnection() = 0; + }; +} diff --git a/Swiften/Network/ConnectionServer.cpp b/Swiften/Network/ConnectionServer.cpp new file mode 100644 index 0000000..7f63fee --- /dev/null +++ b/Swiften/Network/ConnectionServer.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/ConnectionServer.h" + +namespace Swift { + +ConnectionServer::~ConnectionServer() { +} + +} diff --git a/Swiften/Network/ConnectionServer.h b/Swiften/Network/ConnectionServer.h new file mode 100644 index 0000000..539367d --- /dev/null +++ b/Swiften/Network/ConnectionServer.h @@ -0,0 +1,15 @@ +#pragma once + +#include <boost/shared_ptr.hpp> +#include <boost/signal.hpp> + +#include "Swiften/Network/Connection.h" + +namespace Swift { + class ConnectionServer { + public: + virtual ~ConnectionServer(); + + boost::signal<void (boost::shared_ptr<Connection>)> onNewConnection; + }; +} diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp new file mode 100644 index 0000000..d372bf2 --- /dev/null +++ b/Swiften/Network/Connector.cpp @@ -0,0 +1,126 @@ +#include "Swiften/Network/Connector.h" + +#include <boost/bind.hpp> +#include <iostream> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Network/TimerFactory.h" + +namespace Swift { + +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) { +} + +void Connector::setTimeoutMilliseconds(int milliseconds) { + timeoutMilliseconds = milliseconds; +} + +void Connector::start() { + //std::cout << "Connector::start()" << std::endl; + assert(!currentConnection); + assert(!serviceQuery); + assert(!timer); + queriedAllHosts = false; + serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); + serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1)); + if (timeoutMilliseconds > 0) { + timer = timerFactory->createTimer(timeoutMilliseconds); + timer->onTick.connect(boost::bind(&Connector::handleTimeout, this)); + timer->start(); + } + serviceQuery->run(); +} + +void Connector::queryAddress(const String& hostname) { + assert(!addressQuery); + addressQuery = resolver->createAddressQuery(hostname); + addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2)); + addressQuery->run(); +} + +void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { + //std::cout << "Received SRV results" << std::endl; + serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); + serviceQuery.reset(); + tryNextHostname(); +} + +void Connector::tryNextHostname() { + if (queriedAllHosts) { + //std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl; + finish(boost::shared_ptr<Connection>()); + } + else if (serviceQueryResults.empty()) { + //std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl; + // Fall back on simple address resolving + queriedAllHosts = true; + queryAddress(hostname); + } + else { + //std::cout << "Connector::tryNextHostName(): Querying next address" << std::endl; + queryAddress(serviceQueryResults.front().hostname); + } +} + +void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) { + //std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl; + addressQuery.reset(); + if (!serviceQueryResults.empty()) { + DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front(); + serviceQueryResults.pop_front(); + if (error) { + //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl; + tryNextHostname(); + } + else { + //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl; + tryConnect(HostAddressPort(address, serviceQueryResult.port)); + } + } + else if (error) { + //std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl; + // The fallback address query failed + assert(queriedAllHosts); + finish(boost::shared_ptr<Connection>()); + } + else { + //std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl; + // The fallback query succeeded + tryConnect(HostAddressPort(address, 5222)); + } +} + +void Connector::tryConnect(const HostAddressPort& target) { + assert(!currentConnection); + //std::cout << "Connector::tryConnect() " << target.getAddress().toString() << " " << target.getPort() << std::endl; + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); + currentConnection->connect(target); +} + +void Connector::handleConnectionConnectFinished(bool error) { + //std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl; + if (error) { + currentConnection.reset(); + tryNextHostname(); + } + else { + finish(currentConnection); + } +} + +void Connector::finish(boost::shared_ptr<Connection> connection) { + if (timer) { + timer->stop(); + timer.reset(); + } + onConnectFinished(connection); +} + +void Connector::handleTimeout() { + finish(boost::shared_ptr<Connection>()); +} + +}; diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h new file mode 100644 index 0000000..507f085 --- /dev/null +++ b/Swiften/Network/Connector.h @@ -0,0 +1,54 @@ +#pragma once + +#include <deque> +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/Timer.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolveError.h" + +namespace Swift { + class DomainNameAddressQuery; + class DomainNameResolver; + class ConnectionFactory; + class TimerFactory; + + class Connector : public boost::bsignals::trackable { + public: + Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*, TimerFactory*); + + void setTimeoutMilliseconds(int milliseconds); + void start(); + + boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; + + private: + void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); + void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error); + void queryAddress(const String& hostname); + + void tryNextHostname(); + void tryConnect(const HostAddressPort& target); + + void handleConnectionConnectFinished(bool error); + void finish(boost::shared_ptr<Connection>); + void handleTimeout(); + + private: + String hostname; + DomainNameResolver* resolver; + ConnectionFactory* connectionFactory; + TimerFactory* timerFactory; + int timeoutMilliseconds; + boost::shared_ptr<Timer> timer; + boost::shared_ptr<DomainNameServiceQuery> serviceQuery; + std::deque<DomainNameServiceQuery::Result> serviceQueryResults; + boost::shared_ptr<DomainNameAddressQuery> addressQuery; + bool queriedAllHosts; + boost::shared_ptr<Connection> currentConnection; + }; +}; diff --git a/Swiften/Network/DomainNameAddressQuery.cpp b/Swiften/Network/DomainNameAddressQuery.cpp new file mode 100644 index 0000000..5e77cd7 --- /dev/null +++ b/Swiften/Network/DomainNameAddressQuery.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/DomainNameAddressQuery.h" + +namespace Swift { + +DomainNameAddressQuery::~DomainNameAddressQuery() { +} + +} diff --git a/Swiften/Network/DomainNameAddressQuery.h b/Swiften/Network/DomainNameAddressQuery.h new file mode 100644 index 0000000..66a79db --- /dev/null +++ b/Swiften/Network/DomainNameAddressQuery.h @@ -0,0 +1,19 @@ +#pragma once + +#include <boost/signals.hpp> +#include <boost/optional.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/DomainNameResolveError.h" +#include "Swiften/Network/HostAddress.h" + +namespace Swift { + class DomainNameAddressQuery { + public: + virtual ~DomainNameAddressQuery(); + + virtual void run() = 0; + + boost::signal<void (const HostAddress&, boost::optional<DomainNameResolveError>)> onResult; + }; +} diff --git a/Swiften/Network/DomainNameResolveError.h b/Swiften/Network/DomainNameResolveError.h new file mode 100644 index 0000000..860ea23 --- /dev/null +++ b/Swiften/Network/DomainNameResolveError.h @@ -0,0 +1,10 @@ +#pragma once + +#include "Swiften/Base/Error.h" + +namespace Swift { + class DomainNameResolveError : public Error { + public: + DomainNameResolveError() {} + }; +} diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp new file mode 100644 index 0000000..63ed881 --- /dev/null +++ b/Swiften/Network/DomainNameResolver.cpp @@ -0,0 +1,22 @@ +#include "Swiften/Network/DomainNameResolver.h" + +#include <idna.h> + +namespace Swift { + +DomainNameResolver::~DomainNameResolver() { +} + +String DomainNameResolver::getNormalized(const String& domain) { + char* output; + if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { + String result(output); + free(output); + return result; + } + else { + return domain; + } +} + +} diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h new file mode 100644 index 0000000..d3dab26 --- /dev/null +++ b/Swiften/Network/DomainNameResolver.h @@ -0,0 +1,22 @@ +#pragma once + +#include <boost/shared_ptr.hpp> + +#include "Swiften/Base/String.h" + +namespace Swift { + class DomainNameServiceQuery; + class DomainNameAddressQuery; + class String; + + class DomainNameResolver { + public: + virtual ~DomainNameResolver(); + + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name) = 0; + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name) = 0; + + protected: + static String getNormalized(const String& domain); + }; +} diff --git a/Swiften/Network/DomainNameServiceQuery.cpp b/Swiften/Network/DomainNameServiceQuery.cpp new file mode 100644 index 0000000..7dfd353 --- /dev/null +++ b/Swiften/Network/DomainNameServiceQuery.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/DomainNameServiceQuery.h" + +namespace Swift { + +DomainNameServiceQuery::~DomainNameServiceQuery() { +} + +} diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h new file mode 100644 index 0000000..57e48d3 --- /dev/null +++ b/Swiften/Network/DomainNameServiceQuery.h @@ -0,0 +1,33 @@ +#pragma once + +#include <boost/signals.hpp> +#include <boost/optional.hpp> +#include <vector> + +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolveError.h" + +namespace Swift { + class DomainNameServiceQuery { + public: + struct Result { + Result(const String& hostname = "", int port = -1, int priority = -1, int weight = -1) : hostname(hostname), port(port), priority(priority), weight(weight) {} + String hostname; + int port; + int priority; + int weight; + }; + + struct ResultPriorityComparator { + bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { + return a.priority < b.priority; + } + }; + + virtual ~DomainNameServiceQuery(); + + virtual void run() = 0; + + boost::signal<void (const std::vector<Result>&)> onResult; + }; +} diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h new file mode 100644 index 0000000..11281b3 --- /dev/null +++ b/Swiften/Network/DummyConnection.h @@ -0,0 +1,40 @@ +#pragma once + +#include <cassert> +#include <boost/bind.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Network/Connection.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/EventOwner.h" + +namespace Swift { + class DummyConnection : + public Connection, + public EventOwner, + public boost::enable_shared_from_this<DummyConnection> { + + void listen() { + assert(false); + } + + void connect(const HostAddressPort&) { + assert(false); + } + + void disconnect() { + assert(false); + } + + void write(const ByteArray& data) { + onDataWritten(data); + } + + void receive(const ByteArray& data) { + MainEventLoop::postEvent(boost::bind( + boost::ref(onDataRead), ByteArray(data)), shared_from_this()); + } + + boost::signal<void (const ByteArray&)> onDataWritten; + }; +} diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp new file mode 100644 index 0000000..7626584 --- /dev/null +++ b/Swiften/Network/DummyTimerFactory.cpp @@ -0,0 +1,60 @@ +#include "Swiften/Network/DummyTimerFactory.h" + +#include <algorithm> + +#include "Swiften/Base/foreach.h" +#include "Swiften/Network/Timer.h" + +namespace Swift { + +class DummyTimerFactory::DummyTimer : public Timer { + public: + DummyTimer(int timeout) : timeout(timeout), isRunning(false) { + } + + virtual void start() { + isRunning = true; + } + + virtual void stop() { + isRunning = false; + } + + int timeout; + bool isRunning; +}; + + +DummyTimerFactory::DummyTimerFactory() : currentTime(0) { +} + +boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) { + boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds)); + timers.push_back(timer); + return timer; +} + +static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) { + return timer->timeout == 0; +} + +void DummyTimerFactory::setTime(int time) { + assert(time > currentTime); + int increment = time - currentTime; + std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end()); + foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) { + if (increment >= timer->timeout) { + if (timer->isRunning) { + timer->onTick(); + } + timer->timeout = 0; + } + else { + timer->timeout -= increment; + } + } + timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end()); + currentTime = time; +} + +} diff --git a/Swiften/Network/DummyTimerFactory.h b/Swiften/Network/DummyTimerFactory.h new file mode 100644 index 0000000..feac029 --- /dev/null +++ b/Swiften/Network/DummyTimerFactory.h @@ -0,0 +1,22 @@ +#pragma once + +#include <list> + +#include "Swiften/Network/TimerFactory.h" + +namespace Swift { + class DummyTimerFactory : public TimerFactory { + public: + class DummyTimer; + + DummyTimerFactory(); + + virtual boost::shared_ptr<Timer> createTimer(int milliseconds); + void setTime(int time); + + private: + friend class DummyTimer; + int currentTime; + std::list<boost::shared_ptr<DummyTimer> > timers; + }; +} diff --git a/Swiften/Network/FakeConnection.h b/Swiften/Network/FakeConnection.h new file mode 100644 index 0000000..92a03c3 --- /dev/null +++ b/Swiften/Network/FakeConnection.h @@ -0,0 +1,88 @@ +#pragma once + +#include <boost/optional.hpp> +#include <boost/bind.hpp> +#include <boost/enable_shared_from_this.hpp> +#include <vector> + +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/EventLoop/EventOwner.h" +#include "Swiften/EventLoop/MainEventLoop.h" + +namespace Swift { + class FakeConnection : + public Connection, + public EventOwner, + public boost::enable_shared_from_this<FakeConnection> { + public: + enum State { + Initial, + Connecting, + Connected, + Disconnected, + DisconnectedWithError + }; + + FakeConnection() : state(Initial), delayConnect(false) {} + + virtual void listen() { + assert(false); + } + + void setError(const Error& e) { + error = boost::optional<Error>(e); + state = DisconnectedWithError; + if (connectedTo) { + MainEventLoop::postEvent( + boost::bind(boost::ref(onDisconnected), error), + shared_from_this()); + } + } + + virtual void connect(const HostAddressPort& address) { + if (delayConnect) { + state = Connecting; + } + else { + if (!error) { + connectedTo = address; + state = Connected; + } + else { + state = DisconnectedWithError; + } + MainEventLoop::postEvent( + boost::bind(boost::ref(onConnectFinished), error), + shared_from_this()); + } + } + + virtual void disconnect() { + if (!error) { + state = Disconnected; + } + else { + state = DisconnectedWithError; + } + connectedTo.reset(); + MainEventLoop::postEvent( + boost::bind(boost::ref(onDisconnected), error), + shared_from_this()); + } + + virtual void write(const ByteArray& data) { + dataWritten.push_back(data); + } + + void setDelayConnect() { + delayConnect = true; + } + + boost::optional<HostAddressPort> connectedTo; + std::vector<ByteArray> dataWritten; + boost::optional<Error> error; + State state; + bool delayConnect; + }; +} diff --git a/Swiften/Network/HostAddress.cpp b/Swiften/Network/HostAddress.cpp new file mode 100644 index 0000000..8ac66bb --- /dev/null +++ b/Swiften/Network/HostAddress.cpp @@ -0,0 +1,61 @@ +#include "Swiften/Network/HostAddress.h" + +#include <boost/numeric/conversion/cast.hpp> +#include <boost/lexical_cast.hpp> +#include <cassert> +#include <sstream> +#include <iomanip> + +#include "Swiften/Base/foreach.h" +#include "Swiften/Base/String.h" + +namespace Swift { + +HostAddress::HostAddress() { + for (int i = 0; i < 4; ++i) { + address_.push_back(0); + } +} + +HostAddress::HostAddress(const String& address) { + std::vector<String> components = address.split('.'); + assert(components.size() == 4); + foreach(const String& component, components) { + address_.push_back(boost::lexical_cast<int>(component.getUTF8String())); + } +} + +HostAddress::HostAddress(const unsigned char* address, int length) { + assert(length == 4 || length == 16); + address_.reserve(length); + for (int i = 0; i < length; ++i) { + address_.push_back(address[i]); + } +} + +std::string HostAddress::toString() const { + if (address_.size() == 4) { + std::ostringstream result; + for (size_t i = 0; i < address_.size() - 1; ++i) { + result << boost::numeric_cast<unsigned int>(address_[i]) << "."; + } + result << boost::numeric_cast<unsigned int>(address_[address_.size() - 1]); + return result.str(); + } + else if (address_.size() == 16) { + std::ostringstream result; + result << std::hex; + result.fill('0'); + for (size_t i = 0; i < (address_.size() / 2) - 1; ++i) { + result << std::setw(2) << boost::numeric_cast<unsigned int>(address_[2*i]) << std::setw(2) << boost::numeric_cast<unsigned int>(address_[(2*i)+1]) << ":"; + } + result << std::setw(2) << boost::numeric_cast<unsigned int>(address_[address_.size() - 2]) << std::setw(2) << boost::numeric_cast<unsigned int>(address_[address_.size() - 1]); + return result.str(); + } + else { + assert(false); + return ""; + } +} + +} diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h new file mode 100644 index 0000000..11f8a2b --- /dev/null +++ b/Swiften/Network/HostAddress.h @@ -0,0 +1,28 @@ +#pragma once + +#include <string> +#include <vector> + +namespace Swift { + class String; + + class HostAddress { + public: + HostAddress(); + HostAddress(const String&); + HostAddress(const unsigned char* address, int length); + + const std::vector<unsigned char>& getRawAddress() const { + return address_; + } + + std::string toString() const; + + bool operator==(const HostAddress& o) const { + return address_ == o.address_; + } + + private: + std::vector<unsigned char> address_; + }; +} diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h new file mode 100644 index 0000000..d632058 --- /dev/null +++ b/Swiften/Network/HostAddressPort.h @@ -0,0 +1,30 @@ +#ifndef SWIFTEN_HostAddressPort_H +#define SWIFTEN_HostAddressPort_H + +#include "Swiften/Network/HostAddress.h" + +namespace Swift { + class HostAddressPort { + public: + HostAddressPort(const HostAddress& address, int port) : address_(address), port_(port) { + } + + const HostAddress& getAddress() const { + return address_; + } + + int getPort() const { + return port_; + } + + bool operator==(const HostAddressPort& o) const { + return address_ == o.address_ && port_ == o.port_; + } + + private: + HostAddress address_; + int port_; + }; +} + +#endif diff --git a/Swiften/Network/MainBoostIOServiceThread.cpp b/Swiften/Network/MainBoostIOServiceThread.cpp new file mode 100644 index 0000000..672bb07 --- /dev/null +++ b/Swiften/Network/MainBoostIOServiceThread.cpp @@ -0,0 +1,12 @@ +#include "Swiften/Network/MainBoostIOServiceThread.h" + +#include "Swiften/Network/BoostIOServiceThread.h" + +namespace Swift { + +BoostIOServiceThread& MainBoostIOServiceThread::getInstance() { + static BoostIOServiceThread instance; + return instance; +} + +} diff --git a/Swiften/Network/MainBoostIOServiceThread.h b/Swiften/Network/MainBoostIOServiceThread.h new file mode 100644 index 0000000..cca7c2e --- /dev/null +++ b/Swiften/Network/MainBoostIOServiceThread.h @@ -0,0 +1,10 @@ +#pragma once + +namespace Swift { + class BoostIOServiceThread; + + class MainBoostIOServiceThread { + public: + static BoostIOServiceThread& getInstance(); + }; +} diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp new file mode 100644 index 0000000..7b8a6d5 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -0,0 +1,94 @@ +#include "Swiften/Network/PlatformDomainNameResolver.h" + +// Putting this early on, because some system types conflict with thread +#include "Swiften/Network/PlatformDomainNameServiceQuery.h" + +#include <string> +#include <vector> +#include <boost/asio.hpp> +#include <boost/bind.hpp> +#include <boost/thread.hpp> +#include <boost/enable_shared_from_this.hpp> +#include <algorithm> + +#include "Swiften/Base/String.h" +#include "Swiften/Network/HostAddress.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/DomainNameAddressQuery.h" + +using namespace Swift; + +namespace { + struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this<AddressQuery>, public EventOwner { + AddressQuery(const String& host) : hostname(host), thread(NULL), safeToJoin(false) {} + + ~AddressQuery() { + if (safeToJoin) { + thread->join(); + } + else { + // FIXME: UGLYYYYY + } + delete thread; + } + + void run() { + safeToJoin = false; + thread = new boost::thread(boost::bind(&AddressQuery::doRun, shared_from_this())); + } + + void doRun() { + //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname.getUTF8String(), "5222"); + try { + //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; + emitError(); + } + else { + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + HostAddress result = (address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); + //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; + MainEventLoop::postEvent( + boost::bind(boost::ref(onResult), result, boost::optional<DomainNameResolveError>()), + shared_from_this()); + } + } + catch (...) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; + emitError(); + } + safeToJoin = true; + } + + void emitError() { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); + } + + boost::asio::io_service ioService; + String hostname; + boost::thread* thread; + bool safeToJoin; + }; + +} + +namespace Swift { + +PlatformDomainNameResolver::PlatformDomainNameResolver() { +} + +boost::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr<DomainNameServiceQuery>(new PlatformDomainNameServiceQuery(getNormalized(name))); +} + +boost::shared_ptr<DomainNameAddressQuery> PlatformDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(getNormalized(name))); +} + +} diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h new file mode 100644 index 0000000..4617b15 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -0,0 +1,15 @@ +#pragma once + +#include "Swiften/Network/DomainNameResolver.h" + +namespace Swift { + class String; + + class PlatformDomainNameResolver : public DomainNameResolver { + public: + PlatformDomainNameResolver(); + + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); + }; +} diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp new file mode 100644 index 0000000..bde851b --- /dev/null +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -0,0 +1,170 @@ +#include "Swiften/Network/PlatformDomainNameServiceQuery.h" + +#include "Swiften/Base/Platform.h" +#include <stdlib.h> +#ifdef SWIFTEN_PLATFORM_WINDOWS +#undef UNICODE +#include <windows.h> +#include <windns.h> +#ifndef DNS_TYPE_SRV +#define DNS_TYPE_SRV 33 +#endif +#else +#include <arpa/nameser.h> +#include <arpa/nameser_compat.h> +#include <resolv.h> +#endif +#include <boost/bind.hpp> + +#include "Swiften/Base/ByteArray.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Base/foreach.h" + +using namespace Swift; + +namespace Swift { + +PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const String& service) : thread(NULL), service(service), safeToJoin(true) { +} + +PlatformDomainNameServiceQuery::~PlatformDomainNameServiceQuery() { + if (safeToJoin) { + thread->join(); + } + else { + // FIXME: UGLYYYYY + } + delete thread; +} + +void PlatformDomainNameServiceQuery::run() { + safeToJoin = false; + thread = new boost::thread(boost::bind(&PlatformDomainNameServiceQuery::doRun, shared_from_this())); +} + +void PlatformDomainNameServiceQuery::doRun() { + std::vector<DomainNameServiceQuery::Result> records; + +#if defined(SWIFTEN_PLATFORM_WINDOWS) + DNS_RECORD* responses; + // FIXME: This conversion doesn't work if unicode is deffed above + if (DnsQuery(service.getUTF8Data(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { + emitError(); + return; + } + + DNS_RECORD* currentEntry = responses; + while (currentEntry) { + if (currentEntry->wType == DNS_TYPE_SRV) { + DomainNameServiceQuery::Result record; + record.priority = currentEntry->Data.SRV.wPriority; + record.weight = currentEntry->Data.SRV.wWeight; + record.port = currentEntry->Data.SRV.wPort; + + // The pNameTarget is actually a PCWSTR, so I would have expected this + // conversion to not work at all, but it does. + // Actually, it doesn't. Fix this and remove explicit cast + // Remove unicode undef above as well + record.hostname = String((const char*) currentEntry->Data.SRV.pNameTarget); + records.push_back(record); + } + currentEntry = currentEntry->pNext; + } + DnsRecordListFree(responses, DnsFreeRecordList); + +#else + // Make sure we reinitialize the domain list every time + res_init(); + + //std::cout << "SRV: Querying " << service << std::endl; + ByteArray response; + response.resize(NS_PACKETSZ); + int responseLength = res_query(const_cast<char*>(service.getUTF8Data()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize()); + //std::cout << "res_query done " << (responseLength != -1) << std::endl; + if (responseLength == -1) { + emitError(); + return; + } + + // Parse header + HEADER* header = reinterpret_cast<HEADER*>(response.getData()); + unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData()); + unsigned char* messageEnd = messageStart + responseLength; + unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; + + // Skip over the queries + int queriesCount = ntohs(header->qdcount); + while (queriesCount > 0) { + int entryLength = dn_skipname(currentEntry, messageEnd); + if (entryLength < 0) { + emitError(); + return; + } + currentEntry += entryLength + NS_QFIXEDSZ; + queriesCount--; + } + + // Process the SRV answers + int answersCount = ntohs(header->ancount); + while (answersCount > 0) { + DomainNameServiceQuery::Result record; + + int entryLength = dn_skipname(currentEntry, messageEnd); + currentEntry += entryLength; + currentEntry += NS_RRFIXEDSZ; + + // Priority + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.priority = ns_get16(currentEntry); + currentEntry += 2; + + // Weight + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.weight = ns_get16(currentEntry); + currentEntry += 2; + + // Port + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.port = ns_get16(currentEntry); + currentEntry += 2; + + // Hostname + if (currentEntry >= messageEnd) { + emitError(); + return; + } + ByteArray entry; + entry.resize(NS_MAXDNAME); + entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); + if (entryLength < 0) { + emitError(); + return; + } + record.hostname = String(entry.getData()); + records.push_back(record); + currentEntry += entryLength; + answersCount--; + } +#endif + + safeToJoin = true; + std::sort(records.begin(), records.end(), ResultPriorityComparator()); + //std::cout << "Sending out " << records.size() << " SRV results " << std::endl; + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), records)); +} + +void PlatformDomainNameServiceQuery::emitError() { + safeToJoin = true; + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); +} + +} diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.h b/Swiften/Network/PlatformDomainNameServiceQuery.h new file mode 100644 index 0000000..753e2c6 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameServiceQuery.h @@ -0,0 +1,27 @@ +#pragma once + +#include <boost/thread.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/EventLoop/EventOwner.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this<PlatformDomainNameServiceQuery>, public EventOwner { + public: + PlatformDomainNameServiceQuery(const String& service); + ~PlatformDomainNameServiceQuery(); + + virtual void run(); + + private: + void doRun(); + void emitError(); + + private: + boost::thread* thread; + String service; + bool safeToJoin; + }; +} diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript new file mode 100644 index 0000000..937ab0c --- /dev/null +++ b/Swiften/Network/SConscript @@ -0,0 +1,34 @@ +Import("swiften_env") + +myenv = swiften_env.Clone() +myenv.MergeFlags(myenv["LIBIDN_FLAGS"]) +if myenv["target"] == "native": + myenv.MergeFlags(myenv["CARES_FLAGS"]) + +sourceList = [ + "BoostConnection.cpp", + "BoostConnectionFactory.cpp", + "BoostConnectionServer.cpp", + "MainBoostIOServiceThread.cpp", + "BoostIOServiceThread.cpp", + "ConnectionFactory.cpp", + "ConnectionServer.cpp", + "Connector.cpp", + "TimerFactory.cpp", + "DummyTimerFactory.cpp", + "BoostTimerFactory.cpp", + "DomainNameResolver.cpp", + "DomainNameAddressQuery.cpp", + "DomainNameServiceQuery.cpp", + "PlatformDomainNameResolver.cpp", + "PlatformDomainNameServiceQuery.cpp", + "StaticDomainNameResolver.cpp", + "HostAddress.cpp", + "Timer.cpp", + "BoostTimer.cpp"] +if myenv["target"] == "native": + sourceList.append("CAresDomainNameResolver.cpp") + + +objects = myenv.StaticObject(sourceList) +swiften_env.Append(SWIFTEN_OBJECTS = [objects]) diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp new file mode 100644 index 0000000..a7275d2 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -0,0 +1,85 @@ +#include "Swiften/Network/StaticDomainNameResolver.h" + +#include <boost/bind.hpp> +#include <boost/lexical_cast.hpp> + +#include "Swiften/Network/DomainNameResolveError.h" +#include "Swiften/Base/String.h" + +using namespace Swift; + +namespace { + struct ServiceQuery : public DomainNameServiceQuery, public EventOwner { + ServiceQuery(const String& service, Swift::StaticDomainNameResolver* resolver) : service(service), resolver(resolver) {} + + virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } + std::vector<DomainNameServiceQuery::Result> results; + for(StaticDomainNameResolver::ServicesCollection::const_iterator i = resolver->getServices().begin(); i != resolver->getServices().end(); ++i) { + if (i->first == service) { + results.push_back(i->second); + } + } + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), results)); + } + + String service; + StaticDomainNameResolver* resolver; + }; + + struct AddressQuery : public DomainNameAddressQuery, public EventOwner { + AddressQuery(const String& host, StaticDomainNameResolver* resolver) : host(host), resolver(resolver) {} + + virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } + StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); + if (i != resolver->getAddresses().end()) { + MainEventLoop::postEvent( + boost::bind(boost::ref(onResult), i->second, boost::optional<DomainNameResolveError>())); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError()))); + } + + } + + String host; + StaticDomainNameResolver* resolver; + }; +} + +namespace Swift { + +StaticDomainNameResolver::StaticDomainNameResolver() : isResponsive(true) { +} + +void StaticDomainNameResolver::addAddress(const String& domain, const HostAddress& address) { + addresses[domain] = address; +} + +void StaticDomainNameResolver::addService(const String& service, const DomainNameServiceQuery::Result& result) { + services.push_back(std::make_pair(service, result)); +} + +void StaticDomainNameResolver::addXMPPClientService(const String& domain, const HostAddressPort& address) { + static int hostid = 0; + String hostname(std::string("host-") + boost::lexical_cast<std::string>(hostid)); + hostid++; + + addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, address.getPort(), 0, 0)); + addAddress(hostname, address.getAddress()); +} + +boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this)); +} + +boost::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(name, this)); +} + +} diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h new file mode 100644 index 0000000..d7e7ba4 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -0,0 +1,52 @@ +#pragma once + +#include <vector> +#include <map> + +#include "Swiften/Network/HostAddress.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/EventLoop/MainEventLoop.h" + +namespace Swift { + class String; + + class StaticDomainNameResolver : public DomainNameResolver { + public: + typedef std::map<String, HostAddress> AddressesMap; + typedef std::vector< std::pair<String, DomainNameServiceQuery::Result> > ServicesCollection; + + public: + StaticDomainNameResolver(); + + void addAddress(const String& domain, const HostAddress& address); + void addService(const String& service, const DomainNameServiceQuery::Result& result); + void addXMPPClientService(const String& domain, const HostAddressPort&); + + const AddressesMap& getAddresses() const { + return addresses; + } + + const ServicesCollection& getServices() const { + return services; + } + + bool getIsResponsive() const { + return isResponsive; + } + + void setIsResponsive(bool b) { + isResponsive = b; + } + + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); + + private: + bool isResponsive; + AddressesMap addresses; + ServicesCollection services; + }; +} diff --git a/Swiften/Network/Timer.cpp b/Swiften/Network/Timer.cpp new file mode 100644 index 0000000..a8d17c3 --- /dev/null +++ b/Swiften/Network/Timer.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/Timer.h" + +namespace Swift { + +Timer::~Timer() { +} + +} diff --git a/Swiften/Network/Timer.h b/Swiften/Network/Timer.h new file mode 100644 index 0000000..9b01a0d --- /dev/null +++ b/Swiften/Network/Timer.h @@ -0,0 +1,15 @@ +#pragma once + +#include <boost/signals.hpp> + +namespace Swift { + class Timer { + public: + virtual ~Timer(); + + virtual void start() = 0; + virtual void stop() = 0; + + boost::signal<void ()> onTick; + }; +} diff --git a/Swiften/Network/TimerFactory.cpp b/Swiften/Network/TimerFactory.cpp new file mode 100644 index 0000000..642ac52 --- /dev/null +++ b/Swiften/Network/TimerFactory.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/TimerFactory.h" + +namespace Swift { + +TimerFactory::~TimerFactory() { +} + +} diff --git a/Swiften/Network/TimerFactory.h b/Swiften/Network/TimerFactory.h new file mode 100644 index 0000000..f72a8fc --- /dev/null +++ b/Swiften/Network/TimerFactory.h @@ -0,0 +1,14 @@ +#pragma once + +#include <boost/shared_ptr.hpp> + +namespace Swift { + class Timer; + + class TimerFactory { + public: + virtual ~TimerFactory(); + + virtual boost::shared_ptr<Timer> createTimer(int milliseconds) = 0; + }; +} diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp new file mode 100644 index 0000000..663011c --- /dev/null +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -0,0 +1,245 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <boost/optional.hpp> +#include <boost/bind.hpp> + +#include "Swiften/Network/Connector.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/Network/DummyTimerFactory.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" + +using namespace Swift; + +class ConnectorTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ConnectorTest); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_NoSRVHost); + CPPUNIT_TEST(testConnect_NoHosts); + CPPUNIT_TEST(testConnect_FirstSRVHostFails); + CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost); + CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost); + CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail); + CPPUNIT_TEST(testConnect_TimeoutDuringResolve); + CPPUNIT_TEST(testConnect_TimeoutDuringConnect); + CPPUNIT_TEST(testConnect_NoTimeout); + CPPUNIT_TEST_SUITE_END(); + + public: + ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345), host3(HostAddress("3.3.3.3"), 5222) { + } + + void setUp() { + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(); + connectionFactory = new MockConnectionFactory(); + timerFactory = new DummyTimerFactory(); + } + + void tearDown() { + delete timerFactory; + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void testConnect() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); + } + + void testConnect_NoSRVHost() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); + } + + void testConnect_NoHosts() { + std::auto_ptr<Connector> testling(createConnector()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_FirstSRVHostFails() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); + } + + void testConnect_AllSRVHostsFailWithoutFallbackHost() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_AllSRVHostsFailWithFallbackHost() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); + } + + void testConnect_SRVAndFallbackHostsFail() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addAddress("foo.com", host3.getAddress()); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host3); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_TimeoutDuringResolve() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->setIsResponsive(false); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_TimeoutDuringConnect() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + connectionFactory->isResponsive = false; + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_NoTimeout() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + } + + + private: + Connector* createConnector() { + Connector* connector = new Connector("foo.com", resolver, connectionFactory, timerFactory); + connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1)); + return connector; + } + + void handleConnectorFinished(boost::shared_ptr<Connection> connection) { + boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {} + + void listen() { assert(false); } + void connect(const HostAddressPort& address) { + hostAddressPort = address; + if (isResponsive) { + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + } + } + + void disconnect() { assert(false); } + void write(const ByteArray&) { assert(false); } + + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + bool isResponsive; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory() : isResponsive(true) { + } + + boost::shared_ptr<Connection> createConnection() { + return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive)); + } + + bool isResponsive; + std::vector<HostAddressPort> failingPorts; + }; + + private: + HostAddressPort host1; + HostAddressPort host2; + HostAddressPort host3; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + DummyTimerFactory* timerFactory; + std::vector< boost::shared_ptr<MockConnection> > connections; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest); diff --git a/Swiften/Network/UnitTest/HostAddressTest.cpp b/Swiften/Network/UnitTest/HostAddressTest.cpp new file mode 100644 index 0000000..50e9198 --- /dev/null +++ b/Swiften/Network/UnitTest/HostAddressTest.cpp @@ -0,0 +1,38 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include "Swiften/Network/HostAddress.h" +#include "Swiften/Base/String.h" + +using namespace Swift; + +class HostAddressTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(HostAddressTest); + CPPUNIT_TEST(testConstructor); + CPPUNIT_TEST(testToString); + CPPUNIT_TEST(testToString_IPv6); + CPPUNIT_TEST_SUITE_END(); + + public: + void testConstructor() { + HostAddress testling("192.168.1.254"); + + CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.254"), testling.toString()); + } + + void testToString() { + unsigned char address[4] = {10, 0, 1, 253}; + HostAddress testling(address, 4); + + CPPUNIT_ASSERT_EQUAL(std::string("10.0.1.253"), testling.toString()); + } + + void testToString_IPv6() { + unsigned char address[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17}; + HostAddress testling(address, 16); + + CPPUNIT_ASSERT_EQUAL(std::string("0102:0304:0506:0708:090a:0b0c:0d0e:0f11"), testling.toString()); + } +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(HostAddressTest); |