From 7739bcdd068a40ab7c6e64487871f63645f73fb5 Mon Sep 17 00:00:00 2001
From: akwizgran <michael@briarproject.org>
Date: Mon, 8 Oct 2012 18:15:25 +0100
Subject: [PATCH] Second part of key rotation implementation. Work in progress.

---
 .../api/transport/ConnectionRecogniser.java   |  10 +-
 .../briar/api/transport/ConnectionWindow.java |  12 --
 .../transport/ConnectionRecogniserImpl.java   |  67 ++++++++
 ...nWindowImpl.java => ConnectionWindow.java} |  60 ++++---
 .../net/sf/briar/transport/TagEncoder.java    |  25 +--
 .../TransportConnectionRecogniser.java        | 160 ++++++++++++++++++
 test/build.xml                                |   2 +-
 ...mplTest.java => ConnectionWindowTest.java} |  31 +---
 8 files changed, 291 insertions(+), 76 deletions(-)
 delete mode 100644 api/net/sf/briar/api/transport/ConnectionWindow.java
 create mode 100644 components/net/sf/briar/transport/ConnectionRecogniserImpl.java
 rename components/net/sf/briar/transport/{ConnectionWindowImpl.java => ConnectionWindow.java} (53%)
 create mode 100644 components/net/sf/briar/transport/TransportConnectionRecogniser.java
 rename test/net/sf/briar/transport/{ConnectionWindowImplTest.java => ConnectionWindowTest.java} (74%)

diff --git a/api/net/sf/briar/api/transport/ConnectionRecogniser.java b/api/net/sf/briar/api/transport/ConnectionRecogniser.java
index 6bcf2f3bd1..61ca17daff 100644
--- a/api/net/sf/briar/api/transport/ConnectionRecogniser.java
+++ b/api/net/sf/briar/api/transport/ConnectionRecogniser.java
@@ -1,5 +1,6 @@
 package net.sf.briar.api.transport;
 
+import net.sf.briar.api.ContactId;
 import net.sf.briar.api.db.DbException;
 import net.sf.briar.api.protocol.TransportId;
 
@@ -14,5 +15,12 @@ public interface ConnectionRecogniser {
 	 * expected, or null if the connection was not expected.
 	 */
 	ConnectionContext acceptConnection(TransportId t, byte[] tag)
-	throws DbException;
+			throws DbException;
+
+	void addWindow(ContactId c, TransportId t, long period, boolean alice,
+			byte[] secret, long centre, byte[] bitmap) throws DbException;
+
+	void removeWindow(ContactId c, TransportId t, long period);
+
+	void removeWindows(ContactId c);
 }
