Simplify cache and extend cache operations.

Rene Treffer created

Remove the external cache dependency and use a simple LRU based on
LinkedHashMap.
Make it possible to get the parse time of DNSMessage, which means we
can evaluate the TTL later on :-)

Change summary

build.gradle                                     |  7 -
src/main/java/de/measite/minidns/Client.java     | 87 ++++++++++++++++-
src/main/java/de/measite/minidns/DNSMessage.java | 14 ++
3 files changed, 92 insertions(+), 16 deletions(-)

Detailed changes

build.gradle 🔗

@@ -27,13 +27,6 @@ if (isSnapshot) {
 repositories {
 	mavenLocal()
 	mavenCentral()
-	maven {
-		url 'https://oss.sonatype.org/content/repositories/snapshots/'
-	}
-}
-
-dependencies {
-	compile 'org.igniterealtime.jxmpp:jxmpp-util-cache:0.1.0-alpha1-SNAPSHOT'
 }
 
 jar {

src/main/java/de/measite/minidns/Client.java 🔗

@@ -13,12 +13,12 @@ import java.security.SecureRandom;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
 import java.util.Random;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 
-import org.jxmpp.util.cache.ExpirationCache;
-
 import de.measite.minidns.Record.CLASS;
 import de.measite.minidns.Record.TYPE;
 
@@ -30,9 +30,6 @@ public class Client {
 
     private static final Logger LOGGER = Logger.getLogger(Client.class.getName());
 
-    protected static final ExpirationCache<Question, DNSMessage> cache = new ExpirationCache<Question, DNSMessage>(
-            10, 1000 * 60 * 60 * 24);
-
     /**
      * The internal random class for sequence generation.
      */
@@ -48,6 +45,16 @@ public class Client {
      */
     protected int timeout = 5000;
 
+    /**
+     * The internal DNS cache.
+     */
+    protected LinkedHashMap<Question, DNSMessage> cache;
+
+    /**
+     * Maximum acceptable ttl.
+     */
+    protected long maxTTL = 60 * 60 * 1000;
+
     /**
      * Create a new DNS client.
      */
@@ -57,6 +64,7 @@ public class Client {
         } catch (NoSuchAlgorithmException e1) {
             random = new SecureRandom();
         }
+        setCacheSize(10);
     }
 
     /**
@@ -123,9 +131,20 @@ public class Client {
      * @throws IOException On IOErrors.
      */
     public DNSMessage query(Question q, String host, int port) throws IOException {
-        DNSMessage dnsMessage = cache.get(q);
-        if (dnsMessage != null) {
-            return dnsMessage;
+        DNSMessage dnsMessage = (cache == null) ? null : cache.get(q);
+        if (dnsMessage != null && dnsMessage.getReceiveTimestamp() > 0l) {
+            // check the ttl
+            long ttl = maxTTL;
+            for (Record r : dnsMessage.getAnswers()) {
+                ttl = Math.min(ttl, r.ttl);
+            }
+            for (Record r : dnsMessage.getAdditionalResourceRecords()) {
+                ttl = Math.min(ttl, r.ttl);
+            }
+            if (dnsMessage.getReceiveTimestamp() + ttl <
+                System.currentTimeMillis()) {
+                return dnsMessage;
+            }
         }
         DNSMessage message = new DNSMessage();
         message.setQuestions(new Question[]{q});
@@ -145,7 +164,9 @@ public class Client {
             }
             for (Record record : dnsMessage.getAnswers()) {
                 if (record.isAnswer(q)) {
-                    cache.put(q, dnsMessage, record.ttl);
+                    if (cache != null) {
+                        cache.put(q, dnsMessage);
+                    }
                     break;
                 }
             }
@@ -305,4 +326,52 @@ public class Client {
         return null;
     }
 
+    /**
+     * Configure the cache size (default 10).
+     * @param maximumSize The new cache size or 0 to disable.
+     */
+    @SuppressWarnings("serial")
+    public void setCacheSize(final int maximumSize) {
+        if (maximumSize == 0) {
+            this.cache = null;
+        } else {
+            LinkedHashMap<Question,DNSMessage> old = cache;
+            cache = new LinkedHashMap<Question,DNSMessage>() {
+                @Override
+                protected boolean removeEldestEntry(
+                        Entry<Question, DNSMessage> eldest) {
+                    return size() > maximumSize;
+                }
+            };
+            if (old != null) {
+                cache.putAll(old);
+            }
+        }
+    }
+
+    /**
+     * Flush the DNS cache.
+     */
+    public void flushCache() {
+        if (cache != null) {
+            cache.clear();
+        }
+    }
+
+    /**
+     * Get the current maximum record ttl.
+     * @return The maximum record ttl.
+     */
+    public long getMaxTTL() {
+        return maxTTL;
+    }
+
+    /**
+     * Set the maximum record ttl.
+     * @param maxTTL The new maximum ttl.
+     */
+    public void setMaxTTL(long maxTTL) {
+        this.maxTTL = maxTTL;
+    }
+
 }

src/main/java/de/measite/minidns/DNSMessage.java 🔗

@@ -195,6 +195,11 @@ public class DNSMessage {
      */
     protected Record additionalResourceRecords[];
 
+    /**
+     * The receive timestamp of this message.
+     */
+    protected long receiveTimestamp;
+
     /**
      * Retrieve the current DNS message id.
      * @return The current DNS message id.
@@ -211,6 +216,14 @@ public class DNSMessage {
         this.id = id & 0xffff;
     }
 
+    /**
+     * Get the receive timestamp if this message was created via parse.
+     * This should be used to evaluate TTLs.
+     */
+    public long getReceiveTimestamp() {
+        return receiveTimestamp;
+    }
+
     /**
      * Retrieve the query type (true or false;
      * @return True if this DNS message is a query.
@@ -410,6 +423,7 @@ public class DNSMessage {
         message.authenticData = ((header >> 5) & 1) == 1;
         message.checkDisabled = ((header >> 4) & 1) == 1;
         message.responseCode = RESPONSE_CODE.getResponseCode(header & 0xf);
+	message.receiveTimestamp = System.currentTimeMillis();
         int questionCount = dis.readUnsignedShort();
         int answerCount = dis.readUnsignedShort();
         int nameserverCount = dis.readUnsignedShort();