@@ -5,18 +5,30 @@ import android.net.ConnectivityManager;
import android.net.LinkProperties;
import android.net.Network;
import android.os.Build;
+import android.util.Log;
+import androidx.collection.LruCache;
+
+import com.google.common.base.Objects;
import com.google.common.base.Strings;
+import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableList;
import de.measite.minidns.AbstractDNSClient;
import de.measite.minidns.DNSMessage;
+import eu.siacs.conversations.Config;
+
import java.io.IOException;
import java.net.InetAddress;
+import java.time.Duration;
+import java.util.Collections;
import java.util.List;
public class AndroidDNSClient extends AbstractDNSClient {
+
+ private static final LruCache<QuestionServerTuple, DNSMessage> QUERY_CACHE =
+ new LruCache<>(1024);
private final Context context;
private final NetworkDataSource networkDataSource = new NetworkDataSource();
private boolean askForDnssec = false;
@@ -56,6 +68,8 @@ public class AndroidDNSClient extends AbstractDNSClient {
protected DNSMessage query(final DNSMessage.Builder queryBuilder) throws IOException {
final DNSMessage question = newQuestion(queryBuilder).build();
for (final DNSServer dnsServer : getDNSServers()) {
+ final QuestionServerTuple cacheKey = new QuestionServerTuple(dnsServer, question);
+ final DNSMessage cachedResponse = queryCache(cacheKey);
final DNSMessage response = this.networkDataSource.query(question, dnsServer);
if (response == null) {
continue;
@@ -67,7 +81,7 @@ public class AndroidDNSClient extends AbstractDNSClient {
default:
continue;
}
-
+ cacheQuery(cacheKey, response);
return response;
}
return null;
@@ -120,4 +134,68 @@ public class AndroidDNSClient extends AbstractDNSClient {
}
return connectivityManager.getAllNetworks();
}
+
+ private DNSMessage queryCache(final QuestionServerTuple key) {
+ final DNSMessage cachedResponse;
+ synchronized (QUERY_CACHE) {
+ cachedResponse = QUERY_CACHE.get(key);
+ if (cachedResponse == null) {
+ return null;
+ }
+ final long expiresIn = expiresIn(cachedResponse);
+ if (expiresIn < 0) {
+ QUERY_CACHE.remove(key);
+ return null;
+ }
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
+ Log.d(
+ Config.LOGTAG,
+ "DNS query came from cache. expires in " + Duration.ofMillis(expiresIn));
+ }
+ }
+ return cachedResponse;
+ }
+
+ private void cacheQuery(final QuestionServerTuple key, final DNSMessage response) {
+ if (response.receiveTimestamp <= 0) {
+ return;
+ }
+ synchronized (QUERY_CACHE) {
+ QUERY_CACHE.put(key, response);
+ }
+ }
+
+ private static long expiresAt(final DNSMessage dnsMessage) {
+ return dnsMessage.receiveTimestamp
+ + (Collections.min(Collections2.transform(dnsMessage.answerSection, d -> d.ttl))
+ * 1000L);
+ }
+
+ private static long expiresIn(final DNSMessage dnsMessage) {
+ return expiresAt(dnsMessage) - System.currentTimeMillis();
+ }
+
+ private static class QuestionServerTuple {
+ private final DNSServer dnsServer;
+ private final DNSMessage question;
+
+ private QuestionServerTuple(final DNSServer dnsServer, final DNSMessage question) {
+ this.dnsServer = dnsServer;
+ this.question = question.asNormalizedVersion();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ QuestionServerTuple that = (QuestionServerTuple) o;
+ return Objects.equal(dnsServer, that.dnsServer)
+ && Objects.equal(question, that.question);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(dnsServer, question);
+ }
+ }
}