1package eu.siacs.conversations.crypto.sasl;
2
3import android.util.Log;
4import com.google.common.base.Preconditions;
5import com.google.common.base.Strings;
6import eu.siacs.conversations.Config;
7import eu.siacs.conversations.entities.Account;
8import eu.siacs.conversations.utils.SSLSockets;
9import eu.siacs.conversations.xml.Element;
10import eu.siacs.conversations.xml.Namespace;
11import java.util.Collection;
12import java.util.Collections;
13import javax.net.ssl.SSLSocket;
14
15public abstract class SaslMechanism {
16
17 protected final Account account;
18
19 protected State state = State.INITIAL;
20
21 protected SaslMechanism(final Account account) {
22 this.account = account;
23 }
24
25 public static String namespace(final Version version) {
26 if (version == Version.SASL) {
27 return Namespace.SASL;
28 } else {
29 return Namespace.SASL_2;
30 }
31 }
32
33 /**
34 * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be
35 * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism
36 * of lower priority (to prevent downgrade attacks).
37 *
38 * @return An arbitrary int representing the priority
39 */
40 public abstract int getPriority();
41
42 public abstract String getMechanism();
43
44 public abstract String getClientFirstMessage(final SSLSocket sslSocket);
45
46 public abstract String getResponse(final String challenge, final SSLSocket sslSocket)
47 throws AuthenticationException;
48
49 public enum State {
50 INITIAL,
51 AUTH_TEXT_SENT,
52 RESPONSE_SENT,
53 VALID_SERVER_RESPONSE,
54 }
55
56 protected void checkState(final State expected) throws InvalidStateException {
57 final var current = this.state;
58 if (current == null) {
59 throw new InvalidStateException("Current state is null. Implementation problem");
60 }
61 if (current != expected) {
62 throw new InvalidStateException(
63 String.format("State was %s. Expected %s", current, expected));
64 }
65 }
66
67 public enum Version {
68 SASL,
69 SASL_2;
70
71 public static Version of(final Element element) {
72 return switch (Strings.nullToEmpty(element.getNamespace())) {
73 case Namespace.SASL -> SASL;
74 case Namespace.SASL_2 -> SASL_2;
75 default -> throw new IllegalArgumentException("Unrecognized SASL namespace");
76 };
77 }
78 }
79
80 public static class AuthenticationException extends Exception {
81 public AuthenticationException(final String message) {
82 super(message);
83 }
84
85 public AuthenticationException(final Exception inner) {
86 super(inner);
87 }
88
89 public AuthenticationException(final String message, final Exception exception) {
90 super(message, exception);
91 }
92 }
93
94 public static class InvalidStateException extends AuthenticationException {
95 public InvalidStateException(final String message) {
96 super(message);
97 }
98
99 public InvalidStateException(final State state) {
100 this("Invalid state: " + state.toString());
101 }
102 }
103
104 public static final class Factory {
105
106 private final Account account;
107
108 public Factory(final Account account) {
109 this.account = account;
110 }
111
112 private SaslMechanism of(
113 final Collection<String> mechanisms, final ChannelBinding channelBinding) {
114 Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
115 if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
116 return new External(account);
117 } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
118 && channelBinding != ChannelBinding.NONE) {
119 return new ScramSha512Plus(account, channelBinding);
120 } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
121 && channelBinding != ChannelBinding.NONE) {
122 return new ScramSha256Plus(account, channelBinding);
123 } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
124 && channelBinding != ChannelBinding.NONE) {
125 return new ScramSha1Plus(account, channelBinding);
126 } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
127 return new ScramSha512(account);
128 } else if (mechanisms.contains(ScramSha256.MECHANISM)) {
129 return new ScramSha256(account);
130 } else if (mechanisms.contains(ScramSha1.MECHANISM)) {
131 return new ScramSha1(account);
132 } else if (mechanisms.contains(Plain.MECHANISM)
133 && !account.getServer().equals("nimbuzz.com")) {
134 return new Plain(account);
135 } else if (mechanisms.contains(DigestMd5.MECHANISM)) {
136 return new DigestMd5(account);
137 } else if (mechanisms.contains(Anonymous.MECHANISM)) {
138 return new Anonymous(account);
139 } else {
140 return null;
141 }
142 }
143
144 public SaslMechanism of(
145 final Collection<String> mechanisms,
146 final Collection<ChannelBinding> bindings,
147 final Version version,
148 final SSLSockets.Version sslVersion) {
149 final HashedToken fastMechanism = account.getFastMechanism();
150 if (version == Version.SASL_2 && fastMechanism != null) {
151 return fastMechanism;
152 }
153 final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
154 return of(mechanisms, channelBinding);
155 }
156
157 public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
158 return of(Collections.singleton(mechanism), channelBinding);
159 }
160 }
161
162 public static SaslMechanism ensureAvailable(
163 final SaslMechanism mechanism,
164 final SSLSockets.Version sslVersion,
165 final boolean requireChannelBinding) {
166 if (mechanism instanceof ChannelBindingMechanism) {
167 final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding();
168 if (ChannelBinding.isAvailable(cb, sslVersion)) {
169 return mechanism;
170 } else {
171 Log.d(
172 Config.LOGTAG,
173 "pinned channel binding method " + cb + " no longer available");
174 return null;
175 }
176 } else if (requireChannelBinding) {
177 Log.d(Config.LOGTAG, "pinned mechanism did not provide channel binding");
178 return null;
179 } else {
180 return mechanism;
181 }
182 }
183
184 public static boolean hashedToken(final SaslMechanism saslMechanism) {
185 return saslMechanism instanceof HashedToken;
186 }
187
188 public static boolean pin(final SaslMechanism saslMechanism) {
189 return !hashedToken(saslMechanism);
190 }
191}