HashedToken.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Base64;
  4import android.util.Log;
  5
  6import androidx.annotation.NonNull;
  7
  8import com.google.common.base.MoreObjects;
  9import com.google.common.base.Strings;
 10import com.google.common.collect.ImmutableMultimap;
 11import com.google.common.collect.Multimap;
 12import com.google.common.hash.HashFunction;
 13import com.google.common.primitives.Bytes;
 14
 15import java.nio.charset.StandardCharsets;
 16import java.util.Arrays;
 17import java.util.Collection;
 18import java.util.List;
 19
 20import javax.net.ssl.SSLSocket;
 21
 22import eu.siacs.conversations.Config;
 23import eu.siacs.conversations.entities.Account;
 24import eu.siacs.conversations.utils.SSLSockets;
 25
 26public abstract class HashedToken extends SaslMechanism implements ChannelBindingMechanism {
 27
 28    private static final String PREFIX = "HT";
 29
 30    private static final List<String> HASH_FUNCTIONS = Arrays.asList("SHA-512", "SHA-256");
 31    private static final byte[] INITIATOR = "Initiator".getBytes(StandardCharsets.UTF_8);
 32    private static final byte[] RESPONDER = "Responder".getBytes(StandardCharsets.UTF_8);
 33
 34    protected final ChannelBinding channelBinding;
 35
 36    protected HashedToken(final Account account, final ChannelBinding channelBinding) {
 37        super(account);
 38        this.channelBinding = channelBinding;
 39    }
 40
 41    @Override
 42    public int getPriority() {
 43        throw new UnsupportedOperationException();
 44    }
 45
 46    @Override
 47    public String getClientFirstMessage(final SSLSocket sslSocket) {
 48        final String token = Strings.nullToEmpty(this.account.getFastToken());
 49        final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
 50        final byte[] cbData = getChannelBindingData(sslSocket);
 51        final byte[] initiatorHashedToken =
 52                hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes();
 53        final byte[] firstMessage =
 54                Bytes.concat(
 55                        account.getUsername().getBytes(StandardCharsets.UTF_8),
 56                        new byte[] {0x00},
 57                        initiatorHashedToken);
 58        return Base64.encodeToString(firstMessage, Base64.NO_WRAP);
 59    }
 60
 61    private byte[] getChannelBindingData(final SSLSocket sslSocket) {
 62        if (this.channelBinding == ChannelBinding.NONE) {
 63            return new byte[0];
 64        }
 65        try {
 66            return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding);
 67        } catch (final AuthenticationException e) {
 68            Log.e(
 69                    Config.LOGTAG,
 70                    account.getJid().asBareJid()
 71                            + ": unable to retrieve channel binding data for "
 72                            + getMechanism(),
 73                    e);
 74            return new byte[0];
 75        }
 76    }
 77
 78    @Override
 79    public String getResponse(final String challenge, final SSLSocket socket)
 80            throws AuthenticationException {
 81        final byte[] responderMessage;
 82        try {
 83            responderMessage = Base64.decode(challenge, Base64.NO_WRAP);
 84        } catch (final Exception e) {
 85            throw new AuthenticationException("Unable to decode responder message", e);
 86        }
 87        final String token = Strings.nullToEmpty(this.account.getFastToken());
 88        final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
 89        final byte[] cbData = getChannelBindingData(socket);
 90        final byte[] expectedResponderMessage =
 91                hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes();
 92        if (Arrays.equals(responderMessage, expectedResponderMessage)) {
 93            return null;
 94        }
 95        throw new AuthenticationException("Responder message did not match");
 96    }
 97
 98    protected abstract HashFunction getHashFunction(final byte[] key);
 99
100    public abstract Mechanism getTokenMechanism();
101
102    @Override
103    public String getMechanism() {
104        return getTokenMechanism().name();
105    }
106
107    public static final class Mechanism {
108        public final String hashFunction;
109        public final ChannelBinding channelBinding;
110
111        public Mechanism(String hashFunction, ChannelBinding channelBinding) {
112            this.hashFunction = hashFunction;
113            this.channelBinding = channelBinding;
114        }
115
116        public static Mechanism of(final String mechanism) {
117            final int first = mechanism.indexOf('-');
118            final int last = mechanism.lastIndexOf('-');
119            if (last <= first || mechanism.length() <= last) {
120                throw new IllegalArgumentException("Not a valid HashedToken name");
121            }
122            if (mechanism.substring(0, first).equals(PREFIX)) {
123                final String hashFunction = mechanism.substring(first + 1, last);
124                final String cbShortName = mechanism.substring(last + 1);
125                final ChannelBinding channelBinding =
126                        ChannelBinding.SHORT_NAMES.inverse().get(cbShortName);
127                if (channelBinding == null) {
128                    throw new IllegalArgumentException("Unknown channel binding " + cbShortName);
129                }
130                return new Mechanism(hashFunction, channelBinding);
131            } else {
132                throw new IllegalArgumentException("HashedToken name does not start with HT");
133            }
134        }
135
136        public static Mechanism ofOrNull(final String mechanism) {
137            try {
138                return mechanism == null ? null : of(mechanism);
139            } catch (final IllegalArgumentException e) {
140                return null;
141            }
142        }
143
144        public static Multimap<String, ChannelBinding> of(final Collection<String> mechanisms) {
145            final ImmutableMultimap.Builder<String, ChannelBinding> builder =
146                    ImmutableMultimap.builder();
147            for (final String name : mechanisms) {
148                try {
149                    final Mechanism mechanism = Mechanism.of(name);
150                    builder.put(mechanism.hashFunction, mechanism.channelBinding);
151                } catch (final IllegalArgumentException ignored) {
152                }
153            }
154            return builder.build();
155        }
156
157        public static Mechanism best(
158                final Collection<String> mechanisms, final SSLSockets.Version sslVersion) {
159            final Multimap<String, ChannelBinding> multimap = of(mechanisms);
160            for (final String hashFunction : HASH_FUNCTIONS) {
161                final Collection<ChannelBinding> channelBindings = multimap.get(hashFunction);
162                if (channelBindings.isEmpty()) {
163                    continue;
164                }
165                final ChannelBinding cb = ChannelBinding.best(channelBindings, sslVersion);
166                return new Mechanism(hashFunction, cb);
167            }
168            return null;
169        }
170
171        @NonNull
172        @Override
173        public String toString() {
174            return MoreObjects.toStringHelper(this)
175                    .add("hashFunction", hashFunction)
176                    .add("channelBinding", channelBinding)
177                    .toString();
178        }
179
180        public String name() {
181            return String.format(
182                    "%s-%s-%s",
183                    PREFIX, hashFunction, ChannelBinding.SHORT_NAMES.get(channelBinding));
184        }
185    }
186
187    public ChannelBinding getChannelBinding() {
188        return this.channelBinding;
189    }
190}