1package de.gultsch.minidns;
2
3import android.util.Log;
4
5import com.google.common.base.Preconditions;
6import com.google.common.base.Strings;
7import com.google.common.util.concurrent.ListenableFuture;
8import com.google.common.util.concurrent.SettableFuture;
9
10import de.measite.minidns.DNSMessage;
11
12import eu.siacs.conversations.Config;
13
14import org.conscrypt.OkHostnameVerifier;
15
16import java.io.Closeable;
17import java.io.DataInputStream;
18import java.io.DataOutputStream;
19import java.io.EOFException;
20import java.io.IOException;
21import java.net.InetSocketAddress;
22import java.net.Socket;
23import java.net.SocketAddress;
24import java.security.cert.Certificate;
25import java.security.cert.X509Certificate;
26import java.util.HashMap;
27import java.util.Iterator;
28import java.util.Map;
29import java.util.concurrent.ExecutionException;
30import java.util.concurrent.Semaphore;
31
32import javax.net.ssl.SSLPeerUnverifiedException;
33import javax.net.ssl.SSLSession;
34import javax.net.ssl.SSLSocket;
35import javax.net.ssl.SSLSocketFactory;
36
37final class DNSSocket implements Closeable {
38
39 private static final int CONNECT_TIMEOUT = 5_000;
40
41 private final Semaphore semaphore = new Semaphore(1);
42 private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
43 private final Socket socket;
44 private final DataInputStream dataInputStream;
45 private final DataOutputStream dataOutputStream;
46
47 private DNSSocket(
48 final Socket socket,
49 final DataInputStream dataInputStream,
50 final DataOutputStream dataOutputStream) {
51 this.socket = socket;
52 this.dataInputStream = dataInputStream;
53 this.dataOutputStream = dataOutputStream;
54 new Thread(this::readDNSMessages).start();
55 }
56
57 private void readDNSMessages() {
58 try {
59 while (socket.isConnected()) {
60 final DNSMessage response = readDNSMessage();
61 final SettableFuture<DNSMessage> future;
62 synchronized (inFlightQueries) {
63 future = inFlightQueries.remove(response.id);
64 }
65 if (future != null) {
66 future.set(response);
67 } else {
68 Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
69 }
70 }
71 evictInFlightQueries(new EOFException());
72 } catch (final IOException e) {
73 evictInFlightQueries(e);
74 }
75 }
76
77 private void evictInFlightQueries(final Exception e) {
78 synchronized (inFlightQueries) {
79 final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
80 inFlightQueries.entrySet().iterator();
81 while (iterator.hasNext()) {
82 final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
83 entry.getValue().setException(e);
84 iterator.remove();
85 }
86 }
87 }
88
89 private static DNSSocket of(final Socket socket) throws IOException {
90 final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
91 final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
92 return new DNSSocket(socket, dataInputStream, dataOutputStream);
93 }
94
95 public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
96 switch (dnsServer.uniqueTransport()) {
97 case TCP:
98 return connectTcpSocket(dnsServer);
99 case TLS:
100 return connectTlsSocket(dnsServer);
101 default:
102 throw new IllegalStateException("This is not a socket based transport");
103 }
104 }
105
106 private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
107 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
108 final SocketAddress socketAddress =
109 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
110 final Socket socket = new Socket();
111 socket.connect(socketAddress, CONNECT_TIMEOUT);
112 return DNSSocket.of(socket);
113 }
114
115 private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
116 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
117 final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
118 final SSLSocket sslSocket;
119 if (Strings.isNullOrEmpty(dnsServer.hostname)) {
120 final SocketAddress socketAddress =
121 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
122 sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port);
123 sslSocket.connect(socketAddress, 5_000);
124 } else {
125 sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port);
126 final SSLSession session = sslSocket.getSession();
127 final Certificate[] peerCertificates = session.getPeerCertificates();
128 if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
129 throw new IOException("Peer did not provide X509 certificates");
130 }
131 final X509Certificate certificate = (X509Certificate) peerCertificates[0];
132 if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
133 throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
134 }
135 }
136 return DNSSocket.of(sslSocket);
137 }
138
139 public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
140 try {
141 return queryAsync(query).get();
142 } catch (final ExecutionException e) {
143 final Throwable cause = e.getCause();
144 if (cause instanceof IOException) {
145 throw (IOException) cause;
146 } else {
147 throw new IOException(e);
148 }
149 }
150 }
151
152 public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
153 throws InterruptedException, IOException {
154 final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
155 synchronized (this.inFlightQueries) {
156 this.inFlightQueries.put(query.id, responseFuture);
157 }
158 this.semaphore.acquire();
159 try {
160 query.writeTo(this.dataOutputStream);
161 this.dataOutputStream.flush();
162 } finally {
163 this.semaphore.release();
164 }
165 return responseFuture;
166 }
167
168 private DNSMessage readDNSMessage() throws IOException {
169 final int length = this.dataInputStream.readUnsignedShort();
170 byte[] data = new byte[length];
171 int read = 0;
172 while (read < length) {
173 read += this.dataInputStream.read(data, read, length - read);
174 }
175 return new DNSMessage(data);
176 }
177
178 @Override
179 public void close() throws IOException {
180 this.socket.close();
181 }
182
183 public void closeQuietly() {
184 try {
185 this.socket.close();
186 } catch (final IOException ignored) {
187
188 }
189 }
190}