Skip to content
Snippets Groups Projects
Commit 8068fa0d authored by akwizgran's avatar akwizgran
Browse files

Don't keep connection windows in memory.

parent 98148085
No related branches found
No related tags found
No related merge requests found
...@@ -6,8 +6,11 @@ import java.security.InvalidKeyException; ...@@ -6,8 +6,11 @@ import java.security.InvalidKeyException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
...@@ -28,13 +31,13 @@ import net.sf.briar.api.db.event.DatabaseEvent; ...@@ -28,13 +31,13 @@ import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent; import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent; import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
...@@ -47,7 +50,6 @@ DatabaseListener { ...@@ -47,7 +50,6 @@ DatabaseListener {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor; private final Executor executor;
private final ShutdownManager shutdown;
private final Cipher ivCipher; // Locking: this private final Cipher ivCipher; // Locking: this
private final Map<Bytes, Context> expected; // Locking: this private final Map<Bytes, Context> expected; // Locking: this
...@@ -55,65 +57,50 @@ DatabaseListener { ...@@ -55,65 +57,50 @@ DatabaseListener {
@Inject @Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db, ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db,
Executor executor, ShutdownManager shutdown) { Executor executor) {
this.crypto = crypto; this.crypto = crypto;
this.db = db; this.db = db;
this.executor = executor; this.executor = executor;
this.shutdown = shutdown;
ivCipher = crypto.getIvCipher(); ivCipher = crypto.getIvCipher();
expected = new HashMap<Bytes, Context>(); expected = new HashMap<Bytes, Context>();
db.addListener(this);
} }
// Locking: this // Locking: this
private void initialise() throws DbException { private void initialise() throws DbException {
assert !initialised; assert !initialised;
shutdown.addShutdownHook(new Runnable() { db.addListener(this);
public void run() { Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
eraseSecrets();
}
});
Collection<TransportId> transports = new ArrayList<TransportId>(); Collection<TransportId> transports = new ArrayList<TransportId>();
for(Transport t : db.getLocalTransports()) transports.add(t.getId()); for(Transport t : db.getLocalTransports()) transports.add(t.getId());
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
Collection<Context> contexts = new ArrayList<Context>();
try { try {
for(TransportId t : transports) { for(TransportId t : transports) {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; if(i == null) continue;
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} }
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs // The contact was removed - clean up in removeContact()
for(Context ctx : contexts) ctx.window.erase();
continue; continue;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
} }
expected.putAll(ivs);
initialised = true; initialised = true;
} }
private synchronized void eraseSecrets() {
for(Context c : expected.values()) c.window.erase();
}
// Locking: this // Locking: this
private Bytes calculateIv(Context ctx) { private Bytes calculateIv(Context ctx, byte[] secret) {
byte[] secret = ctx.window.getUnseen().get(ctx.connection); byte[] iv = IvEncoder.encodeIv(true, ctx.transportIndex.getInt(),
byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret); ctx.connection);
return new Bytes(iv);
}
// Locking: this
private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) {
byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
ErasableKey ivKey = crypto.deriveIvKey(secret, true); ErasableKey ivKey = crypto.deriveIvKey(secret, true);
try { try {
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
return ivCipher.doFinal(iv); return new Bytes(ivCipher.doFinal(iv));
} catch(BadPaddingException badCipher) { } catch(BadPaddingException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
...@@ -154,15 +141,17 @@ DatabaseListener { ...@@ -154,15 +141,17 @@ DatabaseListener {
ContactId c = ctx.contactId; ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex; TransportIndex i = ctx.transportIndex;
long connection = ctx.connection; long connection = ctx.connection;
ConnectionWindow w = ctx.window; ConnectionWindow w = null;
byte[] secret = null;
// Get the secret and update the connection window // Get the secret and update the connection window
byte[] secret = w.setSeen(connection);
try { try {
w = db.getConnectionWindow(c, i);
secret = w.setSeen(connection);
db.setConnectionWindow(c, i, w); db.setConnectionWindow(c, i, w);
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - reject the connection // The contact was removed - reject the connection
removeContact(c); if(w != null) w.erase();
w.erase(); if(secret != null) ByteUtils.erase(secret);
return null; return null;
} }
// Update the connection window's expected IVs // Update the connection window's expected IVs
...@@ -172,26 +161,15 @@ DatabaseListener { ...@@ -172,26 +161,15 @@ DatabaseListener {
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
it.remove(); it.remove();
} }
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx1 = new Context(c, t, i, unseen, w); Context ctx1 = new Context(c, t, i, e.getKey());
expected.put(calculateIv(ctx1), ctx1); expected.put(calculateIv(ctx1, e.getValue()), ctx1);
} }
w.erase();
return new ConnectionContextImpl(c, i, connection, secret); return new ConnectionContextImpl(c, i, connection, secret);
} }
} }
private synchronized void removeContact(ContactId c) {
if(!initialised) return;
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx = it.next();
if(ctx.contactId.equals(c)) {
ctx.window.erase();
it.remove();
}
}
}
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) { if(e instanceof ContactRemovedEvent) {
// Remove the expected IVs for the ex-contact // Remove the expected IVs for the ex-contact
...@@ -210,7 +188,7 @@ DatabaseListener { ...@@ -210,7 +188,7 @@ DatabaseListener {
} }
}); });
} else if(e instanceof RemoteTransportsUpdatedEvent) { } else if(e instanceof RemoteTransportsUpdatedEvent) {
// Recalculate the expected IVs for the contact // Update the expected IVs for the contact
final ContactId c = final ContactId c =
((RemoteTransportsUpdatedEvent) e).getContactId(); ((RemoteTransportsUpdatedEvent) e).getContactId();
executor.execute(new Runnable() { executor.execute(new Runnable() {
...@@ -221,52 +199,79 @@ DatabaseListener { ...@@ -221,52 +199,79 @@ DatabaseListener {
} }
} }
private synchronized void removeContact(ContactId c) {
if(!initialised) return;
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
}
private synchronized void addTransport(TransportId t) { private synchronized void addTransport(TransportId t) {
if(!initialised) return; if(!initialised) return;
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try { try {
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
Collection<Context> contexts = new ArrayList<Context>();
try { try {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; if(i == null) continue;
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs // The contact was removed - clean up in removeContact()
for(Context ctx : contexts) ctx.window.erase();
continue; continue;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
} }
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
} }
expected.putAll(ivs);
} }
private synchronized void updateContact(ContactId c) { private synchronized void updateContact(ContactId c) {
if(!initialised) return; if(!initialised) return;
removeContact(c); // Don't recalculate IVs for transports that are already known
Set<TransportIndex> known = new HashSet<TransportIndex>();
for(Context ctx : expected.values()) {
if(ctx.contactId.equals(c)) known.add(ctx.transportIndex);
}
Set<TransportIndex> current = new HashSet<TransportIndex>();
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try { try {
Collection<Context> contexts = new ArrayList<Context>(); for(Transport transport : db.getLocalTransports()) {
try { TransportId t = transport.getId();
for(Transport transport : db.getLocalTransports()) { TransportIndex i = db.getRemoteIndex(c, t);
TransportId t = transport.getId(); if(i == null) continue;
TransportIndex i = db.getRemoteIndex(c, t); current.add(i);
// If the transport is not already known, calculate the IVs
if(!known.contains(i)) {
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} }
} catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs
return;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx); } catch(NoSuchContactException e) {
// The contact was removed - clean up in removeContact()
return;
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
}
// Remove any IVs that are no longer current
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx = it.next();
if(ctx.contactId.equals(c) && !current.contains(ctx.transportIndex))
it.remove();
} }
// Add any IVs that were not previously known
expected.putAll(ivs);
} }
private static class Context { private static class Context {
...@@ -275,17 +280,13 @@ DatabaseListener { ...@@ -275,17 +280,13 @@ DatabaseListener {
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex; private final TransportIndex transportIndex;
private final long connection; private final long connection;
// Locking: ConnectionRecogniser.this
private final ConnectionWindow window;
private Context(ContactId contactId, TransportId transportId, private Context(ContactId contactId, TransportId transportId,
TransportIndex transportIndex, long connection, TransportIndex transportIndex, long connection) {
ConnectionWindow window) {
this.contactId = contactId; this.contactId = contactId;
this.transportId = transportId; this.transportId = transportId;
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
this.connection = connection; this.connection = connection;
this.window = window;
} }
} }
} }
...@@ -17,7 +17,6 @@ import net.sf.briar.api.crypto.CryptoComponent; ...@@ -17,7 +17,6 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
...@@ -66,11 +65,9 @@ public class ConnectionRecogniserImplTest extends TestCase { ...@@ -66,11 +65,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
public void testUnexpectedIv() throws Exception { public void testUnexpectedIv() throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class); final DatabaseComponent db = context.mock(DatabaseComponent.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
// Initialise // Initialise
oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
oneOf(db).getLocalTransports(); oneOf(db).getLocalTransports();
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
...@@ -82,7 +79,7 @@ public class ConnectionRecogniserImplTest extends TestCase { ...@@ -82,7 +79,7 @@ public class ConnectionRecogniserImplTest extends TestCase {
}}); }});
Executor executor = new ImmediateExecutor(); Executor executor = new ImmediateExecutor();
ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
executor, shutdown); executor);
c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() { c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() {
public void connectionAccepted(ConnectionContext ctx) { public void connectionAccepted(ConnectionContext ctx) {
...@@ -116,11 +113,9 @@ public class ConnectionRecogniserImplTest extends TestCase { ...@@ -116,11 +113,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
Mockery context = new Mockery(); Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class); final DatabaseComponent db = context.mock(DatabaseComponent.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
// Initialise // Initialise
oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
oneOf(db).getLocalTransports(); oneOf(db).getLocalTransports();
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
...@@ -130,12 +125,14 @@ public class ConnectionRecogniserImplTest extends TestCase { ...@@ -130,12 +125,14 @@ public class ConnectionRecogniserImplTest extends TestCase {
oneOf(db).getConnectionWindow(contactId, remoteIndex); oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(connectionWindow)); will(returnValue(connectionWindow));
// Update the window // Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(connectionWindow));
oneOf(db).setConnectionWindow(contactId, remoteIndex, oneOf(db).setConnectionWindow(contactId, remoteIndex,
connectionWindow); connectionWindow);
}}); }});
Executor executor = new ImmediateExecutor(); Executor executor = new ImmediateExecutor();
ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
executor, shutdown); executor);
// The IV should not be expected by the wrong transport // The IV should not be expected by the wrong transport
TransportId wrong = new TransportId(TestUtils.getRandomId()); TransportId wrong = new TransportId(TestUtils.getRandomId());
c.acceptConnection(wrong, encryptedIv, new Callback() { c.acceptConnection(wrong, encryptedIv, new Callback() {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment