From 0bfa4f15a165449dbed08a3bb895aad8e04e64c7 Mon Sep 17 00:00:00 2001
From: akwizgran <michael@briarproject.org>
Date: Fri, 14 Dec 2012 19:34:25 +0000
Subject: [PATCH] Safer locking for ModemImpl.

---
 .../src/net/sf/briar/plugins/modem/Modem.java |   3 +-
 .../net/sf/briar/plugins/modem/ModemImpl.java | 359 ++++++++++++------
 2 files changed, 238 insertions(+), 124 deletions(-)

diff --git a/briar-core/src/net/sf/briar/plugins/modem/Modem.java b/briar-core/src/net/sf/briar/plugins/modem/Modem.java
index 004f7f18be..277749c33b 100644
--- a/briar-core/src/net/sf/briar/plugins/modem/Modem.java
+++ b/briar-core/src/net/sf/briar/plugins/modem/Modem.java
@@ -12,8 +12,9 @@ interface Modem {
 
 	/**
 	 * Call this method after creating the modem and before making any calls.
+	 * If this method returns false the modem cannot be used.
 	 */
-	void start() throws IOException;
+	boolean start() throws IOException;
 
 	/**
 	 * Call this method when the modem is no longer needed. If a call is in
diff --git a/briar-core/src/net/sf/briar/plugins/modem/ModemImpl.java b/briar-core/src/net/sf/briar/plugins/modem/ModemImpl.java
index ea26b698a3..774ac84938 100644
--- a/briar-core/src/net/sf/briar/plugins/modem/ModemImpl.java
+++ b/briar-core/src/net/sf/briar/plugins/modem/ModemImpl.java
@@ -9,7 +9,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.concurrent.Executor;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.Semaphore;
 import java.util.logging.Logger;
 
 import jssc.SerialPort;
@@ -31,128 +31,132 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 	private final Executor executor;
 	private final Callback callback;
 	private final SerialPort port;
-	private final AtomicBoolean initialised; // Locking: self
-	private final AtomicBoolean connected; // Locking: self
+	private final Semaphore stateChange;
 	private final byte[] line;
 
 	private int lineLen = 0;
 
-	private ReliabilityLayer reliabilityLayer = null; // Locking: this
-	private boolean offHook = false; // Locking: this;
+	// All of the following are locking: this
+	private ReliabilityLayer reliabilityLayer = null;
+	private boolean initialised = false, offHook = false, connected = false;
 
 	ModemImpl(Executor executor, Callback callback, String portName) {
 		this.executor = executor;
 		this.callback = callback;
 		port = new SerialPort(portName);
-		initialised = new AtomicBoolean(false);
-		connected = new AtomicBoolean(false);
+		stateChange = new Semaphore(1);
 		line = new byte[MAX_LINE_LENGTH];
 	}
 
-	public void start() throws IOException {
-		if(LOG.isLoggable(INFO)) LOG.info("Initialising");
+	public boolean start() throws IOException {
+		if(LOG.isLoggable(INFO)) LOG.info("Starting");
 		try {
-			if(!port.openPort())
-				throw new IOException("Failed to open serial port");
-		} catch(SerialPortException e) {
-			throw new IOException(e.toString());
+			stateChange.acquire();
+		} catch(InterruptedException e) {
+			Thread.currentThread().interrupt();
+			throw new IOException("Interrupted while waiting to start");
 		}
 		try {
-			boolean foundBaudRate = false;
-			for(int baudRate : BAUD_RATES) {
-				if(port.setParams(baudRate, 8, 1, 0)) {
-					foundBaudRate = true;
-					break;
+			// Open the serial port
+			try {
+				if(!port.openPort())
+					throw new IOException("Failed to open serial port");
+			} catch(SerialPortException e) {
+				throw new IOException(e.toString());
+			}
+			// Find a suitable baud rate and initialise the modem
+			try {
+				boolean foundBaudRate = false;
+				for(int baudRate : BAUD_RATES) {
+					if(port.setParams(baudRate, 8, 1, 0)) {
+						foundBaudRate = true;
+						break;
+					}
 				}
+				if(!foundBaudRate)
+					throw new IOException("No suitable baud rate");
+				port.purgePort(PURGE_RXCLEAR | PURGE_TXCLEAR);
+				port.addEventListener(this);
+				port.writeBytes("ATZ\r\n".getBytes("US-ASCII")); // Reset
+				port.writeBytes("ATE0\r\n".getBytes("US-ASCII")); // Echo off
+			} catch(SerialPortException e) {
+				tryToClose(port);
+				throw new IOException(e.toString());
 			}
-			if(!foundBaudRate)
-				throw new IOException("Could not find a suitable baud rate");
-			port.addEventListener(this);
-			port.purgePort(PURGE_RXCLEAR | PURGE_TXCLEAR);
-			port.writeBytes("ATZ\r\n".getBytes("US-ASCII")); // Reset
-			port.writeBytes("ATE0\r\n".getBytes("US-ASCII")); // Echo off
-		} catch(SerialPortException e) {
-			tryToClose(port);
-			throw new IOException(e.toString());
-		}
-		try {
-			synchronized(initialised) {
-				if(!initialised.get()) initialised.wait(OK_TIMEOUT);
-				if(initialised.get()) return;
+			// Wait for the event thread to receive "OK"
+			try {
+				synchronized(this) {
+					long now = System.currentTimeMillis();
+					long end = now + OK_TIMEOUT;
+					while(now < end && !initialised) {
+						wait(end - now);
+						now = System.currentTimeMillis();
+					}
+					if(initialised) return true;
+				}
+			} catch(InterruptedException e) {
+				tryToClose(port);
+				Thread.currentThread().interrupt();
+				throw new IOException("Interrupted while initialising");
 			}
-		} catch(InterruptedException e) {
 			tryToClose(port);
-			Thread.currentThread().interrupt();
-			throw new IOException("Interrupted while initialising modem");
+			return false;
+		} finally {
+			stateChange.release();
 		}
-		tryToClose(port);
-		throw new IOException("Modem did not respond");
 	}
 
-	public void stop() throws IOException {
-		hangUp();
+	private void tryToClose(SerialPort port) {
 		try {
 			port.closePort();
 		} catch(SerialPortException e) {
-			throw new IOException(e.toString());
+			if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
 		}
 	}
 
-	public boolean dial(String number) throws IOException {
+	public void stop() throws IOException {
+		if(LOG.isLoggable(INFO)) LOG.info("Stopping");
+		// Wake any threads that are waiting to connect
 		synchronized(this) {
-			if(offHook) {
-				if(LOG.isLoggable(INFO))
-					LOG.info("Not dialling - call in progress");
-				return false;
-			}
-			reliabilityLayer = new ReliabilityLayer(this);
-			reliabilityLayer.start();
-			offHook = true;
+			initialised = false;
+			connected = false;
+			notifyAll();
 		}
-		if(LOG.isLoggable(INFO)) LOG.info("Dialling");
+		// Hang up if necessary and close the port
 		try {
-			port.writeBytes(("ATDT" + number + "\r\n").getBytes("US-ASCII"));
-		} catch(SerialPortException e) {
-			tryToClose(port);
-			throw new IOException(e.toString());
+			stateChange.acquire();
+		} catch(InterruptedException e) {
+			Thread.currentThread().interrupt();
+			throw new IOException("Interrupted while waiting to stop");
 		}
 		try {
-			synchronized(connected) {
-				if(!connected.get()) connected.wait(CONNECT_TIMEOUT);
-				if(connected.get()) return true;
+			hangUpInner();
+			try {
+				port.closePort();
+			} catch(SerialPortException e) {
+				throw new IOException(e.toString());
 			}
-		} catch(InterruptedException e) {
-			tryToClose(port);
-			Thread.currentThread().interrupt();
-			throw new IOException("Interrupted while connecting outgoing call");
+		} finally {
+			stateChange.release();
 		}
-		hangUp();
-		return false;
 	}
 
-	public synchronized InputStream getInputStream() throws IOException {
-		if(offHook) return reliabilityLayer.getInputStream();
-		throw new IOException("Not connected");
-	}
-
-	public synchronized OutputStream getOutputStream() throws IOException {
-		if(offHook) return reliabilityLayer.getOutputStream();
-		throw new IOException("Not connected");
-	}
-
-	public void hangUp() throws IOException {
+	// Locking: stateChange
+	private void hangUpInner() throws IOException {
+		ReliabilityLayer reliabilityLayer;
 		synchronized(this) {
 			if(!offHook) {
 				if(LOG.isLoggable(INFO))
 					LOG.info("Not hanging up - already on the hook");
 				return;
 			}
-			reliabilityLayer.stop();
-			reliabilityLayer = null;
+			if(LOG.isLoggable(INFO)) LOG.info("Hanging up");
+			reliabilityLayer = this.reliabilityLayer;
+			this.reliabilityLayer = null;
 			offHook = false;
+			connected = false;
 		}
-		if(LOG.isLoggable(INFO)) LOG.info("Hanging up");
-		connected.set(false);
+		reliabilityLayer.stop();
 		try {
 			port.setDTR(false);
 		} catch(SerialPortException e) {
@@ -161,6 +165,92 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 		}
 	}
 
+	public boolean dial(String number) throws IOException {
+		if(!stateChange.tryAcquire()) {
+			if(LOG.isLoggable(INFO))
+				LOG.info("Not dialling - state change in progress");
+			return false;
+		}
+		try {
+			ReliabilityLayer reliabilityLayer = new ReliabilityLayer(this);
+			synchronized(this) {
+				if(!initialised) {
+					if(LOG.isLoggable(INFO))
+						LOG.info("Not dialling - modem not initialised");
+					return false;
+				}
+				if(offHook) {
+					if(LOG.isLoggable(INFO))
+						LOG.info("Not dialling - call in progress");
+					return false;
+				}
+				this.reliabilityLayer = reliabilityLayer;
+				offHook = true;
+			}
+			reliabilityLayer.start();
+			if(LOG.isLoggable(INFO)) LOG.info("Dialling");
+			try {
+				String dial = "ATDT" + number + "\r\n";
+				port.writeBytes(dial.getBytes("US-ASCII"));
+			} catch(SerialPortException e) {
+				tryToClose(port);
+				throw new IOException(e.toString());
+			}
+			// Wait for the event thread to receive "CONNECT"
+			try {
+				synchronized(this) {
+					long now = System.currentTimeMillis();
+					long end = now + CONNECT_TIMEOUT;
+					while(now < end && initialised && !connected) {
+						wait(end - now);
+						now = System.currentTimeMillis();
+					}
+					if(connected) return true;
+				}
+			} catch(InterruptedException e) {
+				tryToClose(port);
+				Thread.currentThread().interrupt();
+				throw new IOException("Interrupted while dialling");
+			}
+			hangUpInner();
+			return false;
+		} finally {
+			stateChange.release();
+		}
+	}
+
+	public InputStream getInputStream() throws IOException {
+		ReliabilityLayer reliabilityLayer;
+		synchronized(this) {
+			reliabilityLayer = this.reliabilityLayer;
+		}
+		if(reliabilityLayer == null) throw new IOException("Not connected");
+		return reliabilityLayer.getInputStream();
+	}
+
+	public OutputStream getOutputStream() throws IOException {
+		ReliabilityLayer reliabilityLayer;
+		synchronized(this) {
+			reliabilityLayer = this.reliabilityLayer;
+		}
+		if(reliabilityLayer == null) throw new IOException("Not connected");
+		return reliabilityLayer.getOutputStream();
+	}
+
+	public void hangUp() throws IOException {
+		try {
+			stateChange.acquire();
+		} catch(InterruptedException e) {
+			Thread.currentThread().interrupt();
+			throw new IOException("Interrupted while waiting to hang up");
+		}
+		try {
+			hangUpInner();
+		} finally {
+			stateChange.release();
+		}
+	}
+
 	public void handleWrite(byte[] b) throws IOException {
 		try {
 			port.writeBytes(b);
@@ -174,8 +264,7 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 		try {
 			if(ev.isRXCHAR()) {
 				byte[] b = port.readBytes();
-				if(connected.get()) reliabilityLayer.handleRead(b);
-				else handleText(b);
+				if(!handleData(b)) handleText(b);
 			} else if(ev.isDSR() && ev.getEventValue() == 0) {
 				if(LOG.isLoggable(INFO)) LOG.info("Remote end hung up");
 				hangUp();
@@ -192,6 +281,16 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 		}
 	}
 
+	private boolean handleData(byte[] b) throws IOException {
+		ReliabilityLayer reliabilityLayer;
+		synchronized(this) {
+			reliabilityLayer = this.reliabilityLayer;
+		}
+		if(reliabilityLayer == null) return false;
+		reliabilityLayer.handleRead(b);
+		return true;
+	}
+
 	private void handleText(byte[] b) throws IOException {
 		if(lineLen + b.length > MAX_LINE_LENGTH) {
 			tryToClose(port);
@@ -204,27 +303,28 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 				lineLen = 0;
 				if(LOG.isLoggable(INFO)) LOG.info("Modem status: " + s);
 				if(s.startsWith("CONNECT")) {
-					synchronized(connected) {
-						connected.set(true);
-						connected.notifyAll();
+					synchronized(this) {
+						connected = true;
+						notifyAll();
 					}
 					// There might be data in the buffer as well as text
 					int off = i + 1;
 					if(off < b.length) {
 						byte[] data = new byte[b.length - off];
 						System.arraycopy(b, off, data, 0, data.length);
-						reliabilityLayer.handleRead(data);
+						handleData(data);
 					}
 					return;
 				} else if(s.equals("BUSY") || s.equals("NO DIALTONE")
 						|| s.equals("NO CARRIER")) {
-					synchronized(connected) {
-						connected.notifyAll();
+					synchronized(this) {
+						connected = false;
+						notifyAll();
 					}
 				} else if(s.equals("OK")) {
-					synchronized(initialised) {
-						initialised.set(true);
-						initialised.notifyAll();
+					synchronized(this) {
+						initialised = true;
+						notifyAll();
 					}
 				} else if(s.equals("RING")) {
 					executor.execute(new Runnable() {
@@ -245,43 +345,56 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
 	}
 
 	private void answer() throws IOException {
-		synchronized(this) {
-			if(offHook) {
-				if(LOG.isLoggable(INFO))
-					LOG.info("Not answering - call in progress");
-				return;
-			}
-			reliabilityLayer = new ReliabilityLayer(this);
-			reliabilityLayer.start();
-			offHook = true;
+		if(!stateChange.tryAcquire()) {
+			if(LOG.isLoggable(INFO))
+				LOG.info("Not answering - state change in progress");
+			return;
 		}
-		if(LOG.isLoggable(INFO)) LOG.info("Answering");
 		try {
-			port.writeBytes("ATA\r\n".getBytes("US-ASCII"));
-		} catch(SerialPortException e) {
-			tryToClose(port);
-			throw new IOException(e.toString());
-		}
-		boolean success = false;
-		try {
-			synchronized(connected) {
-				if(!connected.get()) connected.wait(CONNECT_TIMEOUT);
-				success = connected.get();
+			ReliabilityLayer reliabilityLayer = new ReliabilityLayer(this);
+			synchronized(this) {
+				if(!initialised) {
+					if(LOG.isLoggable(INFO))
+						LOG.info("Not answering - modem not initialised");
+					return;
+				}
+				if(offHook) {
+					if(LOG.isLoggable(INFO))
+						LOG.info("Not answering - call in progress");
+					return;
+				}
+				this.reliabilityLayer = reliabilityLayer;
+				offHook = true;
 			}
-		} catch(InterruptedException e) {
-			tryToClose(port);
-			Thread.currentThread().interrupt();
-			throw new IOException("Interrupted while connecting incoming call");
-		}
-		if(success) callback.incomingCallConnected();
-		else hangUp();
-	}
-
-	private void tryToClose(SerialPort port) {
-		try {
-			port.closePort();
-		} catch(SerialPortException e) {
-			if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
+			reliabilityLayer.start();
+			if(LOG.isLoggable(INFO)) LOG.info("Answering");
+			try {
+				port.writeBytes("ATA\r\n".getBytes("US-ASCII"));
+			} catch(SerialPortException e) {
+				tryToClose(port);
+				throw new IOException(e.toString());
+			}
+			// Wait for the event thread to receive "CONNECT"
+			boolean success = false;
+			try {
+				synchronized(this) {
+					long now = System.currentTimeMillis();
+					long end = now + CONNECT_TIMEOUT;
+					while(now < end && initialised && !connected) {
+						wait(end - now);
+						now = System.currentTimeMillis();
+					}
+					success = connected;
+				}
+			} catch(InterruptedException e) {
+				tryToClose(port);
+				Thread.currentThread().interrupt();
+				throw new IOException("Interrupted while answering");
+			}
+			if(success) callback.incomingCallConnected();
+			else hangUpInner();
+		} finally {
+			stateChange.release();
 		}
 	}
 }
-- 
GitLab