From 8068fa0d38e7977e96ac5d8e2afb9ffe9aaa126d Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Thu, 24 Nov 2011 13:56:58 +0000
Subject: [PATCH] Don't keep connection windows in memory.

---
 .../transport/ConnectionRecogniserImpl.java   | 147 +++++++++---------
 .../ConnectionRecogniserImplTest.java         |  11 +-
 2 files changed, 78 insertions(+), 80 deletions(-)

diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java
index 34eacc8684..6be76b664a 100644
--- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java
+++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java
@@ -6,8 +6,11 @@ import java.security.InvalidKeyException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
 import java.util.concurrent.Executor;
 import java.util.logging.Level;
 import java.util.logging.Logger;
@@ -28,13 +31,13 @@ import net.sf.briar.api.db.event.DatabaseEvent;
 import net.sf.briar.api.db.event.DatabaseListener;
 import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
 import net.sf.briar.api.db.event.TransportAddedEvent;
-import net.sf.briar.api.lifecycle.ShutdownManager;
 import net.sf.briar.api.protocol.Transport;
 import net.sf.briar.api.protocol.TransportId;
 import net.sf.briar.api.protocol.TransportIndex;
 import net.sf.briar.api.transport.ConnectionContext;
 import net.sf.briar.api.transport.ConnectionRecogniser;
 import net.sf.briar.api.transport.ConnectionWindow;
+import net.sf.briar.util.ByteUtils;
 
 import com.google.inject.Inject;
 
