summaryrefslogtreecommitdiffstats
blob: 2e4c45bcaa0187b9149e7f715a1b6b2c1733ce2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
/*
 * Copyright (c) 2010-2015, Isode Limited, London, England.
 * All rights reserved.
 */
package com.isode.stroke.disco;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

import com.isode.stroke.client.StanzaChannel;
import com.isode.stroke.crypto.CryptoProvider;
import com.isode.stroke.elements.CapsInfo;
import com.isode.stroke.elements.DiscoInfo;
import com.isode.stroke.elements.ErrorPayload;
import com.isode.stroke.elements.Presence;
import com.isode.stroke.jid.JID;
import com.isode.stroke.queries.IQRouter;
import com.isode.stroke.signals.Slot1;
import com.isode.stroke.signals.Slot2;

/**
 * Resolves entity-capabilities (caps) information advertised in incoming presence.
 *
 * <p>When a presence carries a SHA-1 caps payload whose verification string is not yet in
 * {@link CapsStorage}, a disco#info request for {@code node#ver} is sent to the sender. The
 * response is accepted only if regenerating the caps version from it reproduces the advertised
 * verification string. It is then stored and {@code onCapsAvailable} is emitted. If the request
 * fails or does not verify, the sender is remembered as failing for that hash. The next sender
 * that advertised the same hash while the request was outstanding is then tried instead.
 *
 * <p>All bookkeeping is reset whenever the stanza channel becomes available again.
 */
public class CapsManager extends CapsProvider {

    private static final Logger logger = Logger.getLogger(CapsManager.class.getName());

    private final IQRouter iqRouter;
    private final CryptoProvider crypto;
    private final CapsStorage capsStorage;
    private boolean warnOnInvalidHash;
    /** Verification strings for which a disco#info request is currently outstanding. */
    private final Set<String> requestedDiscoInfos = new HashSet<String>();
    /** (sender, verification string) pairs whose disco#info failed or did not verify. */
    private final Set<CapsPair> failingCaps = new HashSet<CapsPair>();
    /** Per verification string, further (sender, node) pairs to query if the current request fails. */
    private final Map<String, Set<CapsPair>> fallbacks = new HashMap<String, Set<CapsPair>>();

    /**
     * Immutable (JID, string) pair used as a hash-set element.
     *
     * <p>The string component is a caps verification hash when stored in {@code failingCaps},
     * and a caps node when stored in {@code fallbacks}. Either component may be null.
     * Static so that instances do not retain a reference to the owning manager.
     */
    private static final class CapsPair {
        final JID jid;
        final String node;

