From 18f4f0ba13bbfe901dae44e95d869ba0425e93c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Mon, 23 Apr 2012 15:32:24 +0200
Subject: Select SRV randomly, taking weight into account.

Resolves: #1030

diff --git a/3rdParty/Boost/src/boost/random/uniform_real.hpp b/3rdParty/Boost/src/boost/random/uniform_real.hpp
new file mode 100644
index 0000000..06bfbc3
--- /dev/null
+++ b/3rdParty/Boost/src/boost/random/uniform_real.hpp
@@ -0,0 +1,108 @@
+/* boost random/uniform_real.hpp header file
+ *
+ * Copyright Jens Maurer 2000-2001
+ * Distributed under the Boost Software License, Version 1.0. (See
+ * accompanying file LICENSE_1_0.txt or copy at
+ * http://www.boost.org/LICENSE_1_0.txt)
+ *
+ * See http://www.boost.org for most recent version including documentation.
+ *
+ * $Id: uniform_real.hpp 60755 2010-03-22 00:45:06Z steven_watanabe $
+ *
+ * Revision history
+ *  2001-04-08  added min<max assertion (N. Becker)
+ *  2001-02-18  moved to individual header files
+ */
+
+#ifndef BOOST_RANDOM_UNIFORM_REAL_HPP
+#define BOOST_RANDOM_UNIFORM_REAL_HPP
+
+#include <cassert>
+#include <iostream>
+#include <boost/config.hpp>
+#include <boost/limits.hpp>
+#include <boost/static_assert.hpp>
+#include <boost/random/detail/config.hpp>
+
+namespace boost {
+
+/**
+ * The distribution function uniform_real models a random distribution.
+ * On each invocation, it returns a random floating-point value uniformly
+ * distributed in the range [min..max). The value is computed using
+ * std::numeric_limits<RealType>::digits random binary digits, i.e.
+ * the mantissa of the floating-point value is completely filled with
+ * random bits.
+ *
+ * Note: The current implementation is buggy, because it may not fill
+ * all of the mantissa with random bits.
+ */
+template<class RealType = double>
+class uniform_real
+{
+public:
+  typedef RealType input_type;
+  typedef RealType result_type;
+
+  /**
+   * Constructs a uniform_real object. @c min and @c max are the
+   * parameters of the distribution.
+   *
+   * Requires: min <= max
+   */
+  explicit uniform_real(RealType min_arg = RealType(0),
+                        RealType max_arg = RealType(1))
+    : _min(min_arg), _max(max_arg)
+  {
+#ifndef BOOST_NO_LIMITS_COMPILE_TIME_CONSTANTS
+    BOOST_STATIC_ASSERT(!std::numeric_limits<RealType>::is_integer);
+#endif
+    assert(min_arg <= max_arg);
+  }
+
+  // compiler-generated copy ctor and assignment operator are fine
+
+  /**
+   * Returns: The "min" parameter of the distribution
+   */
+  result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return _min; }
+  /**
+   * Returns: The "max" parameter of the distribution
+   */
+  result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const { return _max; }
+  void reset() { }
+
+  template<class Engine>
+  result_type operator()(Engine& eng) {
+    result_type numerator = static_cast<result_type>(eng() - eng.min BOOST_PREVENT_MACRO_SUBSTITUTION());
+    result_type divisor = static_cast<result_type>(eng.max BOOST_PREVENT_MACRO_SUBSTITUTION() - eng.min BOOST_PREVENT_MACRO_SUBSTITUTION());
+    assert(divisor > 0);
+    assert(numerator >= 0 && numerator <= divisor);
+    return numerator / divisor * (_max - _min) + _min;
+  }
+
+#ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
+  template<class CharT, class Traits>
+  friend std::basic_ostream<CharT,Traits>&
+  operator<<(std::basic_ostream<CharT,Traits>& os, const uniform_real& ud)
+  {
+    os << ud._min << " " << ud._max;
+    return os;
+  }
+
+  template<class CharT, class Traits>
+  friend std::basic_istream<CharT,Traits>&
+  operator>>(std::basic_istream<CharT,Traits>& is, uniform_real& ud)
+  {
+    is >> std::ws >> ud._min >> std::ws >> ud._max;
+    return is;
+  }
+#endif
+
+private:
+  RealType _min, _max;
+};
+
+} // namespace boost
+
+#endif // BOOST_RANDOM_UNIFORM_REAL_HPP
diff --git a/3rdParty/Boost/update.sh b/3rdParty/Boost/update.sh
index a7c0638..9b28f2d 100755
--- a/3rdParty/Boost/update.sh
+++ b/3rdParty/Boost/update.sh
@@ -11,6 +11,9 @@ fi
 
 ./bcp --boost="$1" \
 	tools/bcp \
