NetworkDataSource.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import androidx.annotation.NonNull;
  6
  7import com.google.common.cache.CacheBuilder;
  8import com.google.common.cache.CacheLoader;
  9import com.google.common.cache.LoadingCache;
 10import com.google.common.cache.RemovalListener;
 11import com.google.common.collect.ImmutableList;
 12
 13import de.measite.minidns.DNSMessage;
 14import de.measite.minidns.MiniDNSException;
 15import de.measite.minidns.source.DNSDataSource;
 16import de.measite.minidns.util.MultipleIoException;
 17
 18import eu.siacs.conversations.Config;
 19
 20import java.io.IOException;
 21import java.net.DatagramPacket;
 22import java.net.DatagramSocket;
 23import java.net.InetAddress;
 24import java.util.ArrayList;
 25import java.util.List;
 26import java.util.Map;
 27import java.util.concurrent.ExecutionException;
 28import java.util.concurrent.TimeUnit;
 29
 30public class NetworkDataSource extends DNSDataSource {
 31
 32    private static final LoadingCache<DNSServer, DNSSocket> socketCache =
 33            CacheBuilder.newBuilder()
 34                    .removalListener(
 35                            (RemovalListener<DNSServer, DNSSocket>)
 36                                    notification -> {
 37                                        final DNSServer dnsServer = notification.getKey();
 38                                        final DNSSocket dnsSocket = notification.getValue();
 39                                        if (dnsSocket == null) {
 40                                            return;
 41                                        }
 42                                        Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
 43                                        dnsSocket.closeQuietly();
 44                                    })
 45                    .expireAfterAccess(5, TimeUnit.MINUTES)
 46                    .build(
 47                            new CacheLoader<DNSServer, DNSSocket>() {
 48                                @Override
 49                                @NonNull
 50                                public DNSSocket load(@NonNull final DNSServer dnsServer)
 51                                        throws Exception {
 52                                    Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
 53                                    return DNSSocket.connect(dnsServer);
 54                                }
 55                            });
 56
 57    private static List<Transport> transportsForPort(final int port) {
 58        final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
 59        for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
 60            if (entry.getValue().equals(port)) {
 61                transportBuilder.add(entry.getKey());
 62            }
 63        }
 64        return transportBuilder.build();
 65    }
 66
 67    @Override
 68    public DNSMessage query(final DNSMessage message, final InetAddress address, final int port)
 69            throws IOException {
 70        final List<Transport> transports = transportsForPort(port);
 71        Log.w(
 72                Config.LOGTAG,
 73                "using legacy DataSource interface. guessing transports "
 74                        + transports
 75                        + " from port");
 76        if (transports.isEmpty()) {
 77            throw new IOException(String.format("No transports found for port %d", port));
 78        }
 79        return query(message, new DNSServer(address, port, transports));
 80    }
 81
 82    public DNSMessage query(final DNSMessage message, final DNSServer dnsServer)
 83            throws IOException {
 84        Log.d(Config.LOGTAG, "using " + dnsServer);
 85        final List<IOException> ioExceptions = new ArrayList<>();
 86        for (final Transport transport : dnsServer.transports) {
 87            try {
 88                final DNSMessage response =
 89                        queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
 90                if (response != null && !response.truncated) {
 91                    return response;
 92                }
 93            } catch (final IOException e) {
 94                ioExceptions.add(e);
 95            } catch (final InterruptedException e) {
 96                throw new IOException(e);
 97            }
 98        }
 99        MultipleIoException.throwIfRequired(ioExceptions);
100        return null;
101    }
102
103    private DNSMessage queryWithUniqueTransport(final DNSMessage message, final DNSServer dnsServer)
104            throws IOException, InterruptedException {
105        final Transport transport = dnsServer.uniqueTransport();
106        switch (transport) {
107            case UDP:
108                return queryUdp(message, dnsServer.inetAddress, dnsServer.port);
109            case TCP:
110            case TLS:
111                return queryDnsSocket(message, dnsServer);
112            default:
113                throw new IOException(
114                        String.format("Transport %s has not been implemented", transport));
115        }
116    }
117
118    protected DNSMessage queryUdp(
119            final DNSMessage message, final InetAddress address, final int port)
120            throws IOException {
121        final DatagramPacket request = message.asDatagram(address, port);
122        final byte[] buffer = new byte[udpPayloadSize];
123        try (final DatagramSocket socket = new DatagramSocket()) {
124            socket.setSoTimeout(timeout);
125            socket.send(request);
126            final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
127            socket.receive(response);
128            DNSMessage dnsMessage = new DNSMessage(response.getData());
129            if (dnsMessage.id != message.id) {
130                throw new MiniDNSException.IdMismatch(message, dnsMessage);
131            }
132            return dnsMessage;
133        }
134    }
135
136    protected DNSMessage queryDnsSocket(final DNSMessage message, final DNSServer dnsServer)
137            throws IOException, InterruptedException {
138        final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
139        if (cachedDnsSocket != null) {
140            try {
141                return cachedDnsSocket.query(message);
142            } catch (final IOException e) {
143                Log.d(
144                        Config.LOGTAG,
145                        "IOException occurred at cached socket. invalidating and falling through to new socket creation");
146                socketCache.invalidate(dnsServer);
147            }
148        }
149        try {
150            return socketCache.get(dnsServer).query(message);
151        } catch (final ExecutionException e) {
152            final Throwable cause = e.getCause();
153            if (cause instanceof IOException) {
154                throw (IOException) cause;
155            } else {
156                throw new IOException(cause);
157            }
158        }
159    }
160}