From 949278e9418e4090c2f3c68ed313cf4853b25ef1 Mon Sep 17 00:00:00 2001
From: dknn <yoann.blein@free.fr>
Date: Wed, 29 Aug 2012 13:16:57 +0200
Subject: Add UDP support in nat traversal


diff --git a/Swiften/Examples/NetworkTool/main.cpp b/Swiften/Examples/NetworkTool/main.cpp
index 00c12d2..2354c8b 100644
--- a/Swiften/Examples/NetworkTool/main.cpp
+++ b/Swiften/Examples/NetworkTool/main.cpp
@@ -65,7 +65,7 @@ int main(int argc, char* argv[]) {
 		if (argc < 4) {
 			std::cerr << "Invalid parameters" << std::endl;
 		}
-		boost::shared_ptr<NATTraversalForwardPortRequest> query = natTraverser.createForwardPortRequest(boost::lexical_cast<int>(argv[2]), boost::lexical_cast<int>(argv[3]));
+		boost::shared_ptr<NATTraversalForwardPortRequest> query = natTraverser.createForwardPortRequest(boost::lexical_cast<int>(argv[2]), boost::lexical_cast<int>(argv[3]), NATPortMapping::TCP);
 		query->onResult.connect(boost::bind(&handleGetForwardPortRequestResponse, _1));
 		query->run();
 		eventLoop.run();
@@ -74,7 +74,7 @@ int main(int argc, char* argv[]) {
 		if (argc < 4) {
 			std::cerr << "Invalid parameters" << std::endl;
 		}
-		boost::shared_ptr<NATTraversalRemovePortForwardingRequest> query = natTraverser.createRemovePortForwardingRequest(boost::lexical_cast<int>(argv[2]), boost::lexical_cast<int>(argv[3]));
+		boost::shared_ptr<NATTraversalRemovePortForwardingRequest> query = natTraverser.createRemovePortForwardingRequest(boost::lexical_cast<int>(argv[2]), boost::lexical_cast<int>(argv[3]), NATPortMapping::TCP);
 		query->onResult.connect(boost::bind(&handleRemovePortForwardingRequestResponse, _1));
 		query->run();
 		eventLoop.run();
diff --git a/Swiften/FileTransfer/ConnectivityManager.cpp b/Swiften/FileTransfer/ConnectivityManager.cpp
index de5eccb..f4e7e96 100644
--- a/Swiften/FileTransfer/ConnectivityManager.cpp
+++ b/Swiften/FileTransfer/ConnectivityManager.cpp
@@ -23,39 +23,44 @@ ConnectivityManager::ConnectivityManager(NATTraverser* worker) : natTraversalWor
 }
 
 ConnectivityManager::~ConnectivityManager() {
-	std::set<int> leftOpenPorts = ports;
-	foreach(int port, leftOpenPorts) {
-		removeListeningPort(port);
+	std::set<int> leftOpenUDPPorts = udpPorts;
+	foreach(int port, leftOpenUDPPorts) {
+		removeListeningPort(port, NATPortMapping::UDP);
+	}
+	std::set<int> leftOpenTCPPorts = tcpPorts;
+	foreach(int port, leftOpenTCPPorts) {
+		removeListeningPort(port, NATPortMapping::TCP);
 	}
 }
 
-void ConnectivityManager::addListeningPort(int port) {
-	ports.insert(port);
+void ConnectivityManager::addListeningPort(int port, NATPortMapping::Protocol protocol) {
+	udpPorts.insert(port);
 	boost::shared_ptr<NATTraversalGetPublicIPRequest> getIPRequest = natTraversalWorker->createGetPublicIPRequest();
 	if (getIPRequest) {
 		getIPRequest->onResult.connect(boost::bind(&ConnectivityManager::natTraversalGetPublicIPResult, this, _1));
 		getIPRequest->run();
 	}
 
-	boost::shared_ptr<NATTraversalForwardPortRequest> forwardPortRequest = natTraversalWorker->createForwardPortRequest(port, port);
+	boost::shared_ptr<NATTraversalForwardPortRequest> forwardPortRequest = natTraversalWorker->createForwardPortRequest(port, port, protocol);
 	if (forwardPortRequest) {
 		forwardPortRequest->onResult.connect(boost::bind(&ConnectivityManager::natTraversalForwardPortResult, this, _1));
 		forwardPortRequest->run();
 	}
 }
 
-void ConnectivityManager::removeListeningPort(int port) {
+void ConnectivityManager::removeListeningPort(int port, NATPortMapping::Protocol protocol) {
 	SWIFT_LOG(debug) << "remove listening port " << port << std::endl;
-	ports.erase(port);
-	boost::shared_ptr<NATTraversalRemovePortForwardingRequest> removePortForwardingRequest = natTraversalWorker->createRemovePortForwardingRequest(port, port);
+	udpPorts.erase(port);
+	boost::shared_ptr<NATTraversalRemovePortForwardingRequest> removePortForwardingRequest = natTraversalWorker->createRemovePortForwardingRequest(port, port, protocol);
 	if (removePortForwardingRequest) {
 		removePortForwardingRequest->run();
 	}
 }
 
-std::vector<HostAddressPort> ConnectivityManager::getHostAddressPorts() const {
+std::vector<HostAddressPort> ConnectivityManager::getHostAddressPorts(NATPortMapping::Protocol protocol) const {
 	PlatformNetworkEnvironment env;
 	std::vector<HostAddressPort> results;
+	const std::set<int>& ports = (protocol == NATPortMapping::TCP ? udpPorts : tcpPorts);
 
 	//std::vector<HostAddress> addresses;
 
@@ -71,10 +76,11 @@ std::vector<HostAddressPort> ConnectivityManager::getHostAddressPorts() const {
 	return results;
 }
 
-std::vector<HostAddressPort> ConnectivityManager::getAssistedHostAddressPorts() const {
+std::vector<HostAddressPort> ConnectivityManager::getAssistedHostAddressPorts(NATPortMapping::Protocol protocol) const {
 	std::vector<HostAddressPort> results;
 
 	if (publicAddress) {
+		const std::set<int>& ports = (protocol == NATPortMapping::TCP ? udpPorts : tcpPorts);
 		foreach (int port, ports) {
 			results.push_back(HostAddressPort(publicAddress.get(), port));
 		}
diff --git a/Swiften/FileTransfer/ConnectivityManager.h b/Swiften/FileTransfer/ConnectivityManager.h
index c094c02..c70cb5a 100644
--- a/Swiften/FileTransfer/ConnectivityManager.h
+++ b/Swiften/FileTransfer/ConnectivityManager.h
@@ -25,11 +25,11 @@ public:
 	ConnectivityManager(NATTraverser*);
 	~ConnectivityManager();
 public:
-	void addListeningPort(int port);
-	void removeListeningPort(int port);
+	void addListeningPort(int port, NATPortMapping::Protocol protocol);
+	void removeListeningPort(int port, NATPortMapping::Protocol protocol);
 
-	std::vector<HostAddressPort> getHostAddressPorts() const;
-	std::vector<HostAddressPort> getAssistedHostAddressPorts() const;
+	std::vector<HostAddressPort> getHostAddressPorts(NATPortMapping::Protocol protocol) const;
+	std::vector<HostAddressPort> getAssistedHostAddressPorts(NATPortMapping::Protocol protocol) const;
 
 private:
 	void natTraversalGetPublicIPResult(boost::optional<HostAddress> address);
@@ -38,7 +38,8 @@ private:
 private:
 	NATTraverser* natTraversalWorker;
 
-	std::set<int> ports;
+	std::set<int> udpPorts;
+	std::set<int> tcpPorts;
 	boost::optional<HostAddress> publicAddress;
 };
 
diff --git a/Swiften/FileTransfer/DefaultLocalJingleTransportCandidateGenerator.cpp b/Swiften/FileTransfer/DefaultLocalJingleTransportCandidateGenerator.cpp
index 4b205cb..c3921bc 100644
--- a/Swiften/FileTransfer/DefaultLocalJingleTransportCandidateGenerator.cpp
+++ b/Swiften/FileTransfer/DefaultLocalJingleTransportCandidateGenerator.cpp
@@ -41,7 +41,7 @@ void DefaultLocalJingleTransportCandidateGenerator::generateLocalTransportCandid
 		const unsigned long localPreference = 0;
 
 		// get direct candidates
-		std::vector<HostAddressPort> directCandidates = connectivityManager->getHostAddressPorts();
+		std::vector<HostAddressPort> directCandidates = connectivityManager->getHostAddressPorts(NATPortMapping::TCP);
 		foreach(HostAddressPort addressPort, directCandidates) {
 			JingleS5BTransportPayload::Candidate candidate;
 			candidate.type = JingleS5BTransportPayload::Candidate::DirectType;
@@ -53,7 +53,7 @@ void DefaultLocalJingleTransportCandidateGenerator::generateLocalTransportCandid
 		}
 
 		// get assissted candidates
-		std::vector<HostAddressPort> assisstedCandidates = connectivityManager->getAssistedHostAddressPorts();
+		std::vector<HostAddressPort> assisstedCandidates = connectivityManager->getAssistedHostAddressPorts(NATPortMapping::TCP);
 		foreach(HostAddressPort addressPort, assisstedCandidates) {
 			JingleS5BTransportPayload::Candidate candidate;
 			candidate.type = JingleS5BTransportPayload::Candidate::AssistedType;
diff --git a/Swiften/FileTransfer/FileTransferManagerImpl.cpp b/Swiften/FileTransfer/FileTransferManagerImpl.cpp
index 7fd8b07..22dc681 100644
--- a/Swiften/FileTransfer/FileTransferManagerImpl.cpp
+++ b/Swiften/FileTransfer/FileTransferManagerImpl.cpp
@@ -68,7 +68,7 @@ void FileTransferManagerImpl::startListeningOnPort(int port) {
 	server->start();
 	bytestreamServer = new SOCKS5BytestreamServer(server, bytestreamRegistry);
 	bytestreamServer->start();
-	connectivityManager->addListeningPort(port);
+	connectivityManager->addListeningPort(port, NATPortMapping::TCP);
 
 	s5bProxyFinder = new SOCKS5BytestreamProxyFinder(ownJID.getDomain(), iqRouter);
 	s5bProxyFinder->onProxyFound.connect(boost::bind(&FileTransferManagerImpl::addS5BProxy, this, _1));
diff --git a/Swiften/Network/MiniUPnPInterface.cpp b/Swiften/Network/MiniUPnPInterface.cpp
index c729371..e15ba3a 100644
--- a/Swiften/Network/MiniUPnPInterface.cpp
+++ b/Swiften/Network/MiniUPnPInterface.cpp
@@ -61,12 +61,12 @@ boost::optional<HostAddress> MiniUPnPInterface::getPublicIP() {
 	}
 }
 
-boost::optional<NATPortMapping> MiniUPnPInterface::addPortForward(int actualLocalPort, int actualPublicPort) {
+boost::optional<NATPortMapping> MiniUPnPInterface::addPortForward(int actualLocalPort, int actualPublicPort, NATPortMapping::Protocol protocol) {
 	if (!p->isValid) {
 		return boost::optional<NATPortMapping>();
 	}
 
-	NATPortMapping mapping(actualLocalPort, actualPublicPort, NATPortMapping::TCP);
+	NATPortMapping mapping(actualLocalPort, actualPublicPort, protocol);
 
 	std::string publicPort = boost::lexical_cast<std::string>(mapping.getPublicPort());
 	std::string localPort = boost::lexical_cast<std::string>(mapping.getLocalPort());
diff --git a/Swiften/Network/MiniUPnPInterface.h b/Swiften/Network/MiniUPnPInterface.h
index 61d12ca..d917f7c 100644
--- a/Swiften/Network/MiniUPnPInterface.h
+++ b/Swiften/Network/MiniUPnPInterface.h
@@ -21,9 +21,9 @@ namespace Swift {
 
 			virtual bool isAvailable();
 
-			boost::optional<HostAddress> getPublicIP();
-			boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort);
-			bool removePortForward(const NATPortMapping&);
+			virtual boost::optional<HostAddress> getPublicIP();
+			virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort, NATPortMapping::Protocol protocol);
+			virtual bool removePortForward(const NATPortMapping&);
 
 		private:
 			struct Private;
diff --git a/Swiften/Network/NATPMPInterface.cpp b/Swiften/Network/NATPMPInterface.cpp
index 220e3e9..2634eaf 100644
--- a/Swiften/Network/NATPMPInterface.cpp
+++ b/Swiften/Network/NATPMPInterface.cpp
@@ -61,8 +61,8 @@ boost::optional<HostAddress> NATPMPInterface::getPublicIP() {
 	}
 }
 
-boost::optional<NATPortMapping> NATPMPInterface::addPortForward(int localPort, int publicPort) {
-	NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP);
+boost::optional<NATPortMapping> NATPMPInterface::addPortForward(int localPort, int publicPort, NATPortMapping::Protocol protocol) {
+	NATPortMapping mapping(localPort, publicPort, protocol);
 	if (sendnewportmappingrequest(&p->natpmp, mapping.getProtocol() == NATPortMapping::TCP ? NATPMP_PROTOCOL_TCP : NATPMP_PROTOCOL_UDP, mapping.getLeaseInSeconds(), mapping.getPublicPort(), mapping.getLocalPort()) < 0) {
 			SWIFT_LOG(debug) << "Failed to send NAT-PMP port forwarding request!" << std::endl;
 			return boost::optional<NATPortMapping>();
@@ -81,7 +81,7 @@ boost::optional<NATPortMapping> NATPMPInterface::addPortForward(int localPort, i
 	} while(r == NATPMP_TRYAGAIN);
 
 	if (r == 0) {
-		NATPortMapping result(response.pnu.newportmapping.privateport, response.pnu.newportmapping.mappedpublicport, NATPortMapping::TCP, response.pnu.newportmapping.lifetime);
+		NATPortMapping result(response.pnu.newportmapping.privateport, response.pnu.newportmapping.mappedpublicport, protocol, response.pnu.newportmapping.lifetime);
 		return result;
 	}
 	else {
diff --git a/Swiften/Network/NATPMPInterface.h b/Swiften/Network/NATPMPInterface.h
index e079a59..f3bf432 100644
--- a/Swiften/Network/NATPMPInterface.h
+++ b/Swiften/Network/NATPMPInterface.h
@@ -21,7 +21,7 @@ namespace Swift {
 			virtual bool isAvailable();
 
 			virtual boost::optional<HostAddress> getPublicIP();
-			virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort);
+			virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort, NATPortMapping::Protocol protocol);
 			virtual bool removePortForward(const NATPortMapping&);
 
 		private:
diff --git a/Swiften/Network/NATTraversalInterface.h b/Swiften/Network/NATTraversalInterface.h
index c84deba..6c09e7b 100644
--- a/Swiften/Network/NATTraversalInterface.h
+++ b/Swiften/Network/NATTraversalInterface.h
@@ -18,7 +18,7 @@ namespace Swift {
 			virtual bool isAvailable() = 0;
 
 			virtual boost::optional<HostAddress> getPublicIP() = 0;
-			virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort) = 0;
+			virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort, NATPortMapping::Protocol protocol) = 0;
 			virtual bool removePortForward(const NATPortMapping&) = 0;
 	};
 }
diff --git a/Swiften/Network/NATTraverser.h b/Swiften/Network/NATTraverser.h
index e48ce26..5273b35 100644
--- a/Swiften/Network/NATTraverser.h
+++ b/Swiften/Network/NATTraverser.h
@@ -7,6 +7,7 @@
 #pragma once
 
 #include <boost/shared_ptr.hpp>
+#include <Swiften/Network/NATPortMapping.h>
 
 namespace Swift {
 	class NATTraversalGetPublicIPRequest;
@@ -18,7 +19,7 @@ namespace Swift {
 			virtual ~NATTraverser();
 
 			virtual boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest() = 0;
-			virtual boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort) = 0;
-			virtual boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort) = 0;
+			virtual boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol) = 0;
+			virtual boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol) = 0;
 	};
 }
diff --git a/Swiften/Network/NullNATTraversalInterface.h b/Swiften/Network/NullNATTraversalInterface.h
index 72a4a08..ec70615 100644
--- a/Swiften/Network/NullNATTraversalInterface.h
+++ b/Swiften/Network/NullNATTraversalInterface.h
@@ -21,7 +21,7 @@ namespace Swift {
 				return boost::optional<HostAddress>();
 			}
 
-			virtual boost::optional<NATPortMapping> addPortForward(int, int) {
+			virtual boost::optional<NATPortMapping> addPortForward(int, int, NATPortMapping::Protocol) {
 				return boost::optional<NATPortMapping>();
 			}
 
diff --git a/Swiften/Network/NullNATTraverser.cpp b/Swiften/Network/NullNATTraverser.cpp
index 8cb35cd..2cd6658 100644
--- a/Swiften/Network/NullNATTraverser.cpp
+++ b/Swiften/Network/NullNATTraverser.cpp
@@ -62,11 +62,11 @@ boost::shared_ptr<NATTraversalGetPublicIPRequest> NullNATTraverser::createGetPub
 	return boost::make_shared<NullNATTraversalGetPublicIPRequest>(eventLoop);
 }
 
-boost::shared_ptr<NATTraversalForwardPortRequest> NullNATTraverser::createForwardPortRequest(int, int) {
+boost::shared_ptr<NATTraversalForwardPortRequest> NullNATTraverser::createForwardPortRequest(int, int, NATPortMapping::Protocol) {
 	return boost::make_shared<NullNATTraversalForwardPortRequest>(eventLoop);
 }
 
-boost::shared_ptr<NATTraversalRemovePortForwardingRequest> NullNATTraverser::createRemovePortForwardingRequest(int, int) {
+boost::shared_ptr<NATTraversalRemovePortForwardingRequest> NullNATTraverser::createRemovePortForwardingRequest(int, int, NATPortMapping::Protocol) {
 	return boost::make_shared<NullNATTraversalRemovePortForwardingRequest>(eventLoop);
 }
 
diff --git a/Swiften/Network/NullNATTraverser.h b/Swiften/Network/NullNATTraverser.h
index 5775a9b..8505112 100644
--- a/Swiften/Network/NullNATTraverser.h
+++ b/Swiften/Network/NullNATTraverser.h
@@ -16,8 +16,8 @@ namespace Swift {
 			NullNATTraverser(EventLoop* eventLoop);
 
 			boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest();
-			boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort);
-			boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort);
+			boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort, NATPortMapping::Protocol);
+			boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort, NATPortMapping::Protocol);
 
 		private:
 			EventLoop* eventLoop;
diff --git a/Swiften/Network/PlatformNATTraversalWorker.cpp b/Swiften/Network/PlatformNATTraversalWorker.cpp
index c962b3b..bce0bbc 100644
--- a/Swiften/Network/PlatformNATTraversalWorker.cpp
+++ b/Swiften/Network/PlatformNATTraversalWorker.cpp
@@ -60,7 +60,7 @@ class PlatformNATTraversalGetPublicIPRequest : public NATTraversalGetPublicIPReq
 
 class PlatformNATTraversalForwardPortRequest : public NATTraversalForwardPortRequest, public PlatformNATTraversalRequest {
 	public:
-		PlatformNATTraversalForwardPortRequest(PlatformNATTraversalWorker* worker, unsigned int localIP, unsigned int publicIP) : PlatformNATTraversalRequest(worker), localIP(localIP), publicIP(publicIP) {
+		PlatformNATTraversalForwardPortRequest(PlatformNATTraversalWorker* worker, unsigned int localIP, unsigned int publicIP, NATPortMapping::Protocol protocol) : PlatformNATTraversalRequest(worker), localIP(localIP), publicIP(publicIP), protocol(protocol) {
 		}
 
 		virtual void run() {
@@ -68,12 +68,13 @@ class PlatformNATTraversalForwardPortRequest : public NATTraversalForwardPortReq
 		}
 
 		virtual void runBlocking() {
-			onResult(getNATTraversalInterface()->addPortForward(localIP, publicIP));
+			onResult(getNATTraversalInterface()->addPortForward(localIP, publicIP, protocol));
 		}
 
 	private:
 		unsigned int localIP;
 		unsigned int publicIP;
+		NATPortMapping::Protocol protocol;
 };
 
 class PlatformNATTraversalRemovePortForwardingRequest : public NATTraversalRemovePortForwardingRequest, public PlatformNATTraversalRequest {
@@ -133,12 +134,12 @@ boost::shared_ptr<NATTraversalGetPublicIPRequest> PlatformNATTraversalWorker::cr
 	return boost::make_shared<PlatformNATTraversalGetPublicIPRequest>(this);
 }
 
-boost::shared_ptr<NATTraversalForwardPortRequest> PlatformNATTraversalWorker::createForwardPortRequest(int localPort, int publicPort) {
-	return boost::make_shared<PlatformNATTraversalForwardPortRequest>(this, localPort, publicPort);
+boost::shared_ptr<NATTraversalForwardPortRequest> PlatformNATTraversalWorker::createForwardPortRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol) {
+	return boost::make_shared<PlatformNATTraversalForwardPortRequest>(this, localPort, publicPort, protocol);
 }
 
-boost::shared_ptr<NATTraversalRemovePortForwardingRequest> PlatformNATTraversalWorker::createRemovePortForwardingRequest(int localPort, int publicPort) {
-	NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP); // FIXME
+boost::shared_ptr<NATTraversalRemovePortForwardingRequest> PlatformNATTraversalWorker::createRemovePortForwardingRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol) {
+	NATPortMapping mapping(localPort, publicPort, protocol); // FIXME
 	return boost::make_shared<PlatformNATTraversalRemovePortForwardingRequest>(this, mapping);
 }
 
diff --git a/Swiften/Network/PlatformNATTraversalWorker.h b/Swiften/Network/PlatformNATTraversalWorker.h
index 8060e31..5c4a387 100644
--- a/Swiften/Network/PlatformNATTraversalWorker.h
+++ b/Swiften/Network/PlatformNATTraversalWorker.h
@@ -37,8 +37,8 @@ namespace Swift {
 			~PlatformNATTraversalWorker();
 
 			boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest();
-			boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort);
-			boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort);
+			boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol);
+			boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort, NATPortMapping::Protocol protocol);
 
 		private:
 			NATTraversalInterface* getNATTraversalInterface() const;
-- 
cgit v0.10.2-6-g49f6