diff --git a/api/net/sf/briar/api/transport/ConnectionWindow.java b/api/net/sf/briar/api/transport/ConnectionWindow.java
deleted file mode 100644
index e2ecd94903..0000000000
--- a/api/net/sf/briar/api/transport/ConnectionWindow.java
+++ /dev/null
@@ -1,12 +0,0 @@
-package net.sf.briar.api.transport;
-
-import java.util.Set;
-
-public interface ConnectionWindow {
-
-	boolean isSeen(long connection);
-
-	void setSeen(long connection);
-
-	Set<Long> getUnseen();
-}
diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java
new file mode 100644
index 0000000000..3917b0c48f
--- /dev/null
+++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java
@@ -0,0 +1,67 @@
+package net.sf.briar.transport;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import net.sf.briar.api.ContactId;
+import net.sf.briar.api.crypto.CryptoComponent;
+import net.sf.briar.api.db.DatabaseComponent;
+import net.sf.briar.api.db.DbException;
+import net.sf.briar.api.protocol.TransportId;
+import net.sf.briar.api.transport.ConnectionContext;
+import net.sf.briar.api.transport.ConnectionRecogniser;
+
+import com.google.inject.Inject;
+
+class ConnectionRecogniserImpl implements ConnectionRecogniser {
+
+	private final CryptoComponent crypto;
+	private final DatabaseComponent db;
+	// Locking: this
+	private final Map<TransportId, TransportConnectionRecogniser> recognisers;
+
+	@Inject
+	ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) {
+		this.crypto = crypto;
+		this.db = db;
+		recognisers = new HashMap<TransportId, TransportConnectionRecogniser>();
+	}
+
+	public ConnectionContext acceptConnection(TransportId t, byte[] tag)
+			throws DbException {
+		TransportConnectionRecogniser r;
+		synchronized(this) {
+			r = recognisers.get(t);
+		}
+		if(r == null) return null;
+		return r.acceptConnection(tag);
+	}
+
+	public void addWindow(ContactId c, TransportId t, long period,
+			boolean alice, byte[] secret, long centre, byte[] bitmap)
+					throws DbException {
+		TransportConnectionRecogniser r;
+		synchronized(this) {
+			r = recognisers.get(t);
+			if(r == null) {
+				r = new TransportConnectionRecogniser(crypto, db, t);
+				recognisers.put(t, r);
+			}
+		}
+		r.addWindow(c, period, alice, secret, centre, bitmap);
+	}
+
+	public void removeWindow(ContactId c, TransportId t, long period) {
+		TransportConnectionRecogniser r;
+		synchronized(this) {
+			r = recognisers.get(t);
+		}
+		if(r != null) r.removeWindow(c, period);
+	}
+
+	public synchronized void removeWindows(ContactId c) {
+		for(TransportConnectionRecogniser r : recognisers.values()) {
+			r.removeWindows(c);
+		}
+	}
+}
diff --git a/components/net/sf/briar/transport/ConnectionWindowImpl.java b/components/net/sf/briar/transport/ConnectionWindow.java
similarity index 53%
rename from components/net/sf/briar/transport/ConnectionWindowImpl.java
rename to components/net/sf/briar/transport/ConnectionWindow.java
index 057cd9335a..3a2ced3163 100644
--- a/components/net/sf/briar/transport/ConnectionWindowImpl.java
+++ b/components/net/sf/briar/transport/ConnectionWindow.java
@@ -3,59 +3,75 @@ package net.sf.briar.transport;
 import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
 import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
 
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.Set;
 
-import net.sf.briar.api.transport.ConnectionWindow;
-
 // This class is not thread-safe
