diff options
Diffstat (limited to 'Sluift/sluift.cpp')
-rw-r--r-- | Sluift/sluift.cpp | 312 |
1 files changed, 242 insertions, 70 deletions
diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp index 84643b8..b3ca8dc 100644 --- a/Sluift/sluift.cpp +++ b/Sluift/sluift.cpp @@ -21,6 +21,7 @@ extern "C" { #include <Swiften/Base/sleep.h> #include <Swiften/Elements/SoftwareVersion.h> #include <Swiften/Queries/Requests/GetSoftwareVersionRequest.h> +#include <Swiften/Roster/XMPPRoster.h> using namespace Swift; @@ -39,13 +40,28 @@ bool debug = false; SimpleEventLoop eventLoop; BoostNetworkFactories networkFactories(&eventLoop); +class SluiftException { + public: + SluiftException(const std::string& reason) : reason(reason) { + } + + const std::string& getReason() const { + return reason; + } + + private: + std::string reason; +}; + class SluiftClient { public: SluiftClient(const JID& jid, const std::string& password) : tracer(NULL) { client = new Client(jid, password, &networkFactories); client->setAlwaysTrustCertificates(); + client->onDisconnected.connect(boost::bind(&SluiftClient::handleDisconnected, this, _1)); client->onMessageReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1)); client->onPresenceReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1)); + client->getRoster()->onInitialRosterPopulated.connect(boost::bind(&SluiftClient::handleInitialRosterPopulated, this)); if (debug) { tracer = new ClientXMLTracer(client); } @@ -57,7 +73,11 @@ class SluiftClient { } void connect() { + rosterReceived = false; client->connect(); + } + + void waitConnected() { while (client->isActive() && !client->isAvailable()) { processEvents(); } @@ -121,11 +141,23 @@ class SluiftClient { } } + std::vector<XMPPRosterItem> getRoster() { + client->requestRoster(); + while (!rosterReceived) { + processEvents(); + } + return client->getRoster()->getItems(); + } + private: void handleIncomingEvent(Stanza::ref stanza) { pendingEvents.push_back(stanza); } + void handleInitialRosterPopulated() { + rosterReceived = true; + } + void processEvents() { eventLoop.runUntilEvents(); } @@ -141,15 +173,57 @@ class SluiftClient { this->softwareVersion = SoftwareVersion("", "", ""); } } + + void handleDisconnected(const boost::optional<ClientError>& error) { + if (error) { + std::string reason("Disconnected: "); + switch(error->getType()) { + case ClientError::UnknownError: reason += "Unknown Error"; break; + case ClientError::DomainNameResolveError: reason += "Unable to find server"; break; + case ClientError::ConnectionError: reason += "Error connecting to server"; break; + case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break; + case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break; + case ClientError::XMLError: reason += "Error parsing server data"; break; + case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break; + case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break; + case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break; + case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break; + case ClientError::UnexpectedElementError: reason += "Unexpected response"; break; + case ClientError::ResourceBindError: reason += "Error binding resource"; break; + case ClientError::SessionStartError: reason += "Error starting session"; break; + case ClientError::StreamError: reason += "Stream error"; break; + case ClientError::TLSError: reason += "Encryption error"; break; + case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break; + case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break; + case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break; + case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break; + case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break; + case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break; + case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break; + case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break; + case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break; + case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break; + case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break; + case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break; + case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break; + } + throw SluiftException(reason); + } + } private: Client* client; ClientXMLTracer* tracer; + bool rosterReceived; boost::optional<SoftwareVersion> softwareVersion; ErrorPayload::ref error; std::deque<Stanza::ref> pendingEvents; }; +#define CHECK_CLIENT_CONNECTED(client, L) \ + if (!(*client)->isConnected()) { \ + lua_pushnil(L); \ + } /******************************************************************************* * Client functions. @@ -159,14 +233,51 @@ static inline SluiftClient* getClient(lua_State* L) { return *reinterpret_cast<SluiftClient**>(luaL_checkudata(L, 1, SLUIFT_CLIENT)); } +static int sluift_client_connect(lua_State *L) { + try { + SluiftClient* client = getClient(L); + client->connect(); + client->waitConnected(); + return 0; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } +} + +static int sluift_client_async_connect(lua_State *L) { + try { + getClient(L)->connect(); + return 0; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } +} + +static int sluift_client_wait_connected(lua_State *L) { + try { + getClient(L)->waitConnected(); + return 0; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } +} + static int sluift_client_is_connected(lua_State *L) { lua_pushboolean(L, getClient(L)->isConnected()); return 1; } static int sluift_client_disconnect(lua_State *L) { - getClient(L)->disconnect(); - return 0; + try { + getClient(L)->disconnect(); + return 0; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } } static int sluift_client_set_version(lua_State *L) { @@ -183,24 +294,80 @@ static int sluift_client_set_version(lua_State *L) { return 0; } -static int sluift_client_get_version(lua_State *L) { - SluiftClient* client = getClient(L); - JID jid(std::string(luaL_checkstring(L, 2))); +static int sluift_client_get_roster(lua_State *L) { + try { + SluiftClient* client = getClient(L); + std::vector<XMPPRosterItem> items = client->getRoster(); + + lua_createtable(L, 0, items.size()); + foreach(const XMPPRosterItem& item, items) { + lua_createtable(L, 0, 3); + + lua_pushstring(L, item.getName().c_str()); + lua_setfield(L, -2, "name"); + + std::string subscription; + switch(item.getSubscription()) { + case RosterItemPayload::None: subscription = "none"; break; + case RosterItemPayload::To: subscription = "to"; break; + case RosterItemPayload::From: subscription = "from"; break; + case RosterItemPayload::Both: subscription = "both"; break; + case RosterItemPayload::Remove: subscription = "remove"; break; + } + lua_pushstring(L, subscription.c_str()); + lua_setfield(L, -2, "subscription"); + + std::vector<std::string> groups = item.getGroups(); + lua_createtable(L, groups.size(), 0); + for (size_t i = 0; i < groups.size(); ++i) { + lua_pushstring(L, groups[i].c_str()); + lua_rawseti(L, -2, i + 1); + } + lua_setfield(L, -2, "groups"); - boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid); - if (version) { - lua_createtable(L, 0, 3); + lua_setfield(L, -2, item.getJID().toString().c_str()); + } + /*boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid); + if (version) { lua_pushstring(L, version->getName().c_str()); - lua_setfield(L, -2, "name"); - lua_pushstring(L, version->getVersion().c_str()); - lua_setfield(L, -2, "version"); - lua_pushstring(L, version->getOS().c_str()); - lua_setfield(L, -2, "os"); + lua_pushstring(L, version->getName().c_str()); + lua_setfield(L, -2, "name"); + lua_pushstring(L, version->getVersion().c_str()); + lua_setfield(L, -2, "version"); + lua_pushstring(L, version->getOS().c_str()); + lua_setfield(L, -2, "os"); + } + */ + return 1; } - else { - lua_pushnil(L); + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } +} + +static int sluift_client_get_version(lua_State *L) { + try { + SluiftClient* client = getClient(L); + JID jid(std::string(luaL_checkstring(L, 2))); + + boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid); + if (version) { + lua_createtable(L, 0, 3); + lua_pushstring(L, version->getName().c_str()); + lua_setfield(L, -2, "name"); + lua_pushstring(L, version->getVersion().c_str()); + lua_setfield(L, -2, "version"); + lua_pushstring(L, version->getOS().c_str()); + lua_setfield(L, -2, "os"); + } + else { + lua_pushnil(L); + } + return 1; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); } - return 1; } static int sluift_client_send_message(lua_State *L) { @@ -214,60 +381,65 @@ static int sluift_client_send_presence(lua_State *L) { } static int sluift_client_for_event (lua_State *L) { - SluiftClient* client = getClient(L); - luaL_checktype(L, 2, LUA_TFUNCTION); - while (true) { - Stanza::ref event = client->getNextEvent(); - if (!event) { - // We got disconnected - lua_pushnil(L); - lua_pushliteral(L, "disconnected"); - return 2; - } - else { - // Push the function on the stack - lua_pushvalue(L, 2); - - bool emitEvent = false; - if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) { - lua_createtable(L, 0, 3); - lua_pushliteral(L, "message"); - lua_setfield(L, -2, "type"); - lua_pushstring(L, message->getFrom().toString().c_str()); - lua_setfield(L, -2, "from"); - lua_pushstring(L, message->getBody().c_str()); - lua_setfield(L, -2, "body"); - emitEvent = true; - } - else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) { - lua_createtable(L, 0, 3); - lua_pushliteral(L, "presence"); - lua_setfield(L, -2, "type"); - lua_pushstring(L, presence->getFrom().toString().c_str()); - lua_setfield(L, -2, "from"); - lua_pushstring(L, presence->getStatus().c_str()); - lua_setfield(L, -2, "status"); - emitEvent = true; + try { + SluiftClient* client = getClient(L); + luaL_checktype(L, 2, LUA_TFUNCTION); + while (true) { + Stanza::ref event = client->getNextEvent(); + if (!event) { + // We got disconnected + lua_pushnil(L); + lua_pushliteral(L, "disconnected"); + return 2; } else { - assert(false); - } - if (emitEvent) { - int oldTop = lua_gettop(L) - 2; - lua_call(L, 1, LUA_MULTRET); - int returnValues = lua_gettop(L) - oldTop; - if (returnValues > 0) { - lua_remove(L, -1 - returnValues); - return returnValues; + // Push the function on the stack + lua_pushvalue(L, 2); + + bool emitEvent = false; + if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) { + lua_createtable(L, 0, 3); + lua_pushliteral(L, "message"); + lua_setfield(L, -2, "type"); + lua_pushstring(L, message->getFrom().toString().c_str()); + lua_setfield(L, -2, "from"); + lua_pushstring(L, message->getBody().c_str()); + lua_setfield(L, -2, "body"); + emitEvent = true; + } + else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) { + lua_createtable(L, 0, 3); + lua_pushliteral(L, "presence"); + lua_setfield(L, -2, "type"); + lua_pushstring(L, presence->getFrom().toString().c_str()); + lua_setfield(L, -2, "from"); + lua_pushstring(L, presence->getStatus().c_str()); + lua_setfield(L, -2, "status"); + emitEvent = true; + } + else { + assert(false); + } + if (emitEvent) { + int oldTop = lua_gettop(L) - 2; + lua_call(L, 1, LUA_MULTRET); + int returnValues = lua_gettop(L) - oldTop; + if (returnValues > 0) { + lua_remove(L, -1 - returnValues); + return returnValues; + } + } + else { + // Remove the function from the stack again, since + // we're not calling the function + lua_pop(L, 1); } - } - else { - // Remove the function from the stack again, since - // we're not calling the function - lua_pop(L, 1); } } } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } } static int sluift_client_gc (lua_State *L) { @@ -278,11 +450,15 @@ static int sluift_client_gc (lua_State *L) { static const luaL_reg sluift_client_functions[] = { + {"connect", sluift_client_connect}, + {"async_connect", sluift_client_async_connect}, + {"wait_connected", sluift_client_wait_connected}, {"is_connected", sluift_client_is_connected}, {"disconnect", sluift_client_disconnect}, {"send_message", sluift_client_send_message}, {"send_presence", sluift_client_send_presence}, {"set_version", sluift_client_set_version}, + {"get_roster", sluift_client_get_roster}, {"get_version", sluift_client_get_version}, {"for_event", sluift_client_for_event}, {"__gc", sluift_client_gc}, @@ -293,7 +469,7 @@ static const luaL_reg sluift_client_functions[] = { * Module functions ******************************************************************************/ -static int sluift_connect(lua_State *L) { +static int sluift_new_client(lua_State *L) { JID jid(std::string(luaL_checkstring(L, 1))); std::string password(luaL_checkstring(L, 2)); @@ -302,10 +478,6 @@ static int sluift_connect(lua_State *L) { lua_setmetatable(L, -2); *client = new SluiftClient(jid, password); - (*client)->connect(); - if (!(*client)->isConnected()) { - lua_pushnil(L); - } return 1; } @@ -332,7 +504,7 @@ static int sluift_newindex(lua_State *L) { } static const luaL_reg sluift_functions[] = { - {"connect", sluift_connect}, + {"new_client", sluift_new_client}, {NULL, NULL} }; |