diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java index c2c1f53c1b51ec7d4724e413bf48654cf3ed1235..0ce6f7305f112d5072608fc96ca0a9489c157046 100644 --- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java +++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java @@ -8,7 +8,6 @@ import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Level; @@ -33,6 +32,7 @@ import net.sf.briar.api.db.event.TransportAddedEvent; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.util.ByteUtils; @@ -48,8 +48,8 @@ DatabaseListener { private final CryptoComponent crypto; private final DatabaseComponent db; private final Executor executor; + private final Cipher ivCipher; // Locking: this private final Map<Bytes, Context> expected; // Locking: this - private final Collection<TransportId> localTransportIds; // Locking: this private final AtomicBoolean initialised = new AtomicBoolean(false); @Inject @@ -58,91 +58,90 @@ DatabaseListener { this.crypto = crypto; this.db = db; this.executor = executor; + ivCipher = crypto.getIvCipher(); expected = new HashMap<Bytes, Context>(); - localTransportIds = new ArrayList<TransportId>(); db.addListener(this); - } - - private void initialise() throws DbException { Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { eraseSecrets(); } }); - Collection<TransportId> ids = new ArrayList<TransportId>(); - for(Transport t : db.getLocalTransports()) ids.add(t.getId()); - synchronized(this) { - localTransportIds.addAll(ids); + } + + private synchronized void eraseSecrets() { + for(Context c : expected.values()) { + for(byte[] b : c.window.getUnseen().values()) ByteUtils.erase(b); } - Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); + } + + private void initialise() throws DbException { + // Fill in the contexts as far as possible outside the lock + Collection<Context> partial = new ArrayList<Context>(); + Collection<Transport> transports = db.getLocalTransports(); for(ContactId c : db.getContacts()) { - try { - ivs.putAll(calculateIvs(c)); - } catch(NoSuchContactException e) { - // The contact was removed - clean up in eventOccurred() + for(Transport transport : transports) { + getPartialContexts(c, transport.getId(), partial); } } synchronized(this) { - expected.putAll(ivs); + // Complete the contexts and calculate the expected IVs + calculateIvs(completeContexts(partial)); } } - private synchronized void eraseSecrets() { - for(Context c : expected.values()) { - synchronized(c.window) { - for(byte[] b : c.window.getUnseen().values()) { - ByteUtils.erase(b); + private void getPartialContexts(ContactId c, TransportId t, + Collection<Context> partial) throws DbException { + try { + TransportIndex i = db.getRemoteIndex(c, t); + if(i != null) { + // Acquire the lock to avoid getting stale data + synchronized(this) { + ConnectionWindow w = db.getConnectionWindow(c, i); + partial.add(new Context(c, t, i, -1, w)); } } + } catch(NoSuchContactException e) { + // The contact was removed - we'll handle the event later } } - private Map<Bytes, Context> calculateIvs(ContactId c) throws DbException { - Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); - Collection<TransportId> ids; - synchronized(this) { - ids = new ArrayList<TransportId>(localTransportIds); - } - for(TransportId t : ids) { - TransportIndex i = db.getRemoteIndex(c, t); - if(i != null) { - ConnectionWindow w = db.getConnectionWindow(c, i); - ivs.putAll(calculateIvs(c, t, i, w)); + // Locking: this + private Collection<Context> completeContexts(Collection<Context> partial) { + Collection<Context> contexts = new ArrayList<Context>(); + for(Context ctx : partial) { + for(long unseen : ctx.window.getUnseen().keySet()) { + contexts.add(new Context(ctx.contactId, ctx.transportId, + ctx.transportIndex, unseen, ctx.window)); } } - return ivs; + return contexts; } - private Map<Bytes, Context> calculateIvs(ContactId c, TransportId t, - TransportIndex i, ConnectionWindow w) throws DbException { - Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); - synchronized(w) { - for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { - long connection = e.getKey(); - byte[] secret = e.getValue(); - Bytes iv = new Bytes(encryptIv(i, connection, secret)); - ivs.put(iv, new Context(c, t, i, connection, w)); - } + // Locking: this + private void calculateIvs(Collection<Context> contexts) { + for(Context ctx : contexts) { + byte[] secret = ctx.window.getUnseen().get(ctx.connection); + byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret); + expected.put(new Bytes(iv), ctx); } - return ivs; } + // Locking: this private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) { byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection); ErasableKey ivKey = crypto.deriveIvKey(secret, true); try { - Cipher ivCipher = crypto.getIvCipher(); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); - byte[] encryptedIv = ivCipher.doFinal(iv); - ivKey.erase(); - return encryptedIv; + return ivCipher.doFinal(iv); } catch(BadPaddingException badCipher) { throw new RuntimeException(badCipher); } catch(IllegalBlockSizeException badCipher) { throw new RuntimeException(badCipher); } catch(InvalidKeyException badKey) { throw new RuntimeException(badKey); + } finally { + ivKey.erase(); } } @@ -150,92 +149,82 @@ DatabaseListener { final Callback callback) { executor.execute(new Runnable() { public void run() { - acceptConnectionSync(t, encryptedIv, callback); + try { + ConnectionContext ctx = acceptConnection(t, encryptedIv); + if(ctx == null) callback.connectionRejected(); + else callback.connectionAccepted(ctx); + } catch(DbException e) { + callback.handleException(e); + } } }); } - private void acceptConnectionSync(TransportId t, byte[] encryptedIv, - Callback callback) { - try { - if(encryptedIv.length != IV_LENGTH) - throw new IllegalArgumentException(); - if(!initialised.getAndSet(true)) initialise(); - Context ctx; - synchronized(this) { - Bytes b = new Bytes(encryptedIv); - ctx = expected.get(b); - if(ctx == null || !ctx.transportId.equals(t)) { - callback.connectionRejected(); - return; - } - expected.remove(b); - } + private ConnectionContext acceptConnection(TransportId t, + byte[] encryptedIv) throws DbException { + if(encryptedIv.length != IV_LENGTH) + throw new IllegalArgumentException(); + if(!initialised.getAndSet(true)) initialise(); + synchronized(this) { + Bytes b = new Bytes(encryptedIv); + Context ctx = expected.get(b); + if(ctx == null || !ctx.transportId.equals(t)) return null; // The IV was expected + expected.remove(b); ContactId c = ctx.contactId; TransportIndex i = ctx.transportIndex; long connection = ctx.connection; ConnectionWindow w = ctx.window; byte[] secret; - synchronized(w) { - // Get the secret and update the connection window - secret = w.setSeen(connection); - try { - db.setConnectionWindow(c, i, w); - } catch(NoSuchContactException e) { - // The contact was removed - clean up in eventOccurred() - } + // Get the secret and update the connection window + try { + db.setConnectionWindow(c, i, w); + } catch(NoSuchContactException e) { + // The contact was removed - we'll handle the event later } - // Update the set of expected IVs - Map<Bytes, Context> ivs = calculateIvs(c, t, i, w); - synchronized(this) { - Iterator<Context> it = expected.values().iterator(); - while(it.hasNext()) { - Context ctx1 = it.next(); - if(ctx1.contactId.equals(c) - && ctx1.transportIndex.equals(i)) it.remove(); - } - expected.putAll(ivs); + secret = w.setSeen(connection); + // Update the connection window's expected IVs + Iterator<Context> it = expected.values().iterator(); + while(it.hasNext()) { + Context ctx1 = it.next(); + if(ctx1.contactId.equals(c) + && ctx1.transportIndex.equals(i)) it.remove(); } - callback.connectionAccepted(new ConnectionContextImpl(c, i, - connection, secret)); - } catch(DbException e) { - callback.handleException(e); + Collection<Context> contexts = new ArrayList<Context>(); + for(long unseen : w.getUnseen().keySet()) { + contexts.add(new Context(c, t, i, unseen, w)); + } + calculateIvs(contexts); + return new ConnectionContextImpl(c, i, connection, secret); } } public void eventOccurred(DatabaseEvent e) { if(e instanceof ContactRemovedEvent) { // Remove the expected IVs for the ex-contact - removeIvs(((ContactRemovedEvent) e).getContactId()); + final ContactId c = ((ContactRemovedEvent) e).getContactId(); + executor.execute(new Runnable() { + public void run() { + removeIvs(c); + } + }); } else if(e instanceof TransportAddedEvent) { - // Calculate the expected IVs for the new transport - TransportId t = ((TransportAddedEvent) e).getTransportId(); - try { - if(!initialised.getAndSet(true)) initialise(); - Map<Bytes, Context> ivs = calculateIvs(t); - synchronized(this) { - localTransportIds.add(t); - expected.putAll(ivs); + // Add the expected IVs for the new transport + final TransportId t = ((TransportAddedEvent) e).getTransportId(); + executor.execute(new Runnable() { + public void run() { + addTransport(t); } - } catch(DbException e1) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e1.getMessage()); - } + }); } else if(e instanceof RemoteTransportsUpdatedEvent) { - // Remove and recalculate the expected IVs for the contact - ContactId c = ((RemoteTransportsUpdatedEvent) e).getContactId(); - try { - if(!initialised.getAndSet(true)) initialise(); - Map<Bytes, Context> ivs = calculateIvs(c); - synchronized(this) { - removeIvs(c); - expected.putAll(ivs); + // Recalculate the expected IVs for the contact + final ContactId c = + ((RemoteTransportsUpdatedEvent) e).getContactId(); + executor.execute(new Runnable() { + public void run() { + updateContact(c); } - } catch(DbException e1) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e1.getMessage()); - } + }); } } @@ -244,20 +233,42 @@ DatabaseListener { while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); } - private Map<Bytes, Context> calculateIvs(TransportId t) throws DbException { - Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); - for(ContactId c : db.getContacts()) { - try { - TransportIndex i = db.getRemoteIndex(c, t); - if(i != null) { - ConnectionWindow w = db.getConnectionWindow(c, i); - ivs.putAll(calculateIvs(c, t, i, w)); - } - } catch(NoSuchContactException e) { - // The contact was removed - clean up in eventOccurred() + private void addTransport(TransportId t) { + // Fill in the contexts as far as possible outside the lock + Collection<Context> partial = new ArrayList<Context>(); + try { + for(ContactId c : db.getContacts()) { + getPartialContexts(c, t, partial); + } + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + return; + } + synchronized(this) { + // Complete the contexts and calculate the expected IVs + calculateIvs(completeContexts(partial)); + } + } + + private void updateContact(ContactId c) { + // Fill in the contexts as far as possible outside the lock + Collection<Context> partial = new ArrayList<Context>(); + try { + Collection<Transport> transports = db.getLocalTransports(); + for(Transport transport : transports) { + getPartialContexts(c, transport.getId(), partial); } + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + return; + } + synchronized(this) { + // Clear the contact's existing IVs + Iterator<Context> it = expected.values().iterator(); + while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); + // Complete the contexts and calculate the expected IVs + calculateIvs(completeContexts(partial)); } - return ivs; } private static class Context { @@ -266,6 +277,7 @@ DatabaseListener { private final TransportId transportId; private final TransportIndex transportIndex; private final long connection; + // Locking: ConnectionRecogniser.this private final ConnectionWindow window; private Context(ContactId contactId, TransportId transportId,