From f6ae4734cebf82647171c25adedb6bbd11687457 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Wed, 16 Nov 2011 15:35:16 +0000 Subject: [PATCH] Forward secrecy. Each connection's keys are derived from a secret that is erased after deriving the keys and the secret for the next connection. --- .../sf/briar/api/crypto/CryptoComponent.java | 2 +- .../sf/briar/api/db/DatabaseComponent.java | 6 +- .../api/transport/BatchConnectionFactory.java | 4 +- .../api/transport/ConnectionContext.java | 2 + .../transport/ConnectionContextFactory.java | 5 +- .../api/transport/ConnectionDispatcher.java | 4 +- .../transport/ConnectionReaderFactory.java | 15 +- .../briar/api/transport/ConnectionWindow.java | 4 +- .../transport/ConnectionWindowFactory.java | 9 +- .../transport/ConnectionWriterFactory.java | 11 +- .../transport/StreamConnectionFactory.java | 4 +- .../sf/briar/crypto/CryptoComponentImpl.java | 21 +-- .../net/sf/briar/crypto/ErasableKeyImpl.java | 3 +- components/net/sf/briar/db/Database.java | 24 ++-- .../sf/briar/db/DatabaseComponentImpl.java | 35 ++--- components/net/sf/briar/db/H2Database.java | 7 +- components/net/sf/briar/db/JdbcDatabase.java | 132 +++++++++--------- .../sf/briar/plugins/PluginManagerImpl.java | 4 +- .../ConnectionContextFactoryImpl.java | 21 ++- .../transport/ConnectionContextImpl.java | 8 +- .../transport/ConnectionDispatcherImpl.java | 14 +- .../ConnectionReaderFactoryImpl.java | 27 ++-- .../transport/ConnectionRecogniserImpl.java | 32 ++--- .../ConnectionWindowFactoryImpl.java | 23 ++- .../briar/transport/ConnectionWindowImpl.java | 56 +++++--- .../ConnectionWriterFactoryImpl.java | 30 ++-- .../net/sf/briar/transport/IvEncoder.java | 31 ++-- .../batch/BatchConnectionFactoryImpl.java | 10 +- .../batch/IncomingBatchConnection.java | 25 ++-- .../batch/OutgoingBatchConnection.java | 14 +- .../stream/IncomingStreamConnection.java | 18 +-- .../stream/OutgoingStreamConnection.java | 23 ++- .../transport/stream/StreamConnection.java | 6 +- .../stream/StreamConnectionFactoryImpl.java | 9 +- .../net/sf/briar/ProtocolIntegrationTest.java | 41 +++--- .../sf/briar/db/DatabaseComponentTest.java | 29 +--- test/net/sf/briar/db/H2DatabaseTest.java | 100 ++++++------- .../ConnectionDecrypterImplTest.java | 7 +- .../ConnectionEncrypterImplTest.java | 5 +- .../ConnectionRecogniserImplTest.java | 23 +-- .../transport/ConnectionWindowImplTest.java | 60 +++++--- .../briar/transport/ConnectionWriterTest.java | 18 ++- .../briar/transport/FrameReadWriteTest.java | 5 +- .../batch/BatchConnectionReadWriteTest.java | 5 +- util/net/sf/briar/util/ByteUtils.java | 4 + 45 files changed, 506 insertions(+), 430 deletions(-) diff --git a/api/net/sf/briar/api/crypto/CryptoComponent.java b/api/net/sf/briar/api/crypto/CryptoComponent.java index 0c6b7f9ed0..a764e33bc1 100644 --- a/api/net/sf/briar/api/crypto/CryptoComponent.java +++ b/api/net/sf/briar/api/crypto/CryptoComponent.java @@ -15,7 +15,7 @@ public interface CryptoComponent { ErasableKey deriveMacKey(byte[] secret, boolean initiator); - byte[] deriveNextSecret(byte[] secret, long connection); + byte[] deriveNextSecret(byte[] secret, int index, long connection); KeyPair generateKeyPair(); diff --git a/api/net/sf/briar/api/db/DatabaseComponent.java b/api/net/sf/briar/api/db/DatabaseComponent.java index f2df7b5cc2..2f47365caa 100644 --- a/api/net/sf/briar/api/db/DatabaseComponent.java +++ b/api/net/sf/briar/api/db/DatabaseComponent.java @@ -57,8 +57,7 @@ public interface DatabaseComponent { * Adds a new contact to the database with the given secrets and returns an * ID for the contact. */ - ContactId addContact(byte[] incomingSecret, byte[] outgoingSecret) - throws DbException; + ContactId addContact(byte[] inSecret, byte[] outSecret) throws DbException; /** Adds a locally generated group message to the database. */ void addLocalGroupMessage(Message m) throws DbException; @@ -160,9 +159,6 @@ public interface DatabaseComponent { Map<ContactId, TransportProperties> getRemoteProperties(TransportId t) throws DbException; - /** Returns the secret shared with the given contact. */ - byte[] getSharedSecret(ContactId c, boolean incoming) throws DbException; - /** Returns the set of groups to which the user subscribes. */ Collection<Group> getSubscriptions() throws DbException; diff --git a/api/net/sf/briar/api/transport/BatchConnectionFactory.java b/api/net/sf/briar/api/transport/BatchConnectionFactory.java index 20fb14e60b..279e605256 100644 --- a/api/net/sf/briar/api/transport/BatchConnectionFactory.java +++ b/api/net/sf/briar/api/transport/BatchConnectionFactory.java @@ -5,9 +5,9 @@ import net.sf.briar.api.protocol.TransportIndex; public interface BatchConnectionFactory { - void createIncomingConnection(TransportIndex i, ContactId c, + void createIncomingConnection(ConnectionContext ctx, BatchTransportReader r, byte[] encryptedIv); - void createOutgoingConnection(TransportIndex i, ContactId c, + void createOutgoingConnection(ContactId c, TransportIndex i, BatchTransportWriter w); } diff --git a/api/net/sf/briar/api/transport/ConnectionContext.java b/api/net/sf/briar/api/transport/ConnectionContext.java index 314392dc81..1a5dbd5519 100644 --- a/api/net/sf/briar/api/transport/ConnectionContext.java +++ b/api/net/sf/briar/api/transport/ConnectionContext.java @@ -10,4 +10,6 @@ public interface ConnectionContext { TransportIndex getTransportIndex(); long getConnectionNumber(); + + byte[] getSecret(); } diff --git a/api/net/sf/briar/api/transport/ConnectionContextFactory.java b/api/net/sf/briar/api/transport/ConnectionContextFactory.java index 39ab561dce..374a357ef8 100644 --- a/api/net/sf/briar/api/transport/ConnectionContextFactory.java +++ b/api/net/sf/briar/api/transport/ConnectionContextFactory.java @@ -6,5 +6,8 @@ import net.sf.briar.api.protocol.TransportIndex; public interface ConnectionContextFactory { ConnectionContext createConnectionContext(ContactId c, TransportIndex i, - long connection); + long connection, byte[] secret); + + ConnectionContext createNextConnectionContext(ContactId c, TransportIndex i, + long connection, byte[] previousSecret); } diff --git a/api/net/sf/briar/api/transport/ConnectionDispatcher.java b/api/net/sf/briar/api/transport/ConnectionDispatcher.java index 6207063d5a..4317983395 100644 --- a/api/net/sf/briar/api/transport/ConnectionDispatcher.java +++ b/api/net/sf/briar/api/transport/ConnectionDispatcher.java @@ -8,10 +8,10 @@ public interface ConnectionDispatcher { void dispatchReader(TransportId t, BatchTransportReader r); - void dispatchWriter(TransportIndex i, ContactId c, BatchTransportWriter w); + void dispatchWriter(ContactId c, TransportIndex i, BatchTransportWriter w); void dispatchIncomingConnection(TransportId t, StreamTransportConnection s); - void dispatchOutgoingConnection(TransportIndex i, ContactId c, + void dispatchOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s); } diff --git a/api/net/sf/briar/api/transport/ConnectionReaderFactory.java b/api/net/sf/briar/api/transport/ConnectionReaderFactory.java index db9dead7b0..a916be465e 100644 --- a/api/net/sf/briar/api/transport/ConnectionReaderFactory.java +++ b/api/net/sf/briar/api/transport/ConnectionReaderFactory.java @@ -2,22 +2,19 @@ package net.sf.briar.api.transport; import java.io.InputStream; -import net.sf.briar.api.protocol.TransportIndex; - public interface ConnectionReaderFactory { /** * Creates a connection reader for a batch-mode connection or the - * initiator's side of a stream-mode connection. The secret is erased before - * returning. + * initiator's side of a stream-mode connection. */ - ConnectionReader createConnectionReader(InputStream in, TransportIndex i, - byte[] encryptedIv, byte[] secret); + ConnectionReader createConnectionReader(InputStream in, + ConnectionContext ctx, byte[] encryptedIv); /** * Creates a connection reader for the responder's side of a stream-mode - * connection. The secret is erased before returning. + * connection. */ - ConnectionReader createConnectionReader(InputStream in, TransportIndex i, - long connection, byte[] secret); + ConnectionReader createConnectionReader(InputStream in, + ConnectionContext ctx); } diff --git a/api/net/sf/briar/api/transport/ConnectionWindow.java b/api/net/sf/briar/api/transport/ConnectionWindow.java index 35fc872f15..4d0a9d59ce 100644 --- a/api/net/sf/briar/api/transport/ConnectionWindow.java +++ b/api/net/sf/briar/api/transport/ConnectionWindow.java @@ -1,6 +1,6 @@ package net.sf.briar.api.transport; -import java.util.Collection; +import java.util.Map; public interface ConnectionWindow { @@ -8,5 +8,5 @@ public interface ConnectionWindow { void setSeen(long connection); - Collection<Long> getUnseen(); + Map<Long, byte[]> getUnseen(); } diff --git a/api/net/sf/briar/api/transport/ConnectionWindowFactory.java b/api/net/sf/briar/api/transport/ConnectionWindowFactory.java index 33ef8e7a6a..110df0e99c 100644 --- a/api/net/sf/briar/api/transport/ConnectionWindowFactory.java +++ b/api/net/sf/briar/api/transport/ConnectionWindowFactory.java @@ -1,10 +1,13 @@ package net.sf.briar.api.transport; -import java.util.Collection; +import java.util.Map; + +import net.sf.briar.api.protocol.TransportIndex; public interface ConnectionWindowFactory { - ConnectionWindow createConnectionWindow(); + ConnectionWindow createConnectionWindow(TransportIndex i, byte[] secret); - ConnectionWindow createConnectionWindow(Collection<Long> unseen); + ConnectionWindow createConnectionWindow(TransportIndex i, + Map<Long, byte[]> unseen); } diff --git a/api/net/sf/briar/api/transport/ConnectionWriterFactory.java b/api/net/sf/briar/api/transport/ConnectionWriterFactory.java index 8b05d3e2a5..c999efd7e9 100644 --- a/api/net/sf/briar/api/transport/ConnectionWriterFactory.java +++ b/api/net/sf/briar/api/transport/ConnectionWriterFactory.java @@ -2,22 +2,19 @@ package net.sf.briar.api.transport; import java.io.OutputStream; -import net.sf.briar.api.protocol.TransportIndex; - public interface ConnectionWriterFactory { /** * Creates a connection writer for a batch-mode connection or the - * initiator's side of a stream-mode connection. The secret is erased before - * returning. + * initiator's side of a stream-mode connection. */ ConnectionWriter createConnectionWriter(OutputStream out, long capacity, - TransportIndex i, long connection, byte[] secret); + ConnectionContext ctx); /** * Creates a connection writer for the responder's side of a stream-mode - * connection. The secret is erased before returning. + * connection. */ ConnectionWriter createConnectionWriter(OutputStream out, long capacity, - TransportIndex i, byte[] encryptedIv, byte[] secret); + ConnectionContext ctx, byte[] encryptedIv); } diff --git a/api/net/sf/briar/api/transport/StreamConnectionFactory.java b/api/net/sf/briar/api/transport/StreamConnectionFactory.java index 3287453e7e..13c1daf69e 100644 --- a/api/net/sf/briar/api/transport/StreamConnectionFactory.java +++ b/api/net/sf/briar/api/transport/StreamConnectionFactory.java @@ -5,9 +5,9 @@ import net.sf.briar.api.protocol.TransportIndex; public interface StreamConnectionFactory { - void createIncomingConnection(TransportIndex i, ContactId c, + void createIncomingConnection(ConnectionContext ctx, StreamTransportConnection s, byte[] encryptedIv); - void createOutgoingConnection(TransportIndex i, ContactId c, + void createOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s); } diff --git a/components/net/sf/briar/crypto/CryptoComponentImpl.java b/components/net/sf/briar/crypto/CryptoComponentImpl.java index d82d7f924c..0f1aca360a 100644 --- a/components/net/sf/briar/crypto/CryptoComponentImpl.java +++ b/components/net/sf/briar/crypto/CryptoComponentImpl.java @@ -88,14 +88,14 @@ class CryptoComponentImpl implements CryptoComponent { if(secret.length != SECRET_KEY_BYTES) throw new IllegalArgumentException(); ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO); - // The context must leave four bytes free for the length - if(context.length + 4 > SECRET_KEY_BYTES) + // The context must leave two bytes free for the length + if(context.length + 2 > SECRET_KEY_BYTES) throw new IllegalArgumentException(); byte[] input = new byte[SECRET_KEY_BYTES]; - // The initial bytes of the input are the context - System.arraycopy(context, 0, input, 0, context.length); - // The final bytes of the input are the length as a big-endian uint32 - ByteUtils.writeUint32(context.length, input, input.length - 4); + // The input starts with the length of the context as a big-endian int16 + ByteUtils.writeUint16(context.length, input, 0); + // The remaining bytes of the input are the context + System.arraycopy(context, 0, input, 2, context.length); // Initialise the counter to zero byte[] zero = new byte[KEY_DERIVATION_IV_BYTES]; IvParameterSpec iv = new IvParameterSpec(zero); @@ -110,12 +110,15 @@ class CryptoComponentImpl implements CryptoComponent { } } - public byte[] deriveNextSecret(byte[] secret, long connection) { + public byte[] deriveNextSecret(byte[] secret, int index, long connection) { + if(index < 0 || index > ByteUtils.MAX_16_BIT_UNSIGNED) + throw new IllegalArgumentException(); if(connection < 0 || connection > ByteUtils.MAX_32_BIT_UNSIGNED) throw new IllegalArgumentException(); - byte[] context = new byte[NEXT.length + 4]; + byte[] context = new byte[NEXT.length + 6]; System.arraycopy(NEXT, 0, context, 0, NEXT.length); - ByteUtils.writeUint32(connection, context, NEXT.length); + ByteUtils.writeUint16(index, context, NEXT.length); + ByteUtils.writeUint32(connection, context, NEXT.length + 2); return counterModeKdf(secret, context); } diff --git a/components/net/sf/briar/crypto/ErasableKeyImpl.java b/components/net/sf/briar/crypto/ErasableKeyImpl.java index 209699a3c5..595485f3bb 100644 --- a/components/net/sf/briar/crypto/ErasableKeyImpl.java +++ b/components/net/sf/briar/crypto/ErasableKeyImpl.java @@ -3,6 +3,7 @@ package net.sf.briar.crypto; import java.util.Arrays; import net.sf.briar.api.crypto.ErasableKey; +import net.sf.briar.util.ByteUtils; class ErasableKeyImpl implements ErasableKey { @@ -34,7 +35,7 @@ class ErasableKeyImpl implements ErasableKey { public void erase() { if(erased) throw new IllegalStateException(); - for(int i = 0; i < key.length; i++) key[i] = 0; + ByteUtils.erase(key); erased = true; } diff --git a/components/net/sf/briar/db/Database.java b/components/net/sf/briar/db/Database.java index fe1a6fc005..02dc62d98b 100644 --- a/components/net/sf/briar/db/Database.java +++ b/components/net/sf/briar/db/Database.java @@ -84,10 +84,14 @@ interface Database<T> { * Adds a new contact to the database with the given secrets and returns an * ID for the contact. * <p> + * Any secrets generated by the method are stored in the given collection + * and should be erased by the caller once the transaction has been + * committed or aborted. + * <p> * Locking: contact write. */ - ContactId addContact(T txn, byte[] incomingSecret, byte[] outgoingSecret) - throws DbException; + ContactId addContact(T txn, byte[] inSecret, byte[] outSecret, + Collection<byte[]> erase) throws DbException; /** * Returns false if the given message is already in the database. Otherwise @@ -187,10 +191,14 @@ interface Database<T> { * Returns an outgoing connection context for the given contact and * transport. * <p> + * Any secrets generated by the method are stored in the given collection + * and should be erased by the caller once the transaction has been + * committed or aborted. + * <p> * Locking: contact read, window write. */ - ConnectionContext getConnectionContext(T txn, ContactId c, TransportIndex i) - throws DbException; + ConnectionContext getConnectionContext(T txn, ContactId c, TransportIndex i, + Collection<byte[]> erase) throws DbException; /** * Returns the connection reordering window for the given contact and @@ -373,14 +381,6 @@ interface Database<T> { Collection<MessageId> getSendableMessages(T txn, ContactId c, int capacity) throws DbException; - /** - * Returns the secret shared with the given contact. - * <p> - * Locking: contact read. - */ - byte[] getSharedSecret(T txn, ContactId c, boolean incoming) - throws DbException; - /** * Returns true if the given message has been starred. * <p> diff --git a/components/net/sf/briar/db/DatabaseComponentImpl.java b/components/net/sf/briar/db/DatabaseComponentImpl.java index 1fd216d55f..bdf067cdb1 100644 --- a/components/net/sf/briar/db/DatabaseComponentImpl.java +++ b/components/net/sf/briar/db/DatabaseComponentImpl.java @@ -62,6 +62,7 @@ import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionWindow; +import net.sf.briar.util.ByteUtils; import com.google.inject.Inject; @@ -136,23 +137,24 @@ DatabaseCleaner.Callback { } } - public ContactId addContact(byte[] incomingSecret, byte[] outgoingSecret) + public ContactId addContact(byte[] inSecret, byte[] outSecret) throws DbException { - if(LOG.isLoggable(Level.FINE)) LOG.fine("Adding contact"); ContactId c; + Collection<byte[]> erase = new ArrayList<byte[]>(); contactLock.writeLock().lock(); try { T txn = db.startTransaction(); try { - c = db.addContact(txn, incomingSecret, outgoingSecret); + c = db.addContact(txn, inSecret, outSecret, erase); db.commitTransaction(txn); - if(LOG.isLoggable(Level.FINE)) LOG.fine("Added contact " + c); } catch(DbException e) { db.abortTransaction(txn); throw e; } } finally { contactLock.writeLock().unlock(); + // Erase the secrets after committing or aborting the transaction + for(byte[] b : erase) ByteUtils.erase(b); } // Call the listeners outside the lock callListeners(new ContactAddedEvent(c)); @@ -703,6 +705,7 @@ DatabaseCleaner.Callback { public ConnectionContext getConnectionContext(ContactId c, TransportIndex i) throws DbException { + Collection<byte[]> erase = new ArrayList<byte[]>(); contactLock.readLock().lock(); try { if(!containsContact(c)) throw new NoSuchContactException(); @@ -710,7 +713,8 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - ConnectionContext ctx = db.getConnectionContext(txn, c, i); + ConnectionContext ctx = + db.getConnectionContext(txn, c, i, erase); db.commitTransaction(txn); return ctx; } catch(DbException e) { @@ -722,6 +726,8 @@ DatabaseCleaner.Callback { } } finally { contactLock.readLock().unlock(); + // Erase the secrets after committing or aborting the transaction + for(byte[] b : erase) ByteUtils.erase(b); } } @@ -907,25 +913,6 @@ DatabaseCleaner.Callback { } } - public byte[] getSharedSecret(ContactId c, boolean incoming) - throws DbException { - contactLock.readLock().lock(); - try { - if(!containsContact(c)) throw new NoSuchContactException(); - T txn = db.startTransaction(); - try { - byte[] secret = db.getSharedSecret(txn, c, incoming); - db.commitTransaction(txn); - return secret; - } catch(DbException e) { - db.abortTransaction(txn); - throw e; - } - } finally { - contactLock.readLock().unlock(); - } - } - public Collection<Group> getSubscriptions() throws DbException { subscriptionLock.readLock().lock(); try { diff --git a/components/net/sf/briar/db/H2Database.java b/components/net/sf/briar/db/H2Database.java index a721aa9a2e..f9e86e349f 100644 --- a/components/net/sf/briar/db/H2Database.java +++ b/components/net/sf/briar/db/H2Database.java @@ -29,6 +29,11 @@ class H2Database extends JdbcDatabase { private static final Logger LOG = Logger.getLogger(H2Database.class.getName()); + private static final String HASH_TYPE = "BINARY(32)"; + private static final String BINARY_TYPE = "BINARY"; + private static final String COUNTER_TYPE = "INT NOT NULL AUTO_INCREMENT"; + private static final String SECRET_TYPE = "BINARY(32)"; + private final File home; private final Password password; private final String url; @@ -42,7 +47,7 @@ class H2Database extends JdbcDatabase { ConnectionWindowFactory connectionWindowFactory, GroupFactory groupFactory) { super(connectionContextFactory, connectionWindowFactory, groupFactory, - "BINARY(32)", "BINARY", "INT NOT NULL AUTO_INCREMENT"); + HASH_TYPE, BINARY_TYPE, COUNTER_TYPE, SECRET_TYPE); home = new File(dir, "db"); this.password = password; url = "jdbc:h2:split:" + home.getPath() diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index 6a874a16fd..c06b71bc95 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -58,8 +58,6 @@ abstract class JdbcDatabase implements Database<Connection> { private static final String CREATE_CONTACTS = "CREATE TABLE contacts" + " (contactId COUNTER," - + " incomingSecret BINARY NOT NULL," - + " outgoingSecret BINARY NOT NULL," + " PRIMARY KEY (contactId))"; private static final String CREATE_MESSAGES = @@ -221,7 +219,8 @@ abstract class JdbcDatabase implements Database<Connection> { "CREATE TABLE connections" + " (contactId INT NOT NULL," + " index INT NOT NULL," - + " outgoing BIGINT NOT NULL," + + " connection BIGINT NOT NULL," + + " secret SECRET NOT NULL," + " PRIMARY KEY (contactId, index)," + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; @@ -231,6 +230,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " (contactId INT NOT NULL," + " index INT NOT NULL," + " unseen BIGINT NOT NULL," + + " secret SECRET NOT NULL," + " PRIMARY KEY (contactId, index, unseen)," + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; @@ -271,7 +271,7 @@ abstract class JdbcDatabase implements Database<Connection> { private final ConnectionWindowFactory connectionWindowFactory; private final GroupFactory groupFactory; // Different database libraries use different names for certain types - private final String hashType, binaryType, counterType; + private final String hashType, binaryType, counterType, secretType; private final LinkedList<Connection> connections = new LinkedList<Connection>(); // Locking: self @@ -284,13 +284,14 @@ abstract class JdbcDatabase implements Database<Connection> { JdbcDatabase(ConnectionContextFactory connectionContextFactory, ConnectionWindowFactory connectionWindowFactory, GroupFactory groupFactory, String hashType, String binaryType, - String counterType) { + String counterType, String secretType) { this.connectionContextFactory = connectionContextFactory; this.connectionWindowFactory = connectionWindowFactory; this.groupFactory = groupFactory; this.hashType = hashType; this.binaryType = binaryType; this.counterType = counterType; + this.secretType = secretType; } protected void open(boolean resume, File dir, String driverClass) @@ -371,6 +372,7 @@ abstract class JdbcDatabase implements Database<Connection> { s = s.replaceAll("HASH", hashType); s = s.replaceAll("BINARY", binaryType); s = s.replaceAll("COUNTER", counterType); + s = s.replaceAll("SECRET", secretType); return s; } @@ -515,17 +517,14 @@ abstract class JdbcDatabase implements Database<Connection> { } } - public ContactId addContact(Connection txn, byte[] incomingSecret, - byte[] outgoingSecret) throws DbException { + public ContactId addContact(Connection txn, byte[] inSecret, + byte[] outSecret, Collection<byte[]> erase) throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { // Create a new contact row - String sql = "INSERT INTO contacts (incomingSecret, outgoingSecret)" - + " VALUES (?, ?)"; + String sql = "INSERT INTO contacts DEFAULT VALUES"; ps = txn.prepareStatement(sql); - ps.setBytes(1, incomingSecret); - ps.setBytes(2, outgoingSecret); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -558,13 +557,20 @@ abstract class JdbcDatabase implements Database<Connection> { affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); - // Initialise the connection numbers for all transports - sql = "INSERT INTO connections (contactId, index, outgoing)" - + " VALUES (?, ?, ZERO())"; + // Initialise the outgoing connection contexts for all transports + sql = "INSERT INTO connections" + + " (contactId, index, connection, secret)" + + " VALUES (?, ?, ZERO(), ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) { ps.setInt(2, i); + ConnectionContext ctx = + connectionContextFactory.createNextConnectionContext(c, + new TransportIndex(i), 0L, outSecret); + byte[] secret = ctx.getSecret(); + erase.add(secret); + ps.setBytes(3, secret); ps.addBatch(); } int[] batchAffected = ps.executeBatch(); @@ -574,18 +580,23 @@ abstract class JdbcDatabase implements Database<Connection> { if(batchAffected[i] != 1) throw new DbStateException(); } ps.close(); - // Initialise the connection windows for all transports - sql = "INSERT INTO connectionWindows (contactId, index, unseen)" - + " VALUES (?, ?, ?)"; + // Initialise the incoming connection windows for all transports + sql = "INSERT INTO connectionWindows" + + " (contactId, index, unseen, secret)" + + " VALUES (?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); int batchSize = 0; for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) { ps.setInt(2, i); ConnectionWindow w = - connectionWindowFactory.createConnectionWindow(); - for(long l : w.getUnseen()) { - ps.setLong(3, l); + connectionWindowFactory.createConnectionWindow( + new TransportIndex(i), inSecret); + for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { + ps.setLong(3, e.getKey()); + byte[] secret = e.getValue(); + erase.add(secret); + ps.setBytes(4, secret); ps.addBatch(); batchSize++; } @@ -945,31 +956,43 @@ abstract class JdbcDatabase implements Database<Connection> { } public ConnectionContext getConnectionContext(Connection txn, ContactId c, - TransportIndex i) throws DbException { + TransportIndex i, Collection<byte[]> erase) throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "UPDATE connections SET outgoing = outgoing + 1" - + " WHERE contactId = ? AND index = ?"; - ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); - ps.setInt(2, i.getInt()); - int affected = ps.executeUpdate(); - if(affected != 1) throw new DbStateException(); - ps.close(); - sql = "SELECT outgoing FROM connections" + // Retrieve the current context + String sql = "SELECT connection, secret FROM connections" + " WHERE contactId = ? AND index = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(2, i.getInt()); rs = ps.executeQuery(); if(!rs.next()) throw new DbStateException(); - long outgoing = rs.getLong(1); + long connection = rs.getLong(1); + byte[] secret = rs.getBytes(2); if(rs.next()) throw new DbStateException(); rs.close(); ps.close(); - return connectionContextFactory.createConnectionContext(c, i, - outgoing); + ConnectionContext ctx = + connectionContextFactory.createConnectionContext(c, i, + connection, secret); + // Calculate and store the next context + ConnectionContext next = + connectionContextFactory.createNextConnectionContext(c, i, + connection + 1, secret); + byte[] nextSecret = next.getSecret(); + erase.add(nextSecret); + sql = "UPDATE connections" + + " SET connection = connection + 1, secret = ?" + + " WHERE contactId = ? AND index = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, nextSecret); + ps.setInt(2, c.getInt()); + ps.setInt(3, i.getInt()); + int affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + return ctx; } catch(SQLException e) { tryToClose(rs); tryToClose(ps); @@ -982,17 +1005,17 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT unseen FROM connectionWindows" + String sql = "SELECT unseen, secret FROM connectionWindows" + " WHERE contactId = ? AND index = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(2, i.getInt()); rs = ps.executeQuery(); - Collection<Long> unseen = new ArrayList<Long>(); - while(rs.next()) unseen.add(rs.getLong(1)); + Map<Long, byte[]> unseen = new HashMap<Long, byte[]>(); + while(rs.next()) unseen.put(rs.getLong(1), rs.getBytes(2)); rs.close(); ps.close(); - return connectionWindowFactory.createConnectionWindow(unseen); + return connectionWindowFactory.createConnectionWindow(i, unseen); } catch(SQLException e) { tryToClose(rs); tryToClose(ps); @@ -1652,29 +1675,6 @@ abstract class JdbcDatabase implements Database<Connection> { } } - public byte[] getSharedSecret(Connection txn, ContactId c, boolean incoming) - throws DbException { - PreparedStatement ps = null; - ResultSet rs = null; - try { - String col = incoming ? "incomingSecret" : "outgoingSecret"; - String sql = "SELECT " + col + " FROM contacts WHERE contactId = ?"; - ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); - rs = ps.executeQuery(); - if(!rs.next()) throw new DbStateException(); - byte[] secret = rs.getBytes(1); - if(rs.next()) throw new DbStateException(); - rs.close(); - ps.close(); - return secret; - } catch(SQLException e) { - tryToClose(rs); - tryToClose(ps); - throw new DbException(e); - } - } - public boolean getStarred(Connection txn, MessageId m) throws DbException { PreparedStatement ps = null; ResultSet rs = null; @@ -2197,14 +2197,16 @@ abstract class JdbcDatabase implements Database<Connection> { ps.executeUpdate(); ps.close(); // Store the new connection window - sql = "INSERT INTO connectionWindows (contactId, index, unseen)" - + " VALUES(?, ?, ?)"; + sql = "INSERT INTO connectionWindows" + + " (contactId, index, unseen, secret)" + + " VALUES(?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(2, i.getInt()); - Collection<Long> unseen = w.getUnseen(); - for(long l : unseen) { - ps.setLong(3, l); + Map<Long, byte[]> unseen = w.getUnseen(); + for(Entry<Long, byte[]> e : unseen.entrySet()) { + ps.setLong(3, e.getKey()); + ps.setBytes(4, e.getValue()); ps.addBatch(); } int[] affectedBatch = ps.executeBatch(); diff --git a/components/net/sf/briar/plugins/PluginManagerImpl.java b/components/net/sf/briar/plugins/PluginManagerImpl.java index 0d65447654..e6f069580b 100644 --- a/components/net/sf/briar/plugins/PluginManagerImpl.java +++ b/components/net/sf/briar/plugins/PluginManagerImpl.java @@ -292,7 +292,7 @@ class PluginManagerImpl implements PluginManager { public void writerCreated(ContactId c, BatchTransportWriter w) { assert index != null; - dispatcher.dispatchWriter(index, c, w); + dispatcher.dispatchWriter(c, index, w); } } @@ -307,7 +307,7 @@ class PluginManagerImpl implements PluginManager { public void outgoingConnectionCreated(ContactId c, StreamTransportConnection s) { assert index != null; - dispatcher.dispatchOutgoingConnection(index, c, s); + dispatcher.dispatchOutgoingConnection(c, index, s); } } } \ No newline at end of file diff --git a/components/net/sf/briar/transport/ConnectionContextFactoryImpl.java b/components/net/sf/briar/transport/ConnectionContextFactoryImpl.java index 08116a58b7..f50d5193c6 100644 --- a/components/net/sf/briar/transport/ConnectionContextFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionContextFactoryImpl.java @@ -1,14 +1,31 @@ package net.sf.briar.transport; import net.sf.briar.api.ContactId; +import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContextFactory; +import com.google.inject.Inject; + class ConnectionContextFactoryImpl implements ConnectionContextFactory { + private final CryptoComponent crypto; + + @Inject + ConnectionContextFactoryImpl(CryptoComponent crypto) { + this.crypto = crypto; + } + public ConnectionContext createConnectionContext(ContactId c, - TransportIndex i, long connection) { - return new ConnectionContextImpl(c, i, connection); + TransportIndex i, long connection, byte[] secret) { + return new ConnectionContextImpl(c, i, connection, secret); + } + + public ConnectionContext createNextConnectionContext(ContactId c, + TransportIndex i, long connection, byte[] previousSecret) { + byte[] secret = crypto.deriveNextSecret(previousSecret, i.getInt(), + connection); + return new ConnectionContextImpl(c, i, connection, secret); } } diff --git a/components/net/sf/briar/transport/ConnectionContextImpl.java b/components/net/sf/briar/transport/ConnectionContextImpl.java index eb52114e3c..27e8caee68 100644 --- a/components/net/sf/briar/transport/ConnectionContextImpl.java +++ b/components/net/sf/briar/transport/ConnectionContextImpl.java @@ -9,12 +9,14 @@ class ConnectionContextImpl implements ConnectionContext { private final ContactId contactId; private final TransportIndex transportIndex; private final long connectionNumber; + private final byte[] secret; ConnectionContextImpl(ContactId contactId, TransportIndex transportIndex, - long connectionNumber) { + long connectionNumber, byte[] secret) { this.contactId = contactId; this.transportIndex = transportIndex; this.connectionNumber = connectionNumber; + this.secret = secret; } public ContactId getContactId() { @@ -28,4 +30,8 @@ class ConnectionContextImpl implements ConnectionContext { public long getConnectionNumber() { return connectionNumber; } + + public byte[] getSecret() { + return secret; + } } diff --git a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java index 3cb09d0e4c..910c303d0e 100644 --- a/components/net/sf/briar/transport/ConnectionDispatcherImpl.java +++ b/components/net/sf/briar/transport/ConnectionDispatcherImpl.java @@ -62,8 +62,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher { r.dispose(false); return; } - batchConnFactory.createIncomingConnection(ctx.getTransportIndex(), - ctx.getContactId(), r, encryptedIv); + batchConnFactory.createIncomingConnection(ctx, r, encryptedIv); } private byte[] readIv(InputStream in) throws IOException { @@ -77,9 +76,9 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher { return b; } - public void dispatchWriter(TransportIndex i, ContactId c, + public void dispatchWriter(ContactId c, TransportIndex i, BatchTransportWriter w) { - batchConnFactory.createOutgoingConnection(i, c, w); + batchConnFactory.createOutgoingConnection(c, i, w); } public void dispatchIncomingConnection(TransportId t, @@ -106,12 +105,11 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher { s.dispose(false); return; } - streamConnFactory.createIncomingConnection(ctx.getTransportIndex(), - ctx.getContactId(), s, encryptedIv); + streamConnFactory.createIncomingConnection(ctx, s, encryptedIv); } - public void dispatchOutgoingConnection(TransportIndex i, ContactId c, + public void dispatchOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s) { - streamConnFactory.createOutgoingConnection(i, c, s); + streamConnFactory.createOutgoingConnection(c, i, s); } } diff --git a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java index 6469e8b4e2..ffc95ca0c0 100644 --- a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java @@ -7,12 +7,13 @@ import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.Mac; -import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.crypto.ErasableKey; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; +import net.sf.briar.util.ByteUtils; import com.google.inject.Inject; @@ -26,10 +27,10 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory { } public ConnectionReader createConnectionReader(InputStream in, - TransportIndex i, byte[] encryptedIv, byte[] secret) { + ConnectionContext ctx, byte[] encryptedIv) { // Decrypt the IV Cipher ivCipher = crypto.getIvCipher(); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); + ErasableKey ivKey = crypto.deriveIvKey(ctx.getSecret(), true); byte[] iv; try { ivCipher.init(Cipher.DECRYPT_MODE, ivKey); @@ -42,27 +43,25 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory { throw new IllegalArgumentException(badKey); } // Validate the IV - if(!IvEncoder.validateIv(iv, true, i)) + if(!IvEncoder.validateIv(iv, true, ctx)) throw new IllegalArgumentException(); - // Copy the connection number - long connection = IvEncoder.getConnectionNumber(iv); - return createConnectionReader(in, true, i, connection, secret); + return createConnectionReader(in, true, ctx); } public ConnectionReader createConnectionReader(InputStream in, - TransportIndex i, long connection, byte[] secret) { - return createConnectionReader(in, false, i, connection, secret); + ConnectionContext ctx) { + return createConnectionReader(in, false, ctx); } private ConnectionReader createConnectionReader(InputStream in, - boolean initiator, TransportIndex i, long connection, - byte[] secret) { + boolean initiator, ConnectionContext ctx) { // Derive the keys and erase the secret + byte[] secret = ctx.getSecret(); ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ErasableKey macKey = crypto.deriveMacKey(secret, initiator); - for(int j = 0; j < secret.length; j++) secret[j] = 0; + ByteUtils.erase(secret); // Create the decrypter - byte[] iv = IvEncoder.encodeIv(initiator, i, connection); + byte[] iv = IvEncoder.encodeIv(initiator, ctx); Cipher frameCipher = crypto.getFrameCipher(); ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv, frameCipher, frameKey); diff --git a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java index a8745f61e4..fb5e59eff1 100644 --- a/components/net/sf/briar/transport/ConnectionRecogniserImpl.java +++ b/components/net/sf/briar/transport/ConnectionRecogniserImpl.java @@ -8,25 +8,26 @@ import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Map.Entry; import java.util.logging.Level; import java.util.logging.Logger; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; -import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.Bytes; import net.sf.briar.api.ContactId; import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.NoSuchContactException; import net.sf.briar.api.db.event.ContactRemovedEvent; import net.sf.briar.api.db.event.DatabaseEvent; import net.sf.briar.api.db.event.DatabaseListener; -import net.sf.briar.api.db.event.TransportAddedEvent; import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent; +import net.sf.briar.api.db.event.TransportAddedEvent; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; @@ -75,30 +76,29 @@ DatabaseListener { } private synchronized void calculateIvs(ContactId c) throws DbException { - byte[] secret = db.getSharedSecret(c, true); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); - for(int i = 0; i < secret.length; i++) secret[i] = 0; for(TransportId t : localTransportIds) { TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, i, ivKey, w); + calculateIvs(c, i, w); } } } private synchronized void calculateIvs(ContactId c, TransportIndex i, - ErasableKey ivKey, ConnectionWindow w) - throws DbException { - for(Long unseen : w.getUnseen()) { + ConnectionWindow w) throws DbException { + for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { + long unseen = e.getKey(); + byte[] secret = e.getValue(); + ErasableKey ivKey = crypto.deriveIvKey(secret, true); Bytes iv = new Bytes(encryptIv(i, unseen, ivKey)); - expected.put(iv, new ConnectionContextImpl(c, i, unseen)); + expected.put(iv, new ConnectionContextImpl(c, i, unseen, secret)); } } private synchronized byte[] encryptIv(TransportIndex i, long connection, ErasableKey ivKey) { - byte[] iv = IvEncoder.encodeIv(true, i, connection); + byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection); try { ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); return ivCipher.doFinal(iv); @@ -133,10 +133,7 @@ DatabaseListener { TransportIndex i1 = ctx1.getTransportIndex(); if(c1.equals(c) && i1.equals(i)) it.remove(); } - byte[] secret = db.getSharedSecret(c, true); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); - for(int j = 0; j < secret.length; j++) secret[j] = 0; - calculateIvs(c, i, ivKey, w); + calculateIvs(c, i, w); } catch(NoSuchContactException e) { // The contact was removed - clean up when we get the event } @@ -185,13 +182,10 @@ DatabaseListener { private synchronized void calculateIvs(TransportId t) throws DbException { for(ContactId c : db.getContacts()) { try { - byte[] secret = db.getSharedSecret(c, true); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); - for(int i = 0; i < secret.length; i++) secret[i] = 0; TransportIndex i = db.getRemoteIndex(c, t); if(i != null) { ConnectionWindow w = db.getConnectionWindow(c, i); - calculateIvs(c, i, ivKey, w); + calculateIvs(c, i, w); } } catch(NoSuchContactException e) { // The contact was removed - clean up when we get the event diff --git a/components/net/sf/briar/transport/ConnectionWindowFactoryImpl.java b/components/net/sf/briar/transport/ConnectionWindowFactoryImpl.java index 69a80afd2b..2890e04272 100644 --- a/components/net/sf/briar/transport/ConnectionWindowFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionWindowFactoryImpl.java @@ -1,17 +1,30 @@ package net.sf.briar.transport; -import java.util.Collection; +import java.util.Map; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindowFactory; +import com.google.inject.Inject; + class ConnectionWindowFactoryImpl implements ConnectionWindowFactory { - public ConnectionWindow createConnectionWindow() { - return new ConnectionWindowImpl(); + private final CryptoComponent crypto; + + @Inject + ConnectionWindowFactoryImpl(CryptoComponent crypto) { + this.crypto = crypto; + } + + public ConnectionWindow createConnectionWindow(TransportIndex i, + byte[] secret) { + return new ConnectionWindowImpl(crypto, i, secret); } - public ConnectionWindow createConnectionWindow(Collection<Long> unseen) { - return new ConnectionWindowImpl(unseen); + public ConnectionWindow createConnectionWindow(TransportIndex i, + Map<Long, byte[]> unseen) { + return new ConnectionWindowImpl(crypto, i, unseen); } } diff --git a/components/net/sf/briar/transport/ConnectionWindowImpl.java b/components/net/sf/briar/transport/ConnectionWindowImpl.java index 3be375edbf..0bd397c300 100644 --- a/components/net/sf/briar/transport/ConnectionWindowImpl.java +++ b/components/net/sf/briar/transport/ConnectionWindowImpl.java @@ -2,28 +2,38 @@ package net.sf.briar.transport; import static net.sf.briar.api.protocol.ProtocolConstants.CONNECTION_WINDOW_SIZE; -import java.util.Collection; -import java.util.Set; -import java.util.TreeSet; +import java.util.HashMap; +import java.util.Map; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.util.ByteUtils; class ConnectionWindowImpl implements ConnectionWindow { - private final Set<Long> unseen; + private final CryptoComponent crypto; + private final int index; + private final Map<Long, byte[]> unseen; private long centre; - ConnectionWindowImpl() { - unseen = new TreeSet<Long>(); - for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l); + ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i, + byte[] secret) { + this.crypto = crypto; + index = i.getInt(); + unseen = new HashMap<Long, byte[]>(); + for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) { + secret = crypto.deriveNextSecret(secret, index, l); + unseen.put(l, secret); + } centre = 0; } - ConnectionWindowImpl(Collection<Long> unseen) { + ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i, + Map<Long, byte[]> unseen) { long min = Long.MAX_VALUE, max = Long.MIN_VALUE; - for(long l : unseen) { + for(long l : unseen.keySet()) { if(l < 0 || l > ByteUtils.MAX_32_BIT_UNSIGNED) throw new IllegalArgumentException(); if(l < min) min = l; @@ -31,15 +41,17 @@ class ConnectionWindowImpl implements ConnectionWindow { } if(max - min > CONNECTION_WINDOW_SIZE) throw new IllegalArgumentException(); - this.unseen = new TreeSet<Long>(unseen); centre = max - CONNECTION_WINDOW_SIZE / 2 + 1; for(long l = centre; l <= max; l++) { - if(!this.unseen.contains(l)) throw new IllegalArgumentException(); + if(!unseen.containsKey(l)) throw new IllegalArgumentException(); } + this.crypto = crypto; + index = i.getInt(); + this.unseen = unseen; } public boolean isSeen(long connection) { - return !unseen.contains(connection); + return !unseen.containsKey(connection); } public void setSeen(long connection) { @@ -47,14 +59,26 @@ class ConnectionWindowImpl implements ConnectionWindow { long top = getTop(centre); if(connection < bottom || connection > top) throw new IllegalArgumentException(); - if(!unseen.remove(connection)) throw new IllegalArgumentException(); + if(!unseen.containsKey(connection)) + throw new IllegalArgumentException(); if(connection >= centre) { centre = connection + 1; long newBottom = getBottom(centre); long newTop = getTop(centre); - for(long l = bottom; l < newBottom; l++) unseen.remove(l); - for(long l = top + 1; l <= newTop; l++) unseen.add(l); + for(long l = bottom; l < newBottom; l++) { + byte[] expired = unseen.remove(l); + if(expired != null) ByteUtils.erase(expired); + } + byte[] topSecret = unseen.get(top); + assert topSecret != null; + for(long l = top + 1; l <= newTop; l++) { + topSecret = crypto.deriveNextSecret(topSecret, index, l); + unseen.put(l, topSecret); + } } + byte[] seen = unseen.remove(connection); + assert seen != null; + ByteUtils.erase(seen); } // Returns the lowest value contained in a window with the given centre @@ -68,7 +92,7 @@ class ConnectionWindowImpl implements ConnectionWindow { centre + CONNECTION_WINDOW_SIZE / 2 - 1); } - public Collection<Long> getUnseen() { + public Map<Long, byte[]> getUnseen() { return unseen; } } diff --git a/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java b/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java index caab149260..0b91a80d77 100644 --- a/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionWriterFactoryImpl.java @@ -7,12 +7,13 @@ import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.Mac; -import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.crypto.ErasableKey; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriterFactory; +import net.sf.briar.util.ByteUtils; import com.google.inject.Inject; @@ -26,17 +27,15 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory { } public ConnectionWriter createConnectionWriter(OutputStream out, - long capacity, TransportIndex i, long connection, byte[] secret) { - return createConnectionWriter(out, capacity, true, i, connection, - secret); + long capacity, ConnectionContext ctx) { + return createConnectionWriter(out, capacity, true, ctx); } public ConnectionWriter createConnectionWriter(OutputStream out, - long capacity, TransportIndex i, byte[] encryptedIv, - byte[] secret) { + long capacity, ConnectionContext ctx, byte[] encryptedIv) { // Decrypt the IV Cipher ivCipher = crypto.getIvCipher(); - ErasableKey ivKey = crypto.deriveIvKey(secret, true); + ErasableKey ivKey = crypto.deriveIvKey(ctx.getSecret(), true); byte[] iv; try { ivCipher.init(Cipher.DECRYPT_MODE, ivKey); @@ -49,26 +48,23 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory { throw new RuntimeException(badKey); } // Validate the IV - if(!IvEncoder.validateIv(iv, true, i)) + if(!IvEncoder.validateIv(iv, true, ctx)) throw new IllegalArgumentException(); - // Copy the connection number - long connection = IvEncoder.getConnectionNumber(iv); - return createConnectionWriter(out, capacity, false, i, connection, - secret); + return createConnectionWriter(out, capacity, false, ctx); } private ConnectionWriter createConnectionWriter(OutputStream out, - long capacity, boolean initiator, TransportIndex i, long connection, - byte[] secret) { + long capacity, boolean initiator, ConnectionContext ctx) { // Derive the keys and erase the secret + byte[] secret = ctx.getSecret(); ErasableKey ivKey = crypto.deriveIvKey(secret, initiator); ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ErasableKey macKey = crypto.deriveMacKey(secret, initiator); - for(int j = 0; j < secret.length; j++) secret[j] = 0; + ByteUtils.erase(secret); // Create the encrypter Cipher ivCipher = crypto.getIvCipher(); Cipher frameCipher = crypto.getFrameCipher(); - byte[] iv = IvEncoder.encodeIv(initiator, i, connection); + byte[] iv = IvEncoder.encodeIv(initiator, ctx); ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, capacity, iv, ivCipher, frameCipher, ivKey, frameKey); // Create the writer diff --git a/components/net/sf/briar/transport/IvEncoder.java b/components/net/sf/briar/transport/IvEncoder.java index 8bba94ba7c..94aa963ffa 100644 --- a/components/net/sf/briar/transport/IvEncoder.java +++ b/components/net/sf/briar/transport/IvEncoder.java @@ -1,18 +1,22 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; -import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.util.ByteUtils; class IvEncoder { - static byte[] encodeIv(boolean initiator, TransportIndex i, - long connection) { + static byte[] encodeIv(boolean initiator, ConnectionContext ctx) { + return encodeIv(initiator, ctx.getTransportIndex().getInt(), + ctx.getConnectionNumber()); + } + + static byte[] encodeIv(boolean initiator, int index, long connection) { byte[] iv = new byte[IV_LENGTH]; // Bit 31 is the initiator flag if(initiator) iv[3] = 1; - // Encode the transport identifier as an unsigned 16-bit integer - ByteUtils.writeUint16(i.getInt(), iv, 4); + // Encode the transport index as an unsigned 16-bit integer + ByteUtils.writeUint16(index, iv, 4); // Encode the connection number as an unsigned 32-bit integer ByteUtils.writeUint32(connection, iv, 6); return iv; @@ -24,7 +28,14 @@ class IvEncoder { ByteUtils.writeUint32(frame, iv, 10); } - static boolean validateIv(byte[] iv, boolean initiator, TransportIndex i) { + static boolean validateIv(byte[] iv, boolean initiator, + ConnectionContext ctx) { + return validateIv(iv, initiator, ctx.getTransportIndex().getInt(), + ctx.getConnectionNumber()); + } + + static boolean validateIv(byte[] iv, boolean initiator, int index, + long connection) { if(iv.length != IV_LENGTH) return false; // Check that the reserved bits are all zero for(int j = 0; j < 2; j++) if(iv[j] != 0) return false; @@ -32,8 +43,10 @@ class IvEncoder { for(int j = 10; j < iv.length; j++) if(iv[j] != 0) return false; // Check that the initiator flag matches if(initiator != getInitiatorFlag(iv)) return false; - // Check that the transport ID matches - if(i.getInt() != getTransportId(iv)) return false; + // Check that the transport index matches + if(index != getTransportIndex(iv)) return false; + // Check that the connection number matches + if(connection != getConnectionNumber(iv)) return false; // The IV is valid return true; } @@ -43,7 +56,7 @@ class IvEncoder { return (iv[3] & 1) == 1; } - static int getTransportId(byte[] iv) { + static int getTransportIndex(byte[] iv) { if(iv.length != IV_LENGTH) throw new IllegalArgumentException(); return ByteUtils.readUint16(iv, 4); } diff --git a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java index c3329d8492..280f865b41 100644 --- a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java @@ -8,6 +8,7 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.transport.BatchConnectionFactory; import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportWriter; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriterFactory; @@ -33,11 +34,10 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { this.protoWriterFactory = protoWriterFactory; } - public void createIncomingConnection(TransportIndex i, ContactId c, + public void createIncomingConnection(ConnectionContext ctx, BatchTransportReader r, byte[] encryptedIv) { final IncomingBatchConnection conn = new IncomingBatchConnection( - connReaderFactory, db, protoReaderFactory, i, c, r, - encryptedIv); + connReaderFactory, db, protoReaderFactory, ctx, r, encryptedIv); Runnable read = new Runnable() { public void run() { conn.read(); @@ -46,10 +46,10 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { new Thread(read).start(); } - public void createOutgoingConnection(TransportIndex i, ContactId c, + public void createOutgoingConnection(ContactId c, TransportIndex i, BatchTransportWriter w) { final OutgoingBatchConnection conn = new OutgoingBatchConnection( - connWriterFactory, db, protoWriterFactory, i, c, w); + connWriterFactory, db, protoWriterFactory, c, i, w); Runnable write = new Runnable() { public void run() { conn.write(); diff --git a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java index a97c392c97..4440edacd0 100644 --- a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java @@ -13,9 +13,9 @@ import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.transport.BatchTransportReader; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; @@ -27,46 +27,43 @@ class IncomingBatchConnection { private final ConnectionReaderFactory connFactory; private final DatabaseComponent db; private final ProtocolReaderFactory protoFactory; - private final TransportIndex transportIndex; - private final ContactId contactId; + private final ConnectionContext ctx; private final BatchTransportReader reader; private final byte[] encryptedIv; IncomingBatchConnection(ConnectionReaderFactory connFactory, DatabaseComponent db, ProtocolReaderFactory protoFactory, - TransportIndex transportIndex, ContactId contactId, - BatchTransportReader reader, byte[] encryptedIv) { + ConnectionContext ctx, BatchTransportReader reader, + byte[] encryptedIv) { this.connFactory = connFactory; this.db = db; this.protoFactory = protoFactory; - this.transportIndex = transportIndex; - this.contactId = contactId; + this.ctx = ctx; this.reader = reader; this.encryptedIv = encryptedIv; } void read() { try { - byte[] secret = db.getSharedSecret(contactId, true); ConnectionReader conn = connFactory.createConnectionReader( - reader.getInputStream(), transportIndex, encryptedIv, - secret); + reader.getInputStream(), ctx, encryptedIv); ProtocolReader proto = protoFactory.createProtocolReader( conn.getInputStream()); + ContactId c = ctx.getContactId(); // Read packets until EOF while(!proto.eof()) { if(proto.hasAck()) { Ack a = proto.readAck(); - db.receiveAck(contactId, a); + db.receiveAck(c, a); } else if(proto.hasBatch()) { Batch b = proto.readBatch(); - db.receiveBatch(contactId, b); + db.receiveBatch(c, b); } else if(proto.hasSubscriptionUpdate()) { SubscriptionUpdate s = proto.readSubscriptionUpdate(); - db.receiveSubscriptionUpdate(contactId, s); + db.receiveSubscriptionUpdate(c, s); } else if(proto.hasTransportUpdate()) { TransportUpdate t = proto.readTransportUpdate(); - db.receiveTransportUpdate(contactId, t); + db.receiveTransportUpdate(c, t); } else { throw new FormatException(); } diff --git a/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java b/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java index e00529eee1..88f0f127f3 100644 --- a/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java @@ -29,30 +29,28 @@ class OutgoingBatchConnection { private final ConnectionWriterFactory connFactory; private final DatabaseComponent db; private final ProtocolWriterFactory protoFactory; - private final TransportIndex transportIndex; private final ContactId contactId; + private final TransportIndex transportIndex; private final BatchTransportWriter writer; OutgoingBatchConnection(ConnectionWriterFactory connFactory, DatabaseComponent db, ProtocolWriterFactory protoFactory, - TransportIndex transportIndex, ContactId contactId, + ContactId contactId, TransportIndex transportIndex, BatchTransportWriter writer) { this.connFactory = connFactory; this.db = db; this.protoFactory = protoFactory; - this.transportIndex = transportIndex; this.contactId = contactId; + this.transportIndex = transportIndex; this.writer = writer; } void write() { try { - byte[] secret = db.getSharedSecret(contactId, false); - ConnectionContext ctx = - db.getConnectionContext(contactId, transportIndex); + ConnectionContext ctx = db.getConnectionContext(contactId, + transportIndex); ConnectionWriter conn = connFactory.createConnectionWriter( - writer.getOutputStream(), writer.getCapacity(), - transportIndex, ctx.getConnectionNumber(), secret); + writer.getOutputStream(), writer.getCapacity(), ctx); OutputStream out = conn.getOutputStream(); // There should be enough space for a packet long capacity = conn.getRemainingCapacity(); diff --git a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java index bc01c5ad61..d3da63f21c 100644 --- a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java @@ -2,12 +2,11 @@ package net.sf.briar.transport.stream; import java.io.IOException; -import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.protocol.ProtocolReaderFactory; -import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriter; @@ -16,35 +15,32 @@ import net.sf.briar.api.transport.StreamTransportConnection; public class IncomingStreamConnection extends StreamConnection { + private final ConnectionContext ctx; private final byte[] encryptedIv; IncomingStreamConnection(ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, - TransportIndex transportIndex, ContactId contactId, - StreamTransportConnection connection, + ConnectionContext ctx, StreamTransportConnection connection, byte[] encryptedIv) { super(connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, transportIndex, contactId, connection); + protoWriterFactory, ctx.getContactId(), connection); + this.ctx = ctx; this.encryptedIv = encryptedIv; } @Override protected ConnectionReader createConnectionReader() throws DbException, IOException { - byte[] secret = db.getSharedSecret(contactId, true); return connReaderFactory.createConnectionReader( - connection.getInputStream(), transportIndex, encryptedIv, - secret); + connection.getInputStream(), ctx, encryptedIv); } @Override protected ConnectionWriter createConnectionWriter() throws DbException, IOException { - byte[] secret = db.getSharedSecret(contactId, false); return connWriterFactory.createConnectionWriter( - connection.getOutputStream(), Long.MAX_VALUE, transportIndex, - encryptedIv, secret); + connection.getOutputStream(), Long.MAX_VALUE, ctx, encryptedIv); } } diff --git a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java index b02e460074..178190e2a1 100644 --- a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java @@ -17,43 +17,40 @@ import net.sf.briar.api.transport.StreamTransportConnection; public class OutgoingStreamConnection extends StreamConnection { + private final TransportIndex transportIndex; + private ConnectionContext ctx = null; // Locking: this OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, - ProtocolWriterFactory protoWriterFactory, - TransportIndex transportIndex, ContactId contactId, + ProtocolWriterFactory protoWriterFactory, ContactId contactId, + TransportIndex transportIndex, StreamTransportConnection connection) { super(connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, transportIndex, contactId, connection); + protoWriterFactory, contactId, connection); + this.transportIndex = transportIndex; } @Override protected ConnectionReader createConnectionReader() throws DbException, IOException { synchronized(this) { - if(ctx == null) { + if(ctx == null) ctx = db.getConnectionContext(contactId, transportIndex); - } } - byte[] secret = db.getSharedSecret(contactId, true); return connReaderFactory.createConnectionReader( - connection.getInputStream(), transportIndex, - ctx.getConnectionNumber(), secret); + connection.getInputStream(), ctx); } @Override protected ConnectionWriter createConnectionWriter() throws DbException, IOException { synchronized(this) { - if(ctx == null) { + if(ctx == null) ctx = db.getConnectionContext(contactId, transportIndex); - } } - byte[] secret = db.getSharedSecret(contactId, false); return connWriterFactory.createConnectionWriter( - connection.getOutputStream(), Long.MAX_VALUE, transportIndex, - ctx.getConnectionNumber(), secret); + connection.getOutputStream(), Long.MAX_VALUE, ctx); } } diff --git a/components/net/sf/briar/transport/stream/StreamConnection.java b/components/net/sf/briar/transport/stream/StreamConnection.java index 90d026b63b..f24557cedf 100644 --- a/components/net/sf/briar/transport/stream/StreamConnection.java +++ b/components/net/sf/briar/transport/stream/StreamConnection.java @@ -29,7 +29,6 @@ import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.BatchWriter; @@ -56,7 +55,6 @@ abstract class StreamConnection implements DatabaseListener { protected final DatabaseComponent db; protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolWriterFactory protoWriterFactory; - protected final TransportIndex transportIndex; protected final ContactId contactId; protected final StreamTransportConnection connection; @@ -69,15 +67,13 @@ abstract class StreamConnection implements DatabaseListener { StreamConnection(ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, - ProtocolWriterFactory protoWriterFactory, - TransportIndex transportIndex, ContactId contactId, + ProtocolWriterFactory protoWriterFactory, ContactId contactId, StreamTransportConnection connection) { this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; this.db = db; this.protoReaderFactory = protoReaderFactory; this.protoWriterFactory = protoWriterFactory; - this.transportIndex = transportIndex; this.contactId = contactId; this.connection = connection; } diff --git a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java index c3aecf5c54..cfaf8fe747 100644 --- a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java @@ -5,6 +5,7 @@ import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.StreamConnectionFactory; @@ -32,11 +33,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory { this.protoWriterFactory = protoWriterFactory; } - public void createIncomingConnection(TransportIndex i, ContactId c, + public void createIncomingConnection(ConnectionContext ctx, StreamTransportConnection s, byte[] encryptedIv) { final StreamConnection conn = new IncomingStreamConnection( connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, i, c, s, encryptedIv); + protoWriterFactory, ctx, s, encryptedIv); Runnable write = new Runnable() { public void run() { conn.write(); @@ -51,11 +52,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory { new Thread(read).start(); } - public void createOutgoingConnection(TransportIndex i, ContactId c, + public void createOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s) { final StreamConnection conn = new OutgoingStreamConnection( connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, i, c, s); + protoWriterFactory, c, i, s); Runnable write = new Runnable() { public void run() { conn.write(); diff --git a/test/net/sf/briar/ProtocolIntegrationTest.java b/test/net/sf/briar/ProtocolIntegrationTest.java index 001407cdfe..add29b2568 100644 --- a/test/net/sf/briar/ProtocolIntegrationTest.java +++ b/test/net/sf/briar/ProtocolIntegrationTest.java @@ -16,6 +16,7 @@ import java.util.Map; import java.util.Random; import junit.framework.TestCase; +import net.sf.briar.api.ContactId; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Author; @@ -36,7 +37,6 @@ import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.OfferWriter; @@ -44,6 +44,8 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.protocol.writers.RequestWriter; import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; import net.sf.briar.api.protocol.writers.TransportUpdateWriter; +import net.sf.briar.api.transport.ConnectionContext; +import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriter; @@ -68,12 +70,14 @@ public class ProtocolIntegrationTest extends TestCase { private final BatchId ack = new BatchId(TestUtils.getRandomId()); private final long timestamp = System.currentTimeMillis(); + private final ConnectionContextFactory connectionContextFactory; private final ConnectionReaderFactory connectionReaderFactory; private final ConnectionWriterFactory connectionWriterFactory; private final ProtocolReaderFactory protocolReaderFactory; private final ProtocolWriterFactory protocolWriterFactory; private final CryptoComponent crypto; - private final byte[] aliceToBobSecret; + private final byte[] secret; + private final ContactId contactId = new ContactId(13); private final TransportIndex transportIndex = new TransportIndex(13); private final long connection = 12345L; private final Author author; @@ -91,16 +95,17 @@ public class ProtocolIntegrationTest extends TestCase { new ProtocolWritersModule(), new SerialModule(), new TestDatabaseModule(), new TransportBatchModule(), new TransportModule(), new TransportStreamModule()); + connectionContextFactory = + i.getInstance(ConnectionContextFactory.class); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); crypto = i.getInstance(CryptoComponent.class); - assertEquals(crypto.getMessageDigest().getDigestLength(), - UniqueId.LENGTH); + // Create a shared secret Random r = new Random(); - aliceToBobSecret = new byte[32]; - r.nextBytes(aliceToBobSecret); + secret = new byte[32]; + r.nextBytes(secret); // Create two groups: one restricted, one unrestricted GroupFactory groupFactory = i.getInstance(GroupFactory.class); group = groupFactory.createGroup("Unrestricted group", null); @@ -139,9 +144,11 @@ public class ProtocolIntegrationTest extends TestCase { private byte[] write() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); - byte[] copyOfSecret = Arrays.clone(aliceToBobSecret); + ConnectionContext ctx = + connectionContextFactory.createConnectionContext(contactId, + transportIndex, connection, Arrays.clone(secret)); ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, - Long.MAX_VALUE, transportIndex, connection, copyOfSecret); + Long.MAX_VALUE, ctx); OutputStream out1 = w.getOutputStream(); AckWriter a = protocolWriterFactory.createAckWriter(out1); @@ -184,19 +191,15 @@ public class ProtocolIntegrationTest extends TestCase { return out.toByteArray(); } - private void read(byte[] connection) throws Exception { - InputStream in = new ByteArrayInputStream(connection); + private void read(byte[] connectionData) throws Exception { + InputStream in = new ByteArrayInputStream(connectionData); byte[] encryptedIv = new byte[16]; - int offset = 0; - while(offset < 16) { - int read = in.read(encryptedIv, offset, 16 - offset); - if(read == -1) break; - offset += read; - } - assertEquals(16, offset); - byte[] copyOfSecret = Arrays.clone(aliceToBobSecret); + assertEquals(16, in.read(encryptedIv, 0, 16)); + ConnectionContext ctx = + connectionContextFactory.createConnectionContext(contactId, + transportIndex, connection, Arrays.clone(secret)); ConnectionReader r = connectionReaderFactory.createConnectionReader(in, - transportIndex, encryptedIv, copyOfSecret); + ctx, encryptedIv); in = r.getInputStream(); ProtocolReader protocolReader = protocolReaderFactory.createProtocolReader(in); diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java index 66c84b223e..58c95b8a92 100644 --- a/test/net/sf/briar/db/DatabaseComponentTest.java +++ b/test/net/sf/briar/db/DatabaseComponentTest.java @@ -47,7 +47,6 @@ import net.sf.briar.api.transport.ConnectionWindow; import org.jmock.Expectations; import org.jmock.Mockery; -import static org.junit.Assert.assertArrayEquals; import org.junit.Test; public abstract class DatabaseComponentTest extends TestCase { @@ -107,9 +106,9 @@ public abstract class DatabaseComponentTest extends TestCase { Database<T> database, DatabaseCleaner cleaner); @Test + @SuppressWarnings("unchecked") public void testSimpleCalls() throws Exception { Mockery context = new Mockery(); - @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ConnectionWindow connectionWindow = @@ -138,7 +137,8 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).setRating(txn, authorId, Rating.GOOD); will(returnValue(Rating.GOOD)); // addContact() - oneOf(database).addContact(txn, inSecret, outSecret); + oneOf(database).addContact(with(txn), with(inSecret), + with(outSecret), with(any(Collection.class))); will(returnValue(contactId)); oneOf(listener).eventOccurred(with(any(ContactAddedEvent.class))); // getContacts() @@ -149,16 +149,6 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(true)); oneOf(database).getConnectionWindow(txn, contactId, remoteIndex); will(returnValue(connectionWindow)); - // getSharedSecret(contactId, true) - oneOf(database).containsContact(txn, contactId); - will(returnValue(true)); - oneOf(database).getSharedSecret(txn, contactId, true); - will(returnValue(inSecret)); - // getSharedSecret(contactId, false) - oneOf(database).containsContact(txn, contactId); - will(returnValue(true)); - oneOf(database).getSharedSecret(txn, contactId, false); - will(returnValue(outSecret)); // getTransportProperties(transportId) oneOf(database).getRemoteProperties(txn, transportId); will(returnValue(remoteProperties)); @@ -213,8 +203,6 @@ public abstract class DatabaseComponentTest extends TestCase { assertEquals(Collections.singletonList(contactId), db.getContacts()); assertEquals(connectionWindow, db.getConnectionWindow(contactId, remoteIndex)); - assertArrayEquals(inSecret, db.getSharedSecret(contactId, true)); - assertArrayEquals(outSecret, db.getSharedSecret(contactId, false)); assertEquals(remoteProperties, db.getRemoteProperties(transportId)); db.subscribe(group); // First time - listeners called db.subscribe(group); // Second time - not called @@ -516,11 +504,11 @@ public abstract class DatabaseComponentTest extends TestCase { context.mock(TransportUpdate.class); context.checking(new Expectations() {{ // Check whether the contact is still in the DB (which it's not) - exactly(20).of(database).startTransaction(); + exactly(19).of(database).startTransaction(); will(returnValue(txn)); - exactly(20).of(database).containsContact(txn, contactId); + exactly(19).of(database).containsContact(txn, contactId); will(returnValue(false)); - exactly(20).of(database).commitTransaction(txn); + exactly(19).of(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner); @@ -575,11 +563,6 @@ public abstract class DatabaseComponentTest extends TestCase { fail(); } catch(NoSuchContactException expected) {} - try { - db.getSharedSecret(contactId, true); - fail(); - } catch(NoSuchContactException expected) {} - try { db.hasSendableMessages(contactId); fail(); diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index 3e860fc7ad..60ff32646b 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -4,6 +4,7 @@ import static org.junit.Assert.assertArrayEquals; import java.io.File; import java.sql.Connection; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -88,6 +89,7 @@ public class H2DatabaseTest extends TestCase { private final Collection<Transport> remoteTransports; private final Map<Group, Long> subscriptions; private final byte[] inSecret, outSecret; + private final Collection<byte[]> erase; public H2DatabaseTest() throws Exception { super(); @@ -131,6 +133,7 @@ public class H2DatabaseTest extends TestCase { r.nextBytes(inSecret); outSecret = new byte[32]; r.nextBytes(outSecret); + erase = new ArrayList<byte[]>(); } @Before @@ -144,8 +147,7 @@ public class H2DatabaseTest extends TestCase { Database<Connection> db = open(false); Connection txn = db.startTransaction(); assertFalse(db.containsContact(txn, contactId)); - assertEquals(contactId, - db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertTrue(db.containsContact(txn, contactId)); assertFalse(db.containsSubscription(txn, groupId)); db.addSubscription(txn, group); @@ -201,20 +203,23 @@ public class H2DatabaseTest extends TestCase { // Create three contacts assertFalse(db.containsContact(txn, contactId)); - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertTrue(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId1)); - assertEquals(contactId1, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId1, + db.addContact(txn, inSecret, outSecret, erase)); assertTrue(db.containsContact(txn, contactId1)); assertFalse(db.containsContact(txn, contactId2)); - assertEquals(contactId2, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId2, + db.addContact(txn, inSecret, outSecret, erase)); assertTrue(db.containsContact(txn, contactId2)); // Delete the contact with the highest ID db.removeContact(txn, contactId2); assertFalse(db.containsContact(txn, contactId2)); // Add another contact - a new ID should be created assertFalse(db.containsContact(txn, contactId3)); - assertEquals(contactId3, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId3, + db.addContact(txn, inSecret, outSecret, erase)); assertTrue(db.containsContact(txn, contactId3)); db.commitTransaction(txn); @@ -261,7 +266,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and store a private message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addPrivateMessage(txn, privateMessage, contactId); // Removing the contact should remove the message @@ -280,7 +285,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and store a private message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addPrivateMessage(txn, privateMessage, contactId); // The message has no status yet, so it should not be sendable @@ -319,7 +324,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and store a private message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addPrivateMessage(txn, privateMessage, contactId); db.setStatus(txn, contactId, privateMessageId, Status.NEW); @@ -347,7 +352,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -385,7 +390,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -427,7 +432,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.addGroupMessage(txn, message); @@ -466,7 +471,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.addGroupMessage(txn, message); @@ -501,7 +506,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -532,7 +537,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setSubscriptions(txn, contactId, subscriptions, 1); db.addGroupMessage(txn, message); @@ -565,7 +570,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and some batches to ack - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId1); @@ -592,7 +597,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and receive the same batch twice - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId); @@ -618,7 +623,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.addGroupMessage(txn, message); @@ -643,8 +648,8 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add two contacts, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); - ContactId contactId1 = db.addContact(txn, inSecret, outSecret); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); + ContactId contactId1 = db.addContact(txn, inSecret, outSecret, erase); db.addSubscription(txn, group); db.addGroupMessage(txn, message); @@ -666,7 +671,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -705,7 +710,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -750,7 +755,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); // Add some outstanding batches, a few ms apart for(int i = 0; i < ids.length; i++) { @@ -790,7 +795,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); // Add some outstanding batches, a few ms apart for(int i = 0; i < ids.length; i++) { @@ -1010,7 +1015,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact with a transport - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.setTransports(txn, contactId, remoteTransports, 1); assertEquals(remoteProperties, db.getRemoteProperties(txn, transportId)); @@ -1103,7 +1108,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact with a transport - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.setTransports(txn, contactId, remoteTransports, 1); assertEquals(remoteProperties, db.getRemoteProperties(txn, transportId)); @@ -1147,7 +1152,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact with some subscriptions - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.setSubscriptions(txn, contactId, subscriptions, 1); assertEquals(Collections.singletonList(group), db.getSubscriptions(txn, contactId)); @@ -1172,7 +1177,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact with some subscriptions - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.setSubscriptions(txn, contactId, subscriptions, 2); assertEquals(Collections.singletonList(group), db.getSubscriptions(txn, contactId)); @@ -1196,7 +1201,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and subscribe to a group - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1214,7 +1219,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setSubscriptions(txn, contactId, subscriptions, 1); db.addGroupMessage(txn, message); @@ -1237,7 +1242,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setSubscriptions(txn, contactId, subscriptions, 1); db.addGroupMessage(txn, message); @@ -1260,7 +1265,7 @@ public class H2DatabaseTest extends TestCase { // Add a contact, subscribe to a group and store a message - // the message is older than the contact's subscription - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); Map<Group, Long> subs = Collections.singletonMap(group, timestamp + 1); @@ -1284,7 +1289,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1309,7 +1314,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and subscribe to a group - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1328,7 +1333,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact with a subscription - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.setSubscriptions(txn, contactId, subscriptions, 1); // There's no local subscription for the group @@ -1345,7 +1350,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.addGroupMessage(txn, message); db.setStatus(txn, contactId, messageId, Status.NEW); @@ -1364,7 +1369,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.addGroupMessage(txn, message); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1384,7 +1389,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1406,7 +1411,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact, subscribe to a group and store a message - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setSubscriptions(txn, contactId, subscriptions, 1); @@ -1427,7 +1432,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and subscribe to a group - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); // The group should not be visible to the contact assertEquals(Collections.emptyList(), db.getVisibility(txn, groupId)); @@ -1450,7 +1455,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); // Get the connection window for a new index ConnectionWindow w = db.getConnectionWindow(txn, contactId, remoteIndex); @@ -1469,18 +1474,18 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); // Get the connection window for a new index ConnectionWindow w = db.getConnectionWindow(txn, contactId, remoteIndex); // The connection window should exist and be in the initial state assertNotNull(w); - Collection<Long> unseen = w.getUnseen(); + Map<Long, byte[]> unseen = w.getUnseen(); long top = ProtocolConstants.CONNECTION_WINDOW_SIZE / 2 - 1; assertEquals(top + 1, unseen.size()); for(long l = 0; l <= top; l++) { assertFalse(w.isSeen(l)); - assertTrue(unseen.contains(l)); + assertTrue(unseen.containsKey(l)); } // Update the connection window and store it w.setSeen(5); @@ -1573,7 +1578,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and subscribe to a group - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); // A message with a private parent should return null @@ -1622,7 +1627,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); // The subscription and transport timestamps should be initialised to 0 assertEquals(0L, db.getSubscriptionsModified(txn, contactId)); @@ -1653,7 +1658,7 @@ public class H2DatabaseTest extends TestCase { Connection txn = db.startTransaction(); // Add a contact and subscribe to a group - assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); + assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); db.addSubscription(txn, group); // Store a couple of messages @@ -1897,6 +1902,7 @@ public class H2DatabaseTest extends TestCase { @After public void tearDown() { + erase.clear(); TestUtils.deleteTestDirectory(testDir); } diff --git a/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java b/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java index 6f609aeb27..b12651a013 100644 --- a/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java +++ b/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java @@ -52,7 +52,8 @@ public class ConnectionDecrypterImplTest extends TestCase { private void testDecryption(boolean initiator) throws Exception { // Calculate the plaintext and ciphertext for the IV - byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); + byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), + connection); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); byte[] encryptedIv = ivCipher.doFinal(iv); assertEquals(IV_LENGTH, encryptedIv.length); @@ -85,8 +86,8 @@ public class ConnectionDecrypterImplTest extends TestCase { ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); // Use a ConnectionDecrypter to decrypt the ciphertext ConnectionDecrypter d = new ConnectionDecrypterImpl(in, - IvEncoder.encodeIv(initiator, transportIndex, connection), - frameCipher, frameKey); + IvEncoder.encodeIv(initiator, transportIndex.getInt(), + connection), frameCipher, frameKey); // First frame byte[] decrypted = new byte[ciphertext.length]; TestUtils.readFully(d.getInputStream(), decrypted); diff --git a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java index d6ad51a28c..0012ea8a8f 100644 --- a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java +++ b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java @@ -50,7 +50,8 @@ public class ConnectionEncrypterImplTest extends TestCase { private void testEncryption(boolean initiator) throws Exception { // Calculate the expected ciphertext for the IV - byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); + byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), + connection); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); byte[] encryptedIv = ivCipher.doFinal(iv); assertEquals(IV_LENGTH, encryptedIv.length); @@ -82,7 +83,7 @@ public class ConnectionEncrypterImplTest extends TestCase { byte[] expected = out.toByteArray(); // Use a ConnectionEncrypter to encrypt the plaintext out.reset(); - iv = IvEncoder.encodeIv(initiator, transportIndex, connection); + iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), connection); ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE, iv, ivCipher, frameCipher, ivKey, frameKey); e.getOutputStream().write(plaintext); diff --git a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java index 0471541966..b5a3627895 100644 --- a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java +++ b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java @@ -4,6 +4,7 @@ import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; import java.util.Collection; import java.util.Collections; +import java.util.Map; import java.util.Random; import javax.crypto.Cipher; @@ -51,7 +52,8 @@ public class ConnectionRecogniserImplTest extends TestCase { Transport transport = new Transport(transportId, localIndex, Collections.singletonMap("foo", "bar")); transports = Collections.singletonList(transport); - connectionWindow = new ConnectionWindowImpl(); + connectionWindow = new ConnectionWindowImpl(crypto, remoteIndex, + inSecret); } @Test @@ -65,8 +67,6 @@ public class ConnectionRecogniserImplTest extends TestCase { will(returnValue(transports)); oneOf(db).getContacts(); will(returnValue(Collections.singletonList(contactId))); - oneOf(db).getSharedSecret(contactId, true); - will(returnValue(inSecret)); oneOf(db).getRemoteIndex(contactId, transportId); will(returnValue(remoteIndex)); oneOf(db).getConnectionWindow(contactId, remoteIndex); @@ -80,11 +80,16 @@ public class ConnectionRecogniserImplTest extends TestCase { @Test public void testExpectedIv() throws Exception { + // Calculate the shared secret for connection number 3 + byte[] secret = inSecret; + for(int i = 0; i < 4; i++) { + secret = crypto.deriveNextSecret(secret, remoteIndex.getInt(), i); + } // Calculate the expected IV for connection number 3 - ErasableKey ivKey = crypto.deriveIvKey(inSecret, true); + ErasableKey ivKey = crypto.deriveIvKey(secret, true); Cipher ivCipher = crypto.getIvCipher(); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); - byte[] iv = IvEncoder.encodeIv(true, remoteIndex, 3L); + byte[] iv = IvEncoder.encodeIv(true, remoteIndex.getInt(), 3); byte[] encryptedIv = ivCipher.doFinal(iv); Mockery context = new Mockery(); @@ -96,8 +101,6 @@ public class ConnectionRecogniserImplTest extends TestCase { will(returnValue(transports)); oneOf(db).getContacts(); will(returnValue(Collections.singletonList(contactId))); - oneOf(db).getSharedSecret(contactId, true); - will(returnValue(inSecret)); oneOf(db).getRemoteIndex(contactId, transportId); will(returnValue(remoteIndex)); oneOf(db).getConnectionWindow(contactId, remoteIndex); @@ -107,8 +110,6 @@ public class ConnectionRecogniserImplTest extends TestCase { will(returnValue(connectionWindow)); oneOf(db).setConnectionWindow(contactId, remoteIndex, connectionWindow); - oneOf(db).getSharedSecret(contactId, true); - will(returnValue(inSecret)); }}); final ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(crypto, db); @@ -121,11 +122,11 @@ public class ConnectionRecogniserImplTest extends TestCase { // Second time - the IV should no longer be expected assertNull(c.acceptConnection(encryptedIv)); // The window should have advanced - Collection<Long> unseen = connectionWindow.getUnseen(); + Map<Long, byte[]> unseen = connectionWindow.getUnseen(); assertEquals(19, unseen.size()); for(int i = 0; i < 19; i++) { if(i == 3) continue; - assertTrue(unseen.contains(Long.valueOf(i))); + assertTrue(unseen.containsKey(Long.valueOf(i))); } context.assertIsSatisfied(); } diff --git a/test/net/sf/briar/transport/ConnectionWindowImplTest.java b/test/net/sf/briar/transport/ConnectionWindowImplTest.java index d39855f911..5c9839bd86 100644 --- a/test/net/sf/briar/transport/ConnectionWindowImplTest.java +++ b/test/net/sf/briar/transport/ConnectionWindowImplTest.java @@ -1,18 +1,39 @@ package net.sf.briar.transport; -import java.util.ArrayList; -import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; import junit.framework.TestCase; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.transport.ConnectionWindow; +import net.sf.briar.crypto.CryptoModule; import net.sf.briar.util.ByteUtils; import org.junit.Test; +import com.google.inject.Guice; +import com.google.inject.Injector; + public class ConnectionWindowImplTest extends TestCase { + private final CryptoComponent crypto; + private final byte[] secret; + private final TransportIndex transportIndex = new TransportIndex(13); + + public ConnectionWindowImplTest(String name) { + super(name); + Injector i = Guice.createInjector(new CryptoModule()); + crypto = i.getInstance(CryptoComponent.class); + secret = new byte[32]; + new Random().nextBytes(secret); + } + @Test public void testWindowSliding() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, + transportIndex, secret); for(int i = 0; i < 100; i++) { assertFalse(w.isSeen(i)); w.setSeen(i); @@ -22,7 +43,8 @@ public class ConnectionWindowImplTest extends TestCase { @Test public void testWindowJumping() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, + secret); for(int i = 0; i < 100; i += 13) { assertFalse(w.isSeen(i)); w.setSeen(i); @@ -32,7 +54,8 @@ public class ConnectionWindowImplTest extends TestCase { @Test public void testWindowUpperLimit() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, + secret); // Centre is 0, highest value in window is 15 w.setSeen(15); // Centre is 16, highest value in window is 31 @@ -43,11 +66,11 @@ public class ConnectionWindowImplTest extends TestCase { fail(); } catch(IllegalArgumentException expected) {} // Values greater than 2^32 - 1 should never be allowed - Collection<Long> unseen = new ArrayList<Long>(); + Map<Long, byte[]> unseen = new HashMap<Long, byte[]>(); for(int i = 0; i < 32; i++) { - unseen.add(ByteUtils.MAX_32_BIT_UNSIGNED - i); + unseen.put(ByteUtils.MAX_32_BIT_UNSIGNED - i, secret); } - w = new ConnectionWindowImpl(unseen); + w = new ConnectionWindowImpl(crypto, transportIndex, unseen); w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED); try { w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED + 1); @@ -57,7 +80,8 @@ public class ConnectionWindowImplTest extends TestCase { @Test public void testWindowLowerLimit() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, + secret); // Centre is 0, negative values should never be allowed try { w.setSeen(-1); @@ -87,7 +111,8 @@ public class ConnectionWindowImplTest extends TestCase { @Test public void testCannotSetSeenTwice() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, + secret); w.setSeen(15); try { w.setSeen(15); @@ -97,12 +122,13 @@ public class ConnectionWindowImplTest extends TestCase { @Test public void testGetUnseenConnectionNumbers() { - ConnectionWindowImpl w = new ConnectionWindowImpl(); + ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, + secret); // Centre is 0; window should cover 0 to 15, inclusive, with none seen - Collection<Long> unseen = w.getUnseen(); + Map<Long, byte[]> unseen = w.getUnseen(); assertEquals(16, unseen.size()); for(int i = 0; i < 16; i++) { - assertTrue(unseen.contains(Long.valueOf(i))); + assertTrue(unseen.containsKey(Long.valueOf(i))); assertFalse(w.isSeen(i)); } w.setSeen(3); @@ -112,10 +138,10 @@ public class ConnectionWindowImplTest extends TestCase { assertEquals(19, unseen.size()); for(int i = 0; i < 21; i++) { if(i == 3 || i == 4) { - assertFalse(unseen.contains(Long.valueOf(i))); + assertFalse(unseen.containsKey(Long.valueOf(i))); assertTrue(w.isSeen(i)); } else { - assertTrue(unseen.contains(Long.valueOf(i))); + assertTrue(unseen.containsKey(Long.valueOf(i))); assertFalse(w.isSeen(i)); } } @@ -125,10 +151,10 @@ public class ConnectionWindowImplTest extends TestCase { assertEquals(30, unseen.size()); for(int i = 4; i < 36; i++) { if(i == 4 || i == 19) { - assertFalse(unseen.contains(Long.valueOf(i))); + assertFalse(unseen.containsKey(Long.valueOf(i))); assertTrue(w.isSeen(i)); } else { - assertTrue(unseen.contains(Long.valueOf(i))); + assertTrue(unseen.containsKey(Long.valueOf(i))); assertFalse(w.isSeen(i)); } } diff --git a/test/net/sf/briar/transport/ConnectionWriterTest.java b/test/net/sf/briar/transport/ConnectionWriterTest.java index cb27fd9653..9e7290316a 100644 --- a/test/net/sf/briar/transport/ConnectionWriterTest.java +++ b/test/net/sf/briar/transport/ConnectionWriterTest.java @@ -8,7 +8,10 @@ import java.util.Random; import junit.framework.TestCase; import net.sf.briar.TestDatabaseModule; +import net.sf.briar.api.ContactId; import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.transport.ConnectionContext; +import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.crypto.CryptoModule; @@ -26,8 +29,10 @@ import com.google.inject.Injector; public class ConnectionWriterTest extends TestCase { + private final ConnectionContextFactory connectionContextFactory; private final ConnectionWriterFactory connectionWriterFactory; - private final byte[] outSecret; + private final byte[] secret; + private final ContactId contactId = new ContactId(13); private final TransportIndex transportIndex = new TransportIndex(13); private final long connection = 12345L; @@ -38,17 +43,22 @@ public class ConnectionWriterTest extends TestCase { new ProtocolWritersModule(), new SerialModule(), new TestDatabaseModule(), new TransportBatchModule(), new TransportModule(), new TransportStreamModule()); + connectionContextFactory = + i.getInstance(ConnectionContextFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); - outSecret = new byte[32]; - new Random().nextBytes(outSecret); + secret = new byte[32]; + new Random().nextBytes(secret); } @Test public void testOverhead() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(MIN_CONNECTION_LENGTH); + ConnectionContext ctx = + connectionContextFactory.createConnectionContext(contactId, + transportIndex, connection, secret); ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, - MIN_CONNECTION_LENGTH, transportIndex, connection, outSecret); + MIN_CONNECTION_LENGTH, ctx); // Check that the connection writer thinks there's room for a packet long capacity = w.getRemainingCapacity(); assertTrue(capacity >= MAX_PACKET_LENGTH); diff --git a/test/net/sf/briar/transport/FrameReadWriteTest.java b/test/net/sf/briar/transport/FrameReadWriteTest.java index b05fff5224..07b00d9b98 100644 --- a/test/net/sf/briar/transport/FrameReadWriteTest.java +++ b/test/net/sf/briar/transport/FrameReadWriteTest.java @@ -64,7 +64,8 @@ public class FrameReadWriteTest extends TestCase { private void testWriteAndRead(boolean initiator) throws Exception { // Create and encrypt the IV - byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); + byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), + connection); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); byte[] encryptedIv = ivCipher.doFinal(iv); assertEquals(IV_LENGTH, encryptedIv.length); @@ -92,7 +93,7 @@ public class FrameReadWriteTest extends TestCase { // Decrypt the IV ivCipher.init(Cipher.DECRYPT_MODE, ivKey); byte[] recoveredIv = ivCipher.doFinal(recoveredEncryptedIv); - iv = IvEncoder.encodeIv(initiator, transportIndex, connection); + iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), connection); assertArrayEquals(iv, recoveredIv); // Read the frames back ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv, diff --git a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java index 592a46c487..0be09b15c9 100644 --- a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java +++ b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java @@ -119,7 +119,7 @@ public class BatchConnectionReadWriteTest extends TestCase { alice.getInstance(ProtocolWriterFactory.class); BatchTransportWriter writer = new TestBatchTransportWriter(out); OutgoingBatchConnection batchOut = new OutgoingBatchConnection( - connFactory, db, protoFactory, transportIndex, contactId, + connFactory, db, protoFactory, contactId, transportIndex, writer); // Write whatever needs to be written batchOut.write(); @@ -170,8 +170,7 @@ public class BatchConnectionReadWriteTest extends TestCase { bob.getInstance(ProtocolReaderFactory.class); BatchTransportReader reader = new TestBatchTransportReader(in); IncomingBatchConnection batchIn = new IncomingBatchConnection( - connFactory, db, protoFactory, transportIndex, contactId, - reader, encryptedIv); + connFactory, db, protoFactory, ctx, reader, encryptedIv); // No messages should have been added yet assertFalse(listener.messagesAdded); // Read whatever needs to be read diff --git a/util/net/sf/briar/util/ByteUtils.java b/util/net/sf/briar/util/ByteUtils.java index 2066e820a1..4871865437 100644 --- a/util/net/sf/briar/util/ByteUtils.java +++ b/util/net/sf/briar/util/ByteUtils.java @@ -40,4 +40,8 @@ public class ByteUtils { return ((b[offset] & 0xFFL) << 24) | ((b[offset + 1] & 0xFFL) << 16) | ((b[offset + 2] & 0xFFL) << 8) | (b[offset + 3] & 0xFFL); } + + public static void erase(byte[] b) { + for(int i = 0; i < b.length; i++) b[i] = 0; + } } -- GitLab