NetworkDataSource.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import androidx.annotation.NonNull;
  6
  7import com.google.common.base.Throwables;
  8import com.google.common.cache.CacheBuilder;
  9import com.google.common.cache.CacheLoader;
 10import com.google.common.cache.LoadingCache;
 11import com.google.common.cache.RemovalListener;
 12import com.google.common.collect.ImmutableList;
 13
 14import eu.siacs.conversations.Config;
 15
 16import org.minidns.MiniDnsException;
 17import org.minidns.dnsmessage.DnsMessage;
 18import org.minidns.dnsname.InvalidDnsNameException;
 19import org.minidns.dnsqueryresult.DnsQueryResult;
 20import org.minidns.dnsqueryresult.StandardDnsQueryResult;
 21import org.minidns.util.MultipleIoException;
 22
 23import java.io.IOException;
 24import java.net.DatagramPacket;
 25import java.net.DatagramSocket;
 26import java.net.InetAddress;
 27import java.util.ArrayList;
 28import java.util.List;
 29import java.util.Map;
 30import java.util.concurrent.ExecutionException;
 31import java.util.concurrent.TimeUnit;
 32
 33public class NetworkDataSource extends org.minidns.source.NetworkDataSource {
 34
 35    private static final LoadingCache<DNSServer, DNSSocket> socketCache =
 36            CacheBuilder.newBuilder()
 37                    .removalListener(
 38                            (RemovalListener<DNSServer, DNSSocket>)
 39                                    notification -> {
 40                                        final DNSServer dnsServer = notification.getKey();
 41                                        final DNSSocket dnsSocket = notification.getValue();
 42                                        if (dnsSocket == null) {
 43                                            return;
 44                                        }
 45                                        Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
 46                                        dnsSocket.closeQuietly();
 47                                    })
 48                    .expireAfterAccess(5, TimeUnit.MINUTES)
 49                    .build(
 50                            new CacheLoader<DNSServer, DNSSocket>() {
 51                                @Override
 52                                @NonNull
 53                                public DNSSocket load(@NonNull final DNSServer dnsServer)
 54                                        throws Exception {
 55                                    Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
 56                                    return DNSSocket.connect(dnsServer);
 57                                }
 58                            });
 59
 60    private static List<Transport> transportsForPort(final int port) {
 61        final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
 62        for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
 63            if (entry.getValue().equals(port)) {
 64                transportBuilder.add(entry.getKey());
 65            }
 66        }
 67        return transportBuilder.build();
 68    }
 69
 70    @Override
 71    public StandardDnsQueryResult query(
 72            final DnsMessage message, final InetAddress address, final int port)
 73            throws IOException {
 74        final List<Transport> transports = transportsForPort(port);
 75        Log.w(
 76                Config.LOGTAG,
 77                "using legacy DataSource interface. guessing transports "
 78                        + transports
 79                        + " from port");
 80        if (transports.isEmpty()) {
 81            throw new IOException(String.format("No transports found for port %d", port));
 82        }
 83        return query(message, new DNSServer(address, port, transports));
 84    }
 85
 86    public StandardDnsQueryResult query(final DnsMessage message, final DNSServer dnsServer)
 87            throws IOException {
 88        Log.d(Config.LOGTAG, "using " + dnsServer);
 89        final List<IOException> ioExceptions = new ArrayList<>();
 90        for (final Transport transport : dnsServer.transports) {
 91            try {
 92                final DnsMessage response =
 93                        queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
 94                if (response != null && !response.truncated) {
 95                    return new StandardDnsQueryResult(
 96                            dnsServer.inetAddress,
 97                            dnsServer.port,
 98                            transportToMethod(transport),
 99                            message,
100                            response);
101                }
102            } catch (final IOException e) {
103                ioExceptions.add(e);
104            } catch (final InterruptedException e) {
105                throw new IOException(e);
106            }
107        }
108        MultipleIoException.throwIfRequired(ioExceptions);
109        return null;
110    }
111
112    private static DnsQueryResult.QueryMethod transportToMethod(final Transport transport) {
113        return switch (transport) {
114            case UDP -> DnsQueryResult.QueryMethod.udp;
115            default -> DnsQueryResult.QueryMethod.tcp;
116        };
117    }
118
119    private DnsMessage queryWithUniqueTransport(final DnsMessage message, final DNSServer dnsServer)
120            throws IOException, InterruptedException {
121        final Transport transport = dnsServer.uniqueTransport();
122        return switch (transport) {
123            case UDP -> queryUdp(message, dnsServer.inetAddress, dnsServer.port);
124            case TCP, TLS -> queryDnsSocket(message, dnsServer);
125            default -> throw new IOException(
126                    String.format("Transport %s has not been implemented", transport));
127        };
128    }
129
130    protected DnsMessage queryUdp(
131            final DnsMessage message, final InetAddress address, final int port)
132            throws IOException {
133        final DatagramPacket request = message.asDatagram(address, port);
134        final byte[] buffer = new byte[udpPayloadSize];
135        try (final DatagramSocket socket = new DatagramSocket()) {
136            socket.setSoTimeout(timeout);
137            socket.send(request);
138            final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
139            socket.receive(response);
140            final DnsMessage dnsMessage = readDNSMessage(response.getData());
141            if (dnsMessage.id != message.id) {
142                throw new MiniDnsException.IdMismatch(message, dnsMessage);
143            }
144            return dnsMessage;
145        }
146    }
147
148    protected DnsMessage queryDnsSocket(final DnsMessage message, final DNSServer dnsServer)
149            throws IOException, InterruptedException {
150        final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
151        if (cachedDnsSocket != null) {
152            try {
153                return cachedDnsSocket.query(message);
154            } catch (final IOException e) {
155                Log.d(
156                        Config.LOGTAG,
157                        "IOException occurred at cached socket. invalidating and falling through to new socket creation");
158                socketCache.invalidate(dnsServer);
159            }
160        }
161        try {
162            return socketCache.get(dnsServer).query(message);
163        } catch (final ExecutionException e) {
164            final Throwable cause = e.getCause();
165            if (cause instanceof IOException) {
166                throw (IOException) cause;
167            } else {
168                throw new IOException(cause);
169            }
170        }
171    }
172
173    public static DnsMessage readDNSMessage(final byte[] bytes) throws IOException {
174        try {
175            return new DnsMessage(bytes);
176        } catch (final InvalidDnsNameException | IllegalArgumentException e) {
177            throw new IOException(Throwables.getRootCause(e));
178        }
179    }
180}