From 9928be64a4c19f497302963d23ed0efc66b899c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Wed, 1 Jun 2011 23:09:23 +0200
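Subject: Added fix for a billion laughs attack on Expat.

The Expat backend now registers an entity declaration handler that
stops the parser (non-resumably) as soon as a document declares an
entity, so such a document is rejected at its first entity declaration.
To make this possible, the Expat and LibXML callbacks now receive the
parser object as user data instead of the client, and
XMLParser::getClient() becomes public. Unit tests cover predefined
entities and the billion laughs document.
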
diff --git a/Swiften/Parser/ExpatParser.cpp b/Swiften/Parser/ExpatParser.cpp
index 88be752..f091f79 100644
--- a/Swiften/Parser/ExpatParser.cpp
+++ b/Swiften/Parser/ExpatParser.cpp
@@ -16,7 +16,7 @@ namespace Swift {
 
 static const char NAMESPACE_SEPARATOR = '\x01';
 
-static void handleStartElement(void* client, const XML_Char* name, const XML_Char** attributes) {
+static void handleStartElement(void* parser, const XML_Char* name, const XML_Char** attributes) {
 	std::pair<std::string,std::string> nsTagPair = String::getSplittedAtFirst(name, NAMESPACE_SEPARATOR);
 	if (nsTagPair.second == "") {
 		nsTagPair.second = nsTagPair.first;
@@ -34,32 +34,37 @@ static void handleStartElement(void* client, const XML_Char* name, const XML_Cha
 		currentAttribute += 2;
 	}
 
-	static_cast<XMLParserClient*>(client)->handleStartElement(nsTagPair.second, nsTagPair.first, attributeValues);
+	static_cast<ExpatParser*>(parser)->getClient()->handleStartElement(nsTagPair.second, nsTagPair.first, attributeValues); // User data is the ExpatParser set via XML_SetUserData(); cast back to that exact type before forwarding the start tag to the client
 }
 
-static void handleEndElement(void* client, const XML_Char* name) {
+static void handleEndElement(void* parser, const XML_Char* name) {
 	std::pair<std::string,std::string> nsTagPair = String::getSplittedAtFirst(name, NAMESPACE_SEPARATOR);
 	if (nsTagPair.second == "") {
 		nsTagPair.second = nsTagPair.first;
 		nsTagPair.first = "";
 	}
-	static_cast<XMLParserClient*>(client)->handleEndElement(nsTagPair.second, nsTagPair.first);
+	static_cast<ExpatParser*>(parser)->getClient()->handleEndElement(nsTagPair.second, nsTagPair.first); // User data is the ExpatParser set via XML_SetUserData(); cast back to that exact type before forwarding the end tag to the client
 }
 
-static void handleCharacterData(void* client, const XML_Char* data, int len) {
-	static_cast<XMLParserClient*>(client)->handleCharacterData(std::string(data, len));
+static void handleCharacterData(void* parser, const XML_Char* data, int len) {
+	static_cast<ExpatParser*>(parser)->getClient()->handleCharacterData(std::string(data, len)); // User data is the ExpatParser set via XML_SetUserData(); cast back to that exact type before forwarding the text to the client
 }
 
 static void handleXMLDeclaration(void*, const XML_Char*, const XML_Char*, int) {
 }
 
+static void handleEntityDeclaration(void* parser, const XML_Char*, int, const XML_Char*, int, const XML_Char*, const XML_Char*, const XML_Char*, const XML_Char*) { // Entity declaration callback; every argument except the user data is ignored
+	XML_StopParser(static_cast<ExpatParser*>(parser)->getParser(), static_cast<XML_Bool>(0)); // Second argument 0 = not resumable: any entity declaration aborts the parse
+}
+
 
 ExpatParser::ExpatParser(XMLParserClient* client) : XMLParser(client) {
 	parser_ = XML_ParserCreateNS("UTF-8", NAMESPACE_SEPARATOR);
-	XML_SetUserData(parser_, client);
+	XML_SetUserData(parser_, this); // Callbacks get this parser as user data so they can reach both getClient() and getParser(). NOTE(review): parser_ is used without a NULL check - confirm XML_ParserCreateNS() failure is acceptable here
 	XML_SetElementHandler(parser_, handleStartElement, handleEndElement);
 	XML_SetCharacterDataHandler(parser_, handleCharacterData);
 	XML_SetXmlDeclHandler(parser_, handleXMLDeclaration);
+	XML_SetEntityDeclHandler(parser_, handleEntityDeclaration); // The handler stops the parser on any entity declaration
 }
 
 ExpatParser::~ExpatParser() {
diff --git a/Swiften/Parser/ExpatParser.h b/Swiften/Parser/ExpatParser.h
index f6faf17..cd981ef 100644
--- a/Swiften/Parser/ExpatParser.h
+++ b/Swiften/Parser/ExpatParser.h
@@ -20,6 +20,10 @@ namespace Swift {
 
 			bool parse(const std::string& data);
 
+			XML_Parser getParser() const { // Exposes the underlying Expat handle (parser_) so the static Expat callbacks can reach it; const like the sibling accessor getClient()
+				return parser_;
+			}
+
 		private:
 			XML_Parser parser_;
 	};
diff --git a/Swiften/Parser/LibXMLParser.cpp b/Swiften/Parser/LibXMLParser.cpp
index 34db4ca..c94a360 100644
--- a/Swiften/Parser/LibXMLParser.cpp
+++ b/Swiften/Parser/LibXMLParser.cpp
@@ -15,20 +15,20 @@
 
 namespace Swift {
 
-static void handleStartElement(void *client, const xmlChar* name, const xmlChar*, const xmlChar* xmlns, int, const xmlChar**, int nbAttributes, int, const xmlChar ** attributes) {
+static void handleStartElement(void *parser, const xmlChar* name, const xmlChar*, const xmlChar* xmlns, int, const xmlChar**, int nbAttributes, int, const xmlChar ** attributes) {
 	AttributeMap attributeValues;
 	for (int i = 0; i < nbAttributes*5; i += 5) {
 		attributeValues[std::string(reinterpret_cast<const char*>(attributes[i]))] = std::string(reinterpret_cast<const char*>(attributes[i+3]), attributes[i+4]-attributes[i+3]);
 	}
-	static_cast<XMLParserClient*>(client)->handleStartElement(reinterpret_cast<const char*>(name), (xmlns ? reinterpret_cast<const char*>(xmlns) : std::string()), attributeValues);
+	static_cast<LibXMLParser*>(parser)->getClient()->handleStartElement(reinterpret_cast<const char*>(name), (xmlns ? reinterpret_cast<const char*>(xmlns) : std::string()), attributeValues); // User data is the LibXMLParser passed to xmlCreatePushParserCtxt(); cast back to that exact type before forwarding the start tag to the client
 }
 
-static void handleEndElement(void *client, const xmlChar* name, const xmlChar*, const xmlChar* xmlns) {
-	static_cast<XMLParserClient*>(client)->handleEndElement(reinterpret_cast<const char*>(name), (xmlns ? reinterpret_cast<const char*>(xmlns) : std::string()));
+static void handleEndElement(void *parser, const xmlChar* name, const xmlChar*, const xmlChar* xmlns) {
+	static_cast<LibXMLParser*>(parser)->getClient()->handleEndElement(reinterpret_cast<const char*>(name), (xmlns ? reinterpret_cast<const char*>(xmlns) : std::string())); // User data is the LibXMLParser passed to xmlCreatePushParserCtxt(); cast back to that exact type before forwarding the end tag to the client
 }
 
-static void handleCharacterData(void* client, const xmlChar* data, int len) {
-	static_cast<XMLParserClient*>(client)->handleCharacterData(std::string(reinterpret_cast<const char*>(data), len));
+static void handleCharacterData(void* parser, const xmlChar* data, int len) {
+	static_cast<LibXMLParser*>(parser)->getClient()->handleCharacterData(std::string(reinterpret_cast<const char*>(data), len)); // User data is the LibXMLParser passed to xmlCreatePushParserCtxt(); cast back to that exact type before forwarding the text to the client
 }
 
 static void handleError(void*, const char* /*m*/, ... ) {
@@ -54,7 +54,7 @@ LibXMLParser::LibXMLParser(XMLParserClient* client) : XMLParser(client) {
 	handler_.warning = &handleWarning;
 	handler_.error = &handleError;
 
-	context_ = xmlCreatePushParserCtxt(&handler_, client, 0, 0, 0);
+	context_ = xmlCreatePushParserCtxt(&handler_, this, 0, 0, 0);
 	assert(context_);
 }
 
diff --git a/Swiften/Parser/UnitTest/XMLParserTest.cpp b/Swiften/Parser/UnitTest/XMLParserTest.cpp
index 426b7a0..2086ece 100644
--- a/Swiften/Parser/UnitTest/XMLParserTest.cpp
+++ b/Swiften/Parser/UnitTest/XMLParserTest.cpp
@@ -25,12 +25,14 @@ class XMLParserTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST(testParse_NestedElements);
 		CPPUNIT_TEST(testParse_ElementInNamespacedElement);
 		CPPUNIT_TEST(testParse_CharacterData);
+		CPPUNIT_TEST(testParse_XMLEntity);
 		CPPUNIT_TEST(testParse_NamespacePrefix);
 		CPPUNIT_TEST(testParse_UnhandledXML);
 		CPPUNIT_TEST(testParse_InvalidXML);
 		CPPUNIT_TEST(testParse_InErrorState);
 		CPPUNIT_TEST(testParse_Incremental);
 		CPPUNIT_TEST(testParse_WhitespaceInAttribute);
+		CPPUNIT_TEST(testParse_BillionLaughs);
 		CPPUNIT_TEST_SUITE_END();
 
 	public:
@@ -124,6 +126,26 @@ class XMLParserTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT_EQUAL(std::string("html"), client_.events[6].data);
 		}
 
+		void testParse_XMLEntity() { // Predefined entity references (&lt; and &gt;) must parse successfully and reach the client decoded
+			ParserType testling(&client_);
+
+			CPPUNIT_ASSERT(testling.parse("<html>&lt;&gt;</html>"));
+
+			CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(4), client_.events.size()); // Expected events: start tag, one character-data event per entity, end tag
+
+			CPPUNIT_ASSERT_EQUAL(Client::StartElement, client_.events[0].type);
+			CPPUNIT_ASSERT_EQUAL(std::string("html"), client_.events[0].data);
+
+			CPPUNIT_ASSERT_EQUAL(Client::CharacterData, client_.events[1].type);
+			CPPUNIT_ASSERT_EQUAL(std::string("<"), client_.events[1].data);
+
+			CPPUNIT_ASSERT_EQUAL(Client::CharacterData, client_.events[2].type);
+			CPPUNIT_ASSERT_EQUAL(std::string(">"), client_.events[2].data);
+
+			CPPUNIT_ASSERT_EQUAL(Client::EndElement, client_.events[3].type);
+			CPPUNIT_ASSERT_EQUAL(std::string("html"), client_.events[3].data);
+		}
+
 		void testParse_NamespacePrefix() {
 			ParserType testling(&client_);
 
@@ -205,6 +227,27 @@ class XMLParserTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT_EQUAL(Client::EndElement, client_.events[2].type);
 			CPPUNIT_ASSERT_EQUAL(std::string("presence"), client_.events[2].data);
 		}
+
+		void testParse_BillionLaughs() { // A document that declares nested entities must be rejected: parse() is expected to return false
+			ParserType testling(&client_);
+
+			CPPUNIT_ASSERT(!testling.parse(
+				"<?xml version=\"1.0\"?>"
+				"<!DOCTYPE lolz ["
+				"  <!ENTITY lol \"lol\">"
+				"  <!ENTITY lol2 \"&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;\">"
+				"  <!ENTITY lol3 \"&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;\">"
+				"  <!ENTITY lol4 \"&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;\">"
+				"  <!ENTITY lol5 \"&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;\">"
+				"  <!ENTITY lol6 \"&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;\">"
+				"  <!ENTITY lol7 \"&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;\">"
+				"  <!ENTITY lol8 \"&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;\">"
+				"  <!ENTITY lol9 \"&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;\">"
+				"]>"
+				"<lolz>&lol9;</lolz>"
+			));
+
+		}
 	
 	private:
 		class Client : public XMLParserClient {
diff --git a/Swiften/Parser/XMLParser.h b/Swiften/Parser/XMLParser.h
index 69a6ecf..1b866e3 100644
--- a/Swiften/Parser/XMLParser.h
+++ b/Swiften/Parser/XMLParser.h
@@ -19,7 +19,6 @@ namespace Swift {
 
 			virtual bool parse(const std::string& data) = 0;
 
-		protected:
 			XMLParserClient* getClient() const {
 				return client_;
 			}
-- 
cgit v0.10.2-6-g49f6