NetworkDataSource.java

  1package de.gultsch.minidns;
  2
  3import android.util.Log;
  4
  5import androidx.annotation.NonNull;
  6
  7import com.google.common.base.Throwables;
  8import com.google.common.cache.CacheBuilder;
  9import com.google.common.cache.CacheLoader;
 10import com.google.common.cache.LoadingCache;
 11import com.google.common.cache.RemovalListener;
 12import com.google.common.collect.ImmutableList;
 13
 14import eu.siacs.conversations.Config;
 15
 16import org.minidns.MiniDnsException;
 17import org.minidns.dnsmessage.DnsMessage;
 18import org.minidns.dnsqueryresult.DnsQueryResult;
 19import org.minidns.dnsqueryresult.StandardDnsQueryResult;
 20import org.minidns.util.MultipleIoException;
 21
 22import java.io.IOException;
 23import java.net.DatagramPacket;
 24import java.net.DatagramSocket;
 25import java.net.InetAddress;
 26import java.util.ArrayList;
 27import java.util.List;
 28import java.util.Map;
 29import java.util.concurrent.ExecutionException;
 30import java.util.concurrent.TimeUnit;
 31
 32public class NetworkDataSource extends org.minidns.source.NetworkDataSource {
 33
 34    private static final LoadingCache<DNSServer, DNSSocket> socketCache =
 35            CacheBuilder.newBuilder()
 36                    .removalListener(
 37                            (RemovalListener<DNSServer, DNSSocket>)
 38                                    notification -> {
 39                                        final DNSServer dnsServer = notification.getKey();
 40                                        final DNSSocket dnsSocket = notification.getValue();
 41                                        if (dnsSocket == null) {
 42                                            return;
 43                                        }
 44                                        Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
 45                                        dnsSocket.closeQuietly();
 46                                    })
 47                    .expireAfterAccess(5, TimeUnit.MINUTES)
 48                    .build(
 49                            new CacheLoader<DNSServer, DNSSocket>() {
 50                                @Override
 51                                @NonNull
 52                                public DNSSocket load(@NonNull final DNSServer dnsServer)
 53                                        throws Exception {
 54                                    Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
 55                                    return DNSSocket.connect(dnsServer);
 56                                }
 57                            });
 58
 59    private static List<Transport> transportsForPort(final int port) {
 60        final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
 61        for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
 62            if (entry.getValue().equals(port)) {
 63                transportBuilder.add(entry.getKey());
 64            }
 65        }
 66        return transportBuilder.build();
 67    }
 68
 69    @Override
 70    public StandardDnsQueryResult query(
 71            final DnsMessage message, final InetAddress address, final int port)
 72            throws IOException {
 73        final List<Transport> transports = transportsForPort(port);
 74        Log.w(
 75                Config.LOGTAG,
 76                "using legacy DataSource interface. guessing transports "
 77                        + transports
 78                        + " from port");
 79        if (transports.isEmpty()) {
 80            throw new IOException(String.format("No transports found for port %d", port));
 81        }
 82        return query(message, new DNSServer(address, port, transports));
 83    }
 84
 85    public StandardDnsQueryResult query(final DnsMessage message, final DNSServer dnsServer)
 86            throws IOException {
 87        Log.d(Config.LOGTAG, "using " + dnsServer);
 88        final List<IOException> ioExceptions = new ArrayList<>();
 89        for (final Transport transport : dnsServer.transports) {
 90            try {
 91                final DnsMessage response =
 92                        queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
 93                if (response != null && !response.truncated) {
 94                    return new StandardDnsQueryResult(
 95                            dnsServer.inetAddress,
 96                            dnsServer.port,
 97                            transportToMethod(transport),
 98                            message,
 99                            response);
100                }
101            } catch (final IOException e) {
102                ioExceptions.add(e);
103            } catch (final InterruptedException e) {
104                throw new IOException(e);
105            }
106        }
107        MultipleIoException.throwIfRequired(ioExceptions);
108        return null;
109    }
110
111    private static DnsQueryResult.QueryMethod transportToMethod(final Transport transport) {
112        return switch (transport) {
113            case UDP -> DnsQueryResult.QueryMethod.udp;
114            default -> DnsQueryResult.QueryMethod.tcp;
115        };
116    }
117
118    private DnsMessage queryWithUniqueTransport(final DnsMessage message, final DNSServer dnsServer)
119            throws IOException, InterruptedException {
120        final Transport transport = dnsServer.uniqueTransport();
121        return switch (transport) {
122            case UDP -> queryUdp(message, dnsServer.inetAddress, dnsServer.port);
123            case TCP, TLS -> queryDnsSocket(message, dnsServer);
124            default -> throw new IOException(
125                    String.format("Transport %s has not been implemented", transport));
126        };
127    }
128
129    protected DnsMessage queryUdp(
130            final DnsMessage message, final InetAddress address, final int port)
131            throws IOException {
132        final DatagramPacket request = message.asDatagram(address, port);
133        final byte[] buffer = new byte[udpPayloadSize];
134        try (final DatagramSocket socket = new DatagramSocket()) {
135            socket.setSoTimeout(timeout);
136            socket.send(request);
137            final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
138            socket.receive(response);
139            final DnsMessage dnsMessage = readDNSMessage(response.getData());
140            if (dnsMessage.id != message.id) {
141                throw new MiniDnsException.IdMismatch(message, dnsMessage);
142            }
143            return dnsMessage;
144        }
145    }
146
147    protected DnsMessage queryDnsSocket(final DnsMessage message, final DNSServer dnsServer)
148            throws IOException, InterruptedException {
149        final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
150        if (cachedDnsSocket != null) {
151            try {
152                return cachedDnsSocket.query(message);
153            } catch (final IOException e) {
154                Log.d(
155                        Config.LOGTAG,
156                        "IOException occurred at cached socket. invalidating and falling through to new socket creation");
157                socketCache.invalidate(dnsServer);
158            }
159        }
160        try {
161            return socketCache.get(dnsServer).query(message);
162        } catch (final ExecutionException e) {
163            final Throwable cause = e.getCause();
164            if (cause instanceof IOException) {
165                throw (IOException) cause;
166            } else {
167                throw new IOException(cause);
168            }
169        }
170    }
171
172    public static DnsMessage readDNSMessage(final byte[] bytes) throws IOException {
173        try {
174            return new DnsMessage(bytes);
175        } catch (final IllegalArgumentException e) {
176            throw new IOException(Throwables.getRootCause(e));
177        }
178    }
179}