NetworkDataSource.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import androidx.annotation.NonNull;
  6
  7import com.google.common.base.Throwables;
  8import com.google.common.cache.CacheBuilder;
  9import com.google.common.cache.CacheLoader;
 10import com.google.common.cache.LoadingCache;
 11import com.google.common.cache.RemovalListener;
 12import com.google.common.collect.ImmutableList;
 13
 14import de.measite.minidns.DNSMessage;
 15import de.measite.minidns.MiniDNSException;
 16import de.measite.minidns.source.DNSDataSource;
 17import de.measite.minidns.util.MultipleIoException;
 18
 19import eu.siacs.conversations.Config;
 20
 21import java.io.IOException;
 22import java.net.DatagramPacket;
 23import java.net.DatagramSocket;
 24import java.net.InetAddress;
 25import java.util.ArrayList;
 26import java.util.List;
 27import java.util.Map;
 28import java.util.concurrent.ExecutionException;
 29import java.util.concurrent.TimeUnit;
 30
 31public class NetworkDataSource extends DNSDataSource {
 32
 33    private static final LoadingCache<DNSServer, DNSSocket> socketCache =
 34            CacheBuilder.newBuilder()
 35                    .removalListener(
 36                            (RemovalListener<DNSServer, DNSSocket>)
 37                                    notification -> {
 38                                        final DNSServer dnsServer = notification.getKey();
 39                                        final DNSSocket dnsSocket = notification.getValue();
 40                                        if (dnsSocket == null) {
 41                                            return;
 42                                        }
 43                                        Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
 44                                        dnsSocket.closeQuietly();
 45                                    })
 46                    .expireAfterAccess(5, TimeUnit.MINUTES)
 47                    .build(
 48                            new CacheLoader<DNSServer, DNSSocket>() {
 49                                @Override
 50                                @NonNull
 51                                public DNSSocket load(@NonNull final DNSServer dnsServer)
 52                                        throws Exception {
 53                                    Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
 54                                    return DNSSocket.connect(dnsServer);
 55                                }
 56                            });
 57
 58    private static List<Transport> transportsForPort(final int port) {
 59        final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
 60        for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
 61            if (entry.getValue().equals(port)) {
 62                transportBuilder.add(entry.getKey());
 63            }
 64        }
 65        return transportBuilder.build();
 66    }
 67
 68    @Override
 69    public DNSMessage query(final DNSMessage message, final InetAddress address, final int port)
 70            throws IOException {
 71        final List<Transport> transports = transportsForPort(port);
 72        Log.w(
 73                Config.LOGTAG,
 74                "using legacy DataSource interface. guessing transports "
 75                        + transports
 76                        + " from port");
 77        if (transports.isEmpty()) {
 78            throw new IOException(String.format("No transports found for port %d", port));
 79        }
 80        return query(message, new DNSServer(address, port, transports));
 81    }
 82
 83    public DNSMessage query(final DNSMessage message, final DNSServer dnsServer)
 84            throws IOException {
 85        Log.d(Config.LOGTAG, "using " + dnsServer);
 86        final List<IOException> ioExceptions = new ArrayList<>();
 87        for (final Transport transport : dnsServer.transports) {
 88            try {
 89                final DNSMessage response =
 90                        queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
 91                if (response != null && !response.truncated) {
 92                    return response;
 93                }
 94            } catch (final IOException e) {
 95                ioExceptions.add(e);
 96            } catch (final InterruptedException e) {
 97                throw new IOException(e);
 98            }
 99        }
100        MultipleIoException.throwIfRequired(ioExceptions);
101        return null;
102    }
103
104    private DNSMessage queryWithUniqueTransport(final DNSMessage message, final DNSServer dnsServer)
105            throws IOException, InterruptedException {
106        final Transport transport = dnsServer.uniqueTransport();
107        switch (transport) {
108            case UDP:
109                return queryUdp(message, dnsServer.inetAddress, dnsServer.port);
110            case TCP:
111            case TLS:
112                return queryDnsSocket(message, dnsServer);
113            default:
114                throw new IOException(
115                        String.format("Transport %s has not been implemented", transport));
116        }
117    }
118
119    protected DNSMessage queryUdp(
120            final DNSMessage message, final InetAddress address, final int port)
121            throws IOException {
122        final DatagramPacket request = message.asDatagram(address, port);
123        final byte[] buffer = new byte[udpPayloadSize];
124        try (final DatagramSocket socket = new DatagramSocket()) {
125            socket.setSoTimeout(timeout);
126            socket.send(request);
127            final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
128            socket.receive(response);
129            final DNSMessage dnsMessage = readDNSMessage(response.getData());
130            if (dnsMessage.id != message.id) {
131                throw new MiniDNSException.IdMismatch(message, dnsMessage);
132            }
133            return dnsMessage;
134        }
135    }
136
137    protected DNSMessage queryDnsSocket(final DNSMessage message, final DNSServer dnsServer)
138            throws IOException, InterruptedException {
139        final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
140        if (cachedDnsSocket != null) {
141            try {
142                return cachedDnsSocket.query(message);
143            } catch (final IOException e) {
144                Log.d(
145                        Config.LOGTAG,
146                        "IOException occurred at cached socket. invalidating and falling through to new socket creation");
147                socketCache.invalidate(dnsServer);
148            }
149        }
150        try {
151            return socketCache.get(dnsServer).query(message);
152        } catch (final ExecutionException e) {
153            final Throwable cause = e.getCause();
154            if (cause instanceof IOException) {
155                throw (IOException) cause;
156            } else {
157                throw new IOException(cause);
158            }
159        }
160    }
161
162    public static DNSMessage readDNSMessage(final byte[] bytes) throws IOException {
163        try {
164            return new DNSMessage(bytes);
165        } catch (final IllegalArgumentException e) {
166            throw new IOException(Throwables.getRootCause(e));
167        }
168    }
169}