diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java index 44825a42558d090f101132bead3366f1b94306a8..c2c1f53c1b51ec7d4724e413bf48654cf3ed1235 100644 --- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java +++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java @@ -10,6 +10,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Level; import java.util.logging.Logger; @@ -47,11 +48,9 @@ DatabaseListener { private final CryptoComponent crypto; private final DatabaseComponent db; private final Executor executor; - private final Cipher ivCipher; - private final Map<Bytes, Context> expected; - private final Collection<TransportId> localTransportIds; - - private boolean initialised = false; + private final Map<Bytes, Context> expected; // Locking: this + private final Collection<TransportId> localTransportIds; // Locking: this + private final AtomicBoolean initialised = new AtomicBoolean(false); @Inject ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db, @@ -59,60 +58,85 @@ DatabaseListener { this.crypto = crypto; this.db = db; this.executor = executor; - ivCipher = crypto.getIvCipher(); expected = new HashMap<Bytes, Context>(); localTransportIds = new ArrayList<TransportId>(); db.addListener(this); } - private synchronized void initialise() throws DbException { + private void initialise() throws DbException { Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { eraseSecrets(); } }); - for(Transport t : db.getLocalTransports()) { - localTransportIds.add(t.getId()); + Collection<TransportId> ids = new ArrayList<TransportId>(); + for(Transport t : db.getLocalTransports()) ids.add(t.getId()); + synchronized(this) { + localTransportIds.addAll(ids); } + Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); for(ContactId c : db.getContacts()) { try { - calculateIvs(c); + ivs.putAll(calculateIvs(c)); } catch(NoSuchContactException e) { // The contact was removed - clean up in eventOccurred() } } - initialised = true; + synchronized(this) { + expected.putAll(ivs); + } + } + + private synchronized void eraseSecrets() { + for(Context c : expected.values()) { + synchronized(c.window) { + for(byte[] b : c.window.getUnseen().values()) { + ByteUtils.erase(b); + } + } + } } - private synchronized void calculateIvs(ContactId c) throws DbException { - for(TransportId t : localTransportIds) { + private Map<Bytes, Context> calculateIvs(ContactId c) throws DbException { + Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); + Collection<TransportId> ids; + synchronized(this) { + ids = new ArrayList<TransportId>(localTransportIds); + } + for(TransportId t : ids) { TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, t, i, w); + ivs.putAll(calculateIvs(c, t, i, w)); } } + return ivs; } - private synchronized void calculateIvs(ContactId c, TransportId t, + private Map<Bytes, Context> calculateIvs(ContactId c, TransportId t, TransportIndex i, ConnectionWindow w) throws DbException { - for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { - long connection = e.getKey(); - byte[] secret = e.getValue(); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); - Bytes iv = new Bytes(encryptIv(i, connection, ivKey)); - ivKey.erase(); - expected.put(iv, new Context(c, t, i, connection, w)); + Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); + synchronized(w) { + for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { + long connection = e.getKey(); + byte[] secret = e.getValue(); + Bytes iv = new Bytes(encryptIv(i, connection, secret)); + ivs.put(iv, new Context(c, t, i, connection, w)); + } } + return ivs; } - private synchronized byte[] encryptIv(TransportIndex i, long connection, - ErasableKey ivKey) { + private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) { byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection); + ErasableKey ivKey = crypto.deriveIvKey(secret, true); try { + Cipher ivCipher = crypto.getIvCipher(); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); - return ivCipher.doFinal(iv); + byte[] encryptedIv = ivCipher.doFinal(iv); + ivKey.erase(); + return encryptedIv; } catch(BadPaddingException badCipher) { throw new RuntimeException(badCipher); } catch(IllegalBlockSizeException badCipher) { @@ -122,12 +146,6 @@ DatabaseListener { } } - private synchronized void eraseSecrets() { - for(Context c : expected.values()) { - for(byte[] b : c.window.getUnseen().values()) ByteUtils.erase(b); - } - } - public void acceptConnection(final TransportId t, final byte[] encryptedIv, final Callback callback) { executor.execute(new Runnable() { @@ -137,39 +155,48 @@ DatabaseListener { }); } - private synchronized void acceptConnectionSync(TransportId t, - byte[] encryptedIv, Callback callback) { + private void acceptConnectionSync(TransportId t, byte[] encryptedIv, + Callback callback) { try { if(encryptedIv.length != IV_LENGTH) throw new IllegalArgumentException(); - if(!initialised) initialise(); - Bytes b = new Bytes(encryptedIv); - Context ctx = expected.get(b); - if(ctx == null || !ctx.transportId.equals(t)) { - callback.connectionRejected(); - return; + if(!initialised.getAndSet(true)) initialise(); + Context ctx; + synchronized(this) { + Bytes b = new Bytes(encryptedIv); + ctx = expected.get(b); + if(ctx == null || !ctx.transportId.equals(t)) { + callback.connectionRejected(); + return; + } + expected.remove(b); } // The IV was expected - expected.remove(b); ContactId c = ctx.contactId; TransportIndex i = ctx.transportIndex; long connection = ctx.connection; ConnectionWindow w = ctx.window; - // Get the secret and update the connection window - byte[] secret = w.setSeen(connection); - try { - db.setConnectionWindow(c, i, w); - } catch(NoSuchContactException e) { - // The contact was removed - clean up in eventOccurred() + byte[] secret; + synchronized(w) { + // Get the secret and update the connection window + secret = w.setSeen(connection); + try { + db.setConnectionWindow(c, i, w); + } catch(NoSuchContactException e) { + // The contact was removed - clean up in eventOccurred() + } } // Update the set of expected IVs - Iterator<Context> it = expected.values().iterator(); - while(it.hasNext()) { - Context ctx1 = it.next(); - if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) - it.remove(); + Map<Bytes, Context> ivs = calculateIvs(c, t, i, w); + synchronized(this) { + Iterator<Context> it = expected.values().iterator(); + while(it.hasNext()) { + Context ctx1 = it.next(); + if(ctx1.contactId.equals(c) + && ctx1.transportIndex.equals(i)) it.remove(); + } + expected.putAll(ivs); } - calculateIvs(c, t, i, w); callback.connectionAccepted(new ConnectionContextImpl(c, i, connection, secret)); } catch(DbException e) { @@ -184,28 +211,30 @@ DatabaseListener { } else if(e instanceof TransportAddedEvent) { // Calculate the expected IVs for the new transport TransportId t = ((TransportAddedEvent) e).getTransportId(); - synchronized(this) { - if(!initialised) return; - try { + try { + if(!initialised.getAndSet(true)) initialise(); + Map<Bytes, Context> ivs = calculateIvs(t); + synchronized(this) { localTransportIds.add(t); - calculateIvs(t); - } catch(DbException e1) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e1.getMessage()); + expected.putAll(ivs); } + } catch(DbException e1) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e1.getMessage()); } } else if(e instanceof RemoteTransportsUpdatedEvent) { // Remove and recalculate the expected IVs for the contact ContactId c = ((RemoteTransportsUpdatedEvent) e).getContactId(); - synchronized(this) { - if(!initialised) return; - removeIvs(c); - try { - calculateIvs(c); - } catch(DbException e1) { - if(LOG.isLoggable(Level.WARNING)) - LOG.warning(e1.getMessage()); + try { + if(!initialised.getAndSet(true)) initialise(); + Map<Bytes, Context> ivs = calculateIvs(c); + synchronized(this) { + removeIvs(c); + expected.putAll(ivs); } + } catch(DbException e1) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e1.getMessage()); } } } @@ -215,18 +244,20 @@ DatabaseListener { while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); } - private synchronized void calculateIvs(TransportId t) throws DbException { + private Map<Bytes, Context> calculateIvs(TransportId t) throws DbException { + Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); for(ContactId c : db.getContacts()) { try { TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, t, i, w); + ivs.putAll(calculateIvs(c, t, i, w)); } } catch(NoSuchContactException e) { - // The contact was removed - clean up when we get the event + // The contact was removed - clean up in eventOccurred() } } + return ivs; } private static class Context {