catch illegal argument exception when reading DNS

Daniel Gultsch created

Change summary

src/main/java/de/gultsch/minidns/DNSSocket.java         |  5 +++--
src/main/java/de/gultsch/minidns/NetworkDataSource.java | 11 ++++++++++-
2 files changed, 13 insertions(+), 3 deletions(-)

Detailed changes

src/main/java/de/gultsch/minidns/DNSSocket.java 🔗

@@ -126,7 +126,8 @@ final class DNSSocket implements Closeable {
             sslSocket.setSoTimeout(QUERY_TIMEOUT);
             sslSocket.startHandshake();
         } else {
-            final SocketAddress socketAddress = new InetSocketAddress(dnsServer.hostname, dnsServer.port);
+            final SocketAddress socketAddress =
+                    new InetSocketAddress(dnsServer.hostname, dnsServer.port);
             sslSocket.connect(socketAddress, QUERY_TIMEOUT / 2);
             sslSocket.setSoTimeout(QUERY_TIMEOUT);
             sslSocket.startHandshake();
@@ -181,7 +182,7 @@ final class DNSSocket implements Closeable {
         while (read < length) {
             read += this.dataInputStream.read(data, read, length - read);
         }
-        return new DNSMessage(data);
+        return NetworkDataSource.readDNSMessage(data);
     }
 
     @Override

src/main/java/de/gultsch/minidns/NetworkDataSource.java 🔗

@@ -4,6 +4,7 @@ import android.util.Log;
 
 import androidx.annotation.NonNull;
 
+import com.google.common.base.Throwables;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
@@ -125,7 +126,7 @@ public class NetworkDataSource extends DNSDataSource {
             socket.send(request);
             final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
             socket.receive(response);
-            DNSMessage dnsMessage = new DNSMessage(response.getData());
+            final DNSMessage dnsMessage = readDNSMessage(response.getData());
             if (dnsMessage.id != message.id) {
                 throw new MiniDNSException.IdMismatch(message, dnsMessage);
             }
@@ -157,4 +158,12 @@ public class NetworkDataSource extends DNSDataSource {
             }
         }
     }
+
+    public static DNSMessage readDNSMessage(final byte[] bytes) throws IOException {
+        try {
+            return new DNSMessage(bytes);
+        } catch (final IllegalArgumentException e) {
+            throw new IOException(Throwables.getRootCause(e));
+        }
+    }
 }