diff options
Diffstat (limited to 'Swiften/Queries')
-rw-r--r-- | Swiften/Queries/IQRouter.h | 15 | ||||
-rw-r--r-- | Swiften/Queries/Request.cpp | 43 | ||||
-rw-r--r-- | Swiften/Queries/Request.h | 2 | ||||
-rw-r--r-- | Swiften/Queries/UnitTest/RequestTest.cpp | 109 |
4 files changed, 144 insertions, 25 deletions
diff --git a/Swiften/Queries/IQRouter.h b/Swiften/Queries/IQRouter.h index 80d623c..167cb8f 100644 --- a/Swiften/Queries/IQRouter.h +++ b/Swiften/Queries/IQRouter.h @@ -22,6 +22,20 @@ namespace Swift { ~IQRouter(); /** + * Sets the JID of this IQ router. + * + * This JID is used by requests to check whether incoming + * results are addressed correctly. + */ + void setJID(const JID& jid) { + jid_ = jid; + } + + const JID& getJID() { + return jid_; + } + + /** * Sets the 'from' JID for all outgoing IQ stanzas. * * By default, IQRouter does not add a from to IQ stanzas, since @@ -55,6 +69,7 @@ namespace Swift { private: IQChannel* channel_; + JID jid_; JID from_; std::vector< boost::shared_ptr<IQHandler> > handlers_; std::vector< boost::shared_ptr<IQHandler> > queuedRemoves_; diff --git a/Swiften/Queries/Request.cpp b/Swiften/Queries/Request.cpp index 359c6a6..0126d62 100644 --- a/Swiften/Queries/Request.cpp +++ b/Swiften/Queries/Request.cpp @@ -41,27 +41,42 @@ bool Request::handleIQ(boost::shared_ptr<IQ> iq) { bool handled = false; if (iq->getType() == IQ::Result || iq->getType() == IQ::Error) { if (sent_ && iq->getID() == id_) { - if (iq->getType() == IQ::Result) { - boost::shared_ptr<Payload> payload = iq->getPayloadOfSameType(payload_); - if (!payload && boost::dynamic_pointer_cast<RawXMLPayload>(payload_) && !iq->getPayloads().empty()) { - payload = iq->getPayloads().front(); - } - handleResponse(payload, ErrorPayload::ref()); - } - else { - ErrorPayload::ref errorPayload = iq->getPayload<ErrorPayload>(); - if (errorPayload) { - handleResponse(boost::shared_ptr<Payload>(), errorPayload); + if (isCorrectSender(iq->getFrom())) { + if (iq->getType() == IQ::Result) { + boost::shared_ptr<Payload> payload = iq->getPayloadOfSameType(payload_); + if (!payload && boost::dynamic_pointer_cast<RawXMLPayload>(payload_) && !iq->getPayloads().empty()) { + payload = iq->getPayloads().front(); + } + handleResponse(payload, ErrorPayload::ref()); } else { - handleResponse(boost::shared_ptr<Payload>(), ErrorPayload::ref(new ErrorPayload(ErrorPayload::UndefinedCondition))); + ErrorPayload::ref errorPayload = iq->getPayload<ErrorPayload>(); + if (errorPayload) { + handleResponse(boost::shared_ptr<Payload>(), errorPayload); + } + else { + handleResponse(boost::shared_ptr<Payload>(), ErrorPayload::ref(new ErrorPayload(ErrorPayload::UndefinedCondition))); + } } + router_->removeHandler(this); + handled = true; } - router_->removeHandler(this); - handled = true; } } return handled; } +bool Request::isCorrectSender(const JID& jid) { + if (isAccountJID(receiver_)) { + return isAccountJID(jid); + } + else { + return jid.equals(receiver_, JID::WithResource); + } +} + +bool Request::isAccountJID(const JID& jid) { + return jid.isValid() ? router_->getJID().toBare().equals(jid, JID::WithResource) : true; +} + } diff --git a/Swiften/Queries/Request.h b/Swiften/Queries/Request.h index 19687c1..a7139cf 100644 --- a/Swiften/Queries/Request.h +++ b/Swiften/Queries/Request.h @@ -59,6 +59,8 @@ namespace Swift { private: bool handleIQ(boost::shared_ptr<IQ>); + bool isCorrectSender(const JID& jid); + bool isAccountJID(const JID& jid); private: IQRouter* router_; diff --git a/Swiften/Queries/UnitTest/RequestTest.cpp b/Swiften/Queries/UnitTest/RequestTest.cpp index b1d1b07..52d62fb 100644 --- a/Swiften/Queries/UnitTest/RequestTest.cpp +++ b/Swiften/Queries/UnitTest/RequestTest.cpp @@ -31,6 +31,12 @@ class RequestTest : public CppUnit::TestFixture { CPPUNIT_TEST(testHandleIQ_RawXMLPayload); CPPUNIT_TEST(testHandleIQ_GetWithSameID); CPPUNIT_TEST(testHandleIQ_SetWithSameID); + CPPUNIT_TEST(testHandleIQ_IncorrectSender); + CPPUNIT_TEST(testHandleIQ_IncorrectSenderForServerQuery); + CPPUNIT_TEST(testHandleIQ_IncorrectOtherResourceSenderForServerQuery); + CPPUNIT_TEST(testHandleIQ_ServerRespondsWithDomain); + CPPUNIT_TEST(testHandleIQ_ServerRespondsWithBareJID); + CPPUNIT_TEST(testHandleIQ_ServerRespondsWithoutFrom); CPPUNIT_TEST_SUITE_END(); public: @@ -98,7 +104,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - channel_->onIQReceived(createResponse("test-id")); + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"),"test-id")); CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); @@ -111,7 +117,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - channel_->onIQReceived(createResponse("different-id")); + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"),"different-id")); CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); @@ -123,7 +129,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - boost::shared_ptr<IQ> error = createError("test-id"); + boost::shared_ptr<IQ> error = createError(JID("foo@bar.com/baz"),"test-id"); boost::shared_ptr<Payload> errorPayload = boost::shared_ptr<ErrorPayload>(new ErrorPayload(ErrorPayload::InternalServerError)); error->addPayload(errorPayload); channel_->onIQReceived(error); @@ -139,7 +145,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - channel_->onIQReceived(createError("test-id")); + channel_->onIQReceived(createError(JID("foo@bar.com/baz"),"test-id")); CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedErrors.size())); @@ -150,7 +156,7 @@ class RequestTest : public CppUnit::TestFixture { void testHandleIQ_BeforeSend() { MyRequest testling(IQ::Get, JID("foo@bar.com/baz"), payload_, router_); testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); - channel_->onIQReceived(createResponse("test-id")); + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"),"test-id")); CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); @@ -163,7 +169,7 @@ class RequestTest : public CppUnit::TestFixture { testling.send(); responsePayload_ = boost::make_shared<MyOtherPayload>(); - channel_->onIQReceived(createResponse("test-id")); + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"),"test-id")); CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); @@ -177,7 +183,7 @@ class RequestTest : public CppUnit::TestFixture { testling.send(); responsePayload_ = boost::make_shared<MyOtherPayload>(); - channel_->onIQReceived(createResponse("test-id")); + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"),"test-id")); CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); @@ -189,7 +195,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - boost::shared_ptr<IQ> response = createResponse("test-id"); + boost::shared_ptr<IQ> response = createResponse(JID("foo@bar.com/baz"),"test-id"); response->setType(IQ::Get); channel_->onIQReceived(response); @@ -203,7 +209,7 @@ class RequestTest : public CppUnit::TestFixture { testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); testling.send(); - boost::shared_ptr<IQ> response = createResponse("test-id"); + boost::shared_ptr<IQ> response = createResponse(JID("foo@bar.com/baz"), "test-id"); response->setType(IQ::Set); channel_->onIQReceived(response); @@ -212,6 +218,85 @@ class RequestTest : public CppUnit::TestFixture { CPPUNIT_ASSERT_EQUAL(2, static_cast<int>(channel_->iqs_.size())); } + void testHandleIQ_IncorrectSender() { + MyRequest testling(IQ::Get, JID("foo@bar.com/baz"), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID("anotherfoo@bar.com/baz"), "test-id")); + + CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + void testHandleIQ_IncorrectSenderForServerQuery() { + MyRequest testling(IQ::Get, JID(), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID("foo@bar.com/baz"), "test-id")); + + CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + void testHandleIQ_IncorrectOtherResourceSenderForServerQuery() { + MyRequest testling(IQ::Get, JID(), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID("alice@wonderland.lit/RabbitHole"), "test-id")); + + CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + void testHandleIQ_ServerRespondsWithDomain() { + MyRequest testling(IQ::Get, JID(), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID("wonderland.lit"),"test-id")); + + CPPUNIT_ASSERT_EQUAL(0, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + void testHandleIQ_ServerRespondsWithBareJID() { + MyRequest testling(IQ::Get, JID(), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID("alice@wonderland.lit"),"test-id")); + + CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + void testHandleIQ_ServerRespondsWithoutFrom() { + MyRequest testling(IQ::Get, JID(), payload_, router_); + router_->setJID("alice@wonderland.lit/TeaParty"); + testling.onResponse.connect(boost::bind(&RequestTest::handleResponse, this, _1, _2)); + testling.send(); + + channel_->onIQReceived(createResponse(JID(),"test-id")); + + CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast<int>(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(channel_->iqs_.size())); + } + + private: void handleResponse(boost::shared_ptr<Payload> p, ErrorPayload::ref e) { @@ -239,15 +324,17 @@ class RequestTest : public CppUnit::TestFixture { ++responsesReceived_; } - boost::shared_ptr<IQ> createResponse(const std::string& id) { + boost::shared_ptr<IQ> createResponse(const JID& from, const std::string& id) { boost::shared_ptr<IQ> iq(new IQ(IQ::Result)); + iq->setFrom(from); iq->addPayload(responsePayload_); iq->setID(id); return iq; } - boost::shared_ptr<IQ> createError(const std::string& id) { + boost::shared_ptr<IQ> createError(const JID& from, const std::string& id) { boost::shared_ptr<IQ> iq(new IQ(IQ::Error)); + iq->setFrom(from); iq->setID(id); return iq; } |