summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Sluift/sluift.cpp')
-rw-r--r--Sluift/sluift.cpp312
1 files changed, 242 insertions, 70 deletions
diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp
index 84643b8..b3ca8dc 100644
--- a/Sluift/sluift.cpp
+++ b/Sluift/sluift.cpp
@@ -21,6 +21,7 @@ extern "C" {
#include <Swiften/Base/sleep.h>
#include <Swiften/Elements/SoftwareVersion.h>
#include <Swiften/Queries/Requests/GetSoftwareVersionRequest.h>
+#include <Swiften/Roster/XMPPRoster.h>
using namespace Swift;
@@ -39,13 +40,28 @@ bool debug = false;
SimpleEventLoop eventLoop;
BoostNetworkFactories networkFactories(&eventLoop);
+class SluiftException {
+ public:
+ SluiftException(const std::string& reason) : reason(reason) {
+ }
+
+ const std::string& getReason() const {
+ return reason;
+ }
+
+ private:
+ std::string reason;
+};
+
class SluiftClient {
public:
SluiftClient(const JID& jid, const std::string& password) : tracer(NULL) {
client = new Client(jid, password, &networkFactories);
client->setAlwaysTrustCertificates();
+ client->onDisconnected.connect(boost::bind(&SluiftClient::handleDisconnected, this, _1));
client->onMessageReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
client->onPresenceReceived.connect(boost::bind(&SluiftClient::handleIncomingEvent, this, _1));
+ client->getRoster()->onInitialRosterPopulated.connect(boost::bind(&SluiftClient::handleInitialRosterPopulated, this));
if (debug) {
tracer = new ClientXMLTracer(client);
}
@@ -57,7 +73,11 @@ class SluiftClient {
}
void connect() {
+ rosterReceived = false;
client->connect();
+ }
+
+ void waitConnected() {
while (client->isActive() && !client->isAvailable()) {
processEvents();
}
@@ -121,11 +141,23 @@ class SluiftClient {
}
}
+ std::vector<XMPPRosterItem> getRoster() {
+ client->requestRoster();
+ while (!rosterReceived) {
+ processEvents();
+ }
+ return client->getRoster()->getItems();
+ }
+
private:
void handleIncomingEvent(Stanza::ref stanza) {
pendingEvents.push_back(stanza);
}
+ void handleInitialRosterPopulated() {
+ rosterReceived = true;
+ }
+
void processEvents() {
eventLoop.runUntilEvents();
}
@@ -141,15 +173,57 @@ class SluiftClient {
this->softwareVersion = SoftwareVersion("", "", "");
}
}
+
+ void handleDisconnected(const boost::optional<ClientError>& error) {
+ if (error) {
+ std::string reason("Disconnected: ");
+ switch(error->getType()) {
+ case ClientError::UnknownError: reason += "Unknown Error"; break;
+ case ClientError::DomainNameResolveError: reason += "Unable to find server"; break;
+ case ClientError::ConnectionError: reason += "Error connecting to server"; break;
+ case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break;
+ case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break;
+ case ClientError::XMLError: reason += "Error parsing server data"; break;
+ case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break;
+ case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break;
+ case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break;
+ case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break;
+ case ClientError::UnexpectedElementError: reason += "Unexpected response"; break;
+ case ClientError::ResourceBindError: reason += "Error binding resource"; break;
+ case ClientError::SessionStartError: reason += "Error starting session"; break;
+ case ClientError::StreamError: reason += "Stream error"; break;
+ case ClientError::TLSError: reason += "Encryption error"; break;
+ case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break;
+ case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break;
+ case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break;
+ case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break;
+ case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break;
+ case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break;
+ case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break;
+ case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break;
+ case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break;
+ case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break;
+ case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break;
+ case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break;
+ case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break;
+ }
+ throw SluiftException(reason);
+ }
+ }
private:
Client* client;
ClientXMLTracer* tracer;
+ bool rosterReceived;
boost::optional<SoftwareVersion> softwareVersion;
ErrorPayload::ref error;
std::deque<Stanza::ref> pendingEvents;
};
+#define CHECK_CLIENT_CONNECTED(client, L) \
+ if (!(*client)->isConnected()) { \
+ lua_pushnil(L); \
+ }
/*******************************************************************************
* Client functions.
@@ -159,14 +233,51 @@ static inline SluiftClient* getClient(lua_State* L) {
return *reinterpret_cast<SluiftClient**>(luaL_checkudata(L, 1, SLUIFT_CLIENT));
}
+static int sluift_client_connect(lua_State *L) {
+ try {
+ SluiftClient* client = getClient(L);
+ client->connect();
+ client->waitConnected();
+ return 0;
+ }
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
+}
+
+static int sluift_client_async_connect(lua_State *L) {
+ try {
+ getClient(L)->connect();
+ return 0;
+ }
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
+}
+
+static int sluift_client_wait_connected(lua_State *L) {
+ try {
+ getClient(L)->waitConnected();
+ return 0;
+ }
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
+}
+
static int sluift_client_is_connected(lua_State *L) {
lua_pushboolean(L, getClient(L)->isConnected());
return 1;
}
static int sluift_client_disconnect(lua_State *L) {
- getClient(L)->disconnect();
- return 0;
+ try {
+ getClient(L)->disconnect();
+ return 0;
+ }
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
}
static int sluift_client_set_version(lua_State *L) {
@@ -183,24 +294,80 @@ static int sluift_client_set_version(lua_State *L) {
return 0;
}
-static int sluift_client_get_version(lua_State *L) {
- SluiftClient* client = getClient(L);
- JID jid(std::string(luaL_checkstring(L, 2)));
+static int sluift_client_get_roster(lua_State *L) {
+ try {
+ SluiftClient* client = getClient(L);
+ std::vector<XMPPRosterItem> items = client->getRoster();
+
+ lua_createtable(L, 0, items.size());
+ foreach(const XMPPRosterItem& item, items) {
+ lua_createtable(L, 0, 3);
+
+ lua_pushstring(L, item.getName().c_str());
+ lua_setfield(L, -2, "name");
+
+ std::string subscription;
+ switch(item.getSubscription()) {
+ case RosterItemPayload::None: subscription = "none"; break;
+ case RosterItemPayload::To: subscription = "to"; break;
+ case RosterItemPayload::From: subscription = "from"; break;
+ case RosterItemPayload::Both: subscription = "both"; break;
+ case RosterItemPayload::Remove: subscription = "remove"; break;
+ }
+ lua_pushstring(L, subscription.c_str());
+ lua_setfield(L, -2, "subscription");
+
+ std::vector<std::string> groups = item.getGroups();
+ lua_createtable(L, groups.size(), 0);
+ for (size_t i = 0; i < groups.size(); ++i) {
+ lua_pushstring(L, groups[i].c_str());
+ lua_rawseti(L, -2, i + 1);
+ }
+ lua_setfield(L, -2, "groups");
- boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
- if (version) {
- lua_createtable(L, 0, 3);
+ lua_setfield(L, -2, item.getJID().toString().c_str());
+ }
+ /*boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
+ if (version) {
lua_pushstring(L, version->getName().c_str());
- lua_setfield(L, -2, "name");
- lua_pushstring(L, version->getVersion().c_str());
- lua_setfield(L, -2, "version");
- lua_pushstring(L, version->getOS().c_str());
- lua_setfield(L, -2, "os");
+ lua_pushstring(L, version->getName().c_str());
+ lua_setfield(L, -2, "name");
+ lua_pushstring(L, version->getVersion().c_str());
+ lua_setfield(L, -2, "version");
+ lua_pushstring(L, version->getOS().c_str());
+ lua_setfield(L, -2, "os");
+ }
+ */
+ return 1;
}
- else {
- lua_pushnil(L);
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
+}
+
+static int sluift_client_get_version(lua_State *L) {
+ try {
+ SluiftClient* client = getClient(L);
+ JID jid(std::string(luaL_checkstring(L, 2)));
+
+ boost::optional<SoftwareVersion> version = client->getSoftwareVersion(jid);
+ if (version) {
+ lua_createtable(L, 0, 3);
+ lua_pushstring(L, version->getName().c_str());
+ lua_setfield(L, -2, "name");
+ lua_pushstring(L, version->getVersion().c_str());
+ lua_setfield(L, -2, "version");
+ lua_pushstring(L, version->getOS().c_str());
+ lua_setfield(L, -2, "os");
+ }
+ else {
+ lua_pushnil(L);
+ }
+ return 1;
+ }
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
}
- return 1;
}
static int sluift_client_send_message(lua_State *L) {
@@ -214,60 +381,65 @@ static int sluift_client_send_presence(lua_State *L) {
}
static int sluift_client_for_event (lua_State *L) {
- SluiftClient* client = getClient(L);
- luaL_checktype(L, 2, LUA_TFUNCTION);
- while (true) {
- Stanza::ref event = client->getNextEvent();
- if (!event) {
- // We got disconnected
- lua_pushnil(L);
- lua_pushliteral(L, "disconnected");
- return 2;
- }
- else {
- // Push the function on the stack
- lua_pushvalue(L, 2);
-
- bool emitEvent = false;
- if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) {
- lua_createtable(L, 0, 3);
- lua_pushliteral(L, "message");
- lua_setfield(L, -2, "type");
- lua_pushstring(L, message->getFrom().toString().c_str());
- lua_setfield(L, -2, "from");
- lua_pushstring(L, message->getBody().c_str());
- lua_setfield(L, -2, "body");
- emitEvent = true;
- }
- else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) {
- lua_createtable(L, 0, 3);
- lua_pushliteral(L, "presence");
- lua_setfield(L, -2, "type");
- lua_pushstring(L, presence->getFrom().toString().c_str());
- lua_setfield(L, -2, "from");
- lua_pushstring(L, presence->getStatus().c_str());
- lua_setfield(L, -2, "status");
- emitEvent = true;
+ try {
+ SluiftClient* client = getClient(L);
+ luaL_checktype(L, 2, LUA_TFUNCTION);
+ while (true) {
+ Stanza::ref event = client->getNextEvent();
+ if (!event) {
+ // We got disconnected
+ lua_pushnil(L);
+ lua_pushliteral(L, "disconnected");
+ return 2;
}
else {
- assert(false);
- }
- if (emitEvent) {
- int oldTop = lua_gettop(L) - 2;
- lua_call(L, 1, LUA_MULTRET);
- int returnValues = lua_gettop(L) - oldTop;
- if (returnValues > 0) {
- lua_remove(L, -1 - returnValues);
- return returnValues;
+ // Push the function on the stack
+ lua_pushvalue(L, 2);
+
+ bool emitEvent = false;
+ if (Message::ref message = boost::dynamic_pointer_cast<Message>(event)) {
+ lua_createtable(L, 0, 3);
+ lua_pushliteral(L, "message");
+ lua_setfield(L, -2, "type");
+ lua_pushstring(L, message->getFrom().toString().c_str());
+ lua_setfield(L, -2, "from");
+ lua_pushstring(L, message->getBody().c_str());
+ lua_setfield(L, -2, "body");
+ emitEvent = true;
+ }
+ else if (Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event)) {
+ lua_createtable(L, 0, 3);
+ lua_pushliteral(L, "presence");
+ lua_setfield(L, -2, "type");
+ lua_pushstring(L, presence->getFrom().toString().c_str());
+ lua_setfield(L, -2, "from");
+ lua_pushstring(L, presence->getStatus().c_str());
+ lua_setfield(L, -2, "status");
+ emitEvent = true;
+ }
+ else {
+ assert(false);
+ }
+ if (emitEvent) {
+ int oldTop = lua_gettop(L) - 2;
+ lua_call(L, 1, LUA_MULTRET);
+ int returnValues = lua_gettop(L) - oldTop;
+ if (returnValues > 0) {
+ lua_remove(L, -1 - returnValues);
+ return returnValues;
+ }
+ }
+ else {
+ // Remove the function from the stack again, since
+ // we're not calling the function
+ lua_pop(L, 1);
}
- }
- else {
- // Remove the function from the stack again, since
- // we're not calling the function
- lua_pop(L, 1);
}
}
}
+ catch (const SluiftException& e) {
+ return luaL_error(L, e.getReason().c_str());
+ }
}
static int sluift_client_gc (lua_State *L) {
@@ -278,11 +450,15 @@ static int sluift_client_gc (lua_State *L) {
static const luaL_reg sluift_client_functions[] = {
+ {"connect", sluift_client_connect},
+ {"async_connect", sluift_client_async_connect},
+ {"wait_connected", sluift_client_wait_connected},
{"is_connected", sluift_client_is_connected},
{"disconnect", sluift_client_disconnect},
{"send_message", sluift_client_send_message},
{"send_presence", sluift_client_send_presence},
{"set_version", sluift_client_set_version},
+ {"get_roster", sluift_client_get_roster},
{"get_version", sluift_client_get_version},
{"for_event", sluift_client_for_event},
{"__gc", sluift_client_gc},
@@ -293,7 +469,7 @@ static const luaL_reg sluift_client_functions[] = {
* Module functions
******************************************************************************/
-static int sluift_connect(lua_State *L) {
+static int sluift_new_client(lua_State *L) {
JID jid(std::string(luaL_checkstring(L, 1)));
std::string password(luaL_checkstring(L, 2));
@@ -302,10 +478,6 @@ static int sluift_connect(lua_State *L) {
lua_setmetatable(L, -2);
*client = new SluiftClient(jid, password);
- (*client)->connect();
- if (!(*client)->isConnected()) {
- lua_pushnil(L);
- }
return 1;
}
@@ -332,7 +504,7 @@ static int sluift_newindex(lua_State *L) {
}
static const luaL_reg sluift_functions[] = {
- {"connect", sluift_connect},
+ {"new_client", sluift_new_client},
{NULL, NULL}
};