-class ConnectionWindowImpl implements ConnectionWindow {
+class ConnectionWindow {
 
 	private final Set<Long> unseen;
 
 	private long centre;
 
-	ConnectionWindowImpl() {
+	ConnectionWindow() {
 		unseen = new HashSet<Long>();
 		for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l);
 		centre = 0;
 	}
 
-	ConnectionWindowImpl(Set<Long> unseen) {
-		long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
-		for(long l : unseen) {
-			if(l < 0 || l > MAX_32_BIT_UNSIGNED)
-				throw new IllegalArgumentException();
-			if(l < min) min = l;
-			if(l > max) max = l;
-		}
-		if(max - min > CONNECTION_WINDOW_SIZE)
+	ConnectionWindow(long centre, byte[] bitmap) {
+		if(centre < 0 || centre > MAX_32_BIT_UNSIGNED)
+			throw new IllegalArgumentException();
+		if(bitmap.length != CONNECTION_WINDOW_SIZE / 8)
 			throw new IllegalArgumentException();
-		centre = max - CONNECTION_WINDOW_SIZE / 2 + 1;
-		for(long l = centre; l <= max; l++) {
-			if(!unseen.contains(l)) throw new IllegalArgumentException();
+		unseen = new HashSet<Long>();
+		long bottom = getBottom(centre);
+		long top = getTop(centre);
+		for(long l = bottom; l < top; l++) {
+			int offset = (int) (l - bottom);
+			int bytes = offset / 8;
+			int bits = offset % 8;
+			if((bitmap[bytes] & (128 >> bits)) == 0) unseen.add(l);
 		}
-		this.unseen = unseen;
+		this.centre = centre;
 	}
 
-	public boolean isSeen(long connection) {
+	boolean isSeen(long connection) {
 		return !unseen.contains(connection);
 	}
 
-	public void setSeen(long connection) {
+	Collection<Long> setSeen(long connection) {
 		long bottom = getBottom(centre);
 		long top = getTop(centre);
 		if(connection < bottom || connection > top)
 			throw new IllegalArgumentException();
 		if(!unseen.remove(connection))
 			throw new IllegalArgumentException();
+		Collection<Long> changed = new ArrayList<Long>();
 		if(connection >= centre) {
 			centre = connection + 1;
 			long newBottom = getBottom(centre);
 			long newTop = getTop(centre);
-			for(long l = bottom; l < newBottom; l++) unseen.remove(l);
-			for(long l = top + 1; l <= newTop; l++) unseen.add(l);
+			for(long l = bottom; l < newBottom; l++) {
+				if(unseen.remove(l)) changed.add(l);
+			}
+			for(long l = top + 1; l <= newTop; l++) {
+				if(unseen.add(l)) changed.add(l);
+			}
 		}
+		return changed;
+	}
+
+	long getCentre() {
+		return centre;
+	}
+
+	byte[] getBitmap() {
+		byte[] bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
+		// FIXME
+		return bitmap;
 	}
 
 	// Returns the lowest value contained in a window with the given centre
@@ -69,7 +85,7 @@ class ConnectionWindowImpl implements ConnectionWindow {
 				centre + CONNECTION_WINDOW_SIZE / 2 - 1);
 	}
 
-	public Set<Long> getUnseen() {
+	public Collection<Long> getUnseen() {
 		return unseen;
 	}
 }
diff --git a/components/net/sf/briar/transport/TagEncoder.java b/components/net/sf/briar/transport/TagEncoder.java
index 404870e1f5..fc0e1a85fc 100644
--- a/components/net/sf/briar/transport/TagEncoder.java
+++ b/components/net/sf/briar/transport/TagEncoder.java
@@ -1,21 +1,25 @@
 package net.sf.briar.transport;
 
-import static javax.crypto.Cipher.DECRYPT_MODE;
 import static javax.crypto.Cipher.ENCRYPT_MODE;
 import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
+import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
 
 import java.security.GeneralSecurityException;
 
 import javax.crypto.Cipher;
 
 import net.sf.briar.api.crypto.ErasableKey;
+import net.sf.briar.util.ByteUtils;
 
 class TagEncoder {
 
-	static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
+	static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey,
+			long connection) {
 		if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
-		// Blank plaintext
+		if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
+			throw new IllegalArgumentException();
 		for(int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
+		ByteUtils.writeUint32(connection, tag, 0);
 		try {
 			tagCipher.init(ENCRYPT_MODE, tagKey);
 			int encrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
@@ -25,19 +29,4 @@ class TagEncoder {
 			throw new IllegalArgumentException(e);
 		}
 	}
-
-	static boolean decodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
-		if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
-		try {
-			tagCipher.init(DECRYPT_MODE, tagKey);
-			int decrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
-			if(decrypted != TAG_LENGTH) throw new IllegalArgumentException();
-			//The plaintext should be blank
-			for(int i = 0; i < TAG_LENGTH; i++) if(tag[i] != 0) return false;
-			return true;
-		} catch(GeneralSecurityException e) {
-			// Unsuitable cipher or key
-			throw new IllegalArgumentException(e);
-		}
-	}
 }
diff --git a/components/net/sf/briar/transport/TransportConnectionRecogniser.java b/components/net/sf/briar/transport/TransportConnectionRecogniser.java
new file mode 100644
index 0000000000..9ddaacfc28
--- /dev/null
+++ b/components/net/sf/briar/transport/TransportConnectionRecogniser.java
@@ -0,0 +1,160 @@
+package net.sf.briar.transport;
+
+import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.crypto.Cipher;
+
+import net.sf.briar.api.Bytes;
+import net.sf.briar.api.ContactId;
+import net.sf.briar.api.crypto.CryptoComponent;
+import net.sf.briar.api.crypto.ErasableKey;
+import net.sf.briar.api.db.DatabaseComponent;
+import net.sf.briar.api.db.DbException;
+import net.sf.briar.api.protocol.TransportId;
+import net.sf.briar.api.transport.ConnectionContext;
+import net.sf.briar.util.ByteUtils;
+
+/** A connection recogniser for a specific transport. */
+class TransportConnectionRecogniser {
+
+	private final CryptoComponent crypto;
+	private final DatabaseComponent db;
+	private final TransportId transportId;
+	private final Map<Bytes, WindowContext> tagMap; // Locking: this
+	private final Map<RemovalKey, RemovalContext> windowMap; // Locking: this
+
+	TransportConnectionRecogniser(CryptoComponent crypto, DatabaseComponent db,
+			TransportId transportId) {
+		this.crypto = crypto;
+		this.db = db;
+		this.transportId = transportId;
+		tagMap = new HashMap<Bytes, WindowContext>();
+		windowMap = new HashMap<RemovalKey, RemovalContext>();
+	}
+
+	synchronized ConnectionContext acceptConnection(byte[] tag)
+			throws DbException {
+		WindowContext wctx = tagMap.remove(new Bytes(tag));
+		if(wctx == null) return null;
+		ConnectionWindow w = wctx.window;
+		ConnectionContext ctx = wctx.context;
+		long connection = ctx.getConnectionNumber();
+		Cipher cipher = crypto.getTagCipher();
+		ErasableKey key = crypto.deriveTagKey(ctx.getSecret(), ctx.getAlice());
+		byte[] changedTag = new byte[TAG_LENGTH];
+		Bytes changedTagWrapper = new Bytes(changedTag);
+		for(long conn : w.setSeen(connection)) {
+			TagEncoder.encodeTag(changedTag, cipher, key, conn);
+			WindowContext old;
+			if(conn <= connection) old = tagMap.remove(changedTagWrapper);
+			else old = tagMap.put(changedTagWrapper, wctx);
+			assert old == null;
+		}
+		key.erase();
+		ContactId c = ctx.getContactId();
+		long centre = w.getCentre();
+		byte[] bitmap = w.getBitmap();
+		db.setConnectionWindow(c, transportId, wctx.period, centre, bitmap);
+		return wctx.context;
+	}
+
+	synchronized void addWindow(ContactId c, long period, boolean alice,
+			byte[] secret, long centre, byte[] bitmap) throws DbException {
+		Cipher cipher = crypto.getTagCipher();
+		ErasableKey key = crypto.deriveTagKey(secret, alice);
+		ConnectionWindow w = new ConnectionWindow(centre, bitmap);
+		for(long conn : w.getUnseen()) {
+			byte[] tag = new byte[TAG_LENGTH];
+			TagEncoder.encodeTag(tag, cipher, key, conn);
+			ConnectionContext ctx = new ConnectionContext(c, transportId, tag,
+					secret, conn, alice);
+			WindowContext wctx = new WindowContext(w, ctx, period);
+			tagMap.put(new Bytes(tag), wctx);
+		}
+		db.setConnectionWindow(c, transportId, period, centre, bitmap);
+		RemovalContext ctx = new RemovalContext(w, secret, alice);
+		windowMap.put(new RemovalKey(c, period), ctx);
+	}
+
+	synchronized void removeWindow(ContactId c, long period) {
+		RemovalContext ctx = windowMap.remove(new RemovalKey(c, period));
+		if(ctx == null) throw new IllegalArgumentException();
+		Cipher cipher = crypto.getTagCipher();
+		ErasableKey key = crypto.deriveTagKey(ctx.secret, ctx.alice);
+		byte[] removedTag = new byte[TAG_LENGTH];
+		Bytes removedTagWrapper = new Bytes(removedTag);
+		for(long conn : ctx.window.getUnseen()) {
+			TagEncoder.encodeTag(removedTag, cipher, key, conn);
+			WindowContext old = tagMap.remove(removedTagWrapper);
+			assert old != null;
+		}
+		key.erase();
+		ByteUtils.erase(ctx.secret);
+	}
+
+	synchronized void removeWindows(ContactId c) {
+		Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
+		for(RemovalKey k : windowMap.keySet()) {
+			if(k.contactId.equals(c)) keysToRemove.add(k);
+		}
+		for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period);
+	}
+
+	private static class WindowContext {
+
+		private final ConnectionWindow window;
+		private final ConnectionContext context;
+		private final long period;
+
+		private WindowContext(ConnectionWindow window,
+				ConnectionContext context, long period) {
+			this.window = window;
+			this.context = context;
+			this.period = period;
+		}
+	}
+
+	private static class RemovalKey {
+
+		private final ContactId contactId;
+		private final long period;
+
+		private RemovalKey(ContactId contactId, long period) {
+			this.contactId = contactId;
+			this.period = period;
+		}
+
+		@Override
+		public int hashCode() {
+			return contactId.hashCode()+ (int) period;
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if(o instanceof RemovalKey) {
+				RemovalKey w = (RemovalKey) o;
+				return contactId.equals(w.contactId) && period == w.period;
+			}
+			return false;
+		}
+	}
+
+	private static class RemovalContext {
+
+		private final ConnectionWindow window;
+		private final byte[] secret;
+		private final boolean alice;
+
+		private RemovalContext(ConnectionWindow window, byte[] secret,
+				boolean alice) {
+			this.window = window;
+			this.secret = secret;
+			this.alice = alice;
+		}
+	}
+}
diff --git a/test/build.xml b/test/build.xml
index 3b86a959e8..a83fab2ee6 100644
--- a/test/build.xml
+++ b/test/build.xml
@@ -49,7 +49,7 @@
 			<test name='net.sf.briar.serial.WriterImplTest'/>
 			<test name='net.sf.briar.transport.ConnectionReaderImplTest'/>
 			<test name='net.sf.briar.transport.ConnectionRegistryImplTest'/>
-			<test name='net.sf.briar.transport.ConnectionWindowImplTest'/>
+			<test name='net.sf.briar.transport.ConnectionWindowTest'/>
 			<test name='net.sf.briar.transport.ConnectionWriterImplTest'/>
 			<test name='net.sf.briar.transport.IncomingEncryptionLayerTest'/>
 			<test name='net.sf.briar.transport.OutgoingEncryptionLayerTest'/>
diff --git a/test/net/sf/briar/transport/ConnectionWindowImplTest.java b/test/net/sf/briar/transport/ConnectionWindowTest.java
similarity index 74%
rename from test/net/sf/briar/transport/ConnectionWindowImplTest.java
rename to test/net/sf/briar/transport/ConnectionWindowTest.java
index 0db1d33684..74de234cf9 100644
--- a/test/net/sf/briar/transport/ConnectionWindowImplTest.java
+++ b/test/net/sf/briar/transport/ConnectionWindowTest.java
@@ -1,20 +1,16 @@
 package net.sf.briar.transport;
 
-import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
-
-import java.util.HashSet;
-import java.util.Set;
+import java.util.Collection;
 
 import net.sf.briar.BriarTestCase;
-import net.sf.briar.api.transport.ConnectionWindow;
 
 import org.junit.Test;
 
-public class ConnectionWindowImplTest extends BriarTestCase {
+public class ConnectionWindowTest extends BriarTestCase {
 
 	@Test
 	public void testWindowSliding() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		for(int i = 0; i < 100; i++) {
 			assertFalse(w.isSeen(i));
 			w.setSeen(i);
@@ -24,7 +20,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
 
 	@Test
 	public void testWindowJumping() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		for(int i = 0; i < 100; i += 13) {
 			assertFalse(w.isSeen(i));
 			w.setSeen(i);
@@ -34,7 +30,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
 
 	@Test
 	public void testWindowUpperLimit() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		// Centre is 0, highest value in window is 15
 		w.setSeen(15);
 		// Centre is 16, highest value in window is 31
@@ -44,20 +40,11 @@ public class ConnectionWindowImplTest extends BriarTestCase {
 			w.setSeen(48);
 			fail();
 		} catch(IllegalArgumentException expected) {}
-		// Values greater than 2^32 - 1 should never be allowed
-		Set<Long> unseen = new HashSet<Long>();
-		for(int i = 0; i < 32; i++) unseen.add(MAX_32_BIT_UNSIGNED - i);
-		w = new ConnectionWindowImpl(unseen);
-		w.setSeen(MAX_32_BIT_UNSIGNED);
-		try {
-			w.setSeen(MAX_32_BIT_UNSIGNED + 1);
-			fail();
-		} catch(IllegalArgumentException expected) {}
 	}
 
 	@Test
 	public void testWindowLowerLimit() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		// Centre is 0, negative values should never be allowed
 		try {
 			w.setSeen(-1);
@@ -87,7 +74,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
 
 	@Test
 	public void testCannotSetSeenTwice() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		w.setSeen(15);
 		try {
 			w.setSeen(15);
@@ -97,9 +84,9 @@ public class ConnectionWindowImplTest extends BriarTestCase {
 
 	@Test
 	public void testGetUnseenConnectionNumbers() {
-		ConnectionWindow w = new ConnectionWindowImpl();
+		ConnectionWindow w = new ConnectionWindow();
 		// Centre is 0; window should cover 0 to 15, inclusive, with none seen
-		Set<Long> unseen = w.getUnseen();
+		Collection<Long> unseen = w.getUnseen();
 		assertEquals(16, unseen.size());
 		for(int i = 0; i < 16; i++) {
 			assertTrue(unseen.contains(Long.valueOf(i)));
-- 
GitLab