1package eu.siacs.conversations.crypto.sasl;
2
3import android.util.Log;
4
5import com.google.common.base.Preconditions;
6import com.google.common.base.Strings;
7import com.google.common.collect.Collections2;
8
9import java.util.Collection;
10import java.util.Collections;
11
12import javax.net.ssl.SSLSocket;
13
14import eu.siacs.conversations.Config;
15import eu.siacs.conversations.entities.Account;
16import eu.siacs.conversations.utils.SSLSockets;
17import eu.siacs.conversations.xml.Element;
18import eu.siacs.conversations.xml.Namespace;
19
20public abstract class SaslMechanism {
21
22 protected final Account account;
23
24 protected SaslMechanism(final Account account) {
25 this.account = account;
26 }
27
28 public static String namespace(final Version version) {
29 if (version == Version.SASL) {
30 return Namespace.SASL;
31 } else {
32 return Namespace.SASL_2;
33 }
34 }
35
36 /**
37 * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be
38 * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism
39 * of lower priority (to prevent downgrade attacks).
40 *
41 * @return An arbitrary int representing the priority
42 */
43 public abstract int getPriority();
44
45 public abstract String getMechanism();
46
47 public String getClientFirstMessage(final SSLSocket sslSocket) {
48 return "";
49 }
50
51 public String getResponse(final String challenge, final SSLSocket sslSocket)
52 throws AuthenticationException {
53 return "";
54 }
55
56 public static Collection<String> mechanisms(final Element authElement) {
57 if (authElement == null) {
58 return Collections.emptyList();
59 }
60 return Collections2.transform(
61 Collections2.filter(
62 authElement.getChildren(),
63 c -> c != null && "mechanism".equals(c.getName())),
64 c -> c == null ? null : c.getContent());
65 }
66
67 protected enum State {
68 INITIAL,
69 AUTH_TEXT_SENT,
70 RESPONSE_SENT,
71 VALID_SERVER_RESPONSE,
72 }
73
74 public enum Version {
75 SASL,
76 SASL_2;
77
78 public static Version of(final Element element) {
79 switch (Strings.nullToEmpty(element.getNamespace())) {
80 case Namespace.SASL:
81 return SASL;
82 case Namespace.SASL_2:
83 return SASL_2;
84 default:
85 throw new IllegalArgumentException("Unrecognized SASL namespace");
86 }
87 }
88 }
89
90 public static class AuthenticationException extends Exception {
91 public AuthenticationException(final String message) {
92 super(message);
93 }
94
95 public AuthenticationException(final Exception inner) {
96 super(inner);
97 }
98
99 public AuthenticationException(final String message, final Exception exception) {
100 super(message, exception);
101 }
102 }
103
104 public static class InvalidStateException extends AuthenticationException {
105 public InvalidStateException(final String message) {
106 super(message);
107 }
108
109 public InvalidStateException(final State state) {
110 this("Invalid state: " + state.toString());
111 }
112 }
113
114 public static final class Factory {
115
116 private final Account account;
117
118 public Factory(final Account account) {
119 this.account = account;
120 }
121
122 private SaslMechanism of(
123 final Collection<String> mechanisms, final ChannelBinding channelBinding) {
124 Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
125 if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
126 return new External(account);
127 } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
128 && channelBinding != ChannelBinding.NONE) {
129 return new ScramSha512Plus(account, channelBinding);
130 } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
131 && channelBinding != ChannelBinding.NONE) {
132 return new ScramSha256Plus(account, channelBinding);
133 } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
134 && channelBinding != ChannelBinding.NONE) {
135 return new ScramSha1Plus(account, channelBinding);
136 } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
137 return new ScramSha512(account);
138 } else if (mechanisms.contains(ScramSha256.MECHANISM)) {
139 return new ScramSha256(account);
140 } else if (mechanisms.contains(ScramSha1.MECHANISM)) {
141 return new ScramSha1(account);
142 } else if (mechanisms.contains(Plain.MECHANISM)
143 && !account.getServer().equals("nimbuzz.com")) {
144 return new Plain(account);
145 } else if (mechanisms.contains(DigestMd5.MECHANISM)) {
146 return new DigestMd5(account);
147 } else if (mechanisms.contains(Anonymous.MECHANISM)) {
148 return new Anonymous(account);
149 } else {
150 return null;
151 }
152 }
153
154 public SaslMechanism of(
155 final Collection<String> mechanisms,
156 final Collection<ChannelBinding> bindings,
157 final Version version,
158 final SSLSockets.Version sslVersion) {
159 final HashedToken fastMechanism = account.getFastMechanism();
160 if (version == Version.SASL_2 && fastMechanism != null) {
161 return fastMechanism;
162 }
163 final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
164 return of(mechanisms, channelBinding);
165 }
166
167 public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
168 return of(Collections.singleton(mechanism), channelBinding);
169 }
170 }
171
172 public static SaslMechanism ensureAvailable(
173 final SaslMechanism mechanism, final SSLSockets.Version sslVersion) {
174 if (mechanism instanceof ChannelBindingMechanism) {
175 final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding();
176 if (ChannelBinding.isAvailable(cb, sslVersion)) {
177 return mechanism;
178 } else {
179 Log.d(
180 Config.LOGTAG,
181 "pinned channel binding method " + cb + " no longer available");
182 return null;
183 }
184 } else {
185 return mechanism;
186 }
187 }
188
189 public static boolean hashedToken(final SaslMechanism saslMechanism) {
190 return saslMechanism instanceof HashedToken;
191 }
192
193 public static boolean pin(final SaslMechanism saslMechanism) {
194 return !hashedToken(saslMechanism);
195 }
196}