1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import com.google.common.hash.HashFunction;
import com.google.common.primitives.Bytes;

import org.jetbrains.annotations.NotNull;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.SSLSockets;
23
24public abstract class HashedToken extends SaslMechanism implements ChannelBindingMechanism {
25
26 private static final String PREFIX = "HT";
27
28 private static final List<String> HASH_FUNCTIONS = Arrays.asList("SHA-512", "SHA-256");
29 private static final byte[] INITIATOR = "Initiator".getBytes(StandardCharsets.UTF_8);
30 private static final byte[] RESPONDER = "Responder".getBytes(StandardCharsets.UTF_8);
31
32 protected final ChannelBinding channelBinding;
33
34 protected HashedToken(final Account account, final ChannelBinding channelBinding) {
35 super(account);
36 this.channelBinding = channelBinding;
37 }
38
39 @Override
40 public int getPriority() {
41 throw new UnsupportedOperationException();
42 }
43
44 @Override
45 public String getClientFirstMessage() {
46 final String token = Strings.nullToEmpty(this.account.getFastToken());
47 final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
48 final byte[] cbData = new byte[0];
49 final byte[] initiatorHashedToken =
50 hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes();
51 final byte[] firstMessage =
52 Bytes.concat(
53 account.getUsername().getBytes(StandardCharsets.UTF_8),
54 new byte[] {0x00},
55 initiatorHashedToken);
56 return Base64.encodeToString(firstMessage, Base64.NO_WRAP);
57 }
58
59 @Override
60 public String getResponse(final String challenge, final SSLSocket socket)
61 throws AuthenticationException {
62 final byte[] responderMessage;
63 try {
64 responderMessage = Base64.decode(challenge, Base64.NO_WRAP);
65 } catch (final Exception e) {
66 throw new AuthenticationException("Unable to decode responder message", e);
67 }
68 final String token = Strings.nullToEmpty(this.account.getFastToken());
69 final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
70 final byte[] cbData = new byte[0];
71 final byte[] expectedResponderMessage =
72 hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes();
73 if (Arrays.equals(responderMessage, expectedResponderMessage)) {
74 return null;
75 }
76 throw new AuthenticationException("Responder message did not match");
77 }
78
79 protected abstract HashFunction getHashFunction(final byte[] key);
80
81 public abstract Mechanism getTokenMechanism();
82
83 @Override
84 public String getMechanism() {
85 return getTokenMechanism().name();
86 }
87
88 public static final class Mechanism {
89 public final String hashFunction;
90 public final ChannelBinding channelBinding;
91
92 public Mechanism(String hashFunction, ChannelBinding channelBinding) {
93 this.hashFunction = hashFunction;
94 this.channelBinding = channelBinding;
95 }
96
97 public static Mechanism of(final String mechanism) {
98 final int first = mechanism.indexOf('-');
99 final int last = mechanism.lastIndexOf('-');
100 if (last <= first || mechanism.length() <= last) {
101 throw new IllegalArgumentException("Not a valid HashedToken name");
102 }
103 if (mechanism.substring(0, first).equals(PREFIX)) {
104 final String hashFunction = mechanism.substring(first + 1, last);
105 final String cbShortName = mechanism.substring(last + 1);
106 final ChannelBinding channelBinding =
107 ChannelBinding.SHORT_NAMES.inverse().get(cbShortName);
108 if (channelBinding == null) {
109 throw new IllegalArgumentException("Unknown channel binding " + cbShortName);
110 }
111 return new Mechanism(hashFunction, channelBinding);
112 } else {
113 throw new IllegalArgumentException("HashedToken name does not start with HT");
114 }
115 }
116
117 public static Mechanism ofOrNull(final String mechanism) {
118 try {
119 return mechanism == null ? null : of(mechanism);
120 } catch (final IllegalArgumentException e) {
121 return null;
122 }
123 }
124
125 public static Multimap<String, ChannelBinding> of(final Collection<String> mechanisms) {
126 final ImmutableMultimap.Builder<String, ChannelBinding> builder =
127 ImmutableMultimap.builder();
128 for (final String name : mechanisms) {
129 try {
130 final Mechanism mechanism = Mechanism.of(name);
131 builder.put(mechanism.hashFunction, mechanism.channelBinding);
132 } catch (final IllegalArgumentException ignored) {
133 }
134 }
135 return builder.build();
136 }
137
138 public static Mechanism best(
139 final Collection<String> mechanisms, final SSLSockets.Version sslVersion) {
140 final Multimap<String, ChannelBinding> multimap = of(mechanisms);
141 for (final String hashFunction : HASH_FUNCTIONS) {
142 final Collection<ChannelBinding> channelBindings = multimap.get(hashFunction);
143 if (channelBindings.isEmpty()) {
144 continue;
145 }
146 final ChannelBinding cb = ChannelBinding.best(channelBindings, sslVersion);
147 return new Mechanism(hashFunction, cb);
148 }
149 return null;
150 }
151
152 @NotNull
153 @Override
154 public String toString() {
155 return MoreObjects.toStringHelper(this)
156 .add("hashFunction", hashFunction)
157 .add("channelBinding", channelBinding)
158 .toString();
159 }
160
161 public String name() {
162 return String.format(
163 "%s-%s-%s",
164 PREFIX, hashFunction, ChannelBinding.SHORT_NAMES.get(channelBinding));
165 }
166 }
167
168 public ChannelBinding getChannelBinding() {
169 return this.channelBinding;
170 }
171}