/*
 * Copyright (c) 2011 Remko Tronçon
 * Licensed under the GNU General Public License v3.
 * See Documentation/Licenses/GPLv3.txt for more information.
 */

#include "sluift.h"
#include <lauxlib.h>

#include <algorithm>
#include <iostream>
#include <string>
#include <deque>
#include <vector>
#include <boost/assign/list_of.hpp>

#include <Swiften/Base/foreach.h>
#include <Swiften/Swiften.h>

#include "Watchdog.h"
#include "SluiftException.h"
#include "ResponseSink.h"
#include "Lua/Value.h"

using namespace Swift;

#define SLUIFT_CLIENT "SluiftClient*"

/*******************************************************************************
 * Global settings
 ******************************************************************************/

// When true, clients constructed afterwards attach a ClientXMLTracer.
static bool debug = false;

// Timeout handed to the Watchdog while connecting and while waiting for the
// roster. Presumably milliseconds -- TODO confirm against Watchdog.
static int globalTimeout = 30000;

/*******************************************************************************
 * Helper classes
 ******************************************************************************/

// One event loop drives every client; the blocking helpers below pump it until
// the condition they are waiting for holds.
SimpleEventLoop eventLoop;
BoostNetworkFactories networkFactories(&eventLoop);

/**
 * Blocking facade over Swift::Client for use from Lua.
 *
 * Incoming messages and presences are queued for getNextEvent(); the blocking
 * methods run the shared event loop until their condition is met or a
 * Watchdog expires. A disconnection carrying an error is turned into a
 * SluiftException.
 */
class SluiftClient {
	public:
		SluiftClient(const JID& jid, const std::string& password) : client(NULL), tracer(NULL), rosterReceived(false) {
			client = new Client(jid, password, &networkFactories);
			client->setAlwaysTrustCertificates();
			client->onDisconnected.connect(boost::bind(&SluiftClient::handleDisconnected, this, _1));
			client->onMessageReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
			client->onPresenceReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
			client->getRoster()->onInitialRosterPopulated.connect(boost::bind(&SluiftClient::handleInitialRosterPopulated, this));
			if (debug) {
				tracer = new ClientXMLTracer(client);
			}
		}

		~SluiftClient() {
			// The tracer was constructed with `client`, so release it first.
			delete tracer;
			delete client;
		}

		Client* getClient() {
			return client;
		}

		// Options used by the next connect() call.
		ClientOptions& getOptions() {
			return options;
		}

		// Starts connecting with the current options; does not wait.
		void connect() {
			rosterReceived = false;
			client->connect(options);
		}

		// Starts connecting to an explicit host; does not wait.
		void connect(const std::string& host) {
			rosterReceived = false;
			options.manualHostname = host;
			client->connect(options);
		}

		// Blocks until the client is available, becomes inactive, or
		// globalTimeout expires.
		// @throws SluiftException on timeout (after disconnecting).
		void waitConnected() {
			Watchdog watchdog(globalTimeout, networkFactories.getTimerFactory());
			while (!watchdog.getTimedOut() && client->isActive() && !client->isAvailable()) {
				eventLoop.runUntilEvents();
			}
			if (watchdog.getTimedOut()) {
				client->disconnect();
				throw SluiftException("Timeout while connecting");
			}
		}

		bool isConnected() const {
			return client->isAvailable();
		}

		void sendMessage(const JID& to, const std::string& body) {
			Message::ref message = boost::make_shared<Message>();
			message->setTo(to);
			message->setBody(body);
			client->sendMessage(message);
		}

		void sendPresence(const std::string& status) {
			client->sendPresence(boost::make_shared<Presence>(status));
		}

		// Sends a raw IQ of the given type and waits for the reply.
		// @return the raw response, or an empty optional if `timeout` expires.
		boost::optional<std::string> sendQuery(const JID& jid, IQ::Type type, const std::string& data, int timeout) {
			rawRequestResponse.reset();
			RawRequest::ref request = RawRequest::create(type, jid, data, client->getIQRouter());
			boost::signals::scoped_connection c = request->onResponse.connect(boost::bind(&SluiftClient::handleRawRequestResponse, this, _1));
			request->send();
			Watchdog watchdog(timeout, networkFactories.getTimerFactory());
			while (!watchdog.getTimedOut() && !rawRequestResponse) {
				eventLoop.runUntilEvents();
			}
			if (watchdog.getTimedOut()) {
				return boost::optional<std::string>();
			}
			else {
				return *rawRequestResponse;
			}
		}

		// Disconnects and pumps the loop until the client is fully inactive.
		void disconnect() {
			client->disconnect();
			while (client->isActive()) {
				eventLoop.runUntilEvents();
			}
		}

