1package de.gultsch.minidns;
2
3import android.util.Log;
4
5import androidx.annotation.NonNull;
6
7import com.google.common.cache.CacheBuilder;
8import com.google.common.cache.CacheLoader;
9import com.google.common.cache.LoadingCache;
10import com.google.common.cache.RemovalListener;
11import com.google.common.collect.ImmutableList;
12
13import de.measite.minidns.DNSMessage;
14import de.measite.minidns.MiniDNSException;
15import de.measite.minidns.source.DNSDataSource;
16import de.measite.minidns.util.MultipleIoException;
17
18import eu.siacs.conversations.Config;
19
20import java.io.IOException;
21import java.net.DatagramPacket;
22import java.net.DatagramSocket;
23import java.net.InetAddress;
24import java.util.ArrayList;
25import java.util.List;
26import java.util.Map;
27import java.util.concurrent.ExecutionException;
28import java.util.concurrent.TimeUnit;
29
30public class NetworkDataSource extends DNSDataSource {
31
32 private static final LoadingCache<DNSServer, DNSSocket> socketCache =
33 CacheBuilder.newBuilder()
34 .removalListener(
35 (RemovalListener<DNSServer, DNSSocket>)
36 notification -> {
37 final DNSServer dnsServer = notification.getKey();
38 final DNSSocket dnsSocket = notification.getValue();
39 if (dnsSocket == null) {
40 return;
41 }
42 Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
43 dnsSocket.closeQuietly();
44 })
45 .expireAfterAccess(5, TimeUnit.MINUTES)
46 .build(
47 new CacheLoader<DNSServer, DNSSocket>() {
48 @Override
49 @NonNull
50 public DNSSocket load(@NonNull final DNSServer dnsServer)
51 throws Exception {
52 Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
53 return DNSSocket.connect(dnsServer);
54 }
55 });
56
57 private static List<Transport> transportsForPort(final int port) {
58 final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
59 for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
60 if (entry.getValue().equals(port)) {
61 transportBuilder.add(entry.getKey());
62 }
63 }
64 return transportBuilder.build();
65 }
66
67 @Override
68 public DNSMessage query(final DNSMessage message, final InetAddress address, final int port)
69 throws IOException {
70 final List<Transport> transports = transportsForPort(port);
71 Log.w(
72 Config.LOGTAG,
73 "using legacy DataSource interface. guessing transports "
74 + transports
75 + " from port");
76 if (transports.isEmpty()) {
77 throw new IOException(String.format("No transports found for port %d", port));
78 }
79 return query(message, new DNSServer(address, port, transports));
80 }
81
82 public DNSMessage query(final DNSMessage message, final DNSServer dnsServer)
83 throws IOException {
84 Log.d(Config.LOGTAG, "using " + dnsServer);
85 final List<IOException> ioExceptions = new ArrayList<>();
86 for (final Transport transport : dnsServer.transports) {
87 try {
88 final DNSMessage response =
89 queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
90 if (response != null && !response.truncated) {
91 return response;
92 }
93 } catch (final IOException e) {
94 ioExceptions.add(e);
95 } catch (final InterruptedException e) {
96 throw new IOException(e);
97 }
98 }
99 MultipleIoException.throwIfRequired(ioExceptions);
100 return null;
101 }
102
103 private DNSMessage queryWithUniqueTransport(final DNSMessage message, final DNSServer dnsServer)
104 throws IOException, InterruptedException {
105 final Transport transport = dnsServer.uniqueTransport();
106 switch (transport) {
107 case UDP:
108 return queryUdp(message, dnsServer.inetAddress, dnsServer.port);
109 case TCP:
110 case TLS:
111 return queryDnsSocket(message, dnsServer);
112 default:
113 throw new IOException(
114 String.format("Transport %s has not been implemented", transport));
115 }
116 }
117
118 protected DNSMessage queryUdp(
119 final DNSMessage message, final InetAddress address, final int port)
120 throws IOException {
121 final DatagramPacket request = message.asDatagram(address, port);
122 final byte[] buffer = new byte[udpPayloadSize];
123 try (final DatagramSocket socket = new DatagramSocket()) {
124 socket.setSoTimeout(timeout);
125 socket.send(request);
126 final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
127 socket.receive(response);
128 DNSMessage dnsMessage = new DNSMessage(response.getData());
129 if (dnsMessage.id != message.id) {
130 throw new MiniDNSException.IdMismatch(message, dnsMessage);
131 }
132 return dnsMessage;
133 }
134 }
135
136 protected DNSMessage queryDnsSocket(final DNSMessage message, final DNSServer dnsServer)
137 throws IOException, InterruptedException {
138 final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
139 if (cachedDnsSocket != null) {
140 try {
141 return cachedDnsSocket.query(message);
142 } catch (final IOException e) {
143 Log.d(
144 Config.LOGTAG,
145 "IOException occurred at cached socket. invalidating and falling through to new socket creation");
146 socketCache.invalidate(dnsServer);
147 }
148 }
149 try {
150 return socketCache.get(dnsServer).query(message);
151 } catch (final ExecutionException e) {
152 final Throwable cause = e.getCause();
153 if (cause instanceof IOException) {
154 throw (IOException) cause;
155 } else {
156 throw new IOException(cause);
157 }
158 }
159 }
160}