diff --git a/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java b/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java
index 67cd3b4479d8bcdc635b170a85efe22659d3a4f3..8135dd50be02bf5be96826ea6b2deef6791a531c 100644
--- a/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java
+++ b/tor-probe-core/src/main/java/org/briarproject/torprobe/TorProbe.java
@@ -18,6 +18,7 @@ import java.util.logging.Logger;
 
 import static java.util.logging.Level.INFO;
 
+@SuppressWarnings("WeakerAccess")
 public class TorProbe {
 
 	private static final Logger LOG =
@@ -29,7 +30,7 @@ public class TorProbe {
 	private static final int SSL3_RSA_FIPS_WITH_3DES_EDE_CBC_SHA = 0xfeff;
 
 	// https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt#n347
-	private static final int[] TOR_CIPHER_SUITES = new int[] {
+	private static final int[] TOR_CIPHER_SUITES = new int[]{
 			CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
 			CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
 			CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
@@ -61,55 +62,71 @@ public class TorProbe {
 	};
 
 	// https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt#n412
-	private static final byte[] VERSIONS_CELL = new byte[] {
+	private static final byte[] VERSIONS_CELL = new byte[]{
 			0x00, 0x00, // Circuit ID: 0
 			0x07, // Command: Versions
 			0x00, 0x06, // Payload length: 6 bytes
 			0x00, 0x03, 0x00, 0x04, 0x00, 0x05 // Supported versions: 3, 4, 5
 	};
 
-	public List<Integer> probe(String address, int port) throws IOException {
+	List<Integer> probe(String address, int port) throws IOException {
+		try (Socket socket = connectSocket(address, port)) {
+			TlsClientProtocol client = connectTls(socket);
+			try {
+				return exchangeVersions(socket, client);
+			} finally {
+				client.close();
+			}
+		}
+	}
+
+	Socket connectSocket(String address, int port) throws IOException {
 		if (LOG.isLoggable(INFO))
 			LOG.info("Connecting to " + address + ":" + port);
 		Socket socket = new Socket(address, port);
 		LOG.info("Connected");
+		return socket;
+	}
+
+	TlsClientProtocol connectTls(Socket socket) throws IOException {
 		TlsClientProtocol client = new TlsClientProtocol(
 				socket.getInputStream(), socket.getOutputStream(),
 				new SecureRandom());
 		client.connect(new TorTlsClient());
 		LOG.info("TLS handshake succeeded");
+		return client;
+	}
+
+	List<Integer> exchangeVersions(Socket socket, TlsClientProtocol client)
+			throws IOException {
 		socket.setSoTimeout(READ_TIMEOUT);
-		try {
-			// Send a versions cell
-			OutputStream out = client.getOutputStream();
-			out.write(VERSIONS_CELL);
-			out.flush();
-			LOG.info("Sent versions cell");
-
-			// Expect a versions cell in response
-			List<Integer> versions = new ArrayList<>();
-			DataInputStream in = new DataInputStream(client.getInputStream());
-			int circuitId = in.readUnsignedShort();
-			if (circuitId != 0)
-				throw new IOException("Unexpected circuit ID: " + circuitId);
-			int command = in.readUnsignedByte();
-			if (command != 7)
-				throw new IOException("Unexpected command: " + command);
-			int payloadLength = in.readUnsignedShort();
-			if (payloadLength == 0 || payloadLength % 2 != 0) {
-				throw new IOException("Invalid payload length: "
-						+ payloadLength);
-			}
-			for (int i = 0; i < payloadLength / 2; i++) {
-				int version = in.readUnsignedShort();
-				versions.add(version);
-			}
-			if (LOG.isLoggable(INFO))
-				LOG.info("Supported versions: " + versions);
-			return versions;
-		} finally {
-			client.close();
+
+		// Send a versions cell
+		OutputStream out = client.getOutputStream();
+		out.write(VERSIONS_CELL);
+		out.flush();
+		LOG.info("Sent versions cell");
+
+		// Expect a versions cell in response
+		List<Integer> versions = new ArrayList<>();
+		DataInputStream in = new DataInputStream(client.getInputStream());
+		int circuitId = in.readUnsignedShort();
+		if (circuitId != 0)
+			throw new IOException("Unexpected circuit ID: " + circuitId);
+		int command = in.readUnsignedByte();
+		if (command != 7)
+			throw new IOException("Unexpected command: " + command);
+		int payloadLength = in.readUnsignedShort();
+		if (payloadLength == 0 || payloadLength % 2 != 0) {
+			throw new IOException("Invalid payload length: " + payloadLength);
+		}
+		for (int i = 0; i < payloadLength / 2; i++) {
+			int version = in.readUnsignedShort();
+			versions.add(version);
 		}
+		if (LOG.isLoggable(INFO))
+			LOG.info("Supported versions: " + versions);
+		return versions;
 	}
 
 	public static void main(String[] args) throws IOException {