diff --git a/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java b/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java index 67cd3b4479d8bcdc635b170a85efe22659d3a4f3..8135dd50be02bf5be96826ea6b2deef6791a531c 100644 --- a/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java +++ b/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java @@ -18,6 +18,7 @@ import java.util.logging.Logger; import static java.util.logging.Level.INFO; +@SuppressWarnings("WeakerAccess") public class TorProbe { private static final Logger LOG = @@ -29,7 +30,7 @@ public class TorProbe { private static final int SSL3_RSA_FIPS_WITH_3DES_EDE_CBC_SHA = 0xfeff; // https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt#n347 - private static final int[] TOR_CIPHER_SUITES = new int[] { + private static final int[] TOR_CIPHER_SUITES = new int[]{ CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA, @@ -61,55 +62,71 @@ public class TorProbe { }; // https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt#n412 - private static final byte[] VERSIONS_CELL = new byte[] { + private static final byte[] VERSIONS_CELL = new byte[]{ 0x00, 0x00, // Circuit ID: 0 0x07, // Command: Versions 0x00, 0x06, // Payload length: 6 bytes 0x00, 0x03, 0x00, 0x04, 0x00, 0x05 // Supported versions: 3, 4, 5 }; - public List<Integer> probe(String address, int port) throws IOException { + List<Integer> probe(String address, int port) throws IOException { + try (Socket socket = connectSocket(address, port)) { + TlsClientProtocol client = connectTls(socket); + try { + return exchangeVersions(socket, client); + } finally { + client.close(); + } + } + } + + Socket connectSocket(String address, int port) throws IOException { if (LOG.isLoggable(INFO)) LOG.info("Connecting to " + address + ":" + port); Socket socket = new Socket(address, port); LOG.info("Connected"); + return socket; + } + + TlsClientProtocol connectTls(Socket socket) throws IOException { TlsClientProtocol client = new TlsClientProtocol( socket.getInputStream(), socket.getOutputStream(), new SecureRandom()); client.connect(new TorTlsClient()); LOG.info("TLS handshake succeeded"); + return client; + } + + List<Integer> exchangeVersions(Socket socket, TlsClientProtocol client) + throws IOException { socket.setSoTimeout(READ_TIMEOUT); - try { - // Send a versions cell - OutputStream out = client.getOutputStream(); - out.write(VERSIONS_CELL); - out.flush(); - LOG.info("Sent versions cell"); - - // Expect a versions cell in response - List<Integer> versions = new ArrayList<>(); - DataInputStream in = new DataInputStream(client.getInputStream()); - int circuitId = in.readUnsignedShort(); - if (circuitId != 0) - throw new IOException("Unexpected circuit ID: " + circuitId); - int command = in.readUnsignedByte(); - if (command != 7) - throw new IOException("Unexpected command: " + command); - int payloadLength = in.readUnsignedShort(); - if (payloadLength == 0 || payloadLength % 2 != 0) { - throw new IOException("Invalid payload length: " - + payloadLength); - } - for (int i = 0; i < payloadLength / 2; i++) { - int version = in.readUnsignedShort(); - versions.add(version); - } - if (LOG.isLoggable(INFO)) - LOG.info("Supported versions: " + versions); - return versions; - } finally { - client.close(); + + // Send a versions cell + OutputStream out = client.getOutputStream(); + out.write(VERSIONS_CELL); + out.flush(); + LOG.info("Sent versions cell"); + + // Expect a versions cell in response + List<Integer> versions = new ArrayList<>(); + DataInputStream in = new DataInputStream(client.getInputStream()); + int circuitId = in.readUnsignedShort(); + if (circuitId != 0) + throw new IOException("Unexpected circuit ID: " + circuitId); + int command = in.readUnsignedByte(); + if (command != 7) + throw new IOException("Unexpected command: " + command); + int payloadLength = in.readUnsignedShort(); + if (payloadLength == 0 || payloadLength % 2 != 0) { + throw new IOException("Invalid payload length: " + payloadLength); + } + for (int i = 0; i < payloadLength / 2; i++) { + int version = in.readUnsignedShort(); + versions.add(version); } + if (LOG.isLoggable(INFO)) + LOG.info("Supported versions: " + versions); + return versions; } public static void main(String[] args) throws IOException {