1package eu.siacs.conversations.crypto.sasl;
2
3import android.util.Log;
4import com.google.common.base.Preconditions;
5import com.google.common.base.Strings;
6import eu.siacs.conversations.Config;
7import eu.siacs.conversations.entities.Account;
8import eu.siacs.conversations.utils.SSLSockets;
9import eu.siacs.conversations.xml.Element;
10import eu.siacs.conversations.xml.Namespace;
11import java.util.Collection;
12import java.util.Collections;
13import javax.net.ssl.SSLSocket;
14
15public abstract class SaslMechanism {
16
17 protected final Account account;
18
19 protected State state = State.INITIAL;
20
21 protected SaslMechanism(final Account account) {
22 this.account = account;
23 }
24
25 public static String namespace(final Version version) {
26 if (version == Version.SASL) {
27 return Namespace.SASL;
28 } else {
29 return Namespace.SASL_2;
30 }
31 }
32
33 /**
34 * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be
35 * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism
36 * of lower priority (to prevent downgrade attacks).
37 *
38 * @return An arbitrary int representing the priority
39 */
40 public abstract int getPriority();
41
42 public abstract String getMechanism();
43
44 public abstract String getClientFirstMessage(final SSLSocket sslSocket);
45
46 public abstract String getResponse(final String challenge, final SSLSocket sslSocket)
47 throws AuthenticationException;
48
49 public enum State {
50 INITIAL,
51 AUTH_TEXT_SENT,
52 RESPONSE_SENT,
53 VALID_SERVER_RESPONSE,
54 }
55
56 protected void checkState(final State expected) throws InvalidStateException {
57 final var current = this.state;
58 if (current == null) {
59 throw new InvalidStateException("Current state is null. Implementation problem");
60 }
61 if (current != expected) {
62 throw new InvalidStateException(
63 String.format("State was %s. Expected %s", current, expected));
64 }
65 }
66
67 public enum Version {
68 SASL,
69 SASL_2;
70
71 public static Version of(final Element element) {
72 return switch (Strings.nullToEmpty(element.getNamespace())) {
73 case Namespace.SASL -> SASL;
74 case Namespace.SASL_2 -> SASL_2;
75 default -> throw new IllegalArgumentException("Unrecognized SASL namespace");
76 };
77 }
78 }
79
80 public static class AuthenticationException extends Exception {
81 public AuthenticationException(final String message) {
82 super(message);
83 }
84
85 public AuthenticationException(final Exception inner) {
86 super(inner);
87 }
88
89 public AuthenticationException(final String message, final Exception exception) {
90 super(message, exception);
91 }
92 }
93
94 public static class InvalidStateException extends AuthenticationException {
95 public InvalidStateException(final String message) {
96 super(message);
97 }
98
99 public InvalidStateException(final State state) {
100 this("Invalid state: " + state.toString());
101 }
102 }
103
104 public static final class Factory {
105
106 private final Account account;
107
108 public Factory(final Account account) {
109 this.account = account;
110 }
111
112 private SaslMechanism of(
113 final Collection<String> mechanisms, final ChannelBinding channelBinding) {
114 Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
115 if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
116 return new External(account);
117 } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
118 && channelBinding != ChannelBinding.NONE) {
119 return new ScramSha512Plus(account, channelBinding);
120 } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
121 && channelBinding != ChannelBinding.NONE) {
122 return new ScramSha256Plus(account, channelBinding);
123 } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
124 && channelBinding != ChannelBinding.NONE) {
125 return new ScramSha1Plus(account, channelBinding);
126 } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
127 return new ScramSha512(account);
128 } else if (mechanisms.contains(ScramSha256.MECHANISM)) {
129 return new ScramSha256(account);
130 } else if (mechanisms.contains(ScramSha1.MECHANISM)) {
131 return new ScramSha1(account);
132 } else if (mechanisms.contains(Plain.MECHANISM)) {
133 return new Plain(account);
134 } else if (mechanisms.contains(DigestMd5.MECHANISM)) {
135 return new DigestMd5(account);
136 } else if (mechanisms.contains(Anonymous.MECHANISM)) {
137 return new Anonymous(account);
138 } else {
139 return null;
140 }
141 }
142
143 public SaslMechanism of(
144 final Collection<String> mechanisms,
145 final Collection<ChannelBinding> bindings,
146 final Version version,
147 final SSLSockets.Version sslVersion) {
148 final HashedToken fastMechanism = account.getFastMechanism();
149 if (version == Version.SASL_2 && fastMechanism != null) {
150 return fastMechanism;
151 }
152 final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
153 return of(mechanisms, channelBinding);
154 }
155
156 public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
157 return of(Collections.singleton(mechanism), channelBinding);
158 }
159 }
160
161 public static SaslMechanism ensureAvailable(
162 final SaslMechanism mechanism,
163 final SSLSockets.Version sslVersion,
164 final boolean requireChannelBinding) {
165 if (mechanism instanceof ChannelBindingMechanism) {
166 final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding();
167 if (ChannelBinding.isAvailable(cb, sslVersion)) {
168 return mechanism;
169 } else {
170 Log.d(
171 Config.LOGTAG,
172 "pinned channel binding method " + cb + " no longer available");
173 return null;
174 }
175 } else if (requireChannelBinding) {
176 Log.d(Config.LOGTAG, "pinned mechanism did not provide channel binding");
177 return null;
178 } else {
179 return mechanism;
180 }
181 }
182
183 public static boolean hashedToken(final SaslMechanism saslMechanism) {
184 return saslMechanism instanceof HashedToken;
185 }
186
187 public static boolean pin(final SaslMechanism saslMechanism) {
188 return !hashedToken(saslMechanism);
189 }
190}