From c2886c9ff6152130e2adb006f84268f972e629cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Fri, 25 Feb 2011 23:53:06 +0100
Subject: Some more sluift tweaks.


diff --git a/Sluift/client_test.lua b/Sluift/client_test.lua
index 4eebf0c..1652083 100644
--- a/Sluift/client_test.lua
+++ b/Sluift/client_test.lua
@@ -6,11 +6,13 @@ client2_jid = os.getenv("SWIFT_CLIENTTEST_JID") .. "/Client2"
 password = os.getenv("SWIFT_CLIENTTEST_PASS")
 
 print "Connecting client 1"
-client1 = sluift.connect(client1_jid, password)
+client1 = sluift.new_client(client1_jid, password)
+client1:connect()
 client1:send_presence("I'm here")
 
 print "Connecting client 2"
-client2 = sluift.connect(client2_jid, password)
+client2 = sluift.new_client(client2_jid, password)
+client2:connect()
 client2:send_presence("I'm here")
 
 print "Checking version of client 2 from client 1"
@@ -28,5 +30,9 @@ received_message = client2:for_event(function(event)
 	end)
 assert(received_message == "Hello")
 
+print "Retrieving the roster"
+roster = client1:get_roster()
+table.foreach(roster, print)
+
 client1:disconnect()
 client2:disconnect()
