HashedToken.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Base64;
  4
  5import com.google.common.base.MoreObjects;
  6import com.google.common.base.Strings;
  7import com.google.common.collect.ImmutableMultimap;
  8import com.google.common.collect.Multimap;
  9import com.google.common.hash.HashFunction;
 10import com.google.common.primitives.Bytes;
 11
 12import org.jetbrains.annotations.NotNull;
 13
 14import java.nio.charset.StandardCharsets;
 15import java.util.Arrays;
 16import java.util.Collection;
 17import java.util.List;
 18
 19import javax.net.ssl.SSLSocket;
 20
 21import eu.siacs.conversations.entities.Account;
 22import eu.siacs.conversations.utils.SSLSockets;
 23
 24public abstract class HashedToken extends SaslMechanism implements ChannelBindingMechanism {
 25
 26    private static final String PREFIX = "HT";
 27
 28    private static final List<String> HASH_FUNCTIONS = Arrays.asList("SHA-512", "SHA-256");
 29    private static final byte[] INITIATOR = "Initiator".getBytes(StandardCharsets.UTF_8);
 30    private static final byte[] RESPONDER = "Responder".getBytes(StandardCharsets.UTF_8);
 31
 32    protected final ChannelBinding channelBinding;
 33
 34    protected HashedToken(final Account account, final ChannelBinding channelBinding) {
 35        super(account);
 36        this.channelBinding = channelBinding;
 37    }
 38
 39    @Override
 40    public int getPriority() {
 41        throw new UnsupportedOperationException();
 42    }
 43
 44    @Override
 45    public String getClientFirstMessage() {
 46        final String token = Strings.nullToEmpty(this.account.getFastToken());
 47        final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
 48        final byte[] cbData = new byte[0];
 49        final byte[] initiatorHashedToken =
 50                hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes();
 51        final byte[] firstMessage =
 52                Bytes.concat(
 53                        account.getUsername().getBytes(StandardCharsets.UTF_8),
 54                        new byte[] {0x00},
 55                        initiatorHashedToken);
 56        return Base64.encodeToString(firstMessage, Base64.NO_WRAP);
 57    }
 58
 59    @Override
 60    public String getResponse(final String challenge, final SSLSocket socket)
 61            throws AuthenticationException {
 62        final byte[] responderMessage;
 63        try {
 64            responderMessage = Base64.decode(challenge, Base64.NO_WRAP);
 65        } catch (final Exception e) {
 66            throw new AuthenticationException("Unable to decode responder message", e);
 67        }
 68        final String token = Strings.nullToEmpty(this.account.getFastToken());
 69        final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
 70        final byte[] cbData = new byte[0];
 71        final byte[] expectedResponderMessage =
 72                hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes();
 73        if (Arrays.equals(responderMessage, expectedResponderMessage)) {
 74            return null;
 75        }
 76        throw new AuthenticationException("Responder message did not match");
 77    }
 78
 79    protected abstract HashFunction getHashFunction(final byte[] key);
 80
 81    public abstract Mechanism getTokenMechanism();
 82
 83    @Override
 84    public String getMechanism() {
 85        return getTokenMechanism().name();
 86    }
 87
 88    public static final class Mechanism {
 89        public final String hashFunction;
 90        public final ChannelBinding channelBinding;
 91
 92        public Mechanism(String hashFunction, ChannelBinding channelBinding) {
 93            this.hashFunction = hashFunction;
 94            this.channelBinding = channelBinding;
 95        }
 96
 97        public static Mechanism of(final String mechanism) {
 98            final int first = mechanism.indexOf('-');
 99            final int last = mechanism.lastIndexOf('-');
100            if (last <= first || mechanism.length() <= last) {
101                throw new IllegalArgumentException("Not a valid HashedToken name");
102            }
103            if (mechanism.substring(0, first).equals(PREFIX)) {
104                final String hashFunction = mechanism.substring(first + 1, last);
105                final String cbShortName = mechanism.substring(last + 1);
106                final ChannelBinding channelBinding =
107                        ChannelBinding.SHORT_NAMES.inverse().get(cbShortName);
108                if (channelBinding == null) {
109                    throw new IllegalArgumentException("Unknown channel binding " + cbShortName);
110                }
111                return new Mechanism(hashFunction, channelBinding);
112            } else {
113                throw new IllegalArgumentException("HashedToken name does not start with HT");
114            }
115        }
116
117        public static Mechanism ofOrNull(final String mechanism) {
118            try {
119                return mechanism == null ? null : of(mechanism);
120            } catch (final IllegalArgumentException e) {
121                return null;
122            }
123        }
124
125        public static Multimap<String, ChannelBinding> of(final Collection<String> mechanisms) {
126            final ImmutableMultimap.Builder<String, ChannelBinding> builder =
127                    ImmutableMultimap.builder();
128            for (final String name : mechanisms) {
129                try {
130                    final Mechanism mechanism = Mechanism.of(name);
131                    builder.put(mechanism.hashFunction, mechanism.channelBinding);
132                } catch (final IllegalArgumentException ignored) {
133                }
134            }
135            return builder.build();
136        }
137
138        public static Mechanism best(
139                final Collection<String> mechanisms, final SSLSockets.Version sslVersion) {
140            final Multimap<String, ChannelBinding> multimap = of(mechanisms);
141            for (final String hashFunction : HASH_FUNCTIONS) {
142                final Collection<ChannelBinding> channelBindings = multimap.get(hashFunction);
143                if (channelBindings.isEmpty()) {
144                    continue;
145                }
146                final ChannelBinding cb = ChannelBinding.best(channelBindings, sslVersion);
147                return new Mechanism(hashFunction, cb);
148            }
149            return null;
150        }
151
152        @NotNull
153        @Override
154        public String toString() {
155            return MoreObjects.toStringHelper(this)
156                    .add("hashFunction", hashFunction)
157                    .add("channelBinding", channelBinding)
158                    .toString();
159        }
160
161        public String name() {
162            return String.format(
163                    "%s-%s-%s",
164                    PREFIX, hashFunction, ChannelBinding.SHORT_NAMES.get(channelBinding));
165        }
166    }
167
168    public ChannelBinding getChannelBinding() {
169        return this.channelBinding;
170    }
171}