@@ -47,7 +50,6 @@ DatabaseListener {
 	private final CryptoComponent crypto;
 	private final DatabaseComponent db;
 	private final Executor executor;
-	private final ShutdownManager shutdown;
 	private final Cipher ivCipher; // Locking: this
 	private final Map<Bytes, Context> expected; // Locking: this
 
@@ -55,65 +57,50 @@ DatabaseListener {
 
 	@Inject
 	ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db,
-			Executor executor, ShutdownManager shutdown) {
+			Executor executor) {
 		this.crypto = crypto;
 		this.db = db;
 		this.executor = executor;
-		this.shutdown = shutdown;
 		ivCipher = crypto.getIvCipher();
 		expected = new HashMap<Bytes, Context>();
-		db.addListener(this);
 	}
 
 	// Locking: this
 	private void initialise() throws DbException {
 		assert !initialised;
-		shutdown.addShutdownHook(new Runnable() {
-			public void run() {
-				eraseSecrets();
-			}
-		});
+		db.addListener(this);
+		Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
 		Collection<TransportId> transports = new ArrayList<TransportId>();
 		for(Transport t : db.getLocalTransports()) transports.add(t.getId());
 		for(ContactId c : db.getContacts()) {
-			Collection<Context> contexts = new ArrayList<Context>();
 			try {
 				for(TransportId t : transports) {
 					TransportIndex i = db.getRemoteIndex(c, t);
 					if(i == null) continue;
 					ConnectionWindow w = db.getConnectionWindow(c, i);
-					for(long unseen : w.getUnseen().keySet()) {
-						contexts.add(new Context(c, t, i, unseen, w));
+					for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
+						Context ctx = new Context(c, t, i, e.getKey());
+						ivs.put(calculateIv(ctx, e.getValue()), ctx);
 					}
+					w.erase();
 				}
 			} catch(NoSuchContactException e) {
-				// The contact was removed - don't add the IVs
-				for(Context ctx : contexts) ctx.window.erase();
+				// The contact was removed - clean up in removeContact()
 				continue;
 			}
-			for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
 		}
+		expected.putAll(ivs);
 		initialised = true;
 	}
 
-	private synchronized void eraseSecrets() {
-		for(Context c : expected.values()) c.window.erase();
-	}
-
 	// Locking: this
-	private Bytes calculateIv(Context ctx) {
-		byte[] secret = ctx.window.getUnseen().get(ctx.connection);
-		byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret);
-		return new Bytes(iv);
-	}
-
-	// Locking: this
-	private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) {
-		byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
+	private Bytes calculateIv(Context ctx, byte[] secret) {
+		byte[] iv = IvEncoder.encodeIv(true, ctx.transportIndex.getInt(),
+				ctx.connection);
 		ErasableKey ivKey = crypto.deriveIvKey(secret, true);
 		try {
 			ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
-			return ivCipher.doFinal(iv);
+			return new Bytes(ivCipher.doFinal(iv));
 		} catch(BadPaddingException badCipher) {
 			throw new RuntimeException(badCipher);
 		} catch(IllegalBlockSizeException badCipher) {
@@ -154,15 +141,17 @@ DatabaseListener {
 			ContactId c = ctx.contactId;
 			TransportIndex i = ctx.transportIndex;
 			long connection = ctx.connection;
-			ConnectionWindow w = ctx.window;
+			ConnectionWindow w = null;
+			byte[] secret = null;
 			// Get the secret and update the connection window
-			byte[] secret = w.setSeen(connection);
 			try {
+				w = db.getConnectionWindow(c, i);
+				secret = w.setSeen(connection);
 				db.setConnectionWindow(c, i, w);
 			} catch(NoSuchContactException e) {
 				// The contact was removed - reject the connection
-				removeContact(c);
-				w.erase();
+				if(w != null) w.erase();
+				if(secret != null) ByteUtils.erase(secret);
 				return null;
 			}
 			// Update the connection window's expected IVs
@@ -172,26 +161,15 @@ DatabaseListener {
 				if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
 					it.remove();
 			}
-			for(long unseen : w.getUnseen().keySet()) {
-				Context ctx1 = new Context(c, t, i, unseen, w);
-				expected.put(calculateIv(ctx1), ctx1);
+			for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
+				Context ctx1 = new Context(c, t, i, e.getKey());
+				expected.put(calculateIv(ctx1, e.getValue()), ctx1);
 			}
+			w.erase();
 			return new ConnectionContextImpl(c, i, connection, secret);
 		}
 	}
 
-	private synchronized void removeContact(ContactId c) {
-		if(!initialised) return;
-		Iterator<Context> it = expected.values().iterator();
-		while(it.hasNext()) {
-			Context ctx = it.next();
-			if(ctx.contactId.equals(c)) {
-				ctx.window.erase();
-				it.remove();
-			}
-		}
-	}
-
 	public void eventOccurred(DatabaseEvent e) {
 		if(e instanceof ContactRemovedEvent) {
 			// Remove the expected IVs for the ex-contact
@@ -210,7 +188,7 @@ DatabaseListener {
 				}
 			});
 		} else if(e instanceof RemoteTransportsUpdatedEvent) {
-			// Recalculate the expected IVs for the contact
+			// Update the expected IVs for the contact
 			final ContactId c =
 				((RemoteTransportsUpdatedEvent) e).getContactId();
 			executor.execute(new Runnable() {
@@ -221,52 +199,79 @@ DatabaseListener {
 		}
 	}
 
+	private synchronized void removeContact(ContactId c) {
+		if(!initialised) return;
+		Iterator<Context> it = expected.values().iterator();
+		while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
+	}
+
 	private synchronized void addTransport(TransportId t) {
 		if(!initialised) return;
+		Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
 		try {
 			for(ContactId c : db.getContacts()) {
-				Collection<Context> contexts = new ArrayList<Context>();
 				try {
 					TransportIndex i = db.getRemoteIndex(c, t);
 					if(i == null) continue;
 					ConnectionWindow w = db.getConnectionWindow(c, i);
-					for(long unseen : w.getUnseen().keySet()) {
-						contexts.add(new Context(c, t, i, unseen, w));
+					for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
+						Context ctx = new Context(c, t, i, e.getKey());
+						ivs.put(calculateIv(ctx, e.getValue()), ctx);
 					}
+					w.erase();
 				} catch(NoSuchContactException e) {
-					// The contact was removed - don't add the IVs
-					for(Context ctx : contexts) ctx.window.erase();
+					// The contact was removed - clean up in removeContact()
 					continue;
 				}
-				for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
 			}
 		} catch(DbException e) {
 			if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
+			return;
 		}
+		expected.putAll(ivs);
 	}
 
 	private synchronized void updateContact(ContactId c) {
 		if(!initialised) return;
-		removeContact(c);
+		// Don't recalculate IVs for transports that are already known
+		Set<TransportIndex> known = new HashSet<TransportIndex>();
+		for(Context ctx : expected.values()) {
+			if(ctx.contactId.equals(c)) known.add(ctx.transportIndex);
+		}
+		Set<TransportIndex> current = new HashSet<TransportIndex>();
+		Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
 		try {
-			Collection<Context> contexts = new ArrayList<Context>();
-			try {
-				for(Transport transport : db.getLocalTransports()) {
-					TransportId t = transport.getId();
-					TransportIndex i = db.getRemoteIndex(c, t);
+			for(Transport transport : db.getLocalTransports()) {
+				TransportId t = transport.getId();
+				TransportIndex i = db.getRemoteIndex(c, t);
+				if(i == null) continue;
+				current.add(i);
+				// If the transport is not already known, calculate the IVs
+				if(!known.contains(i)) {
 					ConnectionWindow w = db.getConnectionWindow(c, i);
-					for(long unseen : w.getUnseen().keySet()) {
-						contexts.add(new Context(c, t, i, unseen, w));
+					for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
+						Context ctx = new Context(c, t, i, e.getKey());
+						ivs.put(calculateIv(ctx, e.getValue()), ctx);
 					}
+					w.erase();
 				}
-			} catch(NoSuchContactException e) {
-				// The contact was removed - don't add the IVs
-				return;
 			}
-			for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
+		} catch(NoSuchContactException e) {
+			// The contact was removed - clean up in removeContact()
+			return;
 		} catch(DbException e) {
 			if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
+			return;
+		}
+		// Remove any IVs that are no longer current
+		Iterator<Context> it = expected.values().iterator();
+		while(it.hasNext()) {
+			Context ctx = it.next();
+			if(ctx.contactId.equals(c) && !current.contains(ctx.transportIndex))
+				it.remove();
 		}
+		// Add any IVs that were not previously known
+		expected.putAll(ivs);
 	}
 
 	private static class Context {
@@ -275,17 +280,13 @@ DatabaseListener {
 		private final TransportId transportId;
 		private final TransportIndex transportIndex;
 		private final long connection;
-		// Locking: ConnectionRecogniser.this
-		private final ConnectionWindow window;
 
 		private Context(ContactId contactId, TransportId transportId,
-				TransportIndex transportIndex, long connection,
-				ConnectionWindow window) {
+				TransportIndex transportIndex, long connection) {
 			this.contactId = contactId;
 			this.transportId = transportId;
 			this.transportIndex = transportIndex;
 			this.connection = connection;
-			this.window = window;
 		}
 	}
 }
diff --git a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
index f0c5b8b765..b7ae59a940 100644
--- a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
+++ b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
@@ -17,7 +17,6 @@ import net.sf.briar.api.crypto.CryptoComponent;
 import net.sf.briar.api.crypto.ErasableKey;
 import net.sf.briar.api.db.DatabaseComponent;
 import net.sf.briar.api.db.DbException;
-import net.sf.briar.api.lifecycle.ShutdownManager;
 import net.sf.briar.api.protocol.Transport;
 import net.sf.briar.api.protocol.TransportId;
 import net.sf.briar.api.protocol.TransportIndex;
@@ -66,11 +65,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
 	public void testUnexpectedIv() throws Exception {
 		Mockery context = new Mockery();
 		final DatabaseComponent db = context.mock(DatabaseComponent.class);
-		final ShutdownManager shutdown = context.mock(ShutdownManager.class);
 		context.checking(new Expectations() {{
 			oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
 			// Initialise
-			oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
 			oneOf(db).getLocalTransports();
 			will(returnValue(transports));
 			oneOf(db).getContacts();
@@ -82,7 +79,7 @@ public class ConnectionRecogniserImplTest extends TestCase {
 		}});
 		Executor executor = new ImmediateExecutor();
 		ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
-				executor, shutdown);
+				executor);
 		c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() {
 
 			public void connectionAccepted(ConnectionContext ctx) {
@@ -116,11 +113,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
 
 		Mockery context = new Mockery();
 		final DatabaseComponent db = context.mock(DatabaseComponent.class);
-		final ShutdownManager shutdown = context.mock(ShutdownManager.class);
 		context.checking(new Expectations() {{
 			oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
 			// Initialise
-			oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
 			oneOf(db).getLocalTransports();
 			will(returnValue(transports));
 			oneOf(db).getContacts();
@@ -130,12 +125,14 @@ public class ConnectionRecogniserImplTest extends TestCase {
 			oneOf(db).getConnectionWindow(contactId, remoteIndex);
 			will(returnValue(connectionWindow));
 			// Update the window
+			oneOf(db).getConnectionWindow(contactId, remoteIndex);
+			will(returnValue(connectionWindow));
 			oneOf(db).setConnectionWindow(contactId, remoteIndex,
 					connectionWindow);
 		}});
 		Executor executor = new ImmediateExecutor();
 		ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
-				executor, shutdown);
+				executor);
 		// The IV should not be expected by the wrong transport
 		TransportId wrong = new TransportId(TestUtils.getRandomId());
 		c.acceptConnection(wrong, encryptedIv, new Callback() {
-- 
GitLab