1package de.gultsch.minidns;
2
import android.util.Log;

import androidx.annotation.NonNull;

import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.collect.ImmutableList;

import eu.siacs.conversations.Config;

import org.minidns.MiniDnsException;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.dnsqueryresult.DnsQueryResult;
import org.minidns.dnsqueryresult.StandardDnsQueryResult;
import org.minidns.util.MultipleIoException;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
31
32public class NetworkDataSource extends org.minidns.source.NetworkDataSource {
33
34 private static final LoadingCache<DNSServer, DNSSocket> socketCache =
35 CacheBuilder.newBuilder()
36 .removalListener(
37 (RemovalListener<DNSServer, DNSSocket>)
38 notification -> {
39 final DNSServer dnsServer = notification.getKey();
40 final DNSSocket dnsSocket = notification.getValue();
41 if (dnsSocket == null) {
42 return;
43 }
44 Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
45 dnsSocket.closeQuietly();
46 })
47 .expireAfterAccess(5, TimeUnit.MINUTES)
48 .build(
49 new CacheLoader<DNSServer, DNSSocket>() {
50 @Override
51 @NonNull
52 public DNSSocket load(@NonNull final DNSServer dnsServer)
53 throws Exception {
54 Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
55 return DNSSocket.connect(dnsServer);
56 }
57 });
58
59 private static List<Transport> transportsForPort(final int port) {
60 final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
61 for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
62 if (entry.getValue().equals(port)) {
63 transportBuilder.add(entry.getKey());
64 }
65 }
66 return transportBuilder.build();
67 }
68
69 @Override
70 public StandardDnsQueryResult query(
71 final DnsMessage message, final InetAddress address, final int port)
72 throws IOException {
73 final List<Transport> transports = transportsForPort(port);
74 Log.w(
75 Config.LOGTAG,
76 "using legacy DataSource interface. guessing transports "
77 + transports
78 + " from port");
79 if (transports.isEmpty()) {
80 throw new IOException(String.format("No transports found for port %d", port));
81 }
82 return query(message, new DNSServer(address, port, transports));
83 }
84
85 public StandardDnsQueryResult query(final DnsMessage message, final DNSServer dnsServer)
86 throws IOException {
87 Log.d(Config.LOGTAG, "using " + dnsServer);
88 final List<IOException> ioExceptions = new ArrayList<>();
89 for (final Transport transport : dnsServer.transports) {
90 try {
91 final DnsMessage response =
92 queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
93 if (response != null && !response.truncated) {
94 return new StandardDnsQueryResult(
95 dnsServer.inetAddress,
96 dnsServer.port,
97 transportToMethod(transport),
98 message,
99 response);
100 }
101 } catch (final IOException e) {
102 ioExceptions.add(e);
103 } catch (final InterruptedException e) {
104 throw new IOException(e);
105 }
106 }
107 MultipleIoException.throwIfRequired(ioExceptions);
108 return null;
109 }
110
111 private static DnsQueryResult.QueryMethod transportToMethod(final Transport transport) {
112 return switch (transport) {
113 case UDP -> DnsQueryResult.QueryMethod.udp;
114 default -> DnsQueryResult.QueryMethod.tcp;
115 };
116 }
117
118 private DnsMessage queryWithUniqueTransport(final DnsMessage message, final DNSServer dnsServer)
119 throws IOException, InterruptedException {
120 final Transport transport = dnsServer.uniqueTransport();
121 return switch (transport) {
122 case UDP -> queryUdp(message, dnsServer.inetAddress, dnsServer.port);
123 case TCP, TLS -> queryDnsSocket(message, dnsServer);
124 default -> throw new IOException(
125 String.format("Transport %s has not been implemented", transport));
126 };
127 }
128
129 protected DnsMessage queryUdp(
130 final DnsMessage message, final InetAddress address, final int port)
131 throws IOException {
132 final DatagramPacket request = message.asDatagram(address, port);
133 final byte[] buffer = new byte[udpPayloadSize];
134 try (final DatagramSocket socket = new DatagramSocket()) {
135 socket.setSoTimeout(timeout);
136 socket.send(request);
137 final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
138 socket.receive(response);
139 final DnsMessage dnsMessage = readDNSMessage(response.getData());
140 if (dnsMessage.id != message.id) {
141 throw new MiniDnsException.IdMismatch(message, dnsMessage);
142 }
143 return dnsMessage;
144 }
145 }
146
147 protected DnsMessage queryDnsSocket(final DnsMessage message, final DNSServer dnsServer)
148 throws IOException, InterruptedException {
149 final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
150 if (cachedDnsSocket != null) {
151 try {
152 return cachedDnsSocket.query(message);
153 } catch (final IOException e) {
154 Log.d(
155 Config.LOGTAG,
156 "IOException occurred at cached socket. invalidating and falling through to new socket creation");
157 socketCache.invalidate(dnsServer);
158 }
159 }
160 try {
161 return socketCache.get(dnsServer).query(message);
162 } catch (final ExecutionException e) {
163 final Throwable cause = e.getCause();
164 if (cause instanceof IOException) {
165 throw (IOException) cause;
166 } else {
167 throw new IOException(cause);
168 }
169 }
170 }
171
172 public static DnsMessage readDNSMessage(final byte[] bytes) throws IOException {
173 try {
174 return new DnsMessage(bytes);
175 } catch (final IllegalArgumentException e) {
176 throw new IOException(Throwables.getRootCause(e));
177 }
178 }
179}