diff --git a/api/net/sf/briar/api/transport/ConnectionRecogniser.java b/api/net/sf/briar/api/transport/ConnectionRecogniser.java index 7336f002dfc9ec8696c71bdc2f66cc4e03cb3694..01bd760beda392edcb94f4076ddf5005aa95ff1f 100644 --- a/api/net/sf/briar/api/transport/ConnectionRecogniser.java +++ b/api/net/sf/briar/api/transport/ConnectionRecogniser.java @@ -1,6 +1,7 @@ package net.sf.briar.api.transport; import net.sf.briar.api.db.DbException; +import net.sf.briar.api.protocol.TransportId; /** * Maintains the connection reordering windows and decides whether incoming @@ -12,5 +13,6 @@ public interface ConnectionRecogniser { * Returns the connection's context if the connection should be accepted, * or null if the connection should be rejected. */ - ConnectionContext acceptConnection(byte[] encryptedIv) throws DbException; + ConnectionContext acceptConnection(TransportId t, byte[] encryptedIv) + throws DbException; } diff --git a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java index 910c303d0e4aceea62ba501d3b424ce5f1156cc4..e93a44fe6b0a7a091cc92ee73190400b8258e6d3 100644 --- a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java +++ b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java @@ -52,7 +52,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher { // Get the connection context, or null if the IV wasn't expected ConnectionContext ctx; try { - ctx = recogniser.acceptConnection(encryptedIv); + ctx = recogniser.acceptConnection(t, encryptedIv); } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); r.dispose(false); @@ -95,7 +95,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher { // Get the connection context, or null if the IV wasn't expected ConnectionContext ctx; try { - ctx = recogniser.acceptConnection(encryptedIv); + ctx = recogniser.acceptConnection(t, encryptedIv); } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); s.dispose(false); diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java index 6603de1a6657a40ae4d385236622ec6f278f089e..4fac2b3544c543b8b99ee7feac07930c46e1fc1c 100644 --- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java +++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java @@ -87,20 +87,20 @@ DatabaseListener { TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, i, w); + calculateIvs(c, t, i, w); } } } - private synchronized void calculateIvs(ContactId c, TransportIndex i, - ConnectionWindow w) throws DbException { + private synchronized void calculateIvs(ContactId c, TransportId t, + TransportIndex i, ConnectionWindow w) throws DbException { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { long connection = e.getKey(); byte[] secret = e.getValue(); ErasableKey ivKey = crypto.deriveIvKey(secret, true); Bytes iv = new Bytes(encryptIv(i, connection, ivKey)); ivKey.erase(); - expected.put(iv, new Context(c, i, connection, w)); + expected.put(iv, new Context(c, t, i, connection, w)); } } @@ -125,17 +125,24 @@ DatabaseListener { } } - public synchronized ConnectionContext acceptConnection(byte[] encryptedIv) - throws DbException { + public synchronized ConnectionContext acceptConnection(TransportId t, + byte[] encryptedIv) throws DbException { if(encryptedIv.length != IV_LENGTH) throw new IllegalArgumentException(); if(!initialised) initialise(); - Context ctx = expected.remove(new Bytes(encryptedIv)); - if(ctx == null) return null; // The IV was not expected + Bytes b = new Bytes(encryptedIv); + Context ctx = expected.get(b); + // If the IV was not expected by this transport, reject the connection + if(ctx == null || !ctx.transportId.equals(t)) return null; + expected.remove(b); + ContactId c = ctx.contactId; + TransportIndex i = ctx.transportIndex; + long connection = ctx.connection; + ConnectionWindow w = ctx.window; // Get the secret and update the connection window - byte[] secret = ctx.window.setSeen(ctx.connection); + byte[] secret = w.setSeen(connection); try { - db.setConnectionWindow(ctx.contactId, ctx.index, ctx.window); + db.setConnectionWindow(c, i, w); } catch(NoSuchContactException e) { // The contact was removed - clean up when we get the event } @@ -143,12 +150,11 @@ DatabaseListener { Iterator<Context> it = expected.values().iterator(); while(it.hasNext()) { Context ctx1 = it.next(); - if(ctx1.contactId.equals(ctx.contactId) - && ctx1.index.equals(ctx.index)) it.remove(); + if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) + it.remove(); } - calculateIvs(ctx.contactId, ctx.index, ctx.window); - return new ConnectionContextImpl(ctx.contactId, ctx.index, - ctx.connection, secret); + calculateIvs(c, t, i, w); + return new ConnectionContextImpl(c, i, connection, secret); } public void eventOccurred(DatabaseEvent e) { @@ -195,7 +201,7 @@ DatabaseListener { TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, i, w); + calculateIvs(c, t, i, w); } } catch(NoSuchContactException e) { // The contact was removed - clean up when we get the event @@ -206,14 +212,17 @@ DatabaseListener { private static class Context { private final ContactId contactId; - private final TransportIndex index; + private final TransportId transportId; + private final TransportIndex transportIndex; private final long connection; private final ConnectionWindow window; - private Context(ContactId contactId, TransportIndex index, - long connection, ConnectionWindow window) { + private Context(ContactId contactId, TransportId transportId, + TransportIndex transportIndex, long connection, + ConnectionWindow window) { this.contactId = contactId; - this.index = index; + this.transportId = transportId; + this.transportIndex = transportIndex; this.connection = connection; this.window = window; } diff --git a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java index a0c552f8e1081b5c639f8881f06b0d7e4a42356a..87eacbd7f95f3454f27d04674b41fbc227e58643 100644 --- a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java +++ b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java @@ -74,7 +74,7 @@ public class ConnectionRecogniserImplTest extends TestCase { }}); final ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(crypto, db); - assertNull(c.acceptConnection(new byte[IV_LENGTH])); + assertNull(c.acceptConnection(transportId, new byte[IV_LENGTH])); context.assertIsSatisfied(); } @@ -111,20 +111,22 @@ public class ConnectionRecogniserImplTest extends TestCase { }}); final ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(crypto, db); - // First time - the IV should be expected - ConnectionContext ctx = c.acceptConnection(encryptedIv); + // The IV should not be expected by the wrong transport + TransportId wrong = new TransportId(TestUtils.getRandomId()); + assertNull(c.acceptConnection(wrong, encryptedIv)); + // The IV should be expected by the right transport + ConnectionContext ctx = c.acceptConnection(transportId, encryptedIv); assertNotNull(ctx); assertEquals(contactId, ctx.getContactId()); assertEquals(remoteIndex, ctx.getTransportIndex()); assertEquals(3L, ctx.getConnectionNumber()); - // Second time - the IV should no longer be expected - assertNull(c.acceptConnection(encryptedIv)); + // The IV should no longer be expected + assertNull(c.acceptConnection(transportId, encryptedIv)); // The window should have advanced Map<Long, byte[]> unseen = connectionWindow.getUnseen(); assertEquals(19, unseen.size()); for(int i = 0; i < 19; i++) { - if(i == 3) continue; - assertTrue(unseen.containsKey(Long.valueOf(i))); + assertEquals(i != 3, unseen.containsKey(Long.valueOf(i))); } context.assertIsSatisfied(); } diff --git a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java index 0be09b15c910331c94a9b6b41f5cc0f6bed026b6..0eab9d978f592a1e0916c1a15472cbab8217f1a2 100644 --- a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java +++ b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java @@ -159,7 +159,7 @@ public class BatchConnectionReadWriteTest extends TestCase { byte[] encryptedIv = new byte[IV_LENGTH]; int read = in.read(encryptedIv); assertEquals(encryptedIv.length, read); - ConnectionContext ctx = rec.acceptConnection(encryptedIv); + ConnectionContext ctx = rec.acceptConnection(transportId, encryptedIv); assertNotNull(ctx); assertEquals(contactId, ctx.getContactId()); assertEquals(transportIndex, ctx.getTransportIndex());