1package de.gultsch.minidns;
2
3import android.util.Log;
4
5import com.google.common.base.Preconditions;
6import com.google.common.base.Strings;
7import com.google.common.util.concurrent.ListenableFuture;
8import com.google.common.util.concurrent.SettableFuture;
9
10import de.measite.minidns.DNSMessage;
11
12import eu.siacs.conversations.Config;
13
14import org.conscrypt.OkHostnameVerifier;
15
16import java.io.Closeable;
17import java.io.DataInputStream;
18import java.io.DataOutputStream;
19import java.io.EOFException;
20import java.io.IOException;
21import java.net.InetSocketAddress;
22import java.net.Socket;
23import java.net.SocketAddress;
24import java.security.cert.Certificate;
25import java.security.cert.X509Certificate;
26import java.util.HashMap;
27import java.util.Iterator;
28import java.util.Map;
29import java.util.concurrent.ExecutionException;
30import java.util.concurrent.Semaphore;
31import java.util.concurrent.TimeUnit;
32import java.util.concurrent.TimeoutException;
33
34import javax.net.ssl.SSLPeerUnverifiedException;
35import javax.net.ssl.SSLSession;
36import javax.net.ssl.SSLSocket;
37import javax.net.ssl.SSLSocketFactory;
38
39final class DNSSocket implements Closeable {
40
41 private static final int CONNECT_TIMEOUT = 5_000;
42 public static final int QUERY_TIMEOUT = 5_000;
43
44 private final Semaphore semaphore = new Semaphore(1);
45 private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
46 private final Socket socket;
47 private final DataInputStream dataInputStream;
48 private final DataOutputStream dataOutputStream;
49
50 private DNSSocket(
51 final Socket socket,
52 final DataInputStream dataInputStream,
53 final DataOutputStream dataOutputStream) {
54 this.socket = socket;
55 this.dataInputStream = dataInputStream;
56 this.dataOutputStream = dataOutputStream;
57 new Thread(this::readDNSMessages).start();
58 }
59
60 private void readDNSMessages() {
61 try {
62 while (socket.isConnected()) {
63 final DNSMessage response = readDNSMessage();
64 final SettableFuture<DNSMessage> future;
65 synchronized (inFlightQueries) {
66 future = inFlightQueries.remove(response.id);
67 }
68 if (future != null) {
69 future.set(response);
70 } else {
71 Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
72 }
73 }
74 evictInFlightQueries(new EOFException());
75 } catch (final IOException e) {
76 evictInFlightQueries(e);
77 }
78 }
79
80 private void evictInFlightQueries(final Exception e) {
81 synchronized (inFlightQueries) {
82 final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
83 inFlightQueries.entrySet().iterator();
84 while (iterator.hasNext()) {
85 final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
86 entry.getValue().setException(e);
87 iterator.remove();
88 }
89 }
90 }
91
92 private static DNSSocket of(final Socket socket) throws IOException {
93 final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
94 final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
95 return new DNSSocket(socket, dataInputStream, dataOutputStream);
96 }
97
98 public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
99 switch (dnsServer.uniqueTransport()) {
100 case TCP:
101 return connectTcpSocket(dnsServer);
102 case TLS:
103 return connectTlsSocket(dnsServer);
104 default:
105 throw new IllegalStateException("This is not a socket based transport");
106 }
107 }
108
109 private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
110 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
111 final SocketAddress socketAddress =
112 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
113 final Socket socket = new Socket();
114 socket.connect(socketAddress, CONNECT_TIMEOUT);
115 socket.setSoTimeout(QUERY_TIMEOUT);
116 return DNSSocket.of(socket);
117 }
118
119 private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
120 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
121 final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
122 final SSLSocket sslSocket;
123 if (Strings.isNullOrEmpty(dnsServer.hostname)) {
124 final SocketAddress socketAddress =
125 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
126 sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port);
127 sslSocket.connect(socketAddress, CONNECT_TIMEOUT);
128 sslSocket.setSoTimeout(QUERY_TIMEOUT);
129 } else {
130 sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port);
131 sslSocket.setSoTimeout(QUERY_TIMEOUT);
132 final SSLSession session = sslSocket.getSession();
133 final Certificate[] peerCertificates = session.getPeerCertificates();
134 if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
135 throw new IOException("Peer did not provide X509 certificates");
136 }
137 final X509Certificate certificate = (X509Certificate) peerCertificates[0];
138 if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
139 throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
140 }
141 }
142 return DNSSocket.of(sslSocket);
143 }
144
145 public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
146 try {
147 return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
148 } catch (final ExecutionException e) {
149 final Throwable cause = e.getCause();
150 if (cause instanceof IOException) {
151 throw (IOException) cause;
152 } else {
153 throw new IOException(e);
154 }
155 } catch (final TimeoutException e) {
156 throw new IOException(e);
157 }
158 }
159
160 public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
161 throws InterruptedException, IOException {
162 final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
163 synchronized (this.inFlightQueries) {
164 this.inFlightQueries.put(query.id, responseFuture);
165 }
166 this.semaphore.acquire();
167 try {
168 query.writeTo(this.dataOutputStream);
169 this.dataOutputStream.flush();
170 } finally {
171 this.semaphore.release();
172 }
173 return responseFuture;
174 }
175
176 private DNSMessage readDNSMessage() throws IOException {
177 final int length = this.dataInputStream.readUnsignedShort();
178 byte[] data = new byte[length];
179 int read = 0;
180 while (read < length) {
181 read += this.dataInputStream.read(data, read, length - read);
182 }
183 return new DNSMessage(data);
184 }
185
186 @Override
187 public void close() throws IOException {
188 this.socket.close();
189 }
190
191 public void closeQuietly() {
192 try {
193 this.socket.close();
194 } catch (final IOException ignored) {
195
196 }
197 }
198}