1package eu.siacs.conversations.crypto.sasl;
2
3import android.util.Log;
4import com.google.common.base.Preconditions;
5import com.google.common.base.Strings;
6import eu.siacs.conversations.Config;
7import eu.siacs.conversations.entities.Account;
8import eu.siacs.conversations.utils.SSLSockets;
9import eu.siacs.conversations.xml.Element;
10import eu.siacs.conversations.xml.Namespace;
11import java.util.Collection;
12import java.util.Collections;
13import javax.net.ssl.SSLSocket;
14
15public abstract class SaslMechanism {
16
17 protected final Account account;
18
19 protected SaslMechanism(final Account account) {
20 this.account = account;
21 }
22
23 public static String namespace(final Version version) {
24 if (version == Version.SASL) {
25 return Namespace.SASL;
26 } else {
27 return Namespace.SASL_2;
28 }
29 }
30
31 /**
32 * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be
33 * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism
34 * of lower priority (to prevent downgrade attacks).
35 *
36 * @return An arbitrary int representing the priority
37 */
38 public abstract int getPriority();
39
40 public abstract String getMechanism();
41
42 public String getClientFirstMessage(final SSLSocket sslSocket) {
43 return "";
44 }
45
46 public String getResponse(final String challenge, final SSLSocket sslSocket)
47 throws AuthenticationException {
48 return "";
49 }
50
51 public enum State {
52 INITIAL,
53 AUTH_TEXT_SENT,
54 RESPONSE_SENT,
55 VALID_SERVER_RESPONSE,
56 }
57
58 public enum Version {
59 SASL,
60 SASL_2;
61
62 public static Version of(final Element element) {
63 return switch (Strings.nullToEmpty(element.getNamespace())) {
64 case Namespace.SASL -> SASL;
65 case Namespace.SASL_2 -> SASL_2;
66 default -> throw new IllegalArgumentException("Unrecognized SASL namespace");
67 };
68 }
69 }
70
71 public static class AuthenticationException extends Exception {
72 public AuthenticationException(final String message) {
73 super(message);
74 }
75
76 public AuthenticationException(final Exception inner) {
77 super(inner);
78 }
79
80 public AuthenticationException(final String message, final Exception exception) {
81 super(message, exception);
82 }
83 }
84
85 public static class InvalidStateException extends AuthenticationException {
86 public InvalidStateException(final String message) {
87 super(message);
88 }
89
90 public InvalidStateException(final State state) {
91 this("Invalid state: " + state.toString());
92 }
93 }
94
95 public static final class Factory {
96
97 private final Account account;
98
99 public Factory(final Account account) {
100 this.account = account;
101 }
102
103 private SaslMechanism of(
104 final Collection<String> mechanisms, final ChannelBinding channelBinding) {
105 Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
106 if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
107 return new External(account);
108 } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
109 && channelBinding != ChannelBinding.NONE) {
110 return new ScramSha512Plus(account, channelBinding);
111 } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
112 && channelBinding != ChannelBinding.NONE) {
113 return new ScramSha256Plus(account, channelBinding);
114 } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
115 && channelBinding != ChannelBinding.NONE) {
116 return new ScramSha1Plus(account, channelBinding);
117 } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
118 return new ScramSha512(account);
119 } else if (mechanisms.contains(ScramSha256.MECHANISM)) {
120 return new ScramSha256(account);
121 } else if (mechanisms.contains(ScramSha1.MECHANISM)) {
122 return new ScramSha1(account);
123 } else if (mechanisms.contains(Plain.MECHANISM)
124 && !account.getServer().equals("nimbuzz.com")) {
125 return new Plain(account);
126 } else if (mechanisms.contains(DigestMd5.MECHANISM)) {
127 return new DigestMd5(account);
128 } else if (mechanisms.contains(Anonymous.MECHANISM)) {
129 return new Anonymous(account);
130 } else {
131 return null;
132 }
133 }
134
135 public SaslMechanism of(
136 final Collection<String> mechanisms,
137 final Collection<ChannelBinding> bindings,
138 final Version version,
139 final SSLSockets.Version sslVersion) {
140 final HashedToken fastMechanism = account.getFastMechanism();
141 if (version == Version.SASL_2 && fastMechanism != null) {
142 return fastMechanism;
143 }
144 final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
145 return of(mechanisms, channelBinding);
146 }
147
148 public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
149 return of(Collections.singleton(mechanism), channelBinding);
150 }
151 }
152
153 public static SaslMechanism ensureAvailable(
154 final SaslMechanism mechanism,
155 final SSLSockets.Version sslVersion,
156 final boolean requireChannelBinding) {
157 if (mechanism instanceof ChannelBindingMechanism) {
158 final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding();
159 if (ChannelBinding.isAvailable(cb, sslVersion)) {
160 return mechanism;
161 } else {
162 Log.d(
163 Config.LOGTAG,
164 "pinned channel binding method " + cb + " no longer available");
165 return null;
166 }
167 } else if (requireChannelBinding) {
168 Log.d(Config.LOGTAG, "pinned mechanism did not provide channel binding");
169 return null;
170 } else {
171 return mechanism;
172 }
173 }
174
175 public static boolean hashedToken(final SaslMechanism saslMechanism) {
176 return saslMechanism instanceof HashedToken;
177 }
178
179 public static boolean pin(final SaslMechanism saslMechanism) {
180 return !hashedToken(saslMechanism);
181 }
182}