		void setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os) {
			client->setSoftwareVersion(name, version, os);
		}

		// Returns the oldest queued message/presence. If none is queued, waits
		// for one while the client is still active (no new stanzas can arrive
		// once it is not) and `timeout` has not expired.
		// @return the event, or a null ref if nothing arrived.
		Stanza::ref getNextEvent(int timeout) {
			if (pendingEvents.empty()) {
				Watchdog watchdog(timeout, networkFactories.getTimerFactory());
				while (!watchdog.getTimedOut() && pendingEvents.empty() && client->isActive()) {
					eventLoop.runUntilEvents();
				}
			}
			if (pendingEvents.empty()) {
				return Stanza::ref();
			}
			Stanza::ref event = pendingEvents.front();
			pendingEvents.pop_front();
			return event;
		}

		// Returns the roster, requesting it (and waiting up to globalTimeout)
		// if the initial roster has not been received yet.
		// @throws SluiftException if the client goes inactive or the wait times out.
		std::vector<XMPPRosterItem> getRoster() {
			if (!rosterReceived) {
				client->requestRoster();
				Watchdog watchdog(globalTimeout, networkFactories.getTimerFactory());
				while (!rosterReceived && client->isActive() && !watchdog.getTimedOut()) {
					eventLoop.runUntilEvents();
				}
				if (!rosterReceived) {
					if (!client->isActive()) {
						throw SluiftException("Disconnected while requesting roster");
					}
					throw SluiftException("Timeout while requesting roster");
				}
			}
			return client->getRoster()->getItems();
		}

	private:
		// Non-copyable: owns raw pointers and has callbacks bound to `this`.
		SluiftClient(const SluiftClient&);
		SluiftClient& operator=(const SluiftClient&);

		void handleIncomingEvent(Stanza::ref stanza) {
			pendingEvents.push_back(stanza);
		}

		void handleInitialRosterPopulated() {
			rosterReceived = true;
		}

		void handleRawRequestResponse(const std::string& response) {
			rawRequestResponse = response;
		}

		// Presumably invoked while some blocking helper is pumping the event
		// loop, so the exception surfaces from that helper's call.
		void handleDisconnected(const boost::optional<ClientError>& error) {
			if (error) {
				throw SluiftException(*error);
			}
		}

	private:
		Client* client;
		ClientOptions options;
		ClientXMLTracer* tracer;
		bool rosterReceived;
		std::deque<Stanza::ref> pendingEvents;
		boost::optional<std::string> rawRequestResponse;
};

