1package de.gultsch.minidns;
2
import android.util.Log;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import de.measite.minidns.DNSMessage;

import eu.siacs.conversations.Config;

import org.conscrypt.OkHostnameVerifier;

import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
38
39final class DNSSocket implements Closeable {
40
41 public static final int QUERY_TIMEOUT = 5_000;
42
43 private final Semaphore semaphore = new Semaphore(1);
44 private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
45 private final Socket socket;
46 private final DataInputStream dataInputStream;
47 private final DataOutputStream dataOutputStream;
48
49 private DNSSocket(
50 final Socket socket,
51 final DataInputStream dataInputStream,
52 final DataOutputStream dataOutputStream) {
53 this.socket = socket;
54 this.dataInputStream = dataInputStream;
55 this.dataOutputStream = dataOutputStream;
56 new Thread(this::readDNSMessages).start();
57 }
58
59 private void readDNSMessages() {
60 try {
61 while (socket.isConnected()) {
62 final DNSMessage response = readDNSMessage();
63 final SettableFuture<DNSMessage> future;
64 synchronized (inFlightQueries) {
65 future = inFlightQueries.remove(response.id);
66 }
67 if (future != null) {
68 future.set(response);
69 } else {
70 Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
71 }
72 }
73 evictInFlightQueries(new EOFException());
74 } catch (final IOException e) {
75 evictInFlightQueries(e);
76 }
77 }
78
79 private void evictInFlightQueries(final Exception e) {
80 synchronized (inFlightQueries) {
81 final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
82 inFlightQueries.entrySet().iterator();
83 while (iterator.hasNext()) {
84 final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
85 entry.getValue().setException(e);
86 iterator.remove();
87 }
88 }
89 }
90
91 private static DNSSocket of(final Socket socket) throws IOException {
92 final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
93 final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
94 return new DNSSocket(socket, dataInputStream, dataOutputStream);
95 }
96
97 public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
98 switch (dnsServer.uniqueTransport()) {
99 case TCP:
100 return connectTcpSocket(dnsServer);
101 case TLS:
102 return connectTlsSocket(dnsServer);
103 default:
104 throw new IllegalStateException("This is not a socket based transport");
105 }
106 }
107
108 private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
109 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
110 final SocketAddress socketAddress =
111 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
112 final Socket socket = new Socket();
113 socket.connect(socketAddress, QUERY_TIMEOUT / 2);
114 socket.setSoTimeout(QUERY_TIMEOUT);
115 return DNSSocket.of(socket);
116 }
117
118 private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
119 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
120 final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
121 final SSLSocket sslSocket = (SSLSocket) factory.createSocket();
122 if (Strings.isNullOrEmpty(dnsServer.hostname)) {
123 final SocketAddress socketAddress =
124 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
125 sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
126 sslSocket.setSoTimeout(QUERY_TIMEOUT);
127 sslSocket.startHandshake();
128 } else {
129 final SocketAddress socketAddress =
130 new InetSocketAddress(dnsServer.hostname, dnsServer.port);
131 sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
132 sslSocket.setSoTimeout(QUERY_TIMEOUT);
133 sslSocket.startHandshake();
134 final SSLSession session = sslSocket.getSession();
135 final Certificate[] peerCertificates = session.getPeerCertificates();
136 if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
137 throw new IOException("Peer did not provide X509 certificates");
138 }
139 final X509Certificate certificate = (X509Certificate) peerCertificates[0];
140 if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
141 throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
142 }
143 }
144 return DNSSocket.of(sslSocket);
145 }
146
147 public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
148 try {
149 return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
150 } catch (final ExecutionException e) {
151 final Throwable cause = e.getCause();
152 if (cause instanceof IOException) {
153 throw (IOException) cause;
154 } else {
155 throw new IOException(e);
156 }
157 } catch (final TimeoutException e) {
158 throw new IOException(e);
159 }
160 }
161
162 public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
163 throws InterruptedException, IOException {
164 final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
165 synchronized (this.inFlightQueries) {
166 this.inFlightQueries.put(query.id, responseFuture);
167 }
168 this.semaphore.acquire();
169 try {
170 query.writeTo(this.dataOutputStream);
171 this.dataOutputStream.flush();
172 } finally {
173 this.semaphore.release();
174 }
175 return responseFuture;
176 }
177
178 private DNSMessage readDNSMessage() throws IOException {
179 final int length = this.dataInputStream.readUnsignedShort();
180 byte[] data = new byte[length];
181 int read = 0;
182 while (read < length) {
183 read += this.dataInputStream.read(data, read, length - read);
184 }
185 return NetworkDataSource.readDNSMessage(data);
186 }
187
188 @Override
189 public void close() throws IOException {
190 this.socket.close();
191 }
192
193 public void closeQuietly() {
194 try {
195 this.socket.close();
196 } catch (final IOException ignored) {
197
198 }
199 }
200}