+	algorithm/string.hpp \
+	asio.hpp \
+	assign/list_of.hpp \
 	bind.hpp \
 	cast.hpp \
 	date_time/posix_time/posix_time.hpp \
@@ -18,25 +21,25 @@ fi
 	foreach.hpp \
 	filesystem.hpp \
 	filesystem/fstream.hpp \
+	format.hpp \
+	logic/tribool.hpp \
 	noncopyable.hpp \
 	numeric/conversion/cast.hpp \
+	optional.hpp \
+	program_options.hpp \
+	random/mersenne_twister.hpp \
+	random/uniform_real.hpp \
+	random/variate_generator.hpp \
+	regex.hpp \
 	shared_ptr.hpp \
 	smart_ptr/make_shared.hpp \
-	optional.hpp \
 	signals.hpp \
-	program_options.hpp \
 	thread.hpp \
-	asio.hpp \
+	unordered_map.hpp \
 	uuid/uuid.hpp \
 	uuid/uuid_io.hpp \
 	uuid/uuid_generators.hpp \
 	variant.hpp \
-	regex.hpp \
-	unordered_map.hpp \
-	algorithm/string.hpp \
-	format.hpp \
-	logic/tribool.hpp \
-	assign/list_of.hpp \
 	$TARGET_DIR
 
 rm -rf $TARGET_DIR/libs/config
diff --git a/Swiften/Base/BoostRandomGenerator.cpp b/Swiften/Base/BoostRandomGenerator.cpp
new file mode 100644
index 0000000..b8c50d0
--- /dev/null
+++ b/Swiften/Base/BoostRandomGenerator.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <Swiften/Base/BoostRandomGenerator.h>
+
+#include <numeric>
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+
+namespace Swift {
+
+int BoostRandomGenerator::generateWeighedRandomNumber(std::vector<double>::const_iterator probabilities_begin, std::vector<double>::const_iterator probabilities_end) {
+	// Only works starting boost 1.47
+	//boost::random::discrete_distribution<> distribution(weights.begin(), weights.end());
+	//return distribution(generator);
+
+	std::vector<double> cumulative;
+	std::partial_sum(probabilities_begin, probabilities_end, std::back_inserter(cumulative));
+	boost::uniform_real<> dist(0, cumulative.back());
+	boost::variate_generator<boost::mt19937&, boost::uniform_real<> > die(generator, dist);
+	return std::lower_bound(cumulative.begin(), cumulative.end(), die()) - cumulative.begin();
+}
+
+}
diff --git a/Swiften/Base/BoostRandomGenerator.h b/Swiften/Base/BoostRandomGenerator.h
new file mode 100644
index 0000000..ffc7a72
--- /dev/null
+++ b/Swiften/Base/BoostRandomGenerator.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/RandomGenerator.h>
+
+#include <boost/random/mersenne_twister.hpp>
+
+namespace Swift {
+	class BoostRandomGenerator : public RandomGenerator{
+		public:
+			int generateWeighedRandomNumber(std::vector<double>::const_iterator probabilities_begin, std::vector<double>::const_iterator probabilities_end);
+
+		private:
+			boost::mt19937 generator;
+	};
+}
diff --git a/Swiften/Base/RandomGenerator.cpp b/Swiften/Base/RandomGenerator.cpp
new file mode 100644
index 0000000..f2dcca3
--- /dev/null
+++ b/Swiften/Base/RandomGenerator.cpp
@@ -0,0 +1,15 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <Swiften/Base/RandomGenerator.h>
+
+namespace Swift {
+
+RandomGenerator::~RandomGenerator() {
+
+}
+
+}
diff --git a/Swiften/Base/RandomGenerator.h b/Swiften/Base/RandomGenerator.h
new file mode 100644
index 0000000..a998e0d
--- /dev/null
+++ b/Swiften/Base/RandomGenerator.h
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <vector>
+
+namespace Swift {
+	class RandomGenerator {
+		public:
+			virtual ~RandomGenerator();
+
+			virtual int generateWeighedRandomNumber(std::vector<double>::const_iterator probabilities_begin, std::vector<double>::const_iterator probabilities_end) = 0;
+	};
+}
diff --git a/Swiften/Base/SConscript b/Swiften/Base/SConscript
index 1f07483..a5f3592 100644
--- a/Swiften/Base/SConscript
+++ b/Swiften/Base/SConscript
@@ -10,6 +10,8 @@ objects = swiften_env.SwiftenObject([
 			"String.cpp",
 			"IDGenerator.cpp",
 			"SimpleIDGenerator.cpp",
+			"RandomGenerator.cpp",
+			"BoostRandomGenerator.cpp",
 			"sleep.cpp",
 		])
 swiften_env.Append(SWIFTEN_OBJECTS = [objects])
diff --git a/Swiften/Network/DomainNameServiceQuery.cpp b/Swiften/Network/DomainNameServiceQuery.cpp
index 5713b63..eb999e0 100644
--- a/Swiften/Network/DomainNameServiceQuery.cpp
+++ b/Swiften/Network/DomainNameServiceQuery.cpp
@@ -6,9 +6,64 @@
 
 #include <Swiften/Network/DomainNameServiceQuery.h>
 
+#include <numeric>
+#include <cassert>
+
+#include <Swiften/Base/RandomGenerator.h>
+#include <boost/numeric/conversion/cast.hpp>
+
+using namespace Swift;
+
+namespace {
+	struct ResultPriorityComparator {
+		bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const {
+			return a.priority < b.priority;
+		}
+	};
+
+	struct WeightAccumulator {
+			int operator()(int accumulator, const DomainNameServiceQuery::Result& result) {
+				return accumulator + result.weight;
+			}
+	};
+
+	struct WeightToProbability {
+			WeightToProbability(int total) : total(total) {}
+
+			double operator()(const DomainNameServiceQuery::Result& result) {
+				return result.weight / boost::numeric_cast<double>(total);
+			}
+			int total;
+	};
+}
+
 namespace Swift {
 
 DomainNameServiceQuery::~DomainNameServiceQuery() {
 }
 
+void DomainNameServiceQuery::sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator) {
+	ResultPriorityComparator comparator;
+	std::sort(queries.begin(), queries.end(), comparator);
+
+	std::vector<DomainNameServiceQuery::Result>::iterator i = queries.begin();
+	while (i != queries.end()) {
+		std::vector<DomainNameServiceQuery::Result>::iterator next = std::upper_bound(i, queries.end(), *i, comparator);
+		if (std::distance(i, next) > 1) {
+			int weightSum = std::accumulate(i, next, 0, WeightAccumulator());
+			std::vector<double> probabilities;
+			std::transform(i, next, std::back_inserter(probabilities), WeightToProbability(weightSum));
+
+			// Shuffling the result array and the probabilities in parallel
+			for (size_t j = 0; j < probabilities.size(); ++j) {
+				int selectedIndex = generator.generateWeighedRandomNumber(probabilities.begin() + j, probabilities.end());
+				std::swap(i[j], i[j + selectedIndex]);
+				std::swap(probabilities.begin()[j], probabilities.begin()[j + selectedIndex]);
+			}
+		}
+		i = next;
+	}
+}
+
+
 }
diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h
index 0bd1569..0e80233 100644
--- a/Swiften/Network/DomainNameServiceQuery.h
+++ b/Swiften/Network/DomainNameServiceQuery.h
@@ -15,6 +15,8 @@
 #include <Swiften/Network/DomainNameResolveError.h>
 
 namespace Swift {
+	class RandomGenerator;
+
 	class DomainNameServiceQuery {
 		public:
 			typedef boost::shared_ptr<DomainNameServiceQuery> ref;
@@ -27,15 +29,10 @@ namespace Swift {
 				int weight;
 			};
 
-			struct ResultPriorityComparator {
-				bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const {
-					return a.priority < b.priority;
-				}
-			};
-
 			virtual ~DomainNameServiceQuery();
 
 			virtual void run() = 0;
+			static void sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator);
 
 			boost::signal<void (const std::vector<Result>&)> onResult;
 	};
diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp
index 5d076ac..b0579a7 100644
--- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp
+++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp
@@ -29,6 +29,7 @@
 #include <Swiften/Base/ByteArray.h>
 #include <Swiften/EventLoop/EventLoop.h>
 #include <Swiften/Base/foreach.h>
+#include <Swiften/Base/BoostRandomGenerator.h>
 #include <Swiften/Base/Log.h>
 #include <Swiften/Network/PlatformDomainNameResolver.h>
 
@@ -158,7 +159,8 @@ void PlatformDomainNameServiceQuery::runBlocking() {
 	}
 #endif
 
-	std::sort(records.begin(), records.end(), ResultPriorityComparator());
+	BoostRandomGenerator generator;
+	DomainNameServiceQuery::sortResults(records, generator);
 	//std::cout << "Sending out " << records.size() << " SRV results " << std::endl;
 	eventLoop->postEvent(boost::bind(boost::ref(onResult), records), shared_from_this());
 }
diff --git a/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp
new file mode 100644
index 0000000..aefd815
--- /dev/null
+++ b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2012 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <QA/Checker/IO.h>
+
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+
+#include <Swiften/Network/DomainNameServiceQuery.h>
+#include <Swiften/Base/RandomGenerator.h>
+
+using namespace Swift;
+
+namespace {
+	struct RandomGenerator1 : public RandomGenerator {
+		virtual int generateWeighedRandomNumber(std::vector<double>::const_iterator, std::vector<double>::const_iterator) {
+			return 0;
+		}
+	};
+
+	struct RandomGenerator2 : public RandomGenerator {
+		virtual int generateWeighedRandomNumber(std::vector<double>::const_iterator probabilities_begin, std::vector<double>::const_iterator probabilities_end) {
+			return std::max_element(probabilities_begin, probabilities_end) - probabilities_begin;
+		}
+	};
+}
+
+class DomainNameServiceQueryTest : public CppUnit::TestFixture {
+		CPPUNIT_TEST_SUITE(DomainNameServiceQueryTest);
+		CPPUNIT_TEST(testSortResults_Random1);
+		CPPUNIT_TEST(testSortResults_Random2);
+		CPPUNIT_TEST_SUITE_END();
+
+	public:
+		void testSortResults_Random1() {
+			std::vector<DomainNameServiceQuery::Result> results;
+			results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1));
+			results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10));
+			results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1));
+			results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20));
+			results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1));
+			results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10));
+
+			RandomGenerator1 generator;
+			DomainNameServiceQuery::sortResults(results, generator);
+
+			CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[1].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[4].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[5].hostname);
+		}
+
+		void testSortResults_Random2() {
+			std::vector<DomainNameServiceQuery::Result> results;
+			results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1));
+			results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10));
+			results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1));
+			results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20));
+			results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1));
+			results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10));
+			results.push_back(DomainNameServiceQuery::Result("server7.com", 5222, 3, 40));
+
+			RandomGenerator2 generator;
+			DomainNameServiceQuery::sortResults(results, generator);
+
+			CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server7.com"), results[1].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[4].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[5].hostname);
+			CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[6].hostname);
+		}
+};
+
+
+CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameServiceQueryTest);
diff --git a/Swiften/SConscript b/Swiften/SConscript
index 6308a80..2e0b73b 100644
--- a/Swiften/SConscript
+++ b/Swiften/SConscript
@@ -294,6 +294,7 @@ if env["SCONS_STAGE"] == "build" :
 			File("Network/UnitTest/HostAddressTest.cpp"),
 			File("Network/UnitTest/ConnectorTest.cpp"),
 			File("Network/UnitTest/ChainedConnectorTest.cpp"),
+			File("Network/UnitTest/DomainNameServiceQueryTest.cpp"),	
 			File("Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp"),
 			File("Network/UnitTest/BOSHConnectionTest.cpp"),
 			File("Network/UnitTest/BOSHConnectionPoolTest.cpp"),
-- 
cgit v0.10.2-6-g49f6