/*******************************************************************************
 * Client functions.
******************************************************************************/ static inline SluiftClient* getClient(lua_State* L) { return *reinterpret_cast<SluiftClient**>(luaL_checkudata(L, 1, SLUIFT_CLIENT)); } static int sluift_client_connect(lua_State *L) { try { SluiftClient* client = getClient(L); std::string host; if (lua_type(L, 2) != LUA_TNONE) { host = luaL_checkstring(L, 2); } if (host.empty()) { client->connect(); } else { client->connect(host); } client->waitConnected(); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_async_connect(lua_State *L) { try { getClient(L)->connect(); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_wait_connected(lua_State *L) { try { getClient(L)->waitConnected(); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_is_connected(lua_State *L) { lua_pushboolean(L, getClient(L)->isConnected()); return 1; } static int sluift_client_disconnect(lua_State *L) { try { getClient(L)->disconnect(); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_set_version(lua_State *L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); luaL_checktype(L, 2, LUA_TTABLE); lua_getfield(L, 2, "name"); const char* rawName = lua_tostring(L, -1); lua_getfield(L, 2, "version"); const char* rawVersion = lua_tostring(L, -1); lua_getfield(L, 2, "os"); const char* rawOS = lua_tostring(L, -1); client->setSoftwareVersion(rawName ? rawName : "", rawVersion ? rawVersion : "", rawOS ? 
rawOS : ""); lua_pop(L, 3); lua_pushvalue(L, 1); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_get_contacts(lua_State *L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); Lua::Table contactsTable; foreach(const XMPPRosterItem& item, client->getRoster()) { std::string subscription; switch(item.getSubscription()) { case RosterItemPayload::None: subscription = "none"; break; case RosterItemPayload::To: subscription = "to"; break; case RosterItemPayload::From: subscription = "from"; break; case RosterItemPayload::Both: subscription = "both"; break; case RosterItemPayload::Remove: subscription = "remove"; break; } Lua::Value groups(std::vector<Lua::Value>(item.getGroups().begin(), item.getGroups().end())); Lua::Table itemTable = boost::assign::map_list_of ("jid", boost::make_shared<Lua::Value>(item.getJID().toString())) ("name", boost::make_shared<Lua::Value>(item.getName())) ("subscription", boost::make_shared<Lua::Value>(subscription)) ("groups", boost::make_shared<Lua::Value>(std::vector<Lua::Value>(item.getGroups().begin(), item.getGroups().end()))); contactsTable[item.getJID().toString()] = boost::make_shared<Lua::Value>(itemTable); } pushValue(L, contactsTable); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_get_version(lua_State *L) { try { SluiftClient* client = getClient(L); int timeout = -1; if (lua_type(L, 3) != LUA_TNONE) { timeout = luaL_checknumber(L, 3); } ResponseSink<SoftwareVersion> sink; GetSoftwareVersionRequest::ref request = GetSoftwareVersionRequest::create(std::string(luaL_checkstring(L, 2)), client->getClient()->getIQRouter()); boost::signals::scoped_connection c = request->onResponse.connect(boost::ref(sink)); request->send(); Watchdog watchdog(timeout, networkFactories.getTimerFactory()); while (!watchdog.getTimedOut() && !sink.hasResponse()) { eventLoop.runUntilEvents(); } 
ErrorPayload::ref error = sink.getResponseError(); if (error || watchdog.getTimedOut()) { lua_pushnil(L); if (watchdog.getTimedOut()) { lua_pushstring(L, "Timeout"); } else if (error->getCondition() == ErrorPayload::RemoteServerNotFound) { lua_pushstring(L, "Remote server not found"); } // TODO else { lua_pushstring(L, "Error"); } return 2; } else if (SoftwareVersion::ref version = sink.getResponsePayload()) { Lua::Table result = boost::assign::map_list_of ("name", boost::make_shared<Lua::Value>(version->getName())) ("version", boost::make_shared<Lua::Value>(version->getVersion())) ("os", boost::make_shared<Lua::Value>(version->getOS())); Lua::pushValue(L, result); } else { lua_pushnil(L); } return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_send_message(lua_State *L) { try { eventLoop.runOnce(); getClient(L)->sendMessage(std::string(luaL_checkstring(L, 2)), luaL_checkstring(L, 3)); lua_pushvalue(L, 1); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_send_presence(lua_State *L) { try { eventLoop.runOnce(); getClient(L)->sendPresence(std::string(luaL_checkstring(L, 2))); lua_pushvalue(L, 1); return 0; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_get(lua_State *L) { try { SluiftClient* client = getClient(L); JID jid; std::string data; int timeout = -1; if (lua_type(L, 3) == LUA_TSTRING) { jid = JID(std::string(luaL_checkstring(L, 2))); data = std::string(luaL_checkstring(L, 3)); if (lua_type(L, 4) != LUA_TNONE) { timeout = luaL_checknumber(L, 4); } } else { data = std::string(luaL_checkstring(L, 2)); if (lua_type(L, 3) != LUA_TNONE) { timeout = luaL_checknumber(L, 3); } } boost::optional<std::string> result = client->sendQuery(jid, IQ::Get, data, timeout); if (result) { lua_pushstring(L, result->c_str()); } else { lua_pushnil(L); } return 1; } catch (const 
SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_set(lua_State *L) { try { SluiftClient* client = getClient(L); JID jid; std::string data; int timeout = -1; if (lua_type(L, 3) == LUA_TSTRING) { jid = JID(std::string(luaL_checkstring(L, 2))); data = std::string(luaL_checkstring(L, 3)); if (lua_type(L, 4) != LUA_TNONE) { timeout = luaL_checknumber(L, 4); } } else { data = std::string(luaL_checkstring(L, 2)); if (lua_type(L, 3) != LUA_TNONE) { timeout = luaL_checknumber(L, 3); } } boost::optional<std::string> result = client->sendQuery(jid, IQ::Set, data, timeout); if (result) { lua_pushstring(L, result->c_str()); } else { lua_pushnil(L); } return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_send(lua_State *L) { try { eventLoop.runOnce(); getClient(L)->getClient()->sendData(std::string(luaL_checkstring(L, 2))); lua_pushvalue(L, 1); return 0; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_set_options(lua_State* L) { SluiftClient* client = getClient(L); luaL_checktype(L, 2, LUA_TTABLE); lua_getfield(L, 2, "ack"); if (!lua_isnil(L, -1)) { client->getOptions().useAcks = lua_toboolean(L, -1); } lua_getfield(L, 2, "compress"); if (!lua_isnil(L, -1)) { client->getOptions().useStreamCompression = lua_toboolean(L, -1); } lua_getfield(L, 2, "tls"); if (!lua_isnil(L, -1)) { bool useTLS = lua_toboolean(L, -1); client->getOptions().useTLS = (useTLS ? 
ClientOptions::UseTLSWhenAvailable : ClientOptions::NeverUseTLS); } lua_pushvalue(L, 1); return 0; } static void pushEvent(lua_State* L, Stanza::ref event) { if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) { Lua::Table result = boost::assign::map_list_of ("type", boost::make_shared<Lua::Value>(std::string("message"))) ("from", boost::make_shared<Lua::Value>(message->getFrom().toString())) ("body", boost::make_shared<Lua::Value>(message->getBody())); Lua::pushValue(L, result); } else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) { Lua::Table result = boost::assign::map_list_of ("type", boost::make_shared<Lua::Value>(std::string("presence"))) ("from", boost::make_shared<Lua::Value>(presence->getFrom().toString())) ("status", boost::make_shared<Lua::Value>(presence->getStatus())); Lua::pushValue(L, result); } else { lua_pushnil(L); } } static int sluift_client_for_event(lua_State *L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); luaL_checktype(L, 2, LUA_TFUNCTION); int timeout = -1; if (lua_type(L, 3) != LUA_TNONE) { timeout = lua_tonumber(L, 3); } while (true) { Stanza::ref event = client->getNextEvent(timeout); if (!event) { // We got a timeout lua_pushnil(L); return 1; } else { // Push the function and event on the stack lua_pushvalue(L, 2); pushEvent(L, event); int oldTop = lua_gettop(L) - 2; lua_call(L, 1, LUA_MULTRET); int returnValues = lua_gettop(L) - oldTop; if (returnValues > 0) { lua_remove(L, -1 - returnValues); return returnValues; } } } } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_get_next_event(lua_State *L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); int timeout = -1; if (lua_type(L, 2) != LUA_TNONE) { timeout = lua_tonumber(L, 2); } pushEvent(L, client->getNextEvent(timeout)); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int 
sluift_client_add_contact(lua_State* L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); RosterItemPayload item; if (lua_type(L, 2) == LUA_TTABLE) { lua_getfield(L, 2, "jid"); const char* rawJID = lua_tostring(L, -1); if (rawJID) { item.setJID(std::string(rawJID)); } lua_getfield(L, 2, "name"); const char* rawName = lua_tostring(L, -1); if (rawName) { item.setName(rawName); } lua_getfield(L, 2, "groups"); if (!lua_isnil(L, -1)) { if (lua_type(L, -1) == LUA_TTABLE) { for (size_t i = 1; i <= lua_objlen(L, -1); ++i) { lua_rawgeti(L, -1, i); const char* rawGroup = lua_tostring(L, -1); if (rawGroup) { item.addGroup(rawGroup); } lua_pop(L, 1); } } else { return luaL_error(L, "Groups should be a table"); } } } else { item.setJID(luaL_checkstring(L, 2)); } client->getRoster(); if (!client->getClient()->getRoster()->containsJID(item.getJID())) { RosterPayload::ref roster = boost::make_shared<RosterPayload>(); roster->addItem(item); ResponseSink<RosterPayload> sink; SetRosterRequest::ref request = SetRosterRequest::create(roster, client->getClient()->getIQRouter()); boost::signals::scoped_connection c = request->onResponse.connect(boost::ref(sink)); request->send(); while (!sink.hasResponse()) { eventLoop.runUntilEvents(); } if (sink.getResponseError()) { lua_pushboolean(L, false); return 1; } } client->getClient()->getSubscriptionManager()->requestSubscription(item.getJID()); lua_pushboolean(L, true); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_remove_contact(lua_State* L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); JID jid(luaL_checkstring(L, 2)); RosterPayload::ref roster = boost::make_shared<RosterPayload>(); roster->addItem(RosterItemPayload(JID(luaL_checkstring(L, 2)), "", RosterItemPayload::Remove)); ResponseSink<RosterPayload> sink; SetRosterRequest::ref request = SetRosterRequest::create(roster, client->getClient()->getIQRouter()); 
boost::signals::scoped_connection c = request->onResponse.connect(boost::ref(sink)); request->send(); while (!sink.hasResponse()) { eventLoop.runUntilEvents(); } lua_pushboolean(L, !sink.getResponseError()); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_confirm_subscription(lua_State* L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); JID jid(luaL_checkstring(L, 2)); client->getClient()->getSubscriptionManager()->confirmSubscription(jid); return 0; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_cancel_subscription(lua_State* L) { try { eventLoop.runOnce(); SluiftClient* client = getClient(L); JID jid(luaL_checkstring(L, 2)); client->getClient()->getSubscriptionManager()->cancelSubscription(jid); return 0; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_client_gc (lua_State *L) { SluiftClient* client = getClient(L); delete client; return 0; } static const luaL_reg sluift_client_functions[] = { {"connect", sluift_client_connect}, {"async_connect", sluift_client_async_connect}, {"wait_connected", sluift_client_wait_connected}, {"is_connected", sluift_client_is_connected}, {"disconnect", sluift_client_disconnect}, {"send_message", sluift_client_send_message}, {"send_presence", sluift_client_send_presence}, {"get", sluift_client_get}, {"set", sluift_client_set}, {"send", sluift_client_send}, {"set_version", sluift_client_set_version}, {"get_contacts", sluift_client_get_contacts}, {"get_version", sluift_client_get_version}, {"set_options", sluift_client_set_options}, {"for_event", sluift_client_for_event}, {"get_next_event", sluift_client_get_next_event}, {"add_contact", sluift_client_add_contact}, {"remove_contact", sluift_client_remove_contact}, {"confirm_subscription", sluift_client_confirm_subscription}, {"cancel_subscription", 
sluift_client_cancel_subscription}, {"__gc", sluift_client_gc}, {NULL, NULL} }; /******************************************************************************* * Module functions ******************************************************************************/ static int sluift_new_client(lua_State *L) { try { JID jid(std::string(luaL_checkstring(L, 1))); std::string password(luaL_checkstring(L, 2)); SluiftClient** client = reinterpret_cast<SluiftClient**>(lua_newuserdata(L, sizeof(SluiftClient*))); luaL_getmetatable(L, SLUIFT_CLIENT); lua_setmetatable(L, -2); *client = new SluiftClient(jid, password); return 1; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_jid_to_bare(lua_State *L) { JID jid(std::string(luaL_checkstring(L, 1))); lua_pushstring(L, jid.toBare().toString().c_str()); return 1; } static int sluift_jid_node(lua_State *L) { JID jid(std::string(luaL_checkstring(L, 1))); lua_pushstring(L, jid.getNode().c_str()); return 1; } static int sluift_jid_domain(lua_State *L) { JID jid(std::string(luaL_checkstring(L, 1))); lua_pushstring(L, jid.getDomain().c_str()); return 1; } static int sluift_jid_resource(lua_State *L) { JID jid(std::string(luaL_checkstring(L, 1))); lua_pushstring(L, jid.getResource().c_str()); return 1; } static int sluift_sleep(lua_State *L) { try { eventLoop.runOnce(); int timeout = luaL_checknumber(L, 1); Watchdog watchdog(timeout, networkFactories.getTimerFactory()); while (!watchdog.getTimedOut()) { Swift::sleep(std::min(100, timeout)); eventLoop.runOnce(); } return 0; } catch (const SluiftException& e) { return luaL_error(L, e.getReason().c_str()); } } static int sluift_index(lua_State *L) { std::string key(luaL_checkstring(L, 2)); if (key == "debug") { lua_pushboolean(L, debug); return 1; } else if (key == "timeout") { lua_pushnumber(L, globalTimeout); return 1; } else { return luaL_error(L, "Invalid index"); } } static int sluift_newindex(lua_State *L) { std::string 
key(luaL_checkstring(L, 2)); if (key == "debug") { debug = lua_toboolean(L, 3); return 0; } else if (key == "timeout") { globalTimeout = luaL_checknumber(L, 3); return 0; } else { return luaL_error(L, "Invalid index"); } } static const luaL_reg sluift_functions[] = { {"new_client", sluift_new_client}, {"jid_to_bare", sluift_jid_to_bare}, {"jid_node", sluift_jid_node}, {"jid_domain", sluift_jid_domain}, {"jid_resource", sluift_jid_resource}, {"sleep", sluift_sleep}, {NULL, NULL} }; /******************************************************************************* * Module registration ******************************************************************************/ SLUIFT_API int luaopen_sluift(lua_State *L) { // Register functions luaL_register(L, "sluift", sluift_functions); lua_createtable(L, 0, 0); lua_pushcclosure(L, sluift_index, 0); lua_setfield(L, -2, "__index"); lua_pushcclosure(L, sluift_newindex, 0); lua_setfield(L, -2, "__newindex"); lua_setmetatable(L, -2); // Register the client metatable luaL_newmetatable(L, SLUIFT_CLIENT); lua_pushvalue(L, -1); lua_setfield(L, -2, "__index"); luaL_register(L, NULL, sluift_client_functions); return 1; }