DNSSocket.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import com.google.common.base.Preconditions;
  6import com.google.common.base.Strings;
  7import com.google.common.util.concurrent.ListenableFuture;
  8import com.google.common.util.concurrent.SettableFuture;
  9
 10import de.measite.minidns.DNSMessage;
 11
 12import eu.siacs.conversations.Config;
 13
 14import org.conscrypt.OkHostnameVerifier;
 15
 16import java.io.Closeable;
 17import java.io.DataInputStream;
 18import java.io.DataOutputStream;
 19import java.io.EOFException;
 20import java.io.IOException;
 21import java.net.InetSocketAddress;
 22import java.net.Socket;
 23import java.net.SocketAddress;
 24import java.security.cert.Certificate;
 25import java.security.cert.X509Certificate;
 26import java.util.HashMap;
 27import java.util.Iterator;
 28import java.util.Map;
 29import java.util.concurrent.ExecutionException;
 30import java.util.concurrent.Semaphore;
 31import java.util.concurrent.TimeUnit;
 32import java.util.concurrent.TimeoutException;
 33
 34import javax.net.ssl.SSLPeerUnverifiedException;
 35import javax.net.ssl.SSLSession;
 36import javax.net.ssl.SSLSocket;
 37import javax.net.ssl.SSLSocketFactory;
 38
 39final class DNSSocket implements Closeable {
 40
 41    private static final int CONNECT_TIMEOUT = 5_000;
 42    public static final int QUERY_TIMEOUT = 5_000;
 43
 44    private final Semaphore semaphore = new Semaphore(1);
 45    private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
 46    private final Socket socket;
 47    private final DataInputStream dataInputStream;
 48    private final DataOutputStream dataOutputStream;
 49
 50    private DNSSocket(
 51            final Socket socket,
 52            final DataInputStream dataInputStream,
 53            final DataOutputStream dataOutputStream) {
 54        this.socket = socket;
 55        this.dataInputStream = dataInputStream;
 56        this.dataOutputStream = dataOutputStream;
 57        new Thread(this::readDNSMessages).start();
 58    }
 59
 60    private void readDNSMessages() {
 61        try {
 62            while (socket.isConnected()) {
 63                final DNSMessage response = readDNSMessage();
 64                final SettableFuture<DNSMessage> future;
 65                synchronized (inFlightQueries) {
 66                    future = inFlightQueries.remove(response.id);
 67                }
 68                if (future != null) {
 69                    future.set(response);
 70                } else {
 71                    Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
 72                }
 73            }
 74            evictInFlightQueries(new EOFException());
 75        } catch (final IOException e) {
 76            evictInFlightQueries(e);
 77        }
 78    }
 79
 80    private void evictInFlightQueries(final Exception e) {
 81        synchronized (inFlightQueries) {
 82            final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
 83                    inFlightQueries.entrySet().iterator();
 84            while (iterator.hasNext()) {
 85                final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
 86                entry.getValue().setException(e);
 87                iterator.remove();
 88            }
 89        }
 90    }
 91
 92    private static DNSSocket of(final Socket socket) throws IOException {
 93        final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
 94        final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
 95        return new DNSSocket(socket, dataInputStream, dataOutputStream);
 96    }
 97
 98    public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
 99        switch (dnsServer.uniqueTransport()) {
100            case TCP:
101                return connectTcpSocket(dnsServer);
102            case TLS:
103                return connectTlsSocket(dnsServer);
104            default:
105                throw new IllegalStateException("This is not a socket based transport");
106        }
107    }
108
109    private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
110        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
111        final SocketAddress socketAddress =
112                new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
113        final Socket socket = new Socket();
114        socket.connect(socketAddress, CONNECT_TIMEOUT);
115        socket.setSoTimeout(QUERY_TIMEOUT);
116        return DNSSocket.of(socket);
117    }
118
119    private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
120        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
121        final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
122        final SSLSocket sslSocket;
123        if (Strings.isNullOrEmpty(dnsServer.hostname)) {
124            final SocketAddress socketAddress =
125                    new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
126            sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port);
127            sslSocket.connect(socketAddress, CONNECT_TIMEOUT);
128            sslSocket.setSoTimeout(QUERY_TIMEOUT);
129        } else {
130            sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port);
131            sslSocket.setSoTimeout(QUERY_TIMEOUT);
132            final SSLSession session = sslSocket.getSession();
133            final Certificate[] peerCertificates = session.getPeerCertificates();
134            if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
135                throw new IOException("Peer did not provide X509 certificates");
136            }
137            final X509Certificate certificate = (X509Certificate) peerCertificates[0];
138            if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
139                throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
140            }
141        }
142        return DNSSocket.of(sslSocket);
143    }
144
145    public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
146        try {
147            return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
148        } catch (final ExecutionException e) {
149            final Throwable cause = e.getCause();
150            if (cause instanceof IOException) {
151                throw (IOException) cause;
152            } else {
153                throw new IOException(e);
154            }
155        } catch (final TimeoutException e) {
156            throw new IOException(e);
157        }
158    }
159
160    public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
161            throws InterruptedException, IOException {
162        final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
163        synchronized (this.inFlightQueries) {
164            this.inFlightQueries.put(query.id, responseFuture);
165        }
166        this.semaphore.acquire();
167        try {
168            query.writeTo(this.dataOutputStream);
169            this.dataOutputStream.flush();
170        } finally {
171            this.semaphore.release();
172        }
173        return responseFuture;
174    }
175
176    private DNSMessage readDNSMessage() throws IOException {
177        final int length = this.dataInputStream.readUnsignedShort();
178        byte[] data = new byte[length];
179        int read = 0;
180        while (read < length) {
181            read += this.dataInputStream.read(data, read, length - read);
182        }
183        return new DNSMessage(data);
184    }
185
186    @Override
187    public void close() throws IOException {
188        this.socket.close();
189    }
190
191    public void closeQuietly() {
192        try {
193            this.socket.close();
194        } catch (final IOException ignored) {
195
196        }
197    }
198}