diff --git a/Sluift/multiple_client_test.lua b/Sluift/multiple_client_test.lua
new file mode 100644
index 0000000..d2b2cd7
--- /dev/null
+++ b/Sluift/multiple_client_test.lua
@@ -0,0 +1,26 @@
+require "sluift"
+
+-- sluift.debug = true
+num_clients = 10
+
+print("Connecting clients")
+clients = {}
+for i = 1, num_clients do
+	jid = os.getenv("SWIFT_CLIENTTEST_JID") .. "/Client" .. i
+	client = sluift.new_client(jid, os.getenv("SWIFT_CLIENTTEST_PASS"))
+	client:async_connect()
+	table.insert(clients, client)
+end
+
+print("Waiting for clients to be connected")
+for i, client in ipairs(clients) do
+	client:wait_connected()
+	client:send_presence("Hello")
+end
+
+print("Disconnecting clients")
+for i, client in ipairs(clients) do
+	client:disconnect()
+end
+
+print("Done")
diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp
index 84643b8..b3ca8dc 100644
--- a/Sluift/sluift.cpp
+++ b/Sluift/sluift.cpp
@@ -21,6 +21,7 @@ extern "C" {
 #include <Swiften/Base/sleep.h>
 #include <Swiften/Elements/SoftwareVersion.h>
 #include <Swiften/Queries/Requests/GetSoftwareVersionRequest.h>
+#include <Swiften/Roster/XMPPRoster.h>
 
 using namespace Swift;
 
@@ -39,13 +40,28 @@ bool debug = false;
 SimpleEventLoop eventLoop;
 BoostNetworkFactories networkFactories(&eventLoop);
 
+class SluiftException {
+	public:
+		SluiftException(const std::string& reason) : reason(reason) {
+		}
+
+		const std::string& getReason() const {
+			return reason;
+		}
+	
+	private:
+		std::string reason;
+};
+
 class SluiftClient {
 	public:
 		SluiftClient(const JID& jid, const std::string& password) : tracer(NULL) {
 			client = new Client(jid, password, &networkFactories);
 			client->setAlwaysTrustCertificates();
+			client->onDisconnected.connect(boost::bind(&SluiftClient::handleDisconnected, this, _1));
 			client->onMessageReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
 			client->onPresenceReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
+			client->getRoster()->onInitialRosterPopulated.connect(boost::bind(&SluiftClient::handleInitialRosterPopulated, this));
 			if (debug) {
 				tracer = new ClientXMLTracer(client);
 			}
@@ -57,7 +73,11 @@ class SluiftClient {
 		}
 
 		void connect() {
+			rosterReceived = false;
 			client->connect();
+		}
+
+		void waitConnected() {
 			while (client->isActive() && !client->isAvailable()) {
 				processEvents();
 			}
@@ -121,11 +141,23 @@ class SluiftClient {
 			}
 		}
 
+		std::vector<XMPPRosterItem> getRoster() {
+			client->requestRoster();
+			while (!rosterReceived) {
+				processEvents();
+			}
+			return client->getRoster()->getItems();
+		}
+
 	private:
 		void handleIncomingEvent(Stanza::ref stanza) {
 			pendingEvents.push_back(stanza);
 		}
 
+		void handleInitialRosterPopulated() {
+			rosterReceived = true;
+		}
+
 		void processEvents() {
 			eventLoop.runUntilEvents();
 		}
@@ -141,15 +173,57 @@ class SluiftClient {
 				this->softwareVersion = SoftwareVersion("", "", "");
 			}
 		}
+
+		void handleDisconnected(const boost::optional<ClientError>& error) {
+			if (error) {
+				std::string reason("Disconnected: ");
+				switch(error->getType()) {
+					case ClientError::UnknownError: reason += "Unknown Error"; break;
+					case ClientError::DomainNameResolveError: reason += "Unable to find server"; break;
+					case ClientError::ConnectionError: reason += "Error connecting to server"; break;
+					case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break;
+					case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break;
+					case ClientError::XMLError: reason += "Error parsing server data"; break;
+					case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break;
+					case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break;
+					case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break;
+					case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break;
+					case ClientError::UnexpectedElementError: reason += "Unexpected response"; break;
+					case ClientError::ResourceBindError: reason += "Error binding resource"; break;
+					case ClientError::SessionStartError: reason += "Error starting session"; break;
+					case ClientError::StreamError: reason += "Stream error"; break;
+					case ClientError::TLSError: reason += "Encryption error"; break;
+					case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break;
+					case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break;
+					case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break;
+					case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break;
+					case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break;
+					case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break;
+					case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break;
+					case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break;
+					case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break;
+					case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break;
+					case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break;
+					case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break;
+					case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break;
+				}
+				throw SluiftException(reason);
+			}
+		}
 	
 	private:
 		Client* client;
 		ClientXMLTracer* tracer;
+		bool rosterReceived;
 		boost::optional<SoftwareVersion> softwareVersion;
 		ErrorPayload::ref error;
 		std::deque<Stanza::ref> pendingEvents;
 };
 
+#define CHECK_CLIENT_CONNECTED(client, L) \
+	if (!(*client)->isConnected()) { \
+		lua_pushnil(L); \
+	} 
 
 /*******************************************************************************
  * Client functions.
@@ -159,14 +233,51 @@ static inline SluiftClient* getClient(lua_State* L) {
 	return *reinterpret_cast<SluiftClient**>(luaL_checkudata(L, 1, SLUIFT_CLIENT));
 }
 
+static int sluift_client_connect(lua_State *L) {
+	try {
+		SluiftClient* client = getClient(L);
+		client->connect();
+		client->waitConnected();
+		return 0;
+	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
+}
+
+static int sluift_client_async_connect(lua_State *L) {
+	try {
+		getClient(L)->connect();
+		return 0;
+	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
+}
+
+static int sluift_client_wait_connected(lua_State *L) {
+	try {
+		getClient(L)->waitConnected();
+		return 0;
+	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
+}
+
 static int sluift_client_is_connected(lua_State *L) {
 	lua_pushboolean(L, getClient(L)->isConnected());
 	return 1;
 }
 
 static int sluift_client_disconnect(lua_State *L) {
-	getClient(L)->disconnect();
-	return 0;
+	try {
+		getClient(L)->disconnect();
+		return 0;
+	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
 }
 
 static int sluift_client_set_version(lua_State *L) {
@@ -183,24 +294,80 @@ static int sluift_client_set_version(lua_State *L) {
 	return 0;
 }
 
-static int sluift_client_get_version(lua_State *L) {
-	SluiftClient* client = getClient(L);
-	JID jid(std::string(luaL_checkstring(L, 2)));
+static int sluift_client_get_roster(lua_State *L) {
+	try {
+		SluiftClient* client = getClient(L);
+		std::vector<XMPPRosterItem> items = client->getRoster();
+
+		lua_createtable(L, 0, items.size());
+		foreach(const XMPPRosterItem& item, items) {
+			lua_createtable(L, 0, 3);
+
+			lua_pushstring(L, item.getName().c_str());
+			lua_setfield(L, -2, "name");
+
+			std::string subscription;
+			switch(item.getSubscription()) {
+				case RosterItemPayload::None: subscription = "none"; break;
+				case RosterItemPayload::To: subscription = "to"; break;
+				case RosterItemPayload::From: subscription = "from"; break;
+				case RosterItemPayload::Both: subscription = "both"; break;
+				case RosterItemPayload::Remove: subscription = "remove"; break;
+			}
+			lua_pushstring(L, subscription.c_str());
+			lua_setfield(L, -2, "subscription");
+
+			std::vector<std::string> groups = item.getGroups();
+			lua_createtable(L, groups.size(), 0);
+			for (size_t i = 0; i < groups.size(); ++i) {
+				lua_pushstring(L, groups[i].c_str());
+				lua_rawseti(L, -2, i + 1);
+			}
+			lua_setfield(L, -2, "groups");
 
-	boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
-	if (version) {
-		lua_createtable(L, 0, 3);
+			lua_setfield(L, -2, item.getJID().toString().c_str());
+		}
+		/*boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
+		if (version) {
 		lua_pushstring(L, version->getName().c_str());
-		lua_setfield(L, -2, "name");
-		lua_pushstring(L, version->getVersion().c_str());
-		lua_setfield(L, -2, "version");
-		lua_pushstring(L, version->getOS().c_str());
-		lua_setfield(L, -2, "os");
+			lua_pushstring(L, version->getName().c_str());
+			lua_setfield(L, -2, "name");
+			lua_pushstring(L, version->getVersion().c_str());
+			lua_setfield(L, -2, "version");
+			lua_pushstring(L, version->getOS().c_str());
+			lua_setfield(L, -2, "os");
+		}
+		*/
+		return 1;
 	}
