summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/Network')
-rw-r--r--Swiften/Network/DomainNameServiceQuery.cpp55
-rw-r--r--Swiften/Network/DomainNameServiceQuery.h9
-rw-r--r--Swiften/Network/PlatformDomainNameServiceQuery.cpp4
-rw-r--r--Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp82
4 files changed, 143 insertions, 7 deletions
diff --git a/Swiften/Network/DomainNameServiceQuery.cpp b/Swiften/Network/DomainNameServiceQuery.cpp
index 5713b63..eb999e0 100644
--- a/Swiften/Network/DomainNameServiceQuery.cpp
+++ b/Swiften/Network/DomainNameServiceQuery.cpp
@@ -6,9 +6,64 @@
#include <Swiften/Network/DomainNameServiceQuery.h>
+#include <numeric>
+#include <cassert>
+
+#include <Swiften/Base/RandomGenerator.h>
+#include <boost/numeric/conversion/cast.hpp>
+
+using namespace Swift;
+
+namespace {
+ struct ResultPriorityComparator {
+ bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const {
+ return a.priority < b.priority;
+ }
+ };
+
+ struct WeightAccumulator {
+ int operator()(int accumulator, const DomainNameServiceQuery::Result& result) {
+ return accumulator + result.weight;
+ }
+ };
+
+ struct WeightToProbability {
+ WeightToProbability(int total) : total(total) {}
+
+ double operator()(const DomainNameServiceQuery::Result& result) {
+ return result.weight / boost::numeric_cast<double>(total);
+ }
+ int total;
+ };
+}
+
namespace Swift {
DomainNameServiceQuery::~DomainNameServiceQuery() {
}
+void DomainNameServiceQuery::sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator) {
+ ResultPriorityComparator comparator;
+ std::sort(queries.begin(), queries.end(), comparator);
+
+ std::vector<DomainNameServiceQuery::Result>::iterator i = queries.begin();
+ while (i != queries.end()) {
+ std::vector<DomainNameServiceQuery::Result>::iterator next = std::upper_bound(i, queries.end(), *i, comparator);
+ if (std::distance(i, next) > 1) {
+ int weightSum = std::accumulate(i, next, 0, WeightAccumulator());
+ std::vector<double> probabilities;
+ std::transform(i, next, std::back_inserter(probabilities), WeightToProbability(weightSum));
+
+ // Shuffling the result array and the probabilities in parallel
+ for (size_t j = 0; j < probabilities.size(); ++j) {
+ int selectedIndex = generator.generateWeighedRandomNumber(probabilities.begin() + j, probabilities.end());
+ std::swap(i[j], i[j + selectedIndex]);
+ std::swap(probabilities.begin()[j], probabilities.begin()[j + selectedIndex]);
+ }
+ }
+ i = next;
+ }
+}
+
+
}
diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h
index 0bd1569..0e80233 100644
--- a/Swiften/Network/DomainNameServiceQuery.h
+++ b/Swiften/Network/DomainNameServiceQuery.h
@@ -15,6 +15,8 @@
#include <Swiften/Network/DomainNameResolveError.h>
namespace Swift {
+ class RandomGenerator;
+
class DomainNameServiceQuery {
public:
typedef boost::shared_ptr<DomainNameServiceQuery> ref;
@@ -27,15 +29,10 @@ namespace Swift {
int weight;
};
- struct ResultPriorityComparator {
- bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const {
- return a.priority < b.priority;
- }
- };
-
virtual ~DomainNameServiceQuery();
virtual void run() = 0;
+ static void sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator);
boost::signal<void (const std::vector<Result>&)> onResult;
};
diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp
index 5d076ac..b0579a7 100644
--- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp
+++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp
@@ -29,6 +29,7 @@
#include <Swiften/Base/ByteArray.h>
#include <Swiften/EventLoop/EventLoop.h>
#include <Swiften/Base/foreach.h>
+#include <Swiften/Base/BoostRandomGenerator.h>
#include <Swiften/Base/Log.h>
#include <Swiften/Network/PlatformDomainNameResolver.h>
@@ -158,7 +159,8 @@ void PlatformDomainNameServiceQuery::runBlocking() {
}
#endif
- std::sort(records.begin(), records.end(), ResultPriorityComparator());
+ BoostRandomGenerator generator;
+ DomainNameServiceQuery::sortResults(records, generator);
//std::cout << "Sending out " << records.size() << " SRV results " << std::endl;
eventLoop->postEvent(boost::bind(boost::ref(onResult), records), shared_from_this());
}
diff --git a/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp
new file mode 100644
index 0000000..aefd815
--- /dev/null
+++ b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <QA/Checker/IO.h>
+
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+
+#include <Swiften/Network/DomainNameServiceQuery.h>
+#include <Swiften/Base/RandomGenerator.h>
+
+using namespace Swift;
+
+namespace {
+ struct RandomGenerator1 : public RandomGenerator {
+ virtual int generateWeighedRandomNumber(std::vector<double>::const_iterator, std::vector<double>::const_iterator) {
+ return 0;
+ }
+ };
+
+ struct RandomGenerator2 : public RandomGenerator {
+ virtual int generateWeighedRandomNumber(std::vector<double>::const_iterator probabilities_begin, std::vector<double>::const_iterator probabilities_end) {
+ return std::max_element(probabilities_begin, probabilities_end) - probabilities_begin;
+ }
+ };
+}
+
+class DomainNameServiceQueryTest : public CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(DomainNameServiceQueryTest);
+ CPPUNIT_TEST(testSortResults_Random1);
+ CPPUNIT_TEST(testSortResults_Random2);
+ CPPUNIT_TEST_SUITE_END();
+
+ public:
+ void testSortResults_Random1() {
+ std::vector<DomainNameServiceQuery::Result> results;
+ results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1));
+ results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10));
+ results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1));
+ results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20));
+ results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1));
+ results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10));
+
+ RandomGenerator1 generator;
+ DomainNameServiceQuery::sortResults(results, generator);
+
+ CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[1].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[4].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[5].hostname);
+ }
+
+ void testSortResults_Random2() {
+ std::vector<DomainNameServiceQuery::Result> results;
+ results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1));
+ results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10));
+ results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1));
+ results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20));
+ results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1));
+ results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10));
+ results.push_back(DomainNameServiceQuery::Result("server7.com", 5222, 3, 40));
+
+ RandomGenerator2 generator;
+ DomainNameServiceQuery::sortResults(results, generator);
+
+ CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server7.com"), results[1].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[4].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[5].hostname);
+ CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[6].hostname);
+ }
+};
+
+
+CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameServiceQueryTest);