keep cache of DNS messages

Daniel Gultsch created

Change summary

src/main/java/de/gultsch/minidns/AndroidDNSClient.java | 80 +++++++++++
1 file changed, 79 insertions(+), 1 deletion(-)

Detailed changes

src/main/java/de/gultsch/minidns/AndroidDNSClient.java 🔗

@@ -5,18 +5,30 @@ import android.net.ConnectivityManager;
 import android.net.LinkProperties;
 import android.net.Network;
 import android.os.Build;
+import android.util.Log;
 
+import androidx.collection.LruCache;
+
+import com.google.common.base.Objects;
 import com.google.common.base.Strings;
+import com.google.common.collect.Collections2;
 import com.google.common.collect.ImmutableList;
 
 import de.measite.minidns.AbstractDNSClient;
 import de.measite.minidns.DNSMessage;
 
+import eu.siacs.conversations.Config;
+
 import java.io.IOException;
 import java.net.InetAddress;
+import java.time.Duration;
+import java.util.Collections;
 import java.util.List;
 
 public class AndroidDNSClient extends AbstractDNSClient {
+
+    private static final LruCache<QuestionServerTuple, DNSMessage> QUERY_CACHE =
+            new LruCache<>(1024);
     private final Context context;
     private final NetworkDataSource networkDataSource = new NetworkDataSource();
     private boolean askForDnssec = false;
@@ -56,6 +68,8 @@ public class AndroidDNSClient extends AbstractDNSClient {
     protected DNSMessage query(final DNSMessage.Builder queryBuilder) throws IOException {
         final DNSMessage question = newQuestion(queryBuilder).build();
         for (final DNSServer dnsServer : getDNSServers()) {
+            final QuestionServerTuple cacheKey = new QuestionServerTuple(dnsServer, question);
+            final DNSMessage cachedResponse = queryCache(cacheKey);
             final DNSMessage response = this.networkDataSource.query(question, dnsServer);
             if (response == null) {
                 continue;
@@ -67,7 +81,7 @@ public class AndroidDNSClient extends AbstractDNSClient {
                 default:
                     continue;
             }
-
+            cacheQuery(cacheKey, response);
             return response;
         }
         return null;
@@ -120,4 +134,68 @@ public class AndroidDNSClient extends AbstractDNSClient {
         }
         return connectivityManager.getAllNetworks();
     }
+
+    private DNSMessage queryCache(final QuestionServerTuple key) {
+        final DNSMessage cachedResponse;
+        synchronized (QUERY_CACHE) {
+            cachedResponse = QUERY_CACHE.get(key);
+            if (cachedResponse == null) {
+                return null;
+            }
+            final long expiresIn = expiresIn(cachedResponse);
+            if (expiresIn < 0) {
+                QUERY_CACHE.remove(key);
+                return null;
+            }
+            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
+                Log.d(
+                        Config.LOGTAG,
+                        "DNS query came from cache. expires in " + Duration.ofMillis(expiresIn));
+            }
+        }
+        return cachedResponse;
+    }
+
+    private void cacheQuery(final QuestionServerTuple key, final DNSMessage response) {
+        if (response.receiveTimestamp <= 0) {
+            return;
+        }
+        synchronized (QUERY_CACHE) {
+            QUERY_CACHE.put(key, response);
+        }
+    }
+
+    private static long expiresAt(final DNSMessage dnsMessage) {
+        return dnsMessage.receiveTimestamp
+                + (Collections.min(Collections2.transform(dnsMessage.answerSection, d -> d.ttl))
+                        * 1000L);
+    }
+
+    private static long expiresIn(final DNSMessage dnsMessage) {
+        return expiresAt(dnsMessage) - System.currentTimeMillis();
+    }
+
+    private static class QuestionServerTuple {
+        private final DNSServer dnsServer;
+        private final DNSMessage question;
+
+        private QuestionServerTuple(final DNSServer dnsServer, final DNSMessage question) {
+            this.dnsServer = dnsServer;
+            this.question = question.asNormalizedVersion();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            QuestionServerTuple that = (QuestionServerTuple) o;
+            return Objects.equal(dnsServer, that.dnsServer)
+                    && Objects.equal(question, that.question);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(dnsServer, question);
+        }
+    }
 }