DNSSocket.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import com.google.common.base.Preconditions;
  6import com.google.common.base.Strings;
  7import com.google.common.util.concurrent.ListenableFuture;
  8import com.google.common.util.concurrent.SettableFuture;
  9
 10
 11import eu.siacs.conversations.Config;
 12
 13import org.conscrypt.OkHostnameVerifier;
 14import org.minidns.dnsmessage.DnsMessage;
 15
 16import java.io.Closeable;
 17import java.io.DataInputStream;
 18import java.io.DataOutputStream;
 19import java.io.EOFException;
 20import java.io.IOException;
 21import java.net.InetSocketAddress;
 22import java.net.Socket;
 23import java.net.SocketAddress;
 24import java.security.cert.Certificate;
 25import java.security.cert.X509Certificate;
 26import java.util.HashMap;
 27import java.util.Iterator;
 28import java.util.Map;
 29import java.util.concurrent.ExecutionException;
 30import java.util.concurrent.Semaphore;
 31import java.util.concurrent.TimeUnit;
 32import java.util.concurrent.TimeoutException;
 33
 34import javax.net.ssl.SSLPeerUnverifiedException;
 35import javax.net.ssl.SSLSession;
 36import javax.net.ssl.SSLSocket;
 37import javax.net.ssl.SSLSocketFactory;
 38
 39final class DNSSocket implements Closeable {
 40
 41    public static final int QUERY_TIMEOUT = 5_000;
 42
 43    private final Semaphore semaphore = new Semaphore(1);
 44    private final Map<Integer, SettableFuture<DnsMessage>> inFlightQueries = new HashMap<>();
 45    private final Socket socket;
 46    private final DataInputStream dataInputStream;
 47    private final DataOutputStream dataOutputStream;
 48
 49    private DNSSocket(
 50            final Socket socket,
 51            final DataInputStream dataInputStream,
 52            final DataOutputStream dataOutputStream) {
 53        this.socket = socket;
 54        this.dataInputStream = dataInputStream;
 55        this.dataOutputStream = dataOutputStream;
 56        new Thread(this::readDNSMessages).start();
 57    }
 58
 59    private void readDNSMessages() {
 60        try {
 61            while (socket.isConnected()) {
 62                final DnsMessage response = readDNSMessage();
 63                final SettableFuture<DnsMessage> future;
 64                synchronized (inFlightQueries) {
 65                    future = inFlightQueries.remove(response.id);
 66                }
 67                if (future != null) {
 68                    future.set(response);
 69                } else {
 70                    Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
 71                }
 72            }
 73            evictInFlightQueries(new EOFException());
 74        } catch (final IOException e) {
 75            evictInFlightQueries(e);
 76        }
 77    }
 78
 79    private void evictInFlightQueries(final Exception e) {
 80        synchronized (inFlightQueries) {
 81            final Iterator<Map.Entry<Integer, SettableFuture<DnsMessage>>> iterator =
 82                    inFlightQueries.entrySet().iterator();
 83            while (iterator.hasNext()) {
 84                final Map.Entry<Integer, SettableFuture<DnsMessage>> entry = iterator.next();
 85                entry.getValue().setException(e);
 86                iterator.remove();
 87            }
 88        }
 89    }
 90
 91    private static DNSSocket of(final Socket socket) throws IOException {
 92        final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
 93        final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
 94        return new DNSSocket(socket, dataInputStream, dataOutputStream);
 95    }
 96
 97    public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
 98        return switch (dnsServer.uniqueTransport()) {
 99            case TCP -> connectTcpSocket(dnsServer);
100            case TLS -> connectTlsSocket(dnsServer);
101            default -> throw new IllegalStateException("This is not a socket based transport");
102        };
103    }
104
105    private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
106        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
107        final SocketAddress socketAddress =
108                new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
109        final Socket socket = new Socket();
110        socket.connect(socketAddress, QUERY_TIMEOUT / 2);
111        socket.setSoTimeout(QUERY_TIMEOUT);
112        return DNSSocket.of(socket);
113    }
114
115    private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
116        Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
117        final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
118        final SSLSocket sslSocket = (SSLSocket) factory.createSocket();
119        if (Strings.isNullOrEmpty(dnsServer.hostname)) {
120            final SocketAddress socketAddress =
121                    new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
122            sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
123            sslSocket.setSoTimeout(QUERY_TIMEOUT);
124            sslSocket.startHandshake();
125        } else {
126            final SocketAddress socketAddress =
127                    new InetSocketAddress(dnsServer.hostname, dnsServer.port);
128            sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
129            sslSocket.setSoTimeout(QUERY_TIMEOUT);
130            sslSocket.startHandshake();
131            final SSLSession session = sslSocket.getSession();
132            final Certificate[] peerCertificates = session.getPeerCertificates();
133            if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate certificate)) {
134                throw new IOException("Peer did not provide X509 certificates");
135            }
136            if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
137                throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
138            }
139        }
140        return DNSSocket.of(sslSocket);
141    }
142
143    public DnsMessage query(final DnsMessage query) throws IOException, InterruptedException {
144        try {
145            return queryAsync(query).get(QUERY_TIMEOUT, TimeUnit.MILLISECONDS);
146        } catch (final ExecutionException e) {
147            final Throwable cause = e.getCause();
148            if (cause instanceof IOException) {
149                throw (IOException) cause;
150            } else {
151                throw new IOException(e);
152            }
153        } catch (final TimeoutException e) {
154            throw new IOException(e);
155        }
156    }
157
158    public ListenableFuture<DnsMessage> queryAsync(final DnsMessage query)
159            throws InterruptedException, IOException {
160        final SettableFuture<DnsMessage> responseFuture = SettableFuture.create();
161        synchronized (this.inFlightQueries) {
162            this.inFlightQueries.put(query.id, responseFuture);
163        }
164        this.semaphore.acquire();
165        try {
166            query.writeTo(this.dataOutputStream);
167            this.dataOutputStream.flush();
168        } finally {
169            this.semaphore.release();
170        }
171        return responseFuture;
172    }
173
174    private DnsMessage readDNSMessage() throws IOException {
175        final int length = this.dataInputStream.readUnsignedShort();
176        byte[] data = new byte[length];
177        int read = 0;
178        while (read < length) {
179            read += this.dataInputStream.read(data, read, length - read);
180        }
181        return NetworkDataSource.readDNSMessage(data);
182    }
183
184    @Override
185    public void close() throws IOException {
186        this.socket.close();
187    }
188
189    public void closeQuietly() {
190        try {
191            this.socket.close();
192        } catch (final IOException ignored) {
193
194        }
195    }
196}