1package de.gultsch.minidns;
2
3import android.util.Log;
4
5import com.google.common.base.Preconditions;
6import com.google.common.base.Strings;
7import com.google.common.util.concurrent.ListenableFuture;
8import com.google.common.util.concurrent.SettableFuture;
9
10
11import eu.siacs.conversations.Config;
12
13import org.conscrypt.OkHostnameVerifier;
14import org.minidns.dnsmessage.DnsMessage;
15
16import java.io.Closeable;
17import java.io.DataInputStream;
18import java.io.DataOutputStream;
19import java.io.EOFException;
20import java.io.IOException;
21import java.net.InetSocketAddress;
22import java.net.Socket;
23import java.net.SocketAddress;
24import java.security.cert.Certificate;
25import java.security.cert.X509Certificate;
26import java.util.HashMap;
27import java.util.Iterator;
28import java.util.Map;
29import java.util.concurrent.ExecutionException;
30import java.util.concurrent.Semaphore;
31import java.util.concurrent.TimeUnit;
32import java.util.concurrent.TimeoutException;
33
34import javax.net.ssl.SSLPeerUnverifiedException;
35import javax.net.ssl.SSLSession;
36import javax.net.ssl.SSLSocket;
37import javax.net.ssl.SSLSocketFactory;
38
39final class DNSSocket implements Closeable {
40
41 public static final int QUERY_TIMEOUT = 5_000;
42
43 private final Semaphore semaphore = new Semaphore(1);
44 private final Map<Integer, SettableFuture<DnsMessage>> inFlightQueries = new HashMap<>();
45 private final Socket socket;
46 private final DataInputStream dataInputStream;
47 private final DataOutputStream dataOutputStream;
48
49 private DNSSocket(
50 final Socket socket,
51 final DataInputStream dataInputStream,
52 final DataOutputStream dataOutputStream) {
53 this.socket = socket;
54 this.dataInputStream = dataInputStream;
55 this.dataOutputStream = dataOutputStream;
56 new Thread(this::readDNSMessages).start();
57 }
58
59 private void readDNSMessages() {
60 try {
61 while (socket.isConnected()) {
62 final DnsMessage response = readDNSMessage();
63 final SettableFuture<DnsMessage> future;
64 synchronized (inFlightQueries) {
65 future = inFlightQueries.remove(response.id);
66 }
67 if (future != null) {
68 future.set(response);
69 } else {
70 Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
71 }
72 }
73 evictInFlightQueries(new EOFException());
74 } catch (final IOException e) {
75 evictInFlightQueries(e);
76 }
77 }
78
79 private void evictInFlightQueries(final Exception e) {
80 synchronized (inFlightQueries) {
81 final Iterator<Map.Entry<Integer, SettableFuture<DnsMessage>>> iterator =
82 inFlightQueries.entrySet().iterator();
83 while (iterator.hasNext()) {
84 final Map.Entry<Integer, SettableFuture<DnsMessage>> entry = iterator.next();
85 entry.getValue().setException(e);
86 iterator.remove();
87 }
88 }
89 }
90
91 private static DNSSocket of(final Socket socket) throws IOException {
92 final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
93 final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
94 return new DNSSocket(socket, dataInputStream, dataOutputStream);
95 }
96
97 public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
98 return switch (dnsServer.uniqueTransport()) {
99 case TCP -> connectTcpSocket(dnsServer);
100 case TLS -> connectTlsSocket(dnsServer);
101 default -> throw new IllegalStateException("This is not a socket based transport");
102 };
103 }
104
105 private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
106 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
107 final SocketAddress socketAddress =
108 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
109 final Socket socket = new Socket();
110 socket.connect(socketAddress, QUERY_TIMEOUT / 2);
111 socket.setSoTimeout(QUERY_TIMEOUT);
112 return DNSSocket.of(socket);
113 }
114
115 private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
116 Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
117 final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
118 final SSLSocket sslSocket = (SSLSocket) factory.createSocket();
119 if (Strings.isNullOrEmpty(dnsServer.hostname)) {
120 final SocketAddress socketAddress =
121 new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
122 sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
123 sslSocket.setSoTimeout(QUERY_TIMEOUT);
124 sslSocket.startHandshake();
125 } else {
126 final SocketAddress socketAddress =
127 new InetSocketAddress(dnsServer.hostname, dnsServer.port);
128 sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
129 sslSocket.setSoTimeout(QUERY_TIMEOUT);
130 sslSocket.startHandshake();
131 final SSLSession session = sslSocket.getSession();
132 final Certificate[] peerCertificates = session.getPeerCertificates();
133 if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate certificate)) {
134 throw new IOException("Peer did not provide X509 certificates");
135 }
136 if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
137 throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
138 }
139 }
140 return DNSSocket.of(sslSocket);
141 }
142
143 public DnsMessage query(final DnsMessage query) throws IOException, InterruptedException {
144 try {
145 return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
146 } catch (final ExecutionException e) {
147 final Throwable cause = e.getCause();
148 if (cause instanceof IOException) {
149 throw (IOException) cause;
150 } else {
151 throw new IOException(e);
152 }
153 } catch (final TimeoutException e) {
154 throw new IOException(e);
155 }
156 }
157
158 public ListenableFuture<DnsMessage> queryAsync(final DnsMessage query)
159 throws InterruptedException, IOException {
160 final SettableFuture<DnsMessage> responseFuture = SettableFuture.create();
161 synchronized (this.inFlightQueries) {
162 this.inFlightQueries.put(query.id, responseFuture);
163 }
164 this.semaphore.acquire();
165 try {
166 query.writeTo(this.dataOutputStream);
167 this.dataOutputStream.flush();
168 } finally {
169 this.semaphore.release();
170 }
171 return responseFuture;
172 }
173
174 private DnsMessage readDNSMessage() throws IOException {
175 final int length = this.dataInputStream.readUnsignedShort();
176 byte[] data = new byte[length];
177 int read = 0;
178 while (read < length) {
179 read += this.dataInputStream.read(data, read, length - read);
180 }
181 return NetworkDataSource.readDNSMessage(data);
182 }
183
184 @Override
185 public void close() throws IOException {
186 this.socket.close();
187 }
188
189 public void closeQuietly() {
190 try {
191 this.socket.close();
192 } catch (final IOException ignored) {
193
194 }
195 }
196}