diff options
| author | Richard Maudsley <richard.maudsley@isode.com> | 2014-07-17 09:46:50 (GMT) |
|---|---|---|
| committer | Swift Review <review@swift.im> | 2014-08-10 11:08:27 (GMT) |
| commit | 8ec22a9c5591584fd1725ed028d714c51b7509d3 (patch) | |
| tree | 3687e7023696c9e790a24fd54b7d04f14ac58ab2 | |
| parent | 5e9e715e49a5ddb6ce9c76ec61e7ecfd6eacdb58 (diff) | |
| download | swift-contrib-8ec22a9c5591584fd1725ed028d714c51b7509d3.zip swift-contrib-8ec22a9c5591584fd1725ed028d714c51b7509d3.tar.bz2 | |
Fix invalid characters being allowed in JID domains
Test-Information:
Prepare valid and invalid JIDs and make sure that isValid() is reported correctly. Added unit tests.
Change-Id: Ic4d86f8b6ea9defc517ada2f8e3cc54979237cf4
25 files changed, 94 insertions, 37 deletions
diff --git a/Slimber/Qt/SConscript b/Slimber/Qt/SConscript index 054d8b6..b4f0bc3 100644 --- a/Slimber/Qt/SConscript +++ b/Slimber/Qt/SConscript @@ -1,53 +1,59 @@ import os, shutil, datetime Import("env") myenv = env.Clone() myenv["CXXFLAGS"] = filter(lambda x : x != "-Wfloat-equal", myenv["CXXFLAGS"]) myenv.UseFlags(env["SLIMBER_FLAGS"]) myenv.UseFlags(env["LIMBER_FLAGS"]) myenv.UseFlags(env["SWIFTOOLS_FLAGS"]) myenv.UseFlags(env["SWIFTEN_FLAGS"]) -myenv.UseFlags(env["LIBIDN_FLAGS"]) myenv.UseFlags(env["BOOST_FLAGS"]) myenv.UseFlags(env.get("LIBXML_FLAGS", "")) myenv.UseFlags(env.get("EXPAT_FLAGS", "")) myenv.UseFlags(env.get("AVAHI_FLAGS", "")) myenv.UseFlags(myenv["PLATFORM_FLAGS"]) +if myenv.get("HAVE_ICU") : + myenv.MergeFlags(env["ICU_FLAGS"]) + myenv.Append(CPPDEFINES = ["HAVE_ICU"]) +if myenv.get("HAVE_LIBIDN") : + myenv.MergeFlags(env["LIBIDN_FLAGS"]) + myenv.Append(CPPDEFINES = ["HAVE_LIBIDN"]) + myenv.Tool("qt4", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.Tool("nsis", toolpath = ["#/BuildTools/SCons/Tools"]) myenv.EnableQt4Modules(['QtCore', 'QtGui'], debug = False) myenv.Append(CPPPATH = ["."]) if env["PLATFORM"] == "win32" : myenv.Append(LINKFLAGS = ["/SUBSYSTEM:WINDOWS"]) myenv.Append(LIBS = "qtmain") myenv.BuildVersion("BuildVersion.h", project = "slimber") sources = [ "main.cpp", "QtMenulet.cpp", "QtAboutDialog.cpp", myenv.Qrc("Slimber.qrc"), ] #if env["PLATFORM"] == "win32" : # myenv.RES("../resources/Windows/Slimber.rc") # sources += ["../resources/Windows/Slimber.res"] if env["PLATFORM"] == "win32" : slimberProgram = myenv.Program("Slimber", sources) else : slimberProgram = myenv.Program("slimber", sources) if env["PLATFORM"] == "win32" : if "dist" in COMMAND_LINE_TARGETS or env.GetOption("clean") : myenv.WindowsBundle("Slimber", resources = {}, qtlibs = ["QtCore4", "QtGui4"]) myenv.Append(NSIS_OPTIONS = [ "/DmsvccRedistributableDir=\"" + env["vcredist"] + "\"", "/DbuildDate=" + datetime.date.today().strftime("%Y%m%d") ]) #myenv.Nsis("../Packaging/nsis/slimber.nsi") diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp index 5e837c1..bef2e3d 100644 --- a/Sluift/sluift.cpp +++ b/Sluift/sluift.cpp @@ -1,494 +1,500 @@ /* * Copyright (c) 2011-2014 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Sluift/sluift.h> #include <lua.hpp> #include <string> #include <boost/bind.hpp> #include <boost/numeric/conversion/cast.hpp> #include <boost/assign/list_of.hpp> #include <boost/filesystem.hpp> #include "Watchdog.h" #include <Sluift/Lua/Check.h> #include <Sluift/SluiftClient.h> #include <Sluift/SluiftComponent.h> #include <Sluift/globals.h> #include <Sluift/Lua/Exception.h> #include <Sluift/Lua/LuaUtils.h> #include <Sluift/Lua/FunctionRegistration.h> #include <Swiften/Base/sleep.h> #include <Swiften/Base/foreach.h> #include <Swiften/Base/IDGenerator.h> #include <Swiften/Parser/PayloadParsers/UnitTest/PayloadsParserTester.h> #include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h> #include <Swiften/Serializer/PayloadSerializer.h> #include <Swiften/TLS/Certificate.h> #include <Swiften/TLS/CertificateFactory.h> #include <Sluift/LuaElementConvertor.h> #include <Sluift/Lua/Debug.h> #include <Swiften/StringCodecs/Base64.h> #include <Swiften/StringCodecs/Hexify.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Crypto/PlatformCryptoProvider.h> #include <Sluift/ITunesInterface.h> using namespace Swift; namespace Swift { namespace Sluift { SluiftGlobals globals; } } extern "C" const char core_lua[]; extern "C" size_t core_lua_size; static inline bool getGlobalDebug(lua_State* L) { lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.moduleLibIndex); lua_getfield(L, -1, "debug"); int result = lua_toboolean(L, -1); lua_pop(L, 2); return result; } /******************************************************************************* * Module functions ******************************************************************************/ SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, new_client, "Creates a new client.\n\nReturns a @{Client} object.\n", "jid The JID to connect as\n" "passphrase The passphrase to use\n", "" ) { Lua::checkString(L, 1); JID jid(std::string(Lua::checkString(L, 1))); std::string password(Lua::checkString(L, 2)); SluiftClient** client = reinterpret_cast<SluiftClient**>(lua_newuserdata(L, sizeof(SluiftClient*))); lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); lua_getfield(L, -1, "Client"); lua_setmetatable(L, -3); lua_pop(L, 1); *client = new SluiftClient(jid, password, &Sluift::globals.networkFactories, &Sluift::globals.eventLoop); (*client)->setTraceEnabled(getGlobalDebug(L)); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, new_component, "Creates a new component.\n\nReturns a @{Component} object.\n", "jid The JID to connect as\n" "passphrase The passphrase to use\n", "" ) { Lua::checkString(L, 1); JID jid(std::string(Lua::checkString(L, 1))); std::string password(Lua::checkString(L, 2)); SluiftComponent** component = reinterpret_cast<SluiftComponent**>(lua_newuserdata(L, sizeof(SluiftComponent*))); lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); lua_getfield(L, -1, "Component"); lua_setmetatable(L, -3); lua_pop(L, 1); *component = new SluiftComponent(jid, password, &Sluift::globals.networkFactories, &Sluift::globals.eventLoop); (*component)->setTraceEnabled(getGlobalDebug(L)); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, sha1, "Compute the SHA-1 hash of given data", "data the data to hash", "" ) { static boost::shared_ptr<CryptoProvider> crypto(PlatformCryptoProvider::create()); if (!lua_isstring(L, 1)) { throw Lua::Exception("Expected string"); } size_t len; const char* data = lua_tolstring(L, 1, &len); ByteArray result = crypto->getSHA1Hash(createByteArray(data, len)); lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size()); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, sleep, "Sleeps for the given time.", "milliseconds the amount of milliseconds to sleep", "" ) { Sluift::globals.eventLoop.runOnce(); int timeout = Lua::checkIntNumber(L, 1); Watchdog watchdog(timeout, Sluift::globals.networkFactories.getTimerFactory()); while (!watchdog.getTimedOut()) { Swift::sleep(boost::numeric_cast<unsigned int>(std::min(100, timeout))); Sluift::globals.eventLoop.runOnce(); } return 0; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, new_uuid, "Generates a new UUID", "", "" ) { lua_pushstring(L, IDGenerator().generateID().c_str()); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, from_xml, "Convert a raw XML string into a structured representation.", "string the string to convert", "" ) { PayloadsParserTester parser; if (!parser.parse(Lua::checkString(L, 1))) { throw Lua::Exception("Error in XML"); } return Sluift::globals.elementConvertor.convertToLua(L, parser.getPayload()); } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, to_xml, "Convert a structured element into XML.", "element the element to convert", "" ) { static FullPayloadSerializerCollection serializers; boost::shared_ptr<Payload> payload = boost::dynamic_pointer_cast<Payload>(Sluift::globals.elementConvertor.convertFromLua(L, 1)); if (!payload) { throw Lua::Exception("Unrecognized XML"); } PayloadSerializer* serializer = serializers.getPayloadSerializer(payload); if (!payload) { throw Lua::Exception("Unrecognized XML"); } lua_pushstring(L, serializer->serialize(payload).c_str()); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, hexify, "Convert binary data into hexadecimal format.", "data the data to convert", "" ) { if (!lua_isstring(L, 1)) { throw Lua::Exception("Expected string"); } size_t len; const char* data = lua_tolstring(L, 1, &len); lua_pushstring(L, Hexify::hexify(createByteArray(data, len)).c_str()); return 1; } SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, unhexify, "Convert hexadecimal data into binary data.", "data the data in hexadecimal format", "" ) { if (!lua_isstring(L, 1)) { throw Lua::Exception("Expected string"); } ByteArray result = Hexify::unhexify(lua_tostring(L, 1)); lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size()); return 1; } /******************************************************************************* * Crypto functions ******************************************************************************/ SLUIFT_LUA_FUNCTION_WITH_HELP( Crypto, new_certificate, "Creates a new X.509 certificate from DER data.\n", "der the DER-encoded certificate data", "") { ByteArray certData(Lua::checkByteArray(L, 1)); Certificate::ref cert(Sluift::globals.tlsFactories.getCertificateFactory()->createCertificateFromDER(certData)); lua_createtable(L, 0, 0); lua_pushstring(L, cert->getSubjectName().c_str()); lua_setfield(L, -2, "subject_name"); lua_pushstring(L, Certificate::getSHA1Fingerprint(cert, Sluift::globals.networkFactories.getCryptoProvider()).c_str()); lua_setfield(L, -2, "sha1_fingerprint"); Lua::pushStringArray(L, cert->getCommonNames()); lua_setfield(L, -2, "common_names"); Lua::pushStringArray(L, cert->getSRVNames()); lua_setfield(L, -2, "srv_names"); Lua::pushStringArray(L, cert->getDNSNames()); lua_setfield(L, -2, "dns_names"); Lua::pushStringArray(L, cert->getXMPPAddresses()); lua_setfield(L, -2, "xmpp_addresses"); Lua::registerTableToString(L, -1); return 1; } /******************************************************************************* * Filesystem Functions ******************************************************************************/ SLUIFT_LUA_FUNCTION(FS, list) { boost::filesystem::path dir(std::string(Lua::checkString(L, 1))); if (!boost::filesystem::exists(dir) || !boost::filesystem::is_directory(dir)) { lua_pushnil(L); lua_pushstring(L, "Argument is not an existing directory"); return 2; } boost::filesystem::directory_iterator i(dir); std::vector<boost::filesystem::path> items( i, boost::filesystem::directory_iterator()); lua_createtable(L, boost::numeric_cast<int>(items.size()), 0); for (size_t i = 0; i < items.size(); ++i) { lua_pushstring(L, items[i].string().c_str()); lua_rawseti(L, -2, boost::numeric_cast<int>(i+1)); } Lua::registerTableToString(L, -1); return 1; } SLUIFT_LUA_FUNCTION(FS, is_file) { boost::filesystem::path file(std::string(Lua::checkString(L, 1))); lua_pushboolean(L, boost::filesystem::is_regular_file(file)); return 1; } /******************************************************************************* * JID Functions ******************************************************************************/ SLUIFT_LUA_FUNCTION(JID, to_bare) { JID jid(std::string(Lua::checkString(L, 1))); lua_pushstring(L, jid.toBare().toString().c_str()); return 1; } SLUIFT_LUA_FUNCTION(JID, node) { JID jid(std::string(Lua::checkString(L, 1))); lua_pushstring(L, jid.getNode().c_str()); return 1; } SLUIFT_LUA_FUNCTION(JID, domain) { JID jid(std::string(Lua::checkString(L, 1))); lua_pushstring(L, jid.getDomain().c_str()); return 1; } SLUIFT_LUA_FUNCTION(JID, resource) { JID jid(std::string(Lua::checkString(L, 1))); lua_pushstring(L, jid.getResource().c_str()); return 1; } SLUIFT_LUA_FUNCTION(JID, escape_node) { lua_pushstring(L, JID::getEscapedNode(Lua::checkString(L, 1)).c_str()); return 1; } /******************************************************************************* * Base64 Functions ******************************************************************************/ SLUIFT_LUA_FUNCTION(Base64, encode) { if (!lua_isstring(L, 1)) { throw Lua::Exception("Expected string"); } size_t len; const char* data = lua_tolstring(L, 1, &len); lua_pushstring(L, Base64::encode(createByteArray(data, len)).c_str()); return 1; } SLUIFT_LUA_FUNCTION(Base64, decode) { if (!lua_isstring(L, 1)) { throw Lua::Exception("Expected string"); } ByteArray result = Base64::decode(lua_tostring(L, 1)); lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size()); return 1; } /******************************************************************************* * IDN Functions ******************************************************************************/ SLUIFT_LUA_FUNCTION(IDN, encode) { IDNConverter* converter = Sluift::globals.networkFactories.getIDNConverter(); - lua_pushstring(L, converter->getIDNAEncoded(Lua::checkString(L, 1)).c_str()); + boost::optional<std::string> encoded = converter->getIDNAEncoded(Lua::checkString(L, 1)); + if (!encoded) { + lua_pushnil(L); + lua_pushstring(L, "Error encoding domain name"); + return 2; + } + lua_pushstring(L, encoded->c_str()); return 1; } SLUIFT_LUA_FUNCTION(IDN, stringprep) { IDNConverter* converter = Sluift::globals.networkFactories.getIDNConverter(); IDNConverter::StringPrepProfile profile; std::string profileString = Lua::checkString(L, 2); if (profileString == "nameprep") { profile = IDNConverter::NamePrep; } else if (profileString == "xmpp_nodeprep") { profile = IDNConverter::XMPPNodePrep; } else if (profileString == "xmpp_resourceprep") { profile = IDNConverter::XMPPResourcePrep; } else if (profileString == "saslprep") { profile = IDNConverter::SASLPrep; } else { throw Lua::Exception("Invalid profile"); } try { lua_pushstring(L, converter->getStringPrepared(Lua::checkString(L, 1), profile).c_str()); } catch (const std::exception&) { throw Lua::Exception("Error"); } return 1; } /******************************************************************************* * iTunes Functions ******************************************************************************/ #ifdef HAVE_ITUNES SLUIFT_LUA_FUNCTION(iTunes, get_current_track) { boost::optional<ITunesInterface::Track> track = Sluift::globals.iTunes.getCurrentTrack(); if (!track) { return 0; } lua_createtable(L, 0, 0); lua_pushstring(L, track->artist.c_str()); lua_setfield(L, -2, "artist"); lua_pushstring(L, track->name.c_str()); lua_setfield(L, -2, "name"); lua_pushstring(L, track->album.c_str()); lua_setfield(L, -2, "album"); lua_pushinteger(L, track->trackNumber); lua_setfield(L, -2, "track_number"); lua_pushnumber(L, track->duration); lua_setfield(L, -2, "duration"); lua_pushinteger(L, track->rating); lua_setfield(L, -2, "rating"); Lua::registerTableToString(L, -1); Lua::registerTableEquals(L, -1); return 1; } #endif /******************************************************************************* * Module registration ******************************************************************************/ static const luaL_Reg sluift_functions[] = { {NULL, NULL} }; SLUIFT_API int luaopen_sluift(lua_State* L) { // Initialize & store the module table luaL_register(L, lua_tostring(L, 1), sluift_functions); lua_pushinteger(L, -1); lua_setfield(L, -2, "timeout"); lua_pushboolean(L, 0); lua_setfield(L, -2, "debug"); lua_pushvalue(L, -1); Sluift::globals.moduleLibIndex = luaL_ref(L, LUA_REGISTRYINDEX); // Load core lib code if (luaL_loadbuffer(L, core_lua, core_lua_size, "core.lua") != 0) { lua_error(L); } lua_pushvalue(L, -2); lua_call(L, 1, 1); Sluift::globals.coreLibIndex = luaL_ref(L, LUA_REGISTRYINDEX); // Register functions Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, "Sluift"); Lua::FunctionRegistry::getInstance().createFunctionTable(L, "JID"); lua_setfield(L, -2, "jid"); Lua::FunctionRegistry::getInstance().createFunctionTable(L, "Base64"); lua_setfield(L, -2, "base64"); Lua::FunctionRegistry::getInstance().createFunctionTable(L, "IDN"); lua_setfield(L, -2, "idn"); Lua::FunctionRegistry::getInstance().createFunctionTable(L, "Crypto"); lua_setfield(L, -2, "crypto"); Lua::FunctionRegistry::getInstance().createFunctionTable(L, "FS"); lua_setfield(L, -2, "fs"); #ifdef HAVE_ITUNES Lua::FunctionRegistry::getInstance().createFunctionTable(L, "iTunes"); lua_setfield(L, -2, "itunes"); #endif // Register convenience functions lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); std::vector<std::string> coreLibExports = boost::assign::list_of ("tprint")("disco")("help")("get_help")("copy")("with")("read_file")("create_form"); foreach (const std::string& coreLibExport, coreLibExports) { lua_getfield(L, -1, coreLibExport.c_str()); lua_setfield(L, -3, coreLibExport.c_str()); } lua_pop(L, 1); // Load client metatable lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); std::vector<std::string> tables = boost::assign::list_of("Client"); foreach(const std::string& table, tables) { lua_getfield(L, -1, table.c_str()); Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, table); lua_pop(L, 1); } lua_pop(L, 1); // Load component metatable lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); std::vector<std::string> comp_tables = boost::assign::list_of("Component"); foreach(const std::string& table, comp_tables) { lua_getfield(L, -1, table.c_str()); Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, table); lua_pop(L, 1); } lua_pop(L, 1); // Register documentation for all elements foreach (boost::shared_ptr<LuaElementConvertor> convertor, Sluift::globals.elementConvertor.getConvertors()) { boost::optional<LuaElementConvertor::Documentation> documentation = convertor->getDocumentation(); if (documentation) { Lua::registerClassHelp(L, documentation->className, documentation->description); } } // Register global documentation Lua::registerExtraHelp(L, -1, "sluift"); return 1; } diff --git a/Swiften/IDN/ICUConverter.cpp b/Swiften/IDN/ICUConverter.cpp index 18ff231..f698eb9 100644 --- a/Swiften/IDN/ICUConverter.cpp +++ b/Swiften/IDN/ICUConverter.cpp @@ -1,157 +1,157 @@ /* * Copyright (c) 2012-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/IDN/ICUConverter.h> #pragma GCC diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wheader-hygiene" #include <unicode/uidna.h> #include <unicode/usprep.h> #include <unicode/ucnv.h> #include <unicode/ustring.h> #include <boost/numeric/conversion/cast.hpp> using namespace Swift; using boost::numeric_cast; namespace { typedef std::vector<UChar, SafeAllocator<UChar> > ICUString; const char* toConstCharArray(const std::string& input) { return input.c_str(); } const char* toConstCharArray(const std::vector<unsigned char, SafeAllocator<unsigned char> >& input) { return reinterpret_cast<const char*>(vecptr(input)); } template<typename T> ICUString convertToICUString(const T& s) { ICUString result; result.resize(s.size()); UErrorCode status = U_ZERO_ERROR; int32_t icuResultLength = numeric_cast<int32_t>(result.size()); u_strFromUTF8Lenient(vecptr(result), numeric_cast<int32_t>(result.size()), &icuResultLength, toConstCharArray(s), numeric_cast<int32_t>(s.size()), &status); if (status == U_BUFFER_OVERFLOW_ERROR) { status = U_ZERO_ERROR; result.resize(numeric_cast<size_t>(icuResultLength)); u_strFromUTF8Lenient(vecptr(result), numeric_cast<int32_t>(result.size()), &icuResultLength, toConstCharArray(s), numeric_cast<int32_t>(s.size()), &status); } if (U_FAILURE(status)) { return ICUString(); } result.resize(numeric_cast<size_t>(icuResultLength)); return result; } std::vector<char, SafeAllocator<char> > convertToArray(const ICUString& input) { std::vector<char, SafeAllocator<char> > result; result.resize(input.size()); UErrorCode status = U_ZERO_ERROR; int32_t inputLength = numeric_cast<int32_t>(result.size()); u_strToUTF8(vecptr(result), numeric_cast<int32_t>(result.size()), &inputLength, vecptr(input), numeric_cast<int32_t>(input.size()), &status); if (status == U_BUFFER_OVERFLOW_ERROR) { status = U_ZERO_ERROR; result.resize(numeric_cast<size_t>(inputLength)); u_strToUTF8(vecptr(result), numeric_cast<int32_t>(result.size()), &inputLength, vecptr(input), numeric_cast<int32_t>(input.size()), &status); } if (U_FAILURE(status)) { return std::vector<char, SafeAllocator<char> >(); } result.resize(numeric_cast<size_t>(inputLength) + 1); result[result.size() - 1] = '\0'; return result; } std::string convertToString(const ICUString& input) { return std::string(vecptr(convertToArray(input))); } UStringPrepProfileType getICUProfileType(IDNConverter::StringPrepProfile profile) { switch(profile) { case IDNConverter::NamePrep: return USPREP_RFC3491_NAMEPREP; case IDNConverter::XMPPNodePrep: return USPREP_RFC3920_NODEPREP; case IDNConverter::XMPPResourcePrep: return USPREP_RFC3920_RESOURCEPREP; case IDNConverter::SASLPrep: return USPREP_RFC4013_SASLPREP; } assert(false); return USPREP_RFC3491_NAMEPREP; } template<typename StringType> std::vector<char, SafeAllocator<char> > getStringPreparedDetail(const StringType& s, IDNConverter::StringPrepProfile profile) { UErrorCode status = U_ZERO_ERROR; boost::shared_ptr<UStringPrepProfile> icuProfile(usprep_openByType(getICUProfileType(profile), &status), usprep_close); assert(U_SUCCESS(status)); ICUString icuInput = convertToICUString(s); ICUString icuResult; UParseError parseError; icuResult.resize(icuInput.size()); int32_t icuResultLength = usprep_prepare(icuProfile.get(), vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), USPREP_ALLOW_UNASSIGNED, &parseError, &status); icuResult.resize(numeric_cast<size_t>(icuResultLength)); if (status == U_BUFFER_OVERFLOW_ERROR) { status = U_ZERO_ERROR; icuResult.resize(numeric_cast<size_t>(icuResultLength)); icuResultLength = usprep_prepare(icuProfile.get(), vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), USPREP_ALLOW_UNASSIGNED, &parseError, &status); icuResult.resize(numeric_cast<size_t>(icuResultLength)); } if (U_FAILURE(status)) { return std::vector<char, SafeAllocator<char> >(); } icuResult.resize(numeric_cast<size_t>(icuResultLength)); return convertToArray(icuResult); } } namespace Swift { std::string ICUConverter::getStringPrepared(const std::string& s, StringPrepProfile profile) { if (s.empty()) { return ""; } std::vector<char, SafeAllocator<char> > preparedData = getStringPreparedDetail(s, profile); if (preparedData.empty()) { throw std::exception(); } return std::string(vecptr(preparedData)); } SafeByteArray ICUConverter::getStringPrepared(const SafeByteArray& s, StringPrepProfile profile) { if (s.empty()) { return SafeByteArray(); } std::vector<char, SafeAllocator<char> > preparedData = getStringPreparedDetail(s, profile); if (preparedData.empty()) { throw std::exception(); } return createSafeByteArray(reinterpret_cast<const char*>(vecptr(preparedData))); } -std::string ICUConverter::getIDNAEncoded(const std::string& domain) { +boost::optional<std::string> ICUConverter::getIDNAEncoded(const std::string& domain) { UErrorCode status = U_ZERO_ERROR; ICUString icuInput = convertToICUString(domain); ICUString icuResult; icuResult.resize(icuInput.size()); UParseError parseError; - int32_t icuResultLength = uidna_IDNToASCII(vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), UIDNA_DEFAULT, &parseError, &status); + int32_t icuResultLength = uidna_IDNToASCII(vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), UIDNA_USE_STD3_RULES, &parseError, &status); if (status == U_BUFFER_OVERFLOW_ERROR) { status = U_ZERO_ERROR; icuResult.resize(numeric_cast<size_t>(icuResultLength)); - icuResultLength = uidna_IDNToASCII(vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), UIDNA_DEFAULT, &parseError, &status); + icuResultLength = uidna_IDNToASCII(vecptr(icuInput), numeric_cast<int32_t>(icuInput.size()), vecptr(icuResult), numeric_cast<int32_t>(icuResult.size()), UIDNA_USE_STD3_RULES, &parseError, &status); } if (U_FAILURE(status)) { - return domain; + return boost::optional<std::string>(); } icuResult.resize(numeric_cast<size_t>(icuResultLength)); return convertToString(icuResult); } } diff --git a/Swiften/IDN/ICUConverter.h b/Swiften/IDN/ICUConverter.h index 8ba9bb5..05eafcc 100644 --- a/Swiften/IDN/ICUConverter.h +++ b/Swiften/IDN/ICUConverter.h @@ -1,22 +1,22 @@ /* * Copyright (c) 2012-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <string> #include <Swiften/Base/API.h> #include <Swiften/Base/Override.h> #include <Swiften/IDN/IDNConverter.h> namespace Swift { class SWIFTEN_API ICUConverter : public IDNConverter { public: virtual std::string getStringPrepared(const std::string& s, StringPrepProfile profile) SWIFTEN_OVERRIDE; virtual SafeByteArray getStringPrepared(const SafeByteArray& s, StringPrepProfile profile) SWIFTEN_OVERRIDE; - virtual std::string getIDNAEncoded(const std::string& s) SWIFTEN_OVERRIDE; + virtual boost::optional<std::string> getIDNAEncoded(const std::string& s) SWIFTEN_OVERRIDE; }; } diff --git a/Swiften/IDN/IDNConverter.h b/Swiften/IDN/IDNConverter.h index c55d969..f6974bc 100644 --- a/Swiften/IDN/IDNConverter.h +++ b/Swiften/IDN/IDNConverter.h @@ -1,31 +1,32 @@ /* * Copyright (c) 2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <string> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> +#include <boost/optional.hpp> namespace Swift { class SWIFTEN_API IDNConverter { public: virtual ~IDNConverter(); enum StringPrepProfile { NamePrep, XMPPNodePrep, XMPPResourcePrep, SASLPrep }; virtual std::string getStringPrepared(const std::string& s, StringPrepProfile profile) = 0; virtual SafeByteArray getStringPrepared(const SafeByteArray& s, StringPrepProfile profile) = 0; // Thread-safe - virtual std::string getIDNAEncoded(const std::string& s) = 0; + virtual boost::optional<std::string> getIDNAEncoded(const std::string& s) = 0; }; } diff --git a/Swiften/IDN/LibIDNConverter.cpp b/Swiften/IDN/LibIDNConverter.cpp index c4a1c18..45b1d14 100644 --- a/Swiften/IDN/LibIDNConverter.cpp +++ b/Swiften/IDN/LibIDNConverter.cpp @@ -1,80 +1,80 @@ /* * Copyright (c) 2012-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/IDN/LibIDNConverter.h> extern "C" { #include <stringprep.h> #include <idna.h> } #include <vector> #include <cassert> #include <cstdlib> #include <Swiften/Base/ByteArray.h> #include <Swiften/Base/SafeAllocator.h> #include <boost/shared_ptr.hpp> using namespace Swift; namespace { static const int MAX_STRINGPREP_SIZE = 1024; const Stringprep_profile* getLibIDNProfile(IDNConverter::StringPrepProfile profile) { switch(profile) { case IDNConverter::NamePrep: return stringprep_nameprep; case IDNConverter::XMPPNodePrep: return stringprep_xmpp_nodeprep; case IDNConverter::XMPPResourcePrep: return stringprep_xmpp_resourceprep; case IDNConverter::SASLPrep: return stringprep_saslprep; } assert(false); return 0; } template<typename StringType, typename ContainerType> ContainerType getStringPreparedInternal(const StringType& s, IDNConverter::StringPrepProfile profile) { ContainerType input(s.begin(), s.end()); input.resize(MAX_STRINGPREP_SIZE); if (stringprep(&input[0], MAX_STRINGPREP_SIZE, static_cast<Stringprep_profile_flags>(0), getLibIDNProfile(profile)) == 0) { return input; } else { return ContainerType(); } } } namespace Swift { std::string LibIDNConverter::getStringPrepared(const std::string& s, StringPrepProfile profile) { std::vector<char> preparedData = getStringPreparedInternal< std::string, std::vector<char> >(s, profile); if (preparedData.empty()) { throw std::exception(); } return std::string(vecptr(preparedData)); } SafeByteArray LibIDNConverter::getStringPrepared(const SafeByteArray& s, StringPrepProfile profile) { std::vector<char, SafeAllocator<char> > preparedData = getStringPreparedInternal<SafeByteArray, std::vector<char, SafeAllocator<char> > >(s, profile); if (preparedData.empty()) { throw std::exception(); } return createSafeByteArray(reinterpret_cast<const char*>(vecptr(preparedData))); } -std::string LibIDNConverter::getIDNAEncoded(const std::string& domain) { +boost::optional<std::string> LibIDNConverter::getIDNAEncoded(const std::string& domain) { char* output; - if (idna_to_ascii_8z(domain.c_str(), &output, 0) == IDNA_SUCCESS) { + if (idna_to_ascii_8z(domain.c_str(), &output, IDNA_USE_STD3_ASCII_RULES) == IDNA_SUCCESS) { std::string result(output); free(output); return result; } else { - return domain; + return boost::optional<std::string>(); } } } diff --git a/Swiften/IDN/LibIDNConverter.h b/Swiften/IDN/LibIDNConverter.h index 23f6bbd..4cfff1a 100644 --- a/Swiften/IDN/LibIDNConverter.h +++ b/Swiften/IDN/LibIDNConverter.h @@ -1,23 +1,23 @@ /* * Copyright (c) 2012-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <string> #include <Swiften/Base/API.h> #include <Swiften/Base/Override.h> #include <Swiften/IDN/IDNConverter.h> namespace Swift { class SWIFTEN_API LibIDNConverter : public IDNConverter { public: virtual std::string getStringPrepared(const std::string& s, StringPrepProfile profile) SWIFTEN_OVERRIDE; virtual SafeByteArray getStringPrepared(const SafeByteArray& s, StringPrepProfile profile) SWIFTEN_OVERRIDE; - virtual std::string getIDNAEncoded(const std::string& s) SWIFTEN_OVERRIDE; + virtual boost::optional<std::string> getIDNAEncoded(const std::string& s) SWIFTEN_OVERRIDE; }; } diff --git a/Swiften/IDN/UnitTest/IDNConverterTest.cpp b/Swiften/IDN/UnitTest/IDNConverterTest.cpp index 285cf4b..a66e141 100644 --- a/Swiften/IDN/UnitTest/IDNConverterTest.cpp +++ b/Swiften/IDN/UnitTest/IDNConverterTest.cpp @@ -1,56 +1,63 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <boost/shared_ptr.hpp> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> using namespace Swift; class IDNConverterTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(IDNConverterTest); CPPUNIT_TEST(testStringPrep); CPPUNIT_TEST(testStringPrep_Empty); CPPUNIT_TEST(testGetEncoded); CPPUNIT_TEST(testGetEncoded_International); + CPPUNIT_TEST(testGetEncoded_Invalid); CPPUNIT_TEST_SUITE_END(); public: void setUp() { testling = boost::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); } void testStringPrep() { std::string result = testling->getStringPrepared("tron\xc3\x87on", IDNConverter::NamePrep); CPPUNIT_ASSERT_EQUAL(std::string("tron\xc3\xa7on"), result); } void testStringPrep_Empty() { CPPUNIT_ASSERT_EQUAL(std::string(""), testling->getStringPrepared("", IDNConverter::NamePrep)); CPPUNIT_ASSERT_EQUAL(std::string(""), testling->getStringPrepared("", IDNConverter::XMPPNodePrep)); CPPUNIT_ASSERT_EQUAL(std::string(""), testling->getStringPrepared("", IDNConverter::XMPPResourcePrep)); } void testGetEncoded() { - std::string result = testling->getIDNAEncoded("www.swift.im"); - CPPUNIT_ASSERT_EQUAL(std::string("www.swift.im"), result); + boost::optional<std::string> result = testling->getIDNAEncoded("www.swift.im"); + CPPUNIT_ASSERT(!!result); + CPPUNIT_ASSERT_EQUAL(std::string("www.swift.im"), *result); } void testGetEncoded_International() { - std::string result = testling->getIDNAEncoded("www.tron\xc3\x87on.com"); - CPPUNIT_ASSERT_EQUAL(std::string("www.xn--tronon-zua.com"), result); + boost::optional<std::string> result = testling->getIDNAEncoded("www.tron\xc3\x87on.com"); + CPPUNIT_ASSERT(!!result); + CPPUNIT_ASSERT_EQUAL(std::string("www.xn--tronon-zua.com"), *result); } + void testGetEncoded_Invalid() { + boost::optional<std::string> result = testling->getIDNAEncoded("www.foo,bar.com"); + CPPUNIT_ASSERT(!result); + } private: boost::shared_ptr<IDNConverter> testling; }; CPPUNIT_TEST_SUITE_REGISTRATION(IDNConverterTest); diff --git a/Swiften/JID/JID.cpp b/Swiften/JID/JID.cpp index 0f2d8d1..fcd49f9 100644 --- a/Swiften/JID/JID.cpp +++ b/Swiften/JID/JID.cpp @@ -1,326 +1,326 @@ /* * Copyright (c) 2010-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #define SWIFTEN_CACHE_JID_PREP #include <vector> #include <list> #include <iostream> #include <string> #ifdef SWIFTEN_CACHE_JID_PREP #include <boost/thread/mutex.hpp> #include <boost/unordered_map.hpp> #endif #include <boost/assign/list_of.hpp> #include <boost/algorithm/string/find_format.hpp> #include <boost/algorithm/string/finder.hpp> #include <boost/optional.hpp> #include <iostream> #include <sstream> #include <Swiften/Base/String.h> #include <Swiften/JID/JID.h> #include <Swiften/IDN/IDNConverter.h> #ifndef SWIFTEN_JID_NO_DEFAULT_IDN_CONVERTER #include <boost/shared_ptr.hpp> #include <Swiften/IDN/PlatformIDNConverter.h> #endif using namespace Swift; #ifdef SWIFTEN_CACHE_JID_PREP typedef boost::unordered_map<std::string, std::string> PrepCache; static boost::mutex namePrepCacheMutex; static PrepCache nodePrepCache; static PrepCache domainPrepCache; static PrepCache resourcePrepCache; #endif static const std::list<char> escapedChars = boost::assign::list_of(' ')('"')('&')('\'')('/')('<')('>')('@')(':'); static IDNConverter* idnConverter = NULL; #ifndef SWIFTEN_JID_NO_DEFAULT_IDN_CONVERTER namespace { struct IDNInitializer { IDNInitializer() : defaultIDNConverter(PlatformIDNConverter::create()) { idnConverter = defaultIDNConverter.get(); } boost::shared_ptr<IDNConverter> defaultIDNConverter; } initializer; } #endif static std::string getEscaped(char c) { return makeString() << '\\' << std::hex << static_cast<int>(c); } static bool getEscapeSequenceValue(const std::string& sequence, unsigned char& value) { std::stringstream s; unsigned int v; s << std::hex << sequence; s >> v; value = static_cast<unsigned char>(v); return (!s.fail() && !s.bad() && (value == 0x5C || std::find(escapedChars.begin(), escapedChars.end(), value) != escapedChars.end())); } // Disabling this code for now, since GCC4.5+boost1.42 (on ubuntu) seems to // result in a bug. Replacing it with naive code. #if 0 struct UnescapedCharacterFinder { template<typename Iterator> boost::iterator_range<Iterator> operator()(Iterator begin, Iterator end) { for (; begin != end; ++begin) { if (std::find(escapedChars.begin(), escapedChars.end(), *begin) != escapedChars.end()) { return boost::iterator_range<Iterator>(begin, begin + 1); } else if (*begin == '\\') { // Check if we have an escaped dissalowed character sequence Iterator innerBegin = begin + 1; if (innerBegin != end && innerBegin + 1 != end) { Iterator innerEnd = innerBegin + 2; unsigned char value; if (getEscapeSequenceValue(std::string(innerBegin, innerEnd), value)) { return boost::iterator_range<Iterator>(begin, begin + 1); } } } } return boost::iterator_range<Iterator>(end, end); } }; struct UnescapedCharacterFormatter { template<typename FindResult> std::string operator()(const FindResult& match) const { std::ostringstream s; s << '\\' << std::hex << static_cast<int>(*match.begin()); return s.str(); } }; struct EscapedCharacterFinder { template<typename Iterator> boost::iterator_range<Iterator> operator()(Iterator begin, Iterator end) { for (; begin != end; ++begin) { if (*begin == '\\') { Iterator innerEnd = begin + 1; for (size_t i = 0; i < 2 && innerEnd != end; ++i, ++innerEnd) { } unsigned char value; if (getEscapeSequenceValue(std::string(begin + 1, innerEnd), value)) { return boost::iterator_range<Iterator>(begin, innerEnd); } } } return boost::iterator_range<Iterator>(end, end); } }; struct EscapedCharacterFormatter { template<typename FindResult> std::string operator()(const FindResult& match) const { unsigned char value; if (getEscapeSequenceValue(std::string(match.begin() + 1, match.end()), value)) { return std::string(reinterpret_cast<const char*>(&value), 1); } return boost::copy_range<std::string>(match); } }; #endif namespace Swift { JID::JID(const char* jid) : valid_(true) { assert(jid); initializeFromString(std::string(jid)); } JID::JID(const std::string& jid) : valid_(true) { initializeFromString(jid); } JID::JID(const std::string& node, const std::string& domain) : valid_(true), hasResource_(false) { nameprepAndSetComponents(node, domain, ""); } JID::JID(const std::string& node, const std::string& domain, const std::string& resource) : valid_(true), hasResource_(true) { nameprepAndSetComponents(node, domain, resource); } void JID::initializeFromString(const std::string& jid) { if (String::beginsWith(jid, '@')) { valid_ = false; return; } std::string bare, resource; size_t slashIndex = jid.find('/'); if (slashIndex != jid.npos) { hasResource_ = true; bare = jid.substr(0, slashIndex); resource = jid.substr(slashIndex + 1, jid.npos); } else { hasResource_ = false; bare = jid; } std::pair<std::string,std::string> nodeAndDomain = String::getSplittedAtFirst(bare, '@'); if (nodeAndDomain.second.empty()) { nameprepAndSetComponents("", nodeAndDomain.first, resource); } else { nameprepAndSetComponents(nodeAndDomain.first, nodeAndDomain.second, resource); } } void JID::nameprepAndSetComponents(const std::string& node, const std::string& domain, const std::string& resource) { - if (domain.empty()) { + if (domain.empty() || !idnConverter->getIDNAEncoded(domain)) { valid_ = false; return; } #ifndef SWIFTEN_CACHE_JID_PREP node_ = idnConverter->getStringPrepared(node, IDNConverter::XMPPNodePrep); domain_ = idnConverter->getStringPrepared(domain, IDNConverter::NamePrep); resource_ = idnConverter->getStringPrepared(resource, IDNConverter::XMPPResourcePrep); #else boost::mutex::scoped_lock lock(namePrepCacheMutex); std::pair<PrepCache::iterator, bool> r; r = nodePrepCache.insert(std::make_pair(node, std::string())); if (r.second) { try { r.first->second = idnConverter->getStringPrepared(node, IDNConverter::XMPPNodePrep); } catch (...) { nodePrepCache.erase(r.first); valid_ = false; return; } } node_ = r.first->second; r = domainPrepCache.insert(std::make_pair(domain, std::string())); if (r.second) { try { r.first->second = idnConverter->getStringPrepared(domain, IDNConverter::NamePrep); } catch (...) { domainPrepCache.erase(r.first); valid_ = false; return; } } domain_ = r.first->second; r = resourcePrepCache.insert(std::make_pair(resource, std::string())); if (r.second) { try { r.first->second = idnConverter->getStringPrepared(resource, IDNConverter::XMPPResourcePrep); } catch (...) { resourcePrepCache.erase(r.first); valid_ = false; return; } } resource_ = r.first->second; #endif if (domain_.empty()) { valid_ = false; return; } } std::string JID::toString() const { std::string string; if (!node_.empty()) { string += node_ + "@"; } string += domain_; if (!isBare()) { string += "/" + resource_; } return string; } int JID::compare(const Swift::JID& o, CompareType compareType) const { if (node_ < o.node_) { return -1; } if (node_ > o.node_) { return 1; } if (domain_ < o.domain_) { return -1; } if (domain_ > o.domain_) { return 1; } if (compareType == WithResource) { if (hasResource_ != o.hasResource_) { return hasResource_ ? 1 : -1; } if (resource_ < o.resource_) { return -1; } if (resource_ > o.resource_) { return 1; } } return 0; } std::string JID::getEscapedNode(const std::string& node) { std::string result; for (std::string::const_iterator i = node.begin(); i != node.end(); ++i) { if (std::find(escapedChars.begin(), escapedChars.end(), *i) != escapedChars.end()) { result += getEscaped(*i); continue; } else if (*i == '\\') { // Check if we have an escaped dissalowed character sequence std::string::const_iterator innerBegin = i + 1; if (innerBegin != node.end() && innerBegin + 1 != node.end()) { std::string::const_iterator innerEnd = innerBegin + 2; unsigned char value; if (getEscapeSequenceValue(std::string(innerBegin, innerEnd), value)) { result += getEscaped(*i); continue; } } } result += *i; } return result; //return boost::find_format_all_copy(node, UnescapedCharacterFinder(), UnescapedCharacterFormatter()); } std::string JID::getUnescapedNode() const { std::string result; for (std::string::const_iterator j = node_.begin(); j != node_.end();) { if (*j == '\\') { std::string::const_iterator innerEnd = j + 1; for (size_t i = 0; i < 2 && innerEnd != node_.end(); ++i, ++innerEnd) { } unsigned char value; if (getEscapeSequenceValue(std::string(j + 1, innerEnd), value)) { result += std::string(reinterpret_cast<const char*>(&value), 1); j = innerEnd; continue; } } result += *j; ++j; } return result; //return boost::find_format_all_copy(node_, EscapedCharacterFinder(), EscapedCharacterFormatter()); } void JID::setIDNConverter(IDNConverter* converter) { idnConverter = converter; } std::ostream& operator<<(std::ostream& os, const JID& j) { os << j.toString(); return os; } boost::optional<JID> JID::parse(const std::string& s) { JID jid(s); return jid.isValid() ? jid : boost::optional<JID>(); } } diff --git a/Swiften/JID/UnitTest/JIDTest.cpp b/Swiften/JID/UnitTest/JIDTest.cpp index 72ca884..03203de 100644 --- a/Swiften/JID/UnitTest/JIDTest.cpp +++ b/Swiften/JID/UnitTest/JIDTest.cpp @@ -1,396 +1,401 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/JID/JID.h> using namespace Swift; class JIDTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(JIDTest); CPPUNIT_TEST(testConstructorWithString); CPPUNIT_TEST(testConstructorWithString_NoResource); CPPUNIT_TEST(testConstructorWithString_NoNode); CPPUNIT_TEST(testConstructorWithString_EmptyResource); CPPUNIT_TEST(testConstructorWithString_OnlyDomain); + CPPUNIT_TEST(testConstructorWithString_InvalidDomain); CPPUNIT_TEST(testConstructorWithString_UpperCaseNode); CPPUNIT_TEST(testConstructorWithString_UpperCaseDomain); CPPUNIT_TEST(testConstructorWithString_UpperCaseResource); CPPUNIT_TEST(testConstructorWithString_EmptyNode); CPPUNIT_TEST(testConstructorWithString_IllegalResource); CPPUNIT_TEST(testConstructorWithString_SpacesInNode); CPPUNIT_TEST(testConstructorWithStrings); CPPUNIT_TEST(testConstructorWithStrings_EmptyDomain); CPPUNIT_TEST(testIsBare); CPPUNIT_TEST(testIsBare_NotBare); CPPUNIT_TEST(testToBare); CPPUNIT_TEST(testToBare_EmptyNode); CPPUNIT_TEST(testToBare_EmptyResource); CPPUNIT_TEST(testToString); CPPUNIT_TEST(testToString_EmptyNode); CPPUNIT_TEST(testToString_EmptyResource); CPPUNIT_TEST(testToString_NoResource); CPPUNIT_TEST(testCompare_SmallerNode); CPPUNIT_TEST(testCompare_LargerNode); CPPUNIT_TEST(testCompare_SmallerDomain); CPPUNIT_TEST(testCompare_LargerDomain); CPPUNIT_TEST(testCompare_SmallerResource); CPPUNIT_TEST(testCompare_LargerResource); CPPUNIT_TEST(testCompare_Equal); CPPUNIT_TEST(testCompare_EqualWithoutResource); CPPUNIT_TEST(testCompare_NoResourceAndEmptyResource); CPPUNIT_TEST(testCompare_EmptyResourceAndNoResource); CPPUNIT_TEST(testEquals); CPPUNIT_TEST(testEquals_NotEqual); CPPUNIT_TEST(testEquals_WithoutResource); CPPUNIT_TEST(testSmallerThan); CPPUNIT_TEST(testSmallerThan_Equal); CPPUNIT_TEST(testSmallerThan_Larger); CPPUNIT_TEST(testHasResource); CPPUNIT_TEST(testHasResource_NoResource); CPPUNIT_TEST(testGetEscapedNode); CPPUNIT_TEST(testGetEscapedNode_XEP106Examples); CPPUNIT_TEST(testGetEscapedNode_BackslashAtEnd); CPPUNIT_TEST(testGetUnescapedNode); CPPUNIT_TEST(testGetUnescapedNode_XEP106Examples); CPPUNIT_TEST_SUITE_END(); public: JIDTest() {} void testConstructorWithString() { JID testling("foo@bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string("foo"), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); CPPUNIT_ASSERT_EQUAL(std::string("baz"), testling.getResource()); CPPUNIT_ASSERT(!testling.isBare()); } void testConstructorWithString_NoResource() { JID testling("foo@bar"); CPPUNIT_ASSERT_EQUAL(std::string("foo"), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.getResource()); CPPUNIT_ASSERT(testling.isBare()); } void testConstructorWithString_EmptyResource() { JID testling("bar/"); CPPUNIT_ASSERT(testling.isValid()); CPPUNIT_ASSERT(!testling.isBare()); } void testConstructorWithString_NoNode() { JID testling("bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); CPPUNIT_ASSERT_EQUAL(std::string("baz"), testling.getResource()); CPPUNIT_ASSERT(!testling.isBare()); } void testConstructorWithString_OnlyDomain() { JID testling("bar"); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.getResource()); CPPUNIT_ASSERT(testling.isBare()); } + void testConstructorWithString_InvalidDomain() { + CPPUNIT_ASSERT(!JID("foo@bar,baz").isValid()); + } + void testConstructorWithString_UpperCaseNode() { JID testling("Fo\xCE\xA9@bar"); CPPUNIT_ASSERT_EQUAL(std::string("fo\xCF\x89"), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); } void testConstructorWithString_UpperCaseDomain() { JID testling("Fo\xCE\xA9"); CPPUNIT_ASSERT_EQUAL(std::string("fo\xCF\x89"), testling.getDomain()); } void testConstructorWithString_UpperCaseResource() { JID testling("bar/Fo\xCE\xA9"); CPPUNIT_ASSERT_EQUAL(testling.getResource(), std::string("Fo\xCE\xA9")); } void testConstructorWithString_EmptyNode() { JID testling("@bar"); CPPUNIT_ASSERT(!testling.isValid()); } void testConstructorWithString_IllegalResource() { JID testling("foo@bar.com/\xd8\xb1\xd9\x85\xd9\x82\xd9\x87\x20\xd8\xaa\xd8\xb1\xd9\x86\xd8\xb3\x20"); CPPUNIT_ASSERT(!testling.isValid()); } void testConstructorWithString_SpacesInNode() { CPPUNIT_ASSERT(!JID(" alice@wonderland.lit").isValid()); CPPUNIT_ASSERT(!JID("alice @wonderland.lit").isValid()); } void testConstructorWithStrings() { JID testling("foo", "bar", "baz"); CPPUNIT_ASSERT_EQUAL(std::string("foo"), testling.getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.getDomain()); CPPUNIT_ASSERT_EQUAL(std::string("baz"), testling.getResource()); } void testConstructorWithStrings_EmptyDomain() { JID testling("foo", "", "baz"); CPPUNIT_ASSERT(!testling.isValid()); } void testIsBare() { CPPUNIT_ASSERT(JID("foo@bar").isBare()); } void testIsBare_NotBare() { CPPUNIT_ASSERT(!JID("foo@bar/baz").isBare()); } void testToBare() { JID testling("foo@bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string("foo"), testling.toBare().getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.toBare().getDomain()); CPPUNIT_ASSERT(testling.toBare().isBare()); } void testToBare_EmptyNode() { JID testling("bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.toBare().getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.toBare().getDomain()); CPPUNIT_ASSERT(testling.toBare().isBare()); } void testToBare_EmptyResource() { JID testling("bar/"); CPPUNIT_ASSERT_EQUAL(std::string(""), testling.toBare().getNode()); CPPUNIT_ASSERT_EQUAL(std::string("bar"), testling.toBare().getDomain()); CPPUNIT_ASSERT(testling.toBare().isBare()); } void testToString() { JID testling("foo@bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string("foo@bar/baz"), testling.toString()); } void testToString_EmptyNode() { JID testling("bar/baz"); CPPUNIT_ASSERT_EQUAL(std::string("bar/baz"), testling.toString()); } void testToString_NoResource() { JID testling("foo@bar"); CPPUNIT_ASSERT_EQUAL(std::string("foo@bar"), testling.toString()); } void testToString_EmptyResource() { JID testling("foo@bar/"); CPPUNIT_ASSERT_EQUAL(std::string("foo@bar/"), testling.toString()); } void testCompare_SmallerNode() { JID testling1("a@c"); JID testling2("b@b"); CPPUNIT_ASSERT_EQUAL(-1, testling1.compare(testling2, JID::WithResource)); } void testCompare_LargerNode() { JID testling1("c@a"); JID testling2("b@b"); CPPUNIT_ASSERT_EQUAL(1, testling1.compare(testling2, JID::WithResource)); } void testCompare_SmallerDomain() { JID testling1("x@a/c"); JID testling2("x@b/b"); CPPUNIT_ASSERT_EQUAL(-1, testling1.compare(testling2, JID::WithResource)); } void testCompare_LargerDomain() { JID testling1("x@b/b"); JID testling2("x@a/c"); CPPUNIT_ASSERT_EQUAL(1, testling1.compare(testling2, JID::WithResource)); } void testCompare_SmallerResource() { JID testling1("x@y/a"); JID testling2("x@y/b"); CPPUNIT_ASSERT_EQUAL(-1, testling1.compare(testling2, JID::WithResource)); } void testCompare_LargerResource() { JID testling1("x@y/b"); JID testling2("x@y/a"); CPPUNIT_ASSERT_EQUAL(1, testling1.compare(testling2, JID::WithResource)); } void testCompare_Equal() { JID testling1("x@y/z"); JID testling2("x@y/z"); CPPUNIT_ASSERT_EQUAL(0, testling1.compare(testling2, JID::WithResource)); } void testCompare_EqualWithoutResource() { JID testling1("x@y/a"); JID testling2("x@y/b"); CPPUNIT_ASSERT_EQUAL(0, testling1.compare(testling2, JID::WithoutResource)); } void testCompare_NoResourceAndEmptyResource() { JID testling1("x@y/"); JID testling2("x@y"); CPPUNIT_ASSERT_EQUAL(1, testling1.compare(testling2, JID::WithResource)); } void testCompare_EmptyResourceAndNoResource() { JID testling1("x@y"); JID testling2("x@y/"); CPPUNIT_ASSERT_EQUAL(-1, testling1.compare(testling2, JID::WithResource)); } void testEquals() { JID testling1("x@y/c"); JID testling2("x@y/c"); CPPUNIT_ASSERT(testling1.equals(testling2, JID::WithResource)); } void testEquals_NotEqual() { JID testling1("x@y/c"); JID testling2("x@y/d"); CPPUNIT_ASSERT(!testling1.equals(testling2, JID::WithResource)); } void testEquals_WithoutResource() { JID testling1("x@y/c"); JID testling2("x@y/d"); CPPUNIT_ASSERT(testling1.equals(testling2, JID::WithoutResource)); } void testSmallerThan() { JID testling1("x@y/c"); JID testling2("x@y/d"); CPPUNIT_ASSERT(testling1 < testling2); } void testSmallerThan_Equal() { JID testling1("x@y/d"); JID testling2("x@y/d"); CPPUNIT_ASSERT(!(testling1 < testling2)); } void testSmallerThan_Larger() { JID testling1("x@y/d"); JID testling2("x@y/c"); CPPUNIT_ASSERT(!(testling1 < testling2)); } void testHasResource() { JID testling("x@y/d"); CPPUNIT_ASSERT(!testling.isBare()); } void testHasResource_NoResource() { JID testling("x@y"); CPPUNIT_ASSERT(testling.isBare()); } void testGetEscapedNode() { std::string escaped = JID::getEscapedNode("alice@wonderland.lit"); CPPUNIT_ASSERT_EQUAL(std::string("alice\\40wonderland.lit"), escaped); escaped = JID::getEscapedNode("\\& \" ' / <\\\\> @ :\\3a\\40"); CPPUNIT_ASSERT_EQUAL(std::string("\\\\26\\20\\22\\20\\27\\20\\2f\\20\\3c\\\\\\3e\\20\\40\\20\\3a\\5c3a\\5c40"), escaped); } void testGetEscapedNode_XEP106Examples() { CPPUNIT_ASSERT_EQUAL(std::string("\\2plus\\2is\\4"), JID::getEscapedNode("\\2plus\\2is\\4")); CPPUNIT_ASSERT_EQUAL(std::string("foo\\bar"), JID::getEscapedNode("foo\\bar")); CPPUNIT_ASSERT_EQUAL(std::string("foob\\41r"), JID::getEscapedNode("foob\\41r")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("space cadet"), std::string("space\\20cadet")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("call me \"ishmael\""), std::string("call\\20me\\20\\22ishmael\\22")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("at&t guy"), std::string("at\\26t\\20guy")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("d'artagnan"), std::string("d\\27artagnan")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("/.fanboy"), std::string("\\2f.fanboy")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("::foo::"), std::string("\\3a\\3afoo\\3a\\3a")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("<foo>"), std::string("\\3cfoo\\3e")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("user@host"), std::string("user\\40host")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("c:\\net"), std::string("c\\3a\\net")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("c:\\\\net"), std::string("c\\3a\\\\net")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("c:\\cool stuff"), std::string("c\\3a\\cool\\20stuff")); CPPUNIT_ASSERT_EQUAL(JID::getEscapedNode("c:\\5commas"), std::string("c\\3a\\5c5commas")); } void testGetEscapedNode_BackslashAtEnd() { CPPUNIT_ASSERT_EQUAL(std::string("foo\\"), JID::getEscapedNode("foo\\")); } void testGetUnescapedNode() { std::string input = "\\& \" ' / <\\\\> @ : \\5c\\40"; JID testling(JID::getEscapedNode(input) + "@y"); CPPUNIT_ASSERT(testling.isValid()); CPPUNIT_ASSERT_EQUAL(input, testling.getUnescapedNode()); } void testGetUnescapedNode_XEP106Examples() { CPPUNIT_ASSERT_EQUAL(std::string("\\2plus\\2is\\4"), JID("\\2plus\\2is\\4@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("foo\\bar"), JID("foo\\bar@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("foob\\41r"), JID("foob\\41r@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("space cadet"), JID("space\\20cadet@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("call me \"ishmael\""), JID("call\\20me\\20\\22ishmael\\22@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("at&t guy"), JID("at\\26t\\20guy@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("d'artagnan"), JID("d\\27artagnan@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("/.fanboy"), JID("\\2f.fanboy@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("::foo::"), JID("\\3a\\3afoo\\3a\\3a@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("<foo>"), JID("\\3cfoo\\3e@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("user@host"), JID("user\\40host@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("c:\\net"), JID("c\\3a\\net@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("c:\\\\net"), JID("c\\3a\\\\net@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("c:\\cool stuff"), JID("c\\3a\\cool\\20stuff@example.com").getUnescapedNode()); CPPUNIT_ASSERT_EQUAL(std::string("c:\\5commas"), JID("c\\3a\\5c5commas@example.com").getUnescapedNode()); } }; CPPUNIT_TEST_SUITE_REGISTRATION(JIDTest); diff --git a/Swiften/Network/CachingDomainNameResolver.cpp b/Swiften/Network/CachingDomainNameResolver.cpp index 4cf8286..fea14a3 100644 --- a/Swiften/Network/CachingDomainNameResolver.cpp +++ b/Swiften/Network/CachingDomainNameResolver.cpp @@ -1,30 +1,30 @@ /* * Copyright (c) 2012 Kevin Smith * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/CachingDomainNameResolver.h> #include <boost/smart_ptr/make_shared.hpp> namespace Swift { CachingDomainNameResolver::CachingDomainNameResolver(DomainNameResolver* realResolver, EventLoop*) : realResolver(realResolver) { } CachingDomainNameResolver::~CachingDomainNameResolver() { } -DomainNameServiceQuery::ref CachingDomainNameResolver::createServiceQuery(const std::string& name) { +DomainNameServiceQuery::ref CachingDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { //TODO: Cache - return realResolver->createServiceQuery(name); + return realResolver->createServiceQuery(serviceLookupPrefix, domain); } DomainNameAddressQuery::ref CachingDomainNameResolver::createAddressQuery(const std::string& name) { //TODO: Cache return realResolver->createAddressQuery(name); } } diff --git a/Swiften/Network/CachingDomainNameResolver.h b/Swiften/Network/CachingDomainNameResolver.h index 66b4d68..3d50676 100644 --- a/Swiften/Network/CachingDomainNameResolver.h +++ b/Swiften/Network/CachingDomainNameResolver.h @@ -1,31 +1,31 @@ /* * Copyright (c) 2012 Kevin Smith * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/StaticDomainNameResolver.h> /* * FIXME: Does not do any caching yet. */ namespace Swift { class EventLoop; class CachingDomainNameResolver : public DomainNameResolver { public: CachingDomainNameResolver(DomainNameResolver* realResolver, EventLoop* eventLoop); ~CachingDomainNameResolver(); - virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& name); + virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); private: DomainNameResolver* realResolver; }; } diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index da2490f..1db1eac 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -1,186 +1,186 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/Connector.h> #include <boost/bind.hpp> #include <iostream> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/TimerFactory.h> #include <Swiften/Base/Log.h> namespace Swift { Connector::Connector(const std::string& hostname, int port, const boost::optional<std::string>& serviceLookupPrefix, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), port(port), serviceLookupPrefix(serviceLookupPrefix), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllServices(true), foundSomeDNS(false) { } void Connector::setTimeoutMilliseconds(int milliseconds) { timeoutMilliseconds = milliseconds; } void Connector::start() { SWIFT_LOG(debug) << "Starting connector for " << hostname << std::endl; //std::cout << "Connector::start()" << std::endl; assert(!currentConnection); assert(!serviceQuery); assert(!timer); queriedAllServices = false; if (timeoutMilliseconds > 0) { timer = timerFactory->createTimer(timeoutMilliseconds); timer->onTick.connect(boost::bind(&Connector::handleTimeout, shared_from_this())); } if (serviceLookupPrefix) { - serviceQuery = resolver->createServiceQuery((*serviceLookupPrefix) + hostname); + serviceQuery = resolver->createServiceQuery(*serviceLookupPrefix, hostname); serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); serviceQuery->run(); } else { queryAddress(hostname); } } void Connector::stop() { finish(boost::shared_ptr<Connection>()); } void Connector::queryAddress(const std::string& hostname) { assert(!addressQuery); addressQuery = resolver->createAddressQuery(hostname); addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); addressQuery->run(); } void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { SWIFT_LOG(debug) << result.size() << " SRV result(s)" << std::endl; serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); serviceQuery.reset(); if (!serviceQueryResults.empty()) { foundSomeDNS = true; } tryNextServiceOrFallback(); } void Connector::tryNextServiceOrFallback() { if (queriedAllServices) { SWIFT_LOG(debug) << "Queried all services" << std::endl; finish(boost::shared_ptr<Connection>()); } else if (serviceQueryResults.empty()) { SWIFT_LOG(debug) << "Falling back on A resolution" << std::endl; // Fall back on simple address resolving queriedAllServices = true; queryAddress(hostname); } else { SWIFT_LOG(debug) << "Querying next address" << std::endl; queryAddress(serviceQueryResults.front().hostname); } } void Connector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) { SWIFT_LOG(debug) << addresses.size() << " addresses" << std::endl; addressQuery.reset(); if (error || addresses.empty()) { if (!serviceQueryResults.empty()) { serviceQueryResults.pop_front(); } tryNextServiceOrFallback(); } else { foundSomeDNS = true; addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end()); tryNextAddress(); } } void Connector::tryNextAddress() { if (addressQueryResults.empty()) { SWIFT_LOG(debug) << "Done trying addresses. Moving on." << std::endl; // Done trying all addresses. Move on to the next host. if (!serviceQueryResults.empty()) { serviceQueryResults.pop_front(); } tryNextServiceOrFallback(); } else { SWIFT_LOG(debug) << "Trying next address" << std::endl; HostAddress address = addressQueryResults.front(); addressQueryResults.pop_front(); int connectPort = (port == -1 ? 5222 : port); if (!serviceQueryResults.empty()) { connectPort = serviceQueryResults.front().port; } tryConnect(HostAddressPort(address, connectPort)); } } void Connector::tryConnect(const HostAddressPort& target) { assert(!currentConnection); SWIFT_LOG(debug) << "Trying to connect to " << target.getAddress().toString() << ":" << target.getPort() << std::endl; currentConnection = connectionFactory->createConnection(); currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); currentConnection->connect(target); if (timer) { timer->start(); } } void Connector::handleConnectionConnectFinished(bool error) { SWIFT_LOG(debug) << "ConnectFinished: " << (error ? "error" : "success") << std::endl; if (timer) { timer->stop(); timer.reset(); } currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); if (error) { currentConnection.reset(); if (!addressQueryResults.empty()) { tryNextAddress(); } else { if (!serviceQueryResults.empty()) { serviceQueryResults.pop_front(); } tryNextServiceOrFallback(); } } else { finish(currentConnection); } } void Connector::finish(boost::shared_ptr<Connection> connection) { if (timer) { timer->stop(); timer->onTick.disconnect(boost::bind(&Connector::handleTimeout, shared_from_this())); timer.reset(); } if (serviceQuery) { serviceQuery->onResult.disconnect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); serviceQuery.reset(); } if (addressQuery) { addressQuery->onResult.disconnect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); addressQuery.reset(); } if (currentConnection) { currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); currentConnection.reset(); } onConnectFinished(connection, (connection || foundSomeDNS) ? boost::shared_ptr<Error>() : boost::make_shared<DomainNameResolveError>()); } void Connector::handleTimeout() { SWIFT_LOG(debug) << "Timeout" << std::endl; handleConnectionConnectFinished(true); } } diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index 491586a..dc7013e 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -1,26 +1,26 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <string> #include <Swiften/Base/API.h> namespace Swift { class DomainNameServiceQuery; class DomainNameAddressQuery; class SWIFTEN_API DomainNameResolver { public: virtual ~DomainNameResolver(); - virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& name) = 0; + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) = 0; virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name) = 0; }; } diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.cpp b/Swiften/Network/PlatformDomainNameAddressQuery.cpp index ec7e663..91d15b9 100644 --- a/Swiften/Network/PlatformDomainNameAddressQuery.cpp +++ b/Swiften/Network/PlatformDomainNameAddressQuery.cpp @@ -1,58 +1,66 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/PlatformDomainNameAddressQuery.h> #include <boost/asio/ip/tcp.hpp> #include <Swiften/Network/PlatformDomainNameResolver.h> #include <Swiften/EventLoop/EventLoop.h> namespace Swift { -PlatformDomainNameAddressQuery::PlatformDomainNameAddressQuery(const std::string& host, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), hostname(host), eventLoop(eventLoop) { +PlatformDomainNameAddressQuery::PlatformDomainNameAddressQuery(const boost::optional<std::string>& host, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), hostnameValid(false), eventLoop(eventLoop) { + if (!!host) { + hostname = *host; + hostnameValid = true; + } } void PlatformDomainNameAddressQuery::run() { getResolver()->addQueryToQueue(shared_from_this()); } void PlatformDomainNameAddressQuery::runBlocking() { + if (!hostnameValid) { + emitError(); + return; + } //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; boost::asio::ip::tcp::resolver resolver(ioService); boost::asio::ip::tcp::resolver::query query(hostname, "5222"); try { //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; emitError(); } else { std::vector<HostAddress> results; for ( ; endpointIterator != boost::asio::ip::tcp::resolver::iterator(); ++endpointIterator) { boost::asio::ip::address address = (*endpointIterator).endpoint().address(); results.push_back(address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); } //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; eventLoop->postEvent( boost::bind(boost::ref(onResult), results, boost::optional<DomainNameResolveError>()), shared_from_this()); } } catch (...) { //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; emitError(); } } void PlatformDomainNameAddressQuery::emitError() { eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); } } diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.h b/Swiften/Network/PlatformDomainNameAddressQuery.h index e1dc05f..9e89086 100644 --- a/Swiften/Network/PlatformDomainNameAddressQuery.h +++ b/Swiften/Network/PlatformDomainNameAddressQuery.h @@ -1,38 +1,39 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/asio/io_service.hpp> #include <boost/enable_shared_from_this.hpp> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/PlatformDomainNameQuery.h> #include <Swiften/EventLoop/EventOwner.h> #include <string> namespace Swift { class PlatformDomainNameResolver; class EventLoop; class PlatformDomainNameAddressQuery : public DomainNameAddressQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this<PlatformDomainNameAddressQuery>, public EventOwner { public: - PlatformDomainNameAddressQuery(const std::string& host, EventLoop* eventLoop, PlatformDomainNameResolver*); + PlatformDomainNameAddressQuery(const boost::optional<std::string>& host, EventLoop* eventLoop, PlatformDomainNameResolver*); void run(); private: void runBlocking(); void emitError(); private: boost::asio::io_service ioService; std::string hostname; + bool hostnameValid; EventLoop* eventLoop; }; } diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp index 677f1d5..b65e884 100644 --- a/Swiften/Network/PlatformDomainNameResolver.cpp +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -1,76 +1,81 @@ /* * Copyright (c) 2010-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/PlatformDomainNameResolver.h> // Putting this early on, because some system types conflict with thread #include <Swiften/Network/PlatformDomainNameServiceQuery.h> #include <string> #include <vector> #include <boost/bind.hpp> #include <boost/thread.hpp> #include <algorithm> #include <string> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/Network/HostAddress.h> #include <Swiften/EventLoop/EventLoop.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/PlatformDomainNameAddressQuery.h> using namespace Swift; namespace Swift { PlatformDomainNameResolver::PlatformDomainNameResolver(IDNConverter* idnConverter, EventLoop* eventLoop) : idnConverter(idnConverter), eventLoop(eventLoop), stopRequested(false) { thread = new boost::thread(boost::bind(&PlatformDomainNameResolver::run, this)); } PlatformDomainNameResolver::~PlatformDomainNameResolver() { stopRequested = true; addQueryToQueue(boost::shared_ptr<PlatformDomainNameQuery>()); thread->join(); delete thread; } -boost::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const std::string& name) { - return boost::shared_ptr<DomainNameServiceQuery>(new PlatformDomainNameServiceQuery(idnConverter->getIDNAEncoded(name), eventLoop, this)); +boost::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + boost::optional<std::string> encodedDomain = idnConverter->getIDNAEncoded(domain); + std::string result; + if (encodedDomain) { + result = serviceLookupPrefix + *encodedDomain; + } + return boost::shared_ptr<DomainNameServiceQuery>(new PlatformDomainNameServiceQuery(result, eventLoop, this)); } boost::shared_ptr<DomainNameAddressQuery> PlatformDomainNameResolver::createAddressQuery(const std::string& name) { return boost::shared_ptr<DomainNameAddressQuery>(new PlatformDomainNameAddressQuery(idnConverter->getIDNAEncoded(name), eventLoop, this)); } void PlatformDomainNameResolver::run() { while (!stopRequested) { PlatformDomainNameQuery::ref query; { boost::unique_lock<boost::mutex> lock(queueMutex); while (queue.empty()) { queueNonEmpty.wait(lock); } query = queue.front(); queue.pop_front(); } // Check whether we don't have a non-null query (used to stop the // resolver) if (query) { query->runBlocking(); } } } void PlatformDomainNameResolver::addQueryToQueue(PlatformDomainNameQuery::ref query) { { boost::lock_guard<boost::mutex> lock(queueMutex); queue.push_back(query); } queueNonEmpty.notify_one(); } } diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h index 25d87cf..6c3bf10 100644 --- a/Swiften/Network/PlatformDomainNameResolver.h +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -1,47 +1,47 @@ /* * Copyright (c) 2010-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <deque> #include <boost/thread/thread.hpp> #include <boost/thread/mutex.hpp> #include <boost/thread/condition_variable.hpp> #include <Swiften/Base/API.h> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/PlatformDomainNameQuery.h> #include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Network/DomainNameAddressQuery.h> namespace Swift { class IDNConverter; class EventLoop; class SWIFTEN_API PlatformDomainNameResolver : public DomainNameResolver { public: PlatformDomainNameResolver(IDNConverter* idnConverter, EventLoop* eventLoop); ~PlatformDomainNameResolver(); - virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& name); + virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); private: void run(); void addQueryToQueue(PlatformDomainNameQuery::ref); private: friend class PlatformDomainNameServiceQuery; friend class PlatformDomainNameAddressQuery; IDNConverter* idnConverter; EventLoop* eventLoop; bool stopRequested; boost::thread* thread; std::deque<PlatformDomainNameQuery::ref> queue; boost::mutex queueMutex; boost::condition_variable queueNonEmpty; }; } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp index 5788d2f..58cf8d2 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -1,173 +1,182 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <boost/asio.hpp> #include <Swiften/Network/PlatformDomainNameServiceQuery.h> #pragma GCC diagnostic ignored "-Wold-style-cast" #include <Swiften/Base/Platform.h> #include <stdlib.h> #include <boost/numeric/conversion/cast.hpp> #ifdef SWIFTEN_PLATFORM_WINDOWS #undef UNICODE #include <windows.h> #include <windns.h> #ifndef DNS_TYPE_SRV #define DNS_TYPE_SRV 33 #endif #else #include <arpa/nameser.h> #include <arpa/nameser_compat.h> #include <resolv.h> #endif #include <boost/bind.hpp> #include <Swiften/Base/ByteArray.h> #include <Swiften/EventLoop/EventLoop.h> #include <Swiften/Base/foreach.h> #include <Swiften/Base/BoostRandomGenerator.h> #include <Swiften/Base/Log.h> #include <Swiften/Network/PlatformDomainNameResolver.h> using namespace Swift; namespace Swift { -PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const std::string& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), eventLoop(eventLoop), service(service) { +PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const boost::optional<std::string>& serviceName, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), eventLoop(eventLoop), serviceValid(false) { + if (!!serviceName) { + service = *serviceName; + serviceValid = true; + } } void PlatformDomainNameServiceQuery::run() { getResolver()->addQueryToQueue(shared_from_this()); } void PlatformDomainNameServiceQuery::runBlocking() { + if (!serviceValid) { + emitError(); + return; + } + SWIFT_LOG(debug) << "Querying " << service << std::endl; std::vector<DomainNameServiceQuery::Result> records; #if defined(SWIFTEN_PLATFORM_WINDOWS) DNS_RECORD* responses; // FIXME: This conversion doesn't work if unicode is deffed above if (DnsQuery(service.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { emitError(); return; } DNS_RECORD* currentEntry = responses; while (currentEntry) { if (currentEntry->wType == DNS_TYPE_SRV) { DomainNameServiceQuery::Result record; record.priority = currentEntry->Data.SRV.wPriority; record.weight = currentEntry->Data.SRV.wWeight; record.port = currentEntry->Data.SRV.wPort; // The pNameTarget is actually a PCWSTR, so I would have expected this // conversion to not work at all, but it does. // Actually, it doesn't. Fix this and remove explicit cast // Remove unicode undef above as well record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); records.push_back(record); } currentEntry = currentEntry->pNext; } DnsRecordListFree(responses, DnsFreeRecordList); #else // Make sure we reinitialize the domain list every time res_init(); ByteArray response; response.resize(NS_PACKETSZ); int responseLength = res_query(const_cast<char*>(service.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(vecptr(response)), response.size()); if (responseLength == -1) { SWIFT_LOG(debug) << "Error" << std::endl; emitError(); return; } // Parse header HEADER* header = reinterpret_cast<HEADER*>(vecptr(response)); unsigned char* messageStart = vecptr(response); unsigned char* messageEnd = messageStart + responseLength; unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; // Skip over the queries int queriesCount = ntohs(header->qdcount); while (queriesCount > 0) { int entryLength = dn_skipname(currentEntry, messageEnd); if (entryLength < 0) { emitError(); return; } currentEntry += entryLength + NS_QFIXEDSZ; queriesCount--; } // Process the SRV answers int answersCount = ntohs(header->ancount); while (answersCount > 0) { DomainNameServiceQuery::Result record; int entryLength = dn_skipname(currentEntry, messageEnd); currentEntry += entryLength; currentEntry += NS_RRFIXEDSZ; // Priority if (currentEntry + 2 >= messageEnd) { emitError(); return; } record.priority = boost::numeric_cast<int>(ns_get16(currentEntry)); currentEntry += 2; // Weight if (currentEntry + 2 >= messageEnd) { emitError(); return; } record.weight = boost::numeric_cast<int>(ns_get16(currentEntry)); currentEntry += 2; // Port if (currentEntry + 2 >= messageEnd) { emitError(); return; } record.port = boost::numeric_cast<int>(ns_get16(currentEntry)); currentEntry += 2; // Hostname if (currentEntry >= messageEnd) { emitError(); return; } ByteArray entry; entry.resize(NS_MAXDNAME); entryLength = dn_expand(messageStart, messageEnd, currentEntry, reinterpret_cast<char*>(vecptr(entry)), entry.size()); if (entryLength < 0) { emitError(); return; } record.hostname = std::string(reinterpret_cast<const char*>(vecptr(entry))); records.push_back(record); currentEntry += entryLength; answersCount--; } #endif BoostRandomGenerator generator; DomainNameServiceQuery::sortResults(records, generator); //std::cout << "Sending out " << records.size() << " SRV results " << std::endl; eventLoop->postEvent(boost::bind(boost::ref(onResult), records), shared_from_this()); } void PlatformDomainNameServiceQuery::emitError() { eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); } } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.h b/Swiften/Network/PlatformDomainNameServiceQuery.h index 310e639..e105479 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.h +++ b/Swiften/Network/PlatformDomainNameServiceQuery.h @@ -1,33 +1,34 @@ /* * Copyright (c) 2010-2013 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/enable_shared_from_this.hpp> #include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/EventLoop/EventOwner.h> #include <string> #include <Swiften/Network/PlatformDomainNameQuery.h> namespace Swift { class EventLoop; class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this<PlatformDomainNameServiceQuery>, public EventOwner { public: - PlatformDomainNameServiceQuery(const std::string& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver); + PlatformDomainNameServiceQuery(const boost::optional<std::string>& serviceName, EventLoop* eventLoop, PlatformDomainNameResolver* resolver); virtual void run(); private: void runBlocking(); void emitError(); private: EventLoop* eventLoop; std::string service; + bool serviceValid; }; } diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index ee18ee5..17d9c3b 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -1,119 +1,119 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/Network/StaticDomainNameResolver.h> #include <boost/bind.hpp> #include <boost/lexical_cast.hpp> #include <Swiften/Network/DomainNameResolveError.h> #include <Swiften/EventLoop/EventOwner.h> #include <string> using namespace Swift; namespace { struct ServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this<ServiceQuery> { ServiceQuery(const std::string& service, Swift::StaticDomainNameResolver* resolver, EventLoop* eventLoop, boost::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), service(service), resolver(resolver), owner(owner) {} virtual void run() { if (!resolver->getIsResponsive()) { return; } std::vector<DomainNameServiceQuery::Result> results; for(StaticDomainNameResolver::ServicesCollection::const_iterator i = resolver->getServices().begin(); i != resolver->getServices().end(); ++i) { if (i->first == service) { results.push_back(i->second); } } eventLoop->postEvent(boost::bind(&ServiceQuery::emitOnResult, shared_from_this(), results), owner); } void emitOnResult(std::vector<DomainNameServiceQuery::Result> results) { onResult(results); } EventLoop* eventLoop; std::string service; StaticDomainNameResolver* resolver; boost::shared_ptr<EventOwner> owner; }; struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this<AddressQuery> { AddressQuery(const std::string& host, StaticDomainNameResolver* resolver, EventLoop* eventLoop, boost::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), host(host), resolver(resolver), owner(owner) {} virtual void run() { if (!resolver->getIsResponsive()) { return; } StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); if (i != resolver->getAddresses().end()) { eventLoop->postEvent( boost::bind(&AddressQuery::emitOnResult, shared_from_this(), i->second, boost::optional<DomainNameResolveError>())); } else { eventLoop->postEvent(boost::bind(&AddressQuery::emitOnResult, shared_from_this(), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), owner); } } void emitOnResult(std::vector<HostAddress> results, boost::optional<DomainNameResolveError> error) { onResult(results, error); } EventLoop* eventLoop; std::string host; StaticDomainNameResolver* resolver; boost::shared_ptr<EventOwner> owner; }; } class StaticDomainNameResolverEventOwner : public EventOwner { public: ~StaticDomainNameResolverEventOwner() { } }; namespace Swift { StaticDomainNameResolver::StaticDomainNameResolver(EventLoop* eventLoop) : eventLoop(eventLoop), isResponsive(true), owner(new StaticDomainNameResolverEventOwner()) { } StaticDomainNameResolver::~StaticDomainNameResolver() { eventLoop->removeEventsFromOwner(owner); } void StaticDomainNameResolver::addAddress(const std::string& domain, const HostAddress& address) { addresses[domain].push_back(address); } void StaticDomainNameResolver::addService(const std::string& service, const DomainNameServiceQuery::Result& result) { services.push_back(std::make_pair(service, result)); } void StaticDomainNameResolver::addXMPPClientService(const std::string& domain, const HostAddressPort& address) { static int hostid = 0; std::string hostname(std::string("host-") + boost::lexical_cast<std::string>(hostid)); hostid++; addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, address.getPort(), 0, 0)); addAddress(hostname, address.getAddress()); } void StaticDomainNameResolver::addXMPPClientService(const std::string& domain, const std::string& hostname, int port) { addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, port, 0, 0)); } -boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const std::string& name) { - return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this, eventLoop, owner)); +boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(serviceLookupPrefix + domain, this, eventLoop, owner)); } boost::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const std::string& name) { return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(name, this, eventLoop, owner)); } } diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h index 386179b..81ff040 100644 --- a/Swiften/Network/StaticDomainNameResolver.h +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -1,60 +1,60 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <vector> #include <map> #include <Swiften/Base/API.h> #include <Swiften/Network/HostAddress.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/EventLoop/EventLoop.h> namespace Swift { class SWIFTEN_API StaticDomainNameResolver : public DomainNameResolver { public: typedef std::map<std::string, std::vector<HostAddress> > AddressesMap; typedef std::vector< std::pair<std::string, DomainNameServiceQuery::Result> > ServicesCollection; public: StaticDomainNameResolver(EventLoop* eventLoop); ~StaticDomainNameResolver(); void addAddress(const std::string& domain, const HostAddress& address); void addService(const std::string& service, const DomainNameServiceQuery::Result& result); void addXMPPClientService(const std::string& domain, const HostAddressPort&); void addXMPPClientService(const std::string& domain, const std::string& host, int port); const AddressesMap& getAddresses() const { return addresses; } const ServicesCollection& getServices() const { return services; } bool getIsResponsive() const { return isResponsive; } void setIsResponsive(bool b) { isResponsive = b; } - virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& name); + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name); private: EventLoop* eventLoop; bool isResponsive; AddressesMap addresses; ServicesCollection services; boost::shared_ptr<EventOwner> owner; }; } diff --git a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp index bc4f1a3..6d25f49 100644 --- a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp +++ b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp @@ -1,242 +1,242 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <boost/bind.hpp> #include <algorithm> #include <Swiften/Base/sleep.h> #include <string> #include <Swiften/Base/ByteArray.h> #ifdef USE_UNBOUND #include <Swiften/Network/UnboundDomainNameResolver.h> #else #include <Swiften/Network/PlatformDomainNameResolver.h> #endif #include <Swiften/Network/BoostTimerFactory.h> #include <Swiften/Network/NetworkFactories.h> #include <Swiften/Network/BoostIOServiceThread.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> using namespace Swift; struct CompareHostAddresses { bool operator()(const HostAddress& h1, const HostAddress& h2) { return h1.toString() < h2.toString(); } }; class DomainNameResolverTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(DomainNameResolverTest); CPPUNIT_TEST(testResolveAddress); CPPUNIT_TEST(testResolveAddress_Error); #ifndef USE_UNBOUND CPPUNIT_TEST(testResolveAddress_IPv6); CPPUNIT_TEST(testResolveAddress_IPv4and6); CPPUNIT_TEST(testResolveAddress_International); #endif CPPUNIT_TEST(testResolveAddress_Localhost); CPPUNIT_TEST(testResolveAddress_Parallel); #ifndef USE_UNBOUND CPPUNIT_TEST(testResolveService); #endif CPPUNIT_TEST(testResolveService_Error); CPPUNIT_TEST_SUITE_END(); public: void setUp() { ioServiceThread = new BoostIOServiceThread(); eventLoop = new DummyEventLoop(); #ifdef USE_UNBOUND resolver = new UnboundDomainNameResolver(ioServiceThread->getIOService(), eventLoop); #else idnConverter = boost::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); resolver = new PlatformDomainNameResolver(idnConverter.get(), eventLoop); #endif resultsAvailable = false; } void tearDown() { delete ioServiceThread; delete resolver; delete eventLoop; } void testResolveAddress() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("xmpp.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT(!addressQueryError); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(addressQueryResult.size())); CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), addressQueryResult[0].toString()); } void testResolveAddress_Error() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("invalid.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT(addressQueryError); } void testResolveAddress_IPv6() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("xmpp-ipv6.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT(!addressQueryError); CPPUNIT_ASSERT_EQUAL(std::string("2001:470:1f0e:852::2"), addressQueryResult[0].toString()); } void testResolveAddress_IPv4and6() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("xmpp-ipv46.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT(!addressQueryError); CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(addressQueryResult.size())); CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.7"), addressQueryResult[0].toString()); CPPUNIT_ASSERT_EQUAL(std::string("1234:5678:9abc:def0:fed:cba9:8765:4321"), addressQueryResult[1].toString()); } void testResolveAddress_International() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("tron\xc3\xa7on.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT(!addressQueryError); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(addressQueryResult.size())); CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), addressQueryResult[0].toString()); } void testResolveAddress_Localhost() { boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("localhost")); query->run(); waitForResults(); CPPUNIT_ASSERT(!addressQueryError); CPPUNIT_ASSERT(std::find(addressQueryResult.begin(), addressQueryResult.end(), HostAddress("127.0.0.1")) != addressQueryResult.end()); } void testResolveAddress_Parallel() { std::vector<DomainNameAddressQuery::ref> queries; static const size_t numQueries = 100; for (size_t i = 0; i < numQueries; ++i) { DomainNameAddressQuery::ref query(createAddressQuery("xmpp.test.swift.im")); queries.push_back(query); query->run(); } eventLoop->processEvents(); int ticks = 0; while (allAddressQueryResults.size() < numQueries) { ticks++; if (ticks > 1000) { CPPUNIT_ASSERT(false); } Swift::sleep(10); eventLoop->processEvents(); } CPPUNIT_ASSERT_EQUAL(numQueries, allAddressQueryResults.size()); for (size_t i = 0; i < numQueries; ++i) { CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), allAddressQueryResults[i].toString()); } } void testResolveService() { - boost::shared_ptr<DomainNameServiceQuery> query(createServiceQuery("_xmpp-client._tcp.xmpp-srv.test.swift.im")); + boost::shared_ptr<DomainNameServiceQuery> query(createServiceQuery("_xmpp-client._tcp.", "xmpp-srv.test.swift.im")); query->run(); waitForResults(); CPPUNIT_ASSERT_EQUAL(4, static_cast<int>(serviceQueryResult.size())); CPPUNIT_ASSERT_EQUAL(std::string("xmpp1.test.swift.im"), serviceQueryResult[0].hostname); CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[0].port); CPPUNIT_ASSERT_EQUAL(0, serviceQueryResult[0].priority); CPPUNIT_ASSERT_EQUAL(1, serviceQueryResult[0].weight); CPPUNIT_ASSERT_EQUAL(std::string("xmpp-invalid.test.swift.im"), serviceQueryResult[1].hostname); CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[1].port); CPPUNIT_ASSERT_EQUAL(1, serviceQueryResult[1].priority); CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[1].weight); CPPUNIT_ASSERT_EQUAL(std::string("xmpp3.test.swift.im"), serviceQueryResult[2].hostname); CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[2].port); CPPUNIT_ASSERT_EQUAL(3, serviceQueryResult[2].priority); CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[2].weight); CPPUNIT_ASSERT_EQUAL(std::string("xmpp2.test.swift.im"), serviceQueryResult[3].hostname); CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[3].port); CPPUNIT_ASSERT_EQUAL(5, serviceQueryResult[3].priority); CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[3].weight); } void testResolveService_Error() { } private: boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& domain) { boost::shared_ptr<DomainNameAddressQuery> result = resolver->createAddressQuery(domain); result->onResult.connect(boost::bind(&DomainNameResolverTest::handleAddressQueryResult, this, _1, _2)); return result; } void handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) { addressQueryResult = addresses; std::sort(addressQueryResult.begin(), addressQueryResult.end(), CompareHostAddresses()); allAddressQueryResults.insert(allAddressQueryResults.begin(), addresses.begin(), addresses.end()); addressQueryError = error; resultsAvailable = true; } - boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& domain) { - boost::shared_ptr<DomainNameServiceQuery> result = resolver->createServiceQuery(domain); + boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + boost::shared_ptr<DomainNameServiceQuery> result = resolver->createServiceQuery(serviceLookupPrefix, domain); result->onResult.connect(boost::bind(&DomainNameResolverTest::handleServiceQueryResult, this, _1)); return result; } void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { serviceQueryResult = result; resultsAvailable = true; } void waitForResults() { eventLoop->processEvents(); int ticks = 0; while (!resultsAvailable) { ticks++; if (ticks > 1000) { CPPUNIT_ASSERT(false); } Swift::sleep(10); eventLoop->processEvents(); } } private: BoostIOServiceThread* ioServiceThread; DummyEventLoop* eventLoop; boost::shared_ptr<IDNConverter> idnConverter; boost::shared_ptr<TimerFactory> timerFactory; bool resultsAvailable; std::vector<HostAddress> addressQueryResult; std::vector<HostAddress> allAddressQueryResults; boost::optional<DomainNameResolveError> addressQueryError; std::vector<DomainNameServiceQuery::Result> serviceQueryResult; DomainNameResolver* resolver; }; CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameResolverTest); diff --git a/Swiften/TLS/ServerIdentityVerifier.cpp b/Swiften/TLS/ServerIdentityVerifier.cpp index 02459b9..0608a03 100644 --- a/Swiften/TLS/ServerIdentityVerifier.cpp +++ b/Swiften/TLS/ServerIdentityVerifier.cpp @@ -1,88 +1,95 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include <Swiften/TLS/ServerIdentityVerifier.h> #include <boost/algorithm/string.hpp> #include <Swiften/Base/foreach.h> #include <Swiften/IDN/IDNConverter.h> namespace Swift { -ServerIdentityVerifier::ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter) { +ServerIdentityVerifier::ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter) : domainValid(false) { domain = jid.getDomain(); - encodedDomain = idnConverter->getIDNAEncoded(domain); + boost::optional<std::string> domainResult = idnConverter->getIDNAEncoded(domain); + if (!!domainResult) { + encodedDomain = *domainResult; + domainValid = true; + } } bool ServerIdentityVerifier::certificateVerifies(Certificate::ref certificate) { bool hasSAN = false; if (certificate == NULL) { return false; } // DNS names std::vector<std::string> dnsNames = certificate->getDNSNames(); foreach (const std::string& dnsName, dnsNames) { if (matchesDomain(dnsName)) { return true; } } hasSAN |= !dnsNames.empty(); // SRV names std::vector<std::string> srvNames = certificate->getSRVNames(); foreach (const std::string& srvName, srvNames) { // Only match SRV names that begin with the service; this isn't required per // spec, but we're being purist about this. if (boost::starts_with(srvName, "_xmpp-client.") && matchesDomain(srvName.substr(std::string("_xmpp-client.").size(), srvName.npos))) { return true; } } hasSAN |= !srvNames.empty(); // XmppAddr std::vector<std::string> xmppAddresses = certificate->getXMPPAddresses(); foreach (const std::string& xmppAddress, xmppAddresses) { if (matchesAddress(xmppAddress)) { return true; } } hasSAN |= !xmppAddresses.empty(); // CommonNames. Only check this if there was no SAN (according to spec). if (!hasSAN) { std::vector<std::string> commonNames = certificate->getCommonNames(); foreach (const std::string& commonName, commonNames) { if (matchesDomain(commonName)) { return true; } } } return false; } bool ServerIdentityVerifier::matchesDomain(const std::string& s) const { + if (!domainValid) { + return false; + } if (boost::starts_with(s, "*.")) { std::string matchString(s.substr(2, s.npos)); std::string matchDomain = encodedDomain; size_t dotIndex = matchDomain.find('.'); if (dotIndex != matchDomain.npos) { matchDomain = matchDomain.substr(dotIndex + 1, matchDomain.npos); } return matchString == matchDomain; } else { return s == encodedDomain; } } bool ServerIdentityVerifier::matchesAddress(const std::string& s) const { return s == domain; } } diff --git a/Swiften/TLS/ServerIdentityVerifier.h b/Swiften/TLS/ServerIdentityVerifier.h index 4167ce8..ea08749 100644 --- a/Swiften/TLS/ServerIdentityVerifier.h +++ b/Swiften/TLS/ServerIdentityVerifier.h @@ -1,33 +1,34 @@ /* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #pragma once #include <boost/shared_ptr.hpp> #include <string> #include <Swiften/Base/API.h> #include <Swiften/JID/JID.h> #include <Swiften/TLS/Certificate.h> namespace Swift { class IDNConverter; class SWIFTEN_API ServerIdentityVerifier { public: ServerIdentityVerifier(const JID& jid, IDNConverter* idnConverter); bool certificateVerifies(Certificate::ref); private: bool matchesDomain(const std::string&) const ; bool matchesAddress(const std::string&) const; private: std::string domain; std::string encodedDomain; + bool domainValid; }; } |
Swift