DNSSocket.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import com.google.common.base.Preconditions;
  6import com.google.common.base.Strings;
  7import com.google.common.util.concurrent.ListenableFuture;
  8import com.google.common.util.concurrent.SettableFuture;
  9
 10import de.measite.minidns.DNSMessage;
 11
 12import eu.siacs.conversations.Config;
 13
 14import org.conscrypt.OkHostnameVerifier;
 15
 16import java.io.Closeable;
 17import java.io.DataInputStream;
 18import java.io.DataOutputStream;
 19import java.io.EOFException;
 20import java.io.IOException;
 21import java.net.InetSocketAddress;
 22import java.net.Socket;
 23import java.net.SocketAddress;
 24import java.security.cert.Certificate;
 25import java.security.cert.X509Certificate;
 26import java.util.HashMap;
 27import java.util.Iterator;
 28import java.util.Map;
 29import java.util.concurrent.ExecutionException;
 30import java.util.concurrent.Semaphore;
 31import java.util.concurrent.TimeUnit;
 32import java.util.concurrent.TimeoutException;
 33
 34import javax.net.ssl.SSLPeerUnverifiedException;
 35import javax.net.ssl.SSLSession;
 36import javax.net.ssl.SSLSocket;
 37import javax.net.ssl.SSLSocketFactory;
 38
 39final class DNSSocket implements Closeable {
 40
 41    public static final int QUERY_TIMEOUT = 5_000;
 42
 43    private final Semaphore semaphore = new Semaphore(1);
 44    private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
 45    private final Socket socket;
 46    private final DataInputStream dataInputStream;
 47    private final DataOutputStream dataOutputStream;
 48
 49    private DNSSocket(
 50            final Socket socket,
 51            final DataInputStream dataInputStream,
 52            final DataOutputStream dataOutputStream) {
 53        this.socket = socket;
 54        this.dataInputStream = dataInputStream;
 55        this.dataOutputStream = dataOutputStream;
 56        new Thread(this::readDNSMessages).start();
 57    }
 58
 59    private void readDNSMessages() {
 60        try {
 61            while (socket.isConnected()) {
 62                final DNSMessage response = readDNSMessage();
 63                final SettableFuture<DNSMessage> future;
 64                synchronized (inFlightQueries) {
 65                    future = inFlightQueries.remove(response.id);
 66                }
 67                if (future != null) {
 68                    future.set(response);
 69                } else {
 70                    Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
 71                }
 72            }
 73            evictInFlightQueries(new EOFException());
 74        } catch (final IOException e) {
 75            evictInFlightQueries(e);
 76        }
 77    }
 78
 79    private void evictInFlightQueries(final Exception e) {
 80        synchronized (inFlightQueries) {
 81            final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
 82                    inFlightQueries.entrySet().iterator();
 83            while (iterator.hasNext()) {
 84                final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
 85                entry.getValue().setException(e);
 86                iterator.remove();
 87            }
 88        }
 89    }
 90
 91    private static DNSSocket of(final Socket socket) throws IOException {
 92        final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
 93        final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
 94        return new DNSSocket(socket, dataInputStream, dataOutputStream);
 95    }
 96
 97    public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
 98        switch (dnsServer.uniqueTransport()) {
 99            case TCP:
100                return connectTcpSocket(dnsServer);
101            case TLS:
102                return connectTlsSocket(dnsServer);
103            default:
104                throw new IllegalStateException("This is not a socket based transport");
105        }
106    }
107
108    private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
109        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
110        final SocketAddress socketAddress =
111                new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
112        final Socket socket = new Socket();
113        socket.connect(socketAddress, QUERY_TIMEOUT / 2);
114        socket.setSoTimeout(QUERY_TIMEOUT);
115        return DNSSocket.of(socket);
116    }
117
118    private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
119        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
120        final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
121        final SSLSocket sslSocket = (SSLSocket) factory.createSocket();
122        if (Strings.isNullOrEmpty(dnsServer.hostname)) {
123            final SocketAddress socketAddress =
124                    new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
125            sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
126            sslSocket.setSoTimeout(QUERY_TIMEOUT);
127            sslSocket.startHandshake();
128        } else {
129            final SocketAddress socketAddress =
130                    new InetSocketAddress(dnsServer.hostname, dnsServer.port);
131            sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
132            sslSocket.setSoTimeout(QUERY_TIMEOUT);
133            sslSocket.startHandshake();
134            final SSLSession session = sslSocket.getSession();
135            final Certificate[] peerCertificates = session.getPeerCertificates();
136            if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
137                throw new IOException("Peer did not provide X509 certificates");
138            }
139            final X509Certificate certificate = (X509Certificate) peerCertificates[0];
140            if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
141                throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
142            }
143        }
144        return DNSSocket.of(sslSocket);
145    }
146
147    public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
148        try {
149            return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
150        } catch (final ExecutionException e) {
151            final Throwable cause = e.getCause();
152            if (cause instanceof IOException) {
153                throw (IOException) cause;
154            } else {
155                throw new IOException(e);
156            }
157        } catch (final TimeoutException e) {
158            throw new IOException(e);
159        }
160    }
161
162    public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
163            throws InterruptedException, IOException {
164        final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
165        synchronized (this.inFlightQueries) {
166            this.inFlightQueries.put(query.id, responseFuture);
167        }
168        this.semaphore.acquire();
169        try {
170            query.writeTo(this.dataOutputStream);
171            this.dataOutputStream.flush();
172        } finally {
173            this.semaphore.release();
174        }
175        return responseFuture;
176    }
177
178    private DNSMessage readDNSMessage() throws IOException {
179        final int length = this.dataInputStream.readUnsignedShort();
180        byte[] data = new byte[length];
181        int read = 0;
182        while (read < length) {
183            read += this.dataInputStream.read(data, read, length - read);
184        }
185        return NetworkDataSource.readDNSMessage(data);
186    }
187
188    @Override
189    public void close() throws IOException {
190        this.socket.close();
191    }
192
193    public void closeQuietly() {
194        try {
195            this.socket.close();
196        } catch (final IOException ignored) {
197
198        }
199    }
200}