1package eu.siacs.conversations.utils;
2
import com.google.common.io.ByteStreams;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import eu.siacs.conversations.Config;
14
15public class SocksSocketFactory {
16
17 private static final byte[] LOCALHOST = new byte[]{127, 0, 0, 1};
18
19 public static void createSocksConnection(final Socket socket, final String destination, final int port) throws IOException {
20 //TODO use different Socks Addr Type if destination is IP or IPv6
21 final InputStream proxyIs = socket.getInputStream();
22 final OutputStream proxyOs = socket.getOutputStream();
23 proxyOs.write(new byte[]{0x05, 0x01, 0x00});
24 proxyOs.flush();
25 final byte[] handshake = new byte[2];
26 ByteStreams.readFully(proxyIs, handshake);
27 if (handshake[0] != 0x05 || handshake[1] != 0x00) {
28 throw new SocksConnectionException("Socks 5 handshake failed");
29 }
30 final byte[] dest = destination.getBytes();
31 final ByteBuffer request = ByteBuffer.allocate(7 + dest.length);
32 request.put(new byte[]{0x05, 0x01, 0x00, 0x03});
33 request.put((byte) dest.length);
34 request.put(dest);
35 request.putShort((short) port);
36 proxyOs.write(request.array());
37 proxyOs.flush();
38 final byte[] response = new byte[4];
39 ByteStreams.readFully(proxyIs, response);
40 final byte ver = response[0];
41 if (ver != 0x05) {
42 throw new IOException(String.format("Unknown Socks version %02X ", ver));
43 }
44 final byte status = response[1];
45 final byte bndAddrType = response[3];
46 final byte[] bndDestination = readDestination(bndAddrType, proxyIs);
47 final byte[] bndPort = new byte[2];
48 if (bndAddrType == 0x03) {
49 final String receivedDestination = new String(bndDestination);
50 if (!receivedDestination.equalsIgnoreCase(destination)) {
51 throw new IOException(String.format("Destination mismatch. Received %s Expected %s", receivedDestination, destination));
52 }
53 }
54 ByteStreams.readFully(proxyIs, bndPort);
55 if (status != 0x00) {
56 if (status == 0x04) {
57 throw new HostNotFoundException("Host unreachable");
58 }
59 if (status == 0x05) {
60 throw new HostNotFoundException("Connection refused");
61 }
62 throw new IOException(String.format("Unknown status code %02X ", status));
63 }
64 }
65
66 private static byte[] readDestination(final byte type, final InputStream inputStream) throws IOException {
67 final byte[] bndDestination;
68 if (type == 0x01) {
69 bndDestination = new byte[4];
70 } else if (type == 0x03) {
71 final int length = inputStream.read();
72 bndDestination = new byte[length];
73 } else if (type == 0x04) {
74 bndDestination = new byte[16];
75 } else {
76 throw new IOException(String.format("Unknown Socks address type %02X ", type));
77 }
78 ByteStreams.readFully(inputStream, bndDestination);
79 return bndDestination;
80 }
81
82 public static boolean contains(byte needle, byte[] haystack) {
83 for (byte hay : haystack) {
84 if (hay == needle) {
85 return true;
86 }
87 }
88 return false;
89 }
90
91 private static Socket createSocket(InetSocketAddress address, String destination, int port) throws IOException {
92 Socket socket = new Socket();
93 try {
94 socket.connect(address, Config.CONNECT_TIMEOUT * 1000);
95 } catch (IOException e) {
96 throw new SocksProxyNotFoundException();
97 }
98 createSocksConnection(socket, destination, port);
99 return socket;
100 }
101
102 public static Socket createSocketOverTor(String destination, int port) throws IOException {
103 return createSocket(new InetSocketAddress(InetAddress.getByAddress(LOCALHOST), 9050), destination, port);
104 }
105
106 private static class SocksConnectionException extends IOException {
107 SocksConnectionException(String message) {
108 super(message);
109 }
110 }
111
112 public static class SocksProxyNotFoundException extends IOException {
113
114 }
115
116 public static class HostNotFoundException extends SocksConnectionException {
117 HostNotFoundException(String message) {
118 super(message);
119 }
120 }
121}