ensure all bytes are read in socks handshake. fixes #4188

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/utils/SocksSocketFactory.java          | 162 
src/main/java/eu/siacs/conversations/xmpp/jingle/JingleSocks5Transport.java |  14 
2 files changed, 106 insertions(+), 70 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/utils/SocksSocketFactory.java 🔗

@@ -1,5 +1,7 @@
 package eu.siacs.conversations.utils;
 
+import com.google.common.io.ByteStreams;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -12,76 +14,108 @@ import eu.siacs.conversations.Config;
 
 public class SocksSocketFactory {
 
-	private static final byte[] LOCALHOST = new byte[]{127,0,0,1};
+    private static final byte[] LOCALHOST = new byte[]{127, 0, 0, 1};
+
+    public static void createSocksConnection(final Socket socket, final String destination, final int port) throws IOException {
+        //TODO use different Socks Addr Type if destination is IP or IPv6
+        final InputStream proxyIs = socket.getInputStream();
+        final OutputStream proxyOs = socket.getOutputStream();
+        proxyOs.write(new byte[]{0x05, 0x01, 0x00});
+        proxyOs.flush();
+        final byte[] handshake = new byte[2];
+        ByteStreams.readFully(proxyIs, handshake);
+        if (handshake[0] != 0x05 || handshake[1] != 0x00) {
+            throw new SocksConnectionException("Socks 5 handshake failed");
+        }
+        final byte[] dest = destination.getBytes();
+        final ByteBuffer request = ByteBuffer.allocate(7 + dest.length);
+        request.put(new byte[]{0x05, 0x01, 0x00, 0x03});
+        request.put((byte) dest.length);
+        request.put(dest);
+        request.putShort((short) port);
+        proxyOs.write(request.array());
+        proxyOs.flush();
+        final byte[] response = new byte[4];
+        ByteStreams.readFully(proxyIs, response);
+        final byte ver = response[0];
+        if (ver != 0x05) {
+            throw new IOException(String.format("Unknown Socks version %02X ", ver));
+        }
+        final byte status = response[1];
+        final byte bndAddrType = response[3];
+        final byte[] bndDestination = readDestination(bndAddrType, proxyIs);
+        final byte[] bndPort = new byte[2];
+        if (bndAddrType == 0x03) {
+            final String receivedDestination = new String(bndDestination);
+            if (!receivedDestination.equalsIgnoreCase(destination)) {
+                throw new IOException(String.format("Destination mismatch. Received %s Expected %s", receivedDestination, destination));
+            }
+        }
+        ByteStreams.readFully(proxyIs, bndPort);
+        if (status != 0x00) {
+            if (status == 0x04) {
+                throw new HostNotFoundException("Host unreachable");
+            }
+            if (status == 0x05) {
+                throw new HostNotFoundException("Connection refused");
+            }
+            throw new IOException(String.format("Unknown status code %02X ", status));
+        }
+    }
 
-	public static void createSocksConnection(final Socket socket, final String destination, final int port) throws IOException {
-		final InputStream proxyIs = socket.getInputStream();
-		final OutputStream proxyOs = socket.getOutputStream();
-		proxyOs.write(new byte[]{0x05, 0x01, 0x00});
-		proxyOs.flush();
-		final byte[] handshake = new byte[2];
-		proxyIs.read(handshake);
-		if (handshake[0] != 0x05 || handshake[1] != 0x00) {
-			throw new SocksConnectionException("Socks 5 handshake failed");
-		}
-		final byte[] dest = destination.getBytes();
-		final ByteBuffer request = ByteBuffer.allocate(7 + dest.length);
-		request.put(new byte[]{0x05, 0x01, 0x00, 0x03});
-		request.put((byte) dest.length);
-		request.put(dest);
-		request.putShort((short) port);
-		proxyOs.write(request.array());
-		proxyOs.flush();
-		final byte[] response = new byte[7 + dest.length];
-		proxyIs.read(response);
-		if (response[1] != 0x00) {
-			if (response[1] == 0x04) {
-				throw new HostNotFoundException("Host unreachable");
-			}
-			if (response[1] == 0x05) {
-				throw new HostNotFoundException("Connection refused");
-			}
-			throw new SocksConnectionException("Unable to connect to destination "+(int) (response[1]));
-		}
-	}
+    private static byte[] readDestination(final byte type, final InputStream inputStream) throws IOException {
+        final byte[] bndDestination;
+        if (type == 0x01) {
+            bndDestination = new byte[4];
+        } else if (type == 0x03) {
+            final int length = inputStream.read();
+            bndDestination = new byte[length];
+        } else if (type == 0x04) {
+            bndDestination = new byte[16];
+        } else {
+            throw new IOException(String.format("Unknown Socks address type %02X ", type));
+        }
+        ByteStreams.readFully(inputStream, bndDestination);
+        return bndDestination;
+    }
 
-	public static boolean contains(byte needle, byte[] haystack) {
-		for(byte hay : haystack) {
-			if (hay == needle) {
-				return true;
-			}
-		}
-		return false;
-	}
+    public static boolean contains(byte needle, byte[] haystack) {
+        for (byte hay : haystack) {
+            if (hay == needle) {
+                return true;
+            }
+        }
+        return false;
+    }
 
-	private static Socket createSocket(InetSocketAddress address, String destination, int port) throws IOException {
-		Socket socket = new Socket();
-		try {
-			socket.connect(address, Config.CONNECT_TIMEOUT * 1000);
-		} catch (IOException e) {
-			throw new SocksProxyNotFoundException();
-		}
-		createSocksConnection(socket, destination, port);
-		return socket;
-	}
+    private static Socket createSocket(InetSocketAddress address, String destination, int port) throws IOException {
+        Socket socket = new Socket();
+        try {
+            socket.connect(address, Config.CONNECT_TIMEOUT * 1000);
+        } catch (IOException e) {
+            throw new SocksProxyNotFoundException();
+        }
+        createSocksConnection(socket, destination, port);
+        return socket;
+    }
 
-	public static Socket createSocketOverTor(String destination, int port) throws IOException {
-		return createSocket(new InetSocketAddress(InetAddress.getByAddress(LOCALHOST), 9050), destination, port);
-	}
+    public static Socket createSocketOverTor(String destination, int port) throws IOException {
+        return createSocket(new InetSocketAddress(InetAddress.getByAddress(LOCALHOST), 9050), destination, port);
+    }
 
-	private static class SocksConnectionException extends IOException {
-		SocksConnectionException(String message) {
-			super(message);
-		}
-	}
+    private static class SocksConnectionException extends IOException {
+        SocksConnectionException(String message) {
+            super(message);
+        }
+    }
 
-	public static class SocksProxyNotFoundException extends IOException {
+    public static class SocksProxyNotFoundException extends IOException {
 
-	}
+    }
 
-	public static class HostNotFoundException extends SocksConnectionException {
-		HostNotFoundException(String message) {
-			super(message);
-		}
-	}
+    public static class HostNotFoundException extends SocksConnectionException {
+        HostNotFoundException(String message) {
+            super(message);
+        }
+    }
 }

