SaslMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Log;
  4import com.google.common.base.Preconditions;
  5import com.google.common.base.Strings;
  6import eu.siacs.conversations.Config;
  7import eu.siacs.conversations.entities.Account;
  8import eu.siacs.conversations.utils.SSLSockets;
  9import eu.siacs.conversations.xml.Element;
 10import eu.siacs.conversations.xml.Namespace;
 11import java.util.Collection;
 12import java.util.Collections;
 13import javax.net.ssl.SSLSocket;
 14
 15public abstract class SaslMechanism {
 16
 17    protected final Account account;
 18
 19    protected State state = State.INITIAL;
 20
 21    protected SaslMechanism(final Account account) {
 22        this.account = account;
 23    }
 24
 25    public static String namespace(final Version version) {
 26        if (version == Version.SASL) {
 27            return Namespace.SASL;
 28        } else {
 29            return Namespace.SASL_2;
 30        }
 31    }
 32
 33    /**
 34     * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be
 35     * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism
 36     * of lower priority (to prevent downgrade attacks).
 37     *
 38     * @return An arbitrary int representing the priority
 39     */
 40    public abstract int getPriority();
 41
 42    public abstract String getMechanism();
 43
 44    public abstract String getClientFirstMessage(final SSLSocket sslSocket);
 45
 46    public abstract String getResponse(final String challenge, final SSLSocket sslSocket)
 47            throws AuthenticationException;
 48
 49    public enum State {
 50        INITIAL,
 51        AUTH_TEXT_SENT,
 52        RESPONSE_SENT,
 53        VALID_SERVER_RESPONSE,
 54    }
 55
 56    protected void checkState(final State expected) throws InvalidStateException {
 57        final var current = this.state;
 58        if (current == null) {
 59            throw new InvalidStateException("Current state is null. Implementation problem");
 60        }
 61        if (current != expected) {
 62            throw new InvalidStateException(
 63                    String.format("State was %s. Expected %s", current, expected));
 64        }
 65    }
 66
 67    public enum Version {
 68        SASL,
 69        SASL_2;
 70
 71        public static Version of(final Element element) {
 72            return switch (Strings.nullToEmpty(element.getNamespace())) {
 73                case Namespace.SASL -> SASL;
 74                case Namespace.SASL_2 -> SASL_2;
 75                default -> throw new IllegalArgumentException("Unrecognized SASL namespace");
 76            };
 77        }
 78    }
 79
 80    public static class AuthenticationException extends Exception {
 81        public AuthenticationException(final String message) {
 82            super(message);
 83        }
 84
 85        public AuthenticationException(final Exception inner) {
 86            super(inner);
 87        }
 88
 89        public AuthenticationException(final String message, final Exception exception) {
 90            super(message, exception);
 91        }
 92    }
 93
 94    public static class InvalidStateException extends AuthenticationException {
 95        public InvalidStateException(final String message) {
 96            super(message);
 97        }
 98
 99        public InvalidStateException(final State state) {
100            this("Invalid state: " + state.toString());
101        }
102    }
103
104    public static final class Factory {
105
106        private final Account account;
107
108        public Factory(final Account account) {
109            this.account = account;
110        }
111
112        private SaslMechanism of(
113                final Collection<String> mechanisms, final ChannelBinding channelBinding) {
114            Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
115            if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
116                return new External(account);
117            } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
118                    && channelBinding != ChannelBinding.NONE) {
119                return new ScramSha512Plus(account, channelBinding);
120            } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
121                    && channelBinding != ChannelBinding.NONE) {
122                return new ScramSha256Plus(account, channelBinding);
123            } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
124                    && channelBinding != ChannelBinding.NONE) {
125                return new ScramSha1Plus(account, channelBinding);
126            } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
127                return new ScramSha512(account);
128            } else if (mechanisms.contains(ScramSha256.MECHANISM)) {
129                return new ScramSha256(account);
130            } else if (mechanisms.contains(ScramSha1.MECHANISM)) {
131                return new ScramSha1(account);
132            } else if (mechanisms.contains(Plain.MECHANISM)) {
133                return new Plain(account);
134            } else if (mechanisms.contains(DigestMd5.MECHANISM)) {
135                return new DigestMd5(account);
136            } else if (mechanisms.contains(Anonymous.MECHANISM)) {
137                return new Anonymous(account);
138            } else {
139                return null;
140            }
141        }
142
143        public SaslMechanism of(
144                final Collection<String> mechanisms,
145                final Collection<ChannelBinding> bindings,
146                final Version version,
147                final SSLSockets.Version sslVersion) {
148            final HashedToken fastMechanism = account.getFastMechanism();
149            if (version == Version.SASL_2 && fastMechanism != null) {
150                return fastMechanism;
151            }
152            final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
153            return of(mechanisms, channelBinding);
154        }
155
156        public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
157            return of(Collections.singleton(mechanism), channelBinding);
158        }
159    }
160
161    public static SaslMechanism ensureAvailable(
162            final SaslMechanism mechanism,
163            final SSLSockets.Version sslVersion,
164            final boolean requireChannelBinding) {
165        if (mechanism instanceof ChannelBindingMechanism) {
166            final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding();
167            if (ChannelBinding.isAvailable(cb, sslVersion)) {
168                return mechanism;
169            } else {
170                Log.d(
171                        Config.LOGTAG,
172                        "pinned channel binding method " + cb + " no longer available");
173                return null;
174            }
175        } else if (requireChannelBinding) {
176            Log.d(Config.LOGTAG, "pinned mechanism did not provide channel binding");
177            return null;
178        } else {
179            return mechanism;
180        }
181    }
182
183    public static boolean hashedToken(final SaslMechanism saslMechanism) {
184        return saslMechanism instanceof HashedToken;
185    }
186
187    public static boolean pin(final SaslMechanism saslMechanism) {
188        return !hashedToken(saslMechanism);
189    }
190}