DNSSocket.java

  1package de.gultsch.minidns;
  2
import android.util.Log;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import de.measite.minidns.DNSMessage;

import eu.siacs.conversations.Config;

import org.conscrypt.OkHostnameVerifier;

import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;

import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
 36
 37final class DNSSocket implements Closeable {
 38
 39    private static final int CONNECT_TIMEOUT = 5_000;
 40
 41    private final Semaphore semaphore = new Semaphore(1);
 42    private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
 43    private final Socket socket;
 44    private final DataInputStream dataInputStream;
 45    private final DataOutputStream dataOutputStream;
 46
 47    private DNSSocket(
 48            final Socket socket,
 49            final DataInputStream dataInputStream,
 50            final DataOutputStream dataOutputStream) {
 51        this.socket = socket;
 52        this.dataInputStream = dataInputStream;
 53        this.dataOutputStream = dataOutputStream;
 54        new Thread(this::readDNSMessages).start();
 55    }
 56
 57    private void readDNSMessages() {
 58        try {
 59            while (socket.isConnected()) {
 60                final DNSMessage response = readDNSMessage();
 61                final SettableFuture<DNSMessage> future;
 62                synchronized (inFlightQueries) {
 63                    future = inFlightQueries.remove(response.id);
 64                }
 65                if (future != null) {
 66                    future.set(response);
 67                } else {
 68                    Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
 69                }
 70            }
 71            evictInFlightQueries(new EOFException());
 72        } catch (final IOException e) {
 73            evictInFlightQueries(e);
 74        }
 75    }
 76
 77    private void evictInFlightQueries(final Exception e) {
 78        synchronized (inFlightQueries) {
 79            final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
 80                    inFlightQueries.entrySet().iterator();
 81            while (iterator.hasNext()) {
 82                final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
 83                entry.getValue().setException(e);
 84                iterator.remove();
 85            }
 86        }
 87    }
 88
 89    private static DNSSocket of(final Socket socket) throws IOException {
 90        final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
 91        final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
 92        return new DNSSocket(socket, dataInputStream, dataOutputStream);
 93    }
 94
 95    public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
 96        switch (dnsServer.uniqueTransport()) {
 97            case TCP:
 98                return connectTcpSocket(dnsServer);
 99            case TLS:
100                return connectTlsSocket(dnsServer);
101            default:
102                throw new IllegalStateException("This is not a socket based transport");
103        }
104    }
105
106    private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
107        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
108        final SocketAddress socketAddress =
109                new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
110        final Socket socket = new Socket();
111        socket.connect(socketAddress, CONNECT_TIMEOUT);
112        return DNSSocket.of(socket);
113    }
114
115    private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
116        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
117        final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
118        final SSLSocket sslSocket;
119        if (Strings.isNullOrEmpty(dnsServer.hostname)) {
120            final SocketAddress socketAddress =
121                    new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
122            sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port);
123            sslSocket.connect(socketAddress, 5_000);
124        } else {
125            sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port);
126            final SSLSession session = sslSocket.getSession();
127            final Certificate[] peerCertificates = session.getPeerCertificates();
128            if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
129                throw new IOException("Peer did not provide X509 certificates");
130            }
131            final X509Certificate certificate = (X509Certificate) peerCertificates[0];
132            if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
133                throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
134            }
135        }
136        return DNSSocket.of(sslSocket);
137    }
138
139    public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
140        try {
141            return queryAsync(query).get();
142        } catch (final ExecutionException e) {
143            final Throwable cause = e.getCause();
144            if (cause instanceof IOException) {
145                throw (IOException) cause;
146            } else {
147                throw new IOException(e);
148            }
149        }
150    }
151
152    public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
153            throws InterruptedException, IOException {
154        final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
155        synchronized (this.inFlightQueries) {
156            this.inFlightQueries.put(query.id, responseFuture);
157        }
158        this.semaphore.acquire();
159        try {
160            query.writeTo(this.dataOutputStream);
161            this.dataOutputStream.flush();
162        } finally {
163            this.semaphore.release();
164        }
165        return responseFuture;
166    }
167
168    private DNSMessage readDNSMessage() throws IOException {
169        final int length = this.dataInputStream.readUnsignedShort();
170        byte[] data = new byte[length];
171        int read = 0;
172        while (read < length) {
173            read += this.dataInputStream.read(data, read, length - read);
174        }
175        return new DNSMessage(data);
176    }
177
178    @Override
179    public void close() throws IOException {
180        this.socket.close();
181    }
182
183    public void closeQuietly() {
184        try {
185            this.socket.close();
186        } catch (final IOException ignored) {
187
188        }
189    }
190}