1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;
import android.util.Log;

import androidx.annotation.NonNull;

import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import com.google.common.hash.HashFunction;
import com.google.common.primitives.Bytes;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.Config;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.SSLSockets;
25
26public abstract class HashedToken extends SaslMechanism implements ChannelBindingMechanism {
27
28 private static final String PREFIX = "HT";
29
30 private static final List<String> HASH_FUNCTIONS = Arrays.asList("SHA-512", "SHA-256");
31 private static final byte[] INITIATOR = "Initiator".getBytes(StandardCharsets.UTF_8);
32 private static final byte[] RESPONDER = "Responder".getBytes(StandardCharsets.UTF_8);
33
34 protected final ChannelBinding channelBinding;
35
36 protected HashedToken(final Account account, final ChannelBinding channelBinding) {
37 super(account);
38 this.channelBinding = channelBinding;
39 }
40
41 @Override
42 public int getPriority() {
43 throw new UnsupportedOperationException();
44 }
45
46 @Override
47 public String getClientFirstMessage(final SSLSocket sslSocket) {
48 final String token = Strings.nullToEmpty(this.account.getFastToken());
49 final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
50 final byte[] cbData = getChannelBindingData(sslSocket);
51 final byte[] initiatorHashedToken =
52 hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes();
53 final byte[] firstMessage =
54 Bytes.concat(
55 account.getUsername().getBytes(StandardCharsets.UTF_8),
56 new byte[] {0x00},
57 initiatorHashedToken);
58 return Base64.encodeToString(firstMessage, Base64.NO_WRAP);
59 }
60
61 private byte[] getChannelBindingData(final SSLSocket sslSocket) {
62 if (this.channelBinding == ChannelBinding.NONE) {
63 return new byte[0];
64 }
65 try {
66 return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding);
67 } catch (final AuthenticationException e) {
68 Log.e(
69 Config.LOGTAG,
70 account.getJid().asBareJid()
71 + ": unable to retrieve channel binding data for "
72 + getMechanism(),
73 e);
74 return new byte[0];
75 }
76 }
77
78 @Override
79 public String getResponse(final String challenge, final SSLSocket socket)
80 throws AuthenticationException {
81 final byte[] responderMessage;
82 try {
83 responderMessage = Base64.decode(challenge, Base64.NO_WRAP);
84 } catch (final Exception e) {
85 throw new AuthenticationException("Unable to decode responder message", e);
86 }
87 final String token = Strings.nullToEmpty(this.account.getFastToken());
88 final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
89 final byte[] cbData = getChannelBindingData(socket);
90 final byte[] expectedResponderMessage =
91 hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes();
92 if (Arrays.equals(responderMessage, expectedResponderMessage)) {
93 return null;
94 }
95 throw new AuthenticationException("Responder message did not match");
96 }
97
98 protected abstract HashFunction getHashFunction(final byte[] key);
99
100 public abstract Mechanism getTokenMechanism();
101
102 @Override
103 public String getMechanism() {
104 return getTokenMechanism().name();
105 }
106
107 public static final class Mechanism {
108 public final String hashFunction;
109 public final ChannelBinding channelBinding;
110
111 public Mechanism(String hashFunction, ChannelBinding channelBinding) {
112 this.hashFunction = hashFunction;
113 this.channelBinding = channelBinding;
114 }
115
116 public static Mechanism of(final String mechanism) {
117 final int first = mechanism.indexOf('-');
118 final int last = mechanism.lastIndexOf('-');
119 if (last <= first || mechanism.length() <= last) {
120 throw new IllegalArgumentException("Not a valid HashedToken name");
121 }
122 if (mechanism.substring(0, first).equals(PREFIX)) {
123 final String hashFunction = mechanism.substring(first + 1, last);
124 final String cbShortName = mechanism.substring(last + 1);
125 final ChannelBinding channelBinding =
126 ChannelBinding.SHORT_NAMES.inverse().get(cbShortName);
127 if (channelBinding == null) {
128 throw new IllegalArgumentException("Unknown channel binding " + cbShortName);
129 }
130 return new Mechanism(hashFunction, channelBinding);
131 } else {
132 throw new IllegalArgumentException("HashedToken name does not start with HT");
133 }
134 }
135
136 public static Mechanism ofOrNull(final String mechanism) {
137 try {
138 return mechanism == null ? null : of(mechanism);
139 } catch (final IllegalArgumentException e) {
140 return null;
141 }
142 }
143
144 public static Multimap<String, ChannelBinding> of(final Collection<String> mechanisms) {
145 final ImmutableMultimap.Builder<String, ChannelBinding> builder =
146 ImmutableMultimap.builder();
147 for (final String name : mechanisms) {
148 try {
149 final Mechanism mechanism = Mechanism.of(name);
150 builder.put(mechanism.hashFunction, mechanism.channelBinding);
151 } catch (final IllegalArgumentException ignored) {
152 }
153 }
154 return builder.build();
155 }
156
157 public static Mechanism best(
158 final Collection<String> mechanisms, final SSLSockets.Version sslVersion) {
159 final Multimap<String, ChannelBinding> multimap = of(mechanisms);
160 for (final String hashFunction : HASH_FUNCTIONS) {
161 final Collection<ChannelBinding> channelBindings = multimap.get(hashFunction);
162 if (channelBindings.isEmpty()) {
163 continue;
164 }
165 final ChannelBinding cb = ChannelBinding.best(channelBindings, sslVersion);
166 return new Mechanism(hashFunction, cb);
167 }
168 return null;
169 }
170
171 @NonNull
172 @Override
173 public String toString() {
174 return MoreObjects.toStringHelper(this)
175 .add("hashFunction", hashFunction)
176 .add("channelBinding", channelBinding)
177 .toString();
178 }
179
180 public String name() {
181 return String.format(
182 "%s-%s-%s",
183 PREFIX, hashFunction, ChannelBinding.SHORT_NAMES.get(channelBinding));
184 }
185 }
186
187 public ChannelBinding getChannelBinding() {
188 return this.channelBinding;
189 }
190}