diff --git a/api/net/sf/briar/api/transport/ConnectionWindow.java b/api/net/sf/briar/api/transport/ConnectionWindow.java index 1a403f6c947c33de971298144a81d79a410f39a1..794846fcd408a7b15e6890766f0100ae29872eda 100644 --- a/api/net/sf/briar/api/transport/ConnectionWindow.java +++ b/api/net/sf/briar/api/transport/ConnectionWindow.java @@ -9,4 +9,6 @@ public interface ConnectionWindow { byte[] setSeen(long connection); Map<Long, byte[]> getUnseen(); + + void erase(); } diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java index 0ce6f7305f112d5072608fc96ca0a9489c157046..940b37c0934cf8c75cb62a77f384d74c1dd0aab9 100644 --- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java +++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java @@ -9,7 +9,6 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Level; import java.util.logging.Logger; @@ -35,7 +34,6 @@ import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionWindow; -import net.sf.briar.util.ByteUtils; import com.google.inject.Inject; @@ -50,7 +48,8 @@ DatabaseListener { private final Executor executor; private final Cipher ivCipher; // Locking: this private final Map<Bytes, Context> expected; // Locking: this - private final AtomicBoolean initialised = new AtomicBoolean(false); + + private boolean initialised = false; // Locking: this @Inject ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db, @@ -61,70 +60,48 @@ DatabaseListener { ivCipher = crypto.getIvCipher(); expected = new HashMap<Bytes, Context>(); db.addListener(this); + } + + // Locking: this + private void initialise() throws DbException { Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { eraseSecrets(); } }); - } - - private synchronized void eraseSecrets() { - for(Context c : expected.values()) { - for(byte[] b : c.window.getUnseen().values()) ByteUtils.erase(b); - } - } - - private void initialise() throws DbException { - // Fill in the contexts as far as possible outside the lock - Collection<Context> partial = new ArrayList<Context>(); - Collection<Transport> transports = db.getLocalTransports(); + Collection<TransportId> transports = new ArrayList<TransportId>(); + for(Transport t : db.getLocalTransports()) transports.add(t.getId()); for(ContactId c : db.getContacts()) { - for(Transport transport : transports) { - getPartialContexts(c, transport.getId(), partial); - } - } - synchronized(this) { - // Complete the contexts and calculate the expected IVs - calculateIvs(completeContexts(partial)); - } - } - - private void getPartialContexts(ContactId c, TransportId t, - Collection<Context> partial) throws DbException { - try { - TransportIndex i = db.getRemoteIndex(c, t); - if(i != null) { - // Acquire the lock to avoid getting stale data - synchronized(this) { + Collection<Context> contexts = new ArrayList<Context>(); + try { + for(TransportId t : transports) { + TransportIndex i = db.getRemoteIndex(c, t); + if(i == null) continue; ConnectionWindow w = db.getConnectionWindow(c, i); - partial.add(new Context(c, t, i, -1, w)); + for(long unseen : w.getUnseen().keySet()) { + contexts.add(new Context(c, t, i, unseen, w)); + } } + } catch(NoSuchContactException e) { + // The contact was removed - don't add the IVs + for(Context ctx : contexts) ctx.window.erase(); + continue; } - } catch(NoSuchContactException e) { - // The contact was removed - we'll handle the event later + for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx); } + initialised = true; } - // Locking: this - private Collection<Context> completeContexts(Collection<Context> partial) { - Collection<Context> contexts = new ArrayList<Context>(); - for(Context ctx : partial) { - for(long unseen : ctx.window.getUnseen().keySet()) { - contexts.add(new Context(ctx.contactId, ctx.transportId, - ctx.transportIndex, unseen, ctx.window)); - } - } - return contexts; + private synchronized void eraseSecrets() { + for(Context c : expected.values()) c.window.erase(); } // Locking: this - private void calculateIvs(Collection<Context> contexts) { - for(Context ctx : contexts) { - byte[] secret = ctx.window.getUnseen().get(ctx.connection); - byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret); - expected.put(new Bytes(iv), ctx); - } + private Bytes calculateIv(Context ctx) { + byte[] secret = ctx.window.getUnseen().get(ctx.connection); + byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret); + return new Bytes(iv); } // Locking: this @@ -164,8 +141,8 @@ DatabaseListener { byte[] encryptedIv) throws DbException { if(encryptedIv.length != IV_LENGTH) throw new IllegalArgumentException(); - if(!initialised.getAndSet(true)) initialise(); synchronized(this) { + if(!initialised) initialise(); Bytes b = new Bytes(encryptedIv); Context ctx = expected.get(b); if(ctx == null || !ctx.transportId.equals(t)) return null; @@ -175,37 +152,50 @@ DatabaseListener { TransportIndex i = ctx.transportIndex; long connection = ctx.connection; ConnectionWindow w = ctx.window; - byte[] secret; // Get the secret and update the connection window + byte[] secret = w.setSeen(connection); try { db.setConnectionWindow(c, i, w); } catch(NoSuchContactException e) { - // The contact was removed - we'll handle the event later + // The contact was removed - reject the connection + removeContact(c); + w.erase(); + return null; } - secret = w.setSeen(connection); // Update the connection window's expected IVs Iterator<Context> it = expected.values().iterator(); while(it.hasNext()) { Context ctx1 = it.next(); - if(ctx1.contactId.equals(c) - && ctx1.transportIndex.equals(i)) it.remove(); + if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) + it.remove(); } - Collection<Context> contexts = new ArrayList<Context>(); for(long unseen : w.getUnseen().keySet()) { - contexts.add(new Context(c, t, i, unseen, w)); + Context ctx1 = new Context(c, t, i, unseen, w); + expected.put(calculateIv(ctx1), ctx1); } - calculateIvs(contexts); return new ConnectionContextImpl(c, i, connection, secret); } } + private synchronized void removeContact(ContactId c) { + if(!initialised) return; + Iterator<Context> it = expected.values().iterator(); + while(it.hasNext()) { + Context ctx = it.next(); + if(ctx.contactId.equals(c)) { + ctx.window.erase(); + it.remove(); + } + } + } + public void eventOccurred(DatabaseEvent e) { if(e instanceof ContactRemovedEvent) { // Remove the expected IVs for the ex-contact final ContactId c = ((ContactRemovedEvent) e).getContactId(); executor.execute(new Runnable() { public void run() { - removeIvs(c); + removeContact(c); } }); } else if(e instanceof TransportAddedEvent) { @@ -228,46 +218,51 @@ DatabaseListener { } } - private synchronized void removeIvs(ContactId c) { - Iterator<Context> it = expected.values().iterator(); - while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); - } - - private void addTransport(TransportId t) { - // Fill in the contexts as far as possible outside the lock - Collection<Context> partial = new ArrayList<Context>(); + private synchronized void addTransport(TransportId t) { + if(!initialised) return; try { for(ContactId c : db.getContacts()) { - getPartialContexts(c, t, partial); + Collection<Context> contexts = new ArrayList<Context>(); + try { + TransportIndex i = db.getRemoteIndex(c, t); + if(i == null) continue; + ConnectionWindow w = db.getConnectionWindow(c, i); + for(long unseen : w.getUnseen().keySet()) { + contexts.add(new Context(c, t, i, unseen, w)); + } + } catch(NoSuchContactException e) { + // The contact was removed - don't add the IVs + for(Context ctx : contexts) ctx.window.erase(); + continue; + } + for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx); } } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); - return; - } - synchronized(this) { - // Complete the contexts and calculate the expected IVs - calculateIvs(completeContexts(partial)); } } - private void updateContact(ContactId c) { - // Fill in the contexts as far as possible outside the lock - Collection<Context> partial = new ArrayList<Context>(); + private synchronized void updateContact(ContactId c) { + if(!initialised) return; + removeContact(c); try { - Collection<Transport> transports = db.getLocalTransports(); - for(Transport transport : transports) { - getPartialContexts(c, transport.getId(), partial); + Collection<Context> contexts = new ArrayList<Context>(); + try { + for(Transport transport : db.getLocalTransports()) { + TransportId t = transport.getId(); + TransportIndex i = db.getRemoteIndex(c, t); + ConnectionWindow w = db.getConnectionWindow(c, i); + for(long unseen : w.getUnseen().keySet()) { + contexts.add(new Context(c, t, i, unseen, w)); + } + } + } catch(NoSuchContactException e) { + // The contact was removed - don't add the IVs + return; } + for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx); } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); - return; - } - synchronized(this) { - // Clear the contact's existing IVs - Iterator<Context> it = expected.values().iterator(); - while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); - // Complete the contexts and calculate the expected IVs - calculateIvs(completeContexts(partial)); } } diff --git a/components/net/sf/briar/transport/ConnectionWindowImpl.java b/components/net/sf/briar/transport/ConnectionWindowImpl.java index fcc4e9ac55ae5a9722f1d9d3492140b714fbba7d..adb76fa7b87e346642caa1d9d224aef1e41114b7 100644 --- a/components/net/sf/briar/transport/ConnectionWindowImpl.java +++ b/components/net/sf/briar/transport/ConnectionWindowImpl.java @@ -95,4 +95,8 @@ class ConnectionWindowImpl implements ConnectionWindow { public Map<Long, byte[]> getUnseen() { return unseen; } + + public void erase() { + for(byte[] secret : unseen.values()) ByteUtils.erase(secret); + } }