src/main/java/eu/siacs/conversations/xmpp/jingle/JingleSocks5Transport.java 🔗

@@ -3,6 +3,8 @@ package eu.siacs.conversations.xmpp.jingle;
 import android.os.PowerManager;
 import android.util.Log;
 
+import com.google.common.io.ByteStreams;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -114,26 +116,26 @@ public class JingleSocks5Transport extends JingleTransport {
         final byte[] authBegin = new byte[2];
         final InputStream inputStream = socket.getInputStream();
         final OutputStream outputStream = socket.getOutputStream();
-        inputStream.read(authBegin);
+        ByteStreams.readFully(inputStream, authBegin);
         if (authBegin[0] != 0x5) {
             socket.close();
         }
         final short methodCount = authBegin[1];
         final byte[] methods = new byte[methodCount];
-        inputStream.read(methods);
+        ByteStreams.readFully(inputStream, methods);
         if (SocksSocketFactory.contains((byte) 0x00, methods)) {
             outputStream.write(new byte[]{0x05, 0x00});
         } else {
             outputStream.write(new byte[]{0x05, (byte) 0xff});
         }
-        byte[] connectCommand = new byte[4];
-        inputStream.read(connectCommand);
+        final byte[] connectCommand = new byte[4];
+        ByteStreams.readFully(inputStream, connectCommand);
         if (connectCommand[0] == 0x05 && connectCommand[1] == 0x01 && connectCommand[3] == 0x03) {
             int destinationCount = inputStream.read();
             final byte[] destination = new byte[destinationCount];
-            inputStream.read(destination);
+            ByteStreams.readFully(inputStream, destination);
             final byte[] port = new byte[2];
-            inputStream.read(port);
+            ByteStreams.readFully(inputStream, port);
             final String receivedDestination = new String(destination);
             final ByteBuffer response = ByteBuffer.allocate(7 + destination.length);
             final byte[] responseHeader;