-	else {
-		lua_pushnil(L);
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
+}
+
+static int sluift_client_get_version(lua_State *L) {
+	try {
+		SluiftClient* client = getClient(L);
+		JID jid(std::string(luaL_checkstring(L, 2)));
+
+		boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
+		if (version) {
+			lua_createtable(L, 0, 3);
+			lua_pushstring(L, version->getName().c_str());
+			lua_setfield(L, -2, "name");
+			lua_pushstring(L, version->getVersion().c_str());
+			lua_setfield(L, -2, "version");
+			lua_pushstring(L, version->getOS().c_str());
+			lua_setfield(L, -2, "os");
+		}
+		else {
+			lua_pushnil(L);
+		}
+		return 1;
+	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
 	}
-	return 1;
 }
 
 static int sluift_client_send_message(lua_State *L) {
@@ -214,60 +381,65 @@ static int sluift_client_send_presence(lua_State *L) {
 }
 
 static int sluift_client_for_event (lua_State *L) {
-	SluiftClient* client = getClient(L);
-	luaL_checktype(L, 2, LUA_TFUNCTION);
-	while (true) {
-		Stanza::ref event = client->getNextEvent();
-		if (!event) {
-			// We got disconnected
-			lua_pushnil(L);
-			lua_pushliteral(L, "disconnected");
-			return 2;
-		}
-		else {
-			// Push the function on the stack
-			lua_pushvalue(L, 2);
-
-			bool emitEvent = false;
-			if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) {
-				lua_createtable(L, 0, 3);
-				lua_pushliteral(L, "message");
-				lua_setfield(L, -2, "type");
-				lua_pushstring(L, message->getFrom().toString().c_str());
-				lua_setfield(L, -2, "from");
-				lua_pushstring(L, message->getBody().c_str());
-				lua_setfield(L, -2, "body");
-				emitEvent = true;
-			}
-			else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) {
-				lua_createtable(L, 0, 3);
-				lua_pushliteral(L, "presence");
-				lua_setfield(L, -2, "type");
-				lua_pushstring(L, presence->getFrom().toString().c_str());
-				lua_setfield(L, -2, "from");
-				lua_pushstring(L, presence->getStatus().c_str());
-				lua_setfield(L, -2, "status");
-				emitEvent = true;
+	try {
+		SluiftClient* client = getClient(L);
+		luaL_checktype(L, 2, LUA_TFUNCTION);
+		while (true) {
+			Stanza::ref event = client->getNextEvent();
+			if (!event) {
+				// We got disconnected
+				lua_pushnil(L);
+				lua_pushliteral(L, "disconnected");
+				return 2;
 			}
 			else {
-				assert(false);
-			}
-			if (emitEvent) {
-				int oldTop = lua_gettop(L) - 2;
-				lua_call(L, 1, LUA_MULTRET);
-				int returnValues = lua_gettop(L) - oldTop;
-				if (returnValues > 0) {
-					lua_remove(L, -1 - returnValues);
-					return returnValues;
+				// Push the function on the stack
+				lua_pushvalue(L, 2);
+
+				bool emitEvent = false;
+				if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) {
+					lua_createtable(L, 0, 3);
+					lua_pushliteral(L, "message");
+					lua_setfield(L, -2, "type");
+					lua_pushstring(L, message->getFrom().toString().c_str());
+					lua_setfield(L, -2, "from");
+					lua_pushstring(L, message->getBody().c_str());
+					lua_setfield(L, -2, "body");
+					emitEvent = true;
+				}
+				else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) {
+					lua_createtable(L, 0, 3);
+					lua_pushliteral(L, "presence");
+					lua_setfield(L, -2, "type");
+					lua_pushstring(L, presence->getFrom().toString().c_str());
+					lua_setfield(L, -2, "from");
+					lua_pushstring(L, presence->getStatus().c_str());
+					lua_setfield(L, -2, "status");
+					emitEvent = true;
+				}
+				else {
+					assert(false);
+				}
+				if (emitEvent) {
+					int oldTop = lua_gettop(L) - 2;
+					lua_call(L, 1, LUA_MULTRET);
+					int returnValues = lua_gettop(L) - oldTop;
+					if (returnValues > 0) {
+						lua_remove(L, -1 - returnValues);
+						return returnValues;
+					}
+				}
+				else {
+					// Remove the function from the stack again, since
+					// we're not calling the function
+					lua_pop(L, 1);
 				}
-			}
-			else {
-				// Remove the function from the stack again, since
-				// we're not calling the function
-				lua_pop(L, 1);
 			}
 		}
 	}
+	catch (const SluiftException& e) {
+		return luaL_error(L, e.getReason().c_str());
+	}
 }
 
 static int sluift_client_gc (lua_State *L) {
@@ -278,11 +450,15 @@ static int sluift_client_gc (lua_State *L) {
 
 
 static const luaL_reg sluift_client_functions[] = {
+	{"connect",  sluift_client_connect},
+	{"async_connect",  sluift_client_async_connect},
+	{"wait_connected", sluift_client_wait_connected},
 	{"is_connected", sluift_client_is_connected},
 	{"disconnect",  sluift_client_disconnect},
 	{"send_message", sluift_client_send_message},
 	{"send_presence", sluift_client_send_presence},
 	{"set_version", sluift_client_set_version},
+	{"get_roster", sluift_client_get_roster},
 	{"get_version", sluift_client_get_version},
 	{"for_event", sluift_client_for_event},
 	{"__gc", sluift_client_gc},
@@ -293,7 +469,7 @@ static const luaL_reg sluift_client_functions[] = {
  * Module functions
  ******************************************************************************/
 
-static int sluift_connect(lua_State *L) {
+static int sluift_new_client(lua_State *L) {
 	JID jid(std::string(luaL_checkstring(L, 1)));
 	std::string password(luaL_checkstring(L, 2));
 
@@ -302,10 +478,6 @@ static int sluift_connect(lua_State *L) {
 	lua_setmetatable(L, -2);
 
 	*client = new SluiftClient(jid, password);
-	(*client)->connect();
-	if (!(*client)->isConnected()) {
-		lua_pushnil(L);
-	}
 	return 1;
 }
 
@@ -332,7 +504,7 @@ static int sluift_newindex(lua_State *L) {
 }
 
 static const luaL_reg sluift_functions[] = {
-	{"connect",   sluift_connect},
+	{"new_client", sluift_new_client},
 	{NULL, NULL}
 };
 
diff --git a/Swiften/Roster/XMPPRoster.h b/Swiften/Roster/XMPPRoster.h
index 958c1f6..be3494d 100644
--- a/Swiften/Roster/XMPPRoster.h
+++ b/Swiften/Roster/XMPPRoster.h
@@ -86,5 +86,11 @@ namespace Swift {
 			 * onJIDAdded and onJIDRemoved events.
 			 */
 			boost::signal<void ()> onRosterCleared;
+
+			/**
+			 * Emitted after the last contact of the initial roster request response
+			 * was added.
+			 */
+			boost::signal<void ()> onInitialRosterPopulated;
 	};
 }
diff --git a/Swiften/Roster/XMPPRosterController.cpp b/Swiften/Roster/XMPPRosterController.cpp
index 3a1d11f..a294d35 100644
--- a/Swiften/Roster/XMPPRosterController.cpp
+++ b/Swiften/Roster/XMPPRosterController.cpp
@@ -20,7 +20,7 @@ namespace Swift {
  * The controller does not gain ownership of these parameters.
  */
 XMPPRosterController::XMPPRosterController(IQRouter* iqRouter, XMPPRosterImpl* xmppRoster) : iqRouter_(iqRouter), rosterPushResponder_(iqRouter), xmppRoster_(xmppRoster) {
-	rosterPushResponder_.onRosterReceived.connect(boost::bind(&XMPPRosterController::handleRosterReceived, this, _1));
+	rosterPushResponder_.onRosterReceived.connect(boost::bind(&XMPPRosterController::handleRosterReceived, this, _1, false));
 	rosterPushResponder_.start();
 }
 
@@ -31,11 +31,11 @@ XMPPRosterController::~XMPPRosterController() {
 void XMPPRosterController::requestRoster() {
 	xmppRoster_->clear();
 	GetRosterRequest::ref rosterRequest = GetRosterRequest::create(iqRouter_);
-	rosterRequest->onResponse.connect(boost::bind(&XMPPRosterController::handleRosterReceived, this, _1));
+	rosterRequest->onResponse.connect(boost::bind(&XMPPRosterController::handleRosterReceived, this, _1, true));
 	rosterRequest->send();
 }
 
-void XMPPRosterController::handleRosterReceived(boost::shared_ptr<RosterPayload> rosterPayload) {
+void XMPPRosterController::handleRosterReceived(boost::shared_ptr<RosterPayload> rosterPayload, bool initial) {
 	if (rosterPayload) {
 		foreach(const RosterItemPayload& item, rosterPayload->getItems()) {
 			//Don't worry about the updated case, the XMPPRoster sorts that out.
@@ -46,6 +46,9 @@ void XMPPRosterController::handleRosterReceived(boost::shared_ptr<RosterPayload>
 			}
 		}
 	}
+	if (initial) {
+		xmppRoster_->onInitialRosterPopulated();
+	}
 }
 
 }
diff --git a/Swiften/Roster/XMPPRosterController.h b/Swiften/Roster/XMPPRosterController.h
index 28c2541..eeb84f6 100644
--- a/Swiften/Roster/XMPPRosterController.h
+++ b/Swiften/Roster/XMPPRosterController.h
@@ -27,7 +27,7 @@ namespace Swift {
 			void requestRoster();
 
 		private:
-			void handleRosterReceived(boost::shared_ptr<RosterPayload> rosterPayload);
+			void handleRosterReceived(boost::shared_ptr<RosterPayload> rosterPayload, bool initial);
 
 		private:
 			IQRouter* iqRouter_;
-- 
cgit v0.10.2-6-g49f6