        CapsPair(JID jid, String node) {
            this.jid = jid;
            this.node = node;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CapsPair)) {
                return false;
            }
            CapsPair other = (CapsPair) o;
            return nullSafeEquals(jid, other.jid) && nullSafeEquals(node, other.node);
        }

        @Override
        public int hashCode() {
            int jidHash = (jid == null) ? 0 : jid.hashCode();
            int nodeHash = (node == null) ? 0 : node.hashCode();
            return 31 * jidHash + nodeHash;
        }

        /** Null-tolerant equality; avoids java.util.Objects to stay compatible with older JDKs. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return (a == null) ? (b == null) : a.equals(b);
        }
    }

    /**
     * Creates a manager and subscribes it to presence and availability changes on the channel.
     *
     * @param capsStorage cache of already-verified disco#info results, keyed by verification string
     * @param stanzaChannel channel whose presences and availability changes are observed
     * @param iqRouter router used to send disco#info requests
     * @param crypto provider used to regenerate caps versions for verification
     */
    public CapsManager(CapsStorage capsStorage, StanzaChannel stanzaChannel,
            IQRouter iqRouter, CryptoProvider crypto) {
        this.iqRouter = iqRouter;
        this.crypto = crypto;
        this.capsStorage = capsStorage;
        this.warnOnInvalidHash = true;

        stanzaChannel.onPresenceReceived.connect(new Slot1<Presence>() {
            @Override
            public void call(Presence p1) {
                handlePresenceReceived(p1);
            }
        });
        stanzaChannel.onAvailableChanged.connect(new Slot1<Boolean>() {
            @Override
            public void call(Boolean p1) {
                // Guard against a null Boolean, which would otherwise throw on auto-unboxing.
                handleStanzaChannelAvailableChanged(Boolean.TRUE.equals(p1));
            }
        });
    }

    /**
     * Starts or queues a disco#info lookup for the caps advertised in {@code presence}, if needed.
     */
    private void handlePresenceReceived(Presence presence) {
        CapsInfo capsInfo = presence.getPayload(new CapsInfo());
        // Only SHA-1 caps are handled; error presences are ignored. Constant-first comparison
        // tolerates a missing hash attribute.
        if (capsInfo == null || !"sha-1".equals(capsInfo.getHash())
                || presence.getPayload(new ErrorPayload()) != null) {
            return;
        }
        String hash = capsInfo.getVersion();
        if (capsStorage.getDiscoInfo(hash) != null) {
            return; // Already resolved and cached.
        }
        JID from = presence.getFrom();
        if (failingCaps.contains(new CapsPair(from, hash))) {
            return; // This sender already failed to provide verifiable info for this hash.
        }
        if (requestedDiscoInfos.contains(hash)) {
            // A request for this hash is in flight; remember this sender as a fallback.
            Set<CapsPair> candidates = fallbacks.get(hash);
            if (candidates == null) {
                candidates = new HashSet<CapsPair>();
                fallbacks.put(hash, candidates);
            }
            candidates.add(new CapsPair(from, capsInfo.getNode()));
            return;
        }
        requestDiscoInfo(from, capsInfo.getNode(), hash);
    }

    /** Clears all per-session bookkeeping when the channel becomes available. */
    private void handleStanzaChannelAvailableChanged(boolean available) {
        if (available) {
            failingCaps.clear();
            fallbacks.clear();
            requestedDiscoInfos.clear();
        }
    }

    /**
     * Handles a disco#info response. A verified response is stored and announced. Otherwise
     * the sender is marked as failing and the next fallback sender, if any, is queried.
     */
    private void handleDiscoInfoReceived(final JID from, final String hash, DiscoInfo discoInfo, ErrorPayload error) {
        requestedDiscoInfos.remove(hash);
        boolean answered = (error == null && discoInfo != null);
        if (!answered || !verifies(discoInfo, hash)) {
            if (warnOnInvalidHash && answered) {
                // The peer replied, but its disco#info does not hash to the advertised version.
                logger.warning("Caps from " + from + " do not verify");
            }
            failingCaps.add(new CapsPair(from, hash));
            requestFromNextFallback(hash);
            return;
        }
        fallbacks.remove(hash);
        capsStorage.setDiscoInfo(hash, discoInfo);
        onCapsAvailable.emit(hash);
    }

    /** Returns true if regenerating the caps version from {@code discoInfo} yields {@code hash}. */
    private boolean verifies(DiscoInfo discoInfo, String hash) {
        String computed = new CapsInfoGenerator("", crypto).generateCapsInfo(discoInfo).getVersion();
        return computed != null && computed.equals(hash);
    }

    /** Removes one queued fallback sender for {@code hash}, if any, and queries it. */
    private void requestFromNextFallback(String hash) {
        Set<CapsPair> candidates = fallbacks.get(hash);
        if (candidates == null || candidates.isEmpty()) {
            return;
        }
        Iterator<CapsPair> it = candidates.iterator();
        CapsPair next = it.next();
        it.remove();
        if (candidates.isEmpty()) {
            fallbacks.remove(hash); // Do not keep empty sets around.
        }
        requestDiscoInfo(next.jid, next.node, hash);
    }

    /** Sends a disco#info request for {@code node#hash} to {@code jid} and marks the hash as requested. */
    private void requestDiscoInfo(final JID jid, final String node, final String hash) {
        GetDiscoInfoRequest request = GetDiscoInfoRequest.create(jid, node
                + "#" + hash, iqRouter);
        request.onResponse.connect(new Slot2<DiscoInfo, ErrorPayload>() {
            @Override
            public void call(DiscoInfo p1, ErrorPayload p2) {
                handleDiscoInfoReceived(jid, hash, p1, p2);
            }
        });
        requestedDiscoInfos.add(hash);
        request.send();
    }

    /** Returns the cached disco#info for {@code hash}, or null if it is not known. */
    @Override
    DiscoInfo getCaps(String hash) {
        return capsStorage.getDiscoInfo(hash);
    }

    /**
     * Enables or disables the warning logged when a peer's disco#info does not verify.
     * Mainly for testing purposes.
     */
    void setWarnOnInvalidHash(boolean b) {
        warnOnInvalidHash = b;
    }

}