diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java index 9e11e555783018e242eef89a7141e8565f68ea12..21293599febac3854411de04de5ace41017165e4 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java @@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable; /** * Type-safe wrapper for an integer that uniquely identifies a contact within - * the scope of a single node. + * the scope of the local device. */ @Immutable @NotNullByDefault diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java index 08fbe45405b0b5c3a1011e4780156a7d157c6219..45f721f22df00a5160352619a0801955a037fbd0 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java @@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.Offer; import org.briarproject.bramble.api.sync.Request; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.Collection; @@ -102,10 +104,11 @@ public interface DatabaseComponent { throws DbException; /** - * Stores transport keys for a newly added contact. + * Stores the given transport keys, optionally binding them to the given + * contact, and returns a key set ID. */ - void addTransportKeys(Transaction txn, ContactId c, TransportKeys k) - throws DbException; + KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c, + TransportKeys k) throws DbException; /** * Returns true if the database contains the given contact for the given @@ -394,8 +397,8 @@ public interface DatabaseComponent { * <p/> * Read-only. */ - Map<ContactId, TransportKeys> getTransportKeys(Transaction txn, - TransportId t) throws DbException; + Collection<KeySet> getTransportKeys(Transaction txn, TransportId t) + throws DbException; /** * Increments the outgoing stream counter for the given contact and @@ -507,15 +510,15 @@ public interface DatabaseComponent { Collection<MessageId> dependencies) throws DbException; /** - * Sets the reordering window for the given contact and transport in the + * Sets the reordering window for the given key set and transport in the * given rotation period. */ - void setReorderingWindow(Transaction txn, ContactId c, TransportId t, + void setReorderingWindow(Transaction txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException; /** * Stores the given transport keys, deleting any keys they have replaced. */ - void updateTransportKeys(Transaction txn, - Map<ContactId, TransportKeys> keys) throws DbException; + void updateTransportKeys(Transaction txn, Collection<KeySet> keys) + throws DbException; } diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java new file mode 100644 index 0000000000000000000000000000000000000000..9cc8f63c294a6cd0b4674944e51307f72058ddf6 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java @@ -0,0 +1,51 @@ +package org.briarproject.bramble.api.transport; + +import org.briarproject.bramble.api.contact.ContactId; +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.Immutable; + +/** + * A set of transport keys for communicating with a contact. If the keys have + * not yet been bound to a contact, {@link #getContactId()}} returns null. + */ +@Immutable +@NotNullByDefault +public class KeySet { + + private final KeySetId keySetId; + @Nullable + private final ContactId contactId; + private final TransportKeys transportKeys; + + public KeySet(KeySetId keySetId, @Nullable ContactId contactId, + TransportKeys transportKeys) { + this.keySetId = keySetId; + this.contactId = contactId; + this.transportKeys = transportKeys; + } + + public KeySetId getKeySetId() { + return keySetId; + } + + @Nullable + public ContactId getContactId() { + return contactId; + } + + public TransportKeys getTransportKeys() { + return transportKeys; + } + + @Override + public int hashCode() { + return keySetId.hashCode(); + } + + @Override + public boolean equals(Object o) { + return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId); + } +} diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java new file mode 100644 index 0000000000000000000000000000000000000000..1f872e72a34d7ffb480423d11bbe1aa0db1bdf82 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java @@ -0,0 +1,36 @@ +package org.briarproject.bramble.api.transport; + +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +import javax.annotation.concurrent.Immutable; + +/** + * Type-safe wrapper for an integer that uniquely identifies a set of transport + * keys within the scope of the local device. + * <p/> + * Key sets created on a given device must have increasing identifiers. + */ +@Immutable +@NotNullByDefault +public class KeySetId { + + private final int id; + + public KeySetId(int id) { + this.id = id; + } + + public int getInt() { + return id; + } + + @Override + public int hashCode() { + return id; + } + + @Override + public boolean equals(Object o) { + return o instanceof KeySetId && id == ((KeySetId) o).id; + } +} diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java index bc3166d5622e605a1fdddcbf0858bc438b2216ce..e8597eaba7deca87148ede882f806b7369d35f70 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java @@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.Collection; @@ -123,9 +125,10 @@ interface Database<T> { throws DbException; /** - * Stores transport keys for a newly added contact. + * Stores the given transport keys, optionally binding them to the given + * contact, and returns a key set ID. */ - void addTransportKeys(T txn, ContactId c, TransportKeys k) + KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k) throws DbException; /** @@ -486,7 +489,7 @@ interface Database<T> { * <p/> * Read-only. */ - Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t) + Collection<KeySet> getTransportKeys(T txn, TransportId t) throws DbException; /** @@ -619,10 +622,10 @@ interface Database<T> { void setMessageState(T txn, MessageId m, State state) throws DbException; /** - * Sets the reordering window for the given contact and transport in the + * Sets the reordering window for the given key set and transport in the * given rotation period. */ - void setReorderingWindow(T txn, ContactId c, TransportId t, + void setReorderingWindow(T txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException; /** @@ -636,6 +639,5 @@ interface Database<T> { /** * Stores the given transport keys, deleting any keys they have replaced. */ - void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys) - throws DbException; + void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException; } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java index 013233f395ceac5ff0f65cf5a0044473238858c4..f90f9e3ff18631275a73cfd890bd6f51dd4ad9c6 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java @@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent; import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.logging.Logger; @@ -234,15 +234,15 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public void addTransportKeys(Transaction transaction, ContactId c, - TransportKeys k) throws DbException { + public KeySetId addTransportKeys(Transaction transaction, + @Nullable ContactId c, TransportKeys k) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - if (!db.containsContact(txn, c)) + if (c != null && !db.containsContact(txn, c)) throw new NoSuchContactException(); if (!db.containsTransport(txn, k.getTransportId())) throw new NoSuchTransportException(); - db.addTransportKeys(txn, c, k); + return db.addTransportKeys(txn, c, k); } @Override @@ -586,8 +586,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public Map<ContactId, TransportKeys> getTransportKeys( - Transaction transaction, TransportId t) throws DbException { + public Collection<KeySet> getTransportKeys(Transaction transaction, + TransportId t) throws DbException { T txn = unbox(transaction); if (!db.containsTransport(txn, t)) throw new NoSuchTransportException(); @@ -858,31 +858,25 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public void setReorderingWindow(Transaction transaction, ContactId c, + public void setReorderingWindow(Transaction transaction, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - if (!db.containsContact(txn, c)) - throw new NoSuchContactException(); if (!db.containsTransport(txn, t)) throw new NoSuchTransportException(); - db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap); + db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap); } @Override public void updateTransportKeys(Transaction transaction, - Map<ContactId, TransportKeys> keys) throws DbException { + Collection<KeySet> keys) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - Map<ContactId, TransportKeys> filtered = new HashMap<>(); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ContactId c = e.getKey(); - TransportKeys k = e.getValue(); - if (db.containsContact(txn, c) - && db.containsTransport(txn, k.getTransportId())) { - filtered.put(c, k); - } + Collection<KeySet> filtered = new ArrayList<>(); + for (KeySet ks : keys) { + TransportId t = ks.getTransportKeys().getTransportId(); + if (db.containsTransport(txn, t)) filtered.add(ks); } db.updateTransportKeys(txn, filtered); } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java index 51fa9ae309c39b93421e37673745ce143e8e59f5..327203d63005514b6cb66ef389641c8f49859a99 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java @@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; @@ -223,37 +225,43 @@ abstract class JdbcDatabase implements Database<Connection> { + " maxLatency INT NOT NULL," + " PRIMARY KEY (transportId))"; - private static final String CREATE_INCOMING_KEYS = - "CREATE TABLE incomingKeys" - + " (contactId INT NOT NULL," - + " transportId _STRING NOT NULL," + private static final String CREATE_OUTGOING_KEYS = + "CREATE TABLE outgoingKeys" + + " (transportId _STRING NOT NULL," + + " keySetId _COUNTER," + " rotationPeriod BIGINT NOT NULL," + + " contactId INT," // Null if keys are not bound + " tagKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL," - + " base BIGINT NOT NULL," - + " bitmap _BINARY NOT NULL," - + " PRIMARY KEY (contactId, transportId, rotationPeriod)," - + " FOREIGN KEY (contactId)" - + " REFERENCES contacts (contactId)" - + " ON DELETE CASCADE," + + " stream BIGINT NOT NULL," + + " PRIMARY KEY (transportId, keySetId)," + " FOREIGN KEY (transportId)" + " REFERENCES transports (transportId)" + + " ON DELETE CASCADE," + + " UNIQUE (keySetId)," + + " FOREIGN KEY (contactId)" + + " REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; - private static final String CREATE_OUTGOING_KEYS = - "CREATE TABLE outgoingKeys" - + " (contactId INT NOT NULL," - + " transportId _STRING NOT NULL," + private static final String CREATE_INCOMING_KEYS = + "CREATE TABLE incomingKeys" + + " (transportId _STRING NOT NULL," + + " keySetId INT NOT NULL," + " rotationPeriod BIGINT NOT NULL," + + " contactId INT," // Null if keys are not bound + " tagKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL," - + " stream BIGINT NOT NULL," - + " PRIMARY KEY (contactId, transportId)," - + " FOREIGN KEY (contactId)" - + " REFERENCES contacts (contactId)" - + " ON DELETE CASCADE," + + " base BIGINT NOT NULL," + + " bitmap _BINARY NOT NULL," + + " PRIMARY KEY (transportId, keySetId, rotationPeriod)," + " FOREIGN KEY (transportId)" + " REFERENCES transports (transportId)" + + " ON DELETE CASCADE," + + " FOREIGN KEY (keySetId)" + + " REFERENCES outgoingKeys (keySetId)" + + " ON DELETE CASCADE," + + " FOREIGN KEY (contactId)" + + " REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; private static final String INDEX_CONTACTS_BY_AUTHOR_ID = @@ -415,8 +423,8 @@ abstract class JdbcDatabase implements Database<Connection> { s.executeUpdate(insertTypeNames(CREATE_OFFERS)); s.executeUpdate(insertTypeNames(CREATE_STATUSES)); s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS)); - s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS)); s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS)); + s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS)); s.close(); } catch (SQLException e) { tryToClose(s); @@ -865,62 +873,78 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void addTransportKeys(Connection txn, ContactId c, TransportKeys k) - throws DbException { + public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c, + TransportKeys k) throws DbException { PreparedStatement ps = null; + ResultSet rs = null; try { + // Store the outgoing keys + String sql = "INSERT INTO outgoingKeys (contactId, transportId," + + " rotationPeriod, tagKey, headerKey, stream)" + + " VALUES (?, ?, ?, ?, ?, ?)"; + ps = txn.prepareStatement(sql); + if (c == null) ps.setNull(1, INTEGER); + else ps.setInt(1, c.getInt()); + ps.setString(2, k.getTransportId().getString()); + OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); + ps.setLong(3, outCurr.getRotationPeriod()); + ps.setBytes(4, outCurr.getTagKey().getBytes()); + ps.setBytes(5, outCurr.getHeaderKey().getBytes()); + ps.setLong(6, outCurr.getStreamCounter()); + int affected = ps.executeUpdate(); + if (affected != 1) throw new DbStateException(); + ps.close(); + // Get the new (highest) key set ID + sql = "SELECT keySetId FROM outgoingKeys" + + " ORDER BY keySetId DESC LIMIT 1"; + ps = txn.prepareStatement(sql); + rs = ps.executeQuery(); + if (!rs.next()) throw new DbStateException(); + KeySetId keySetId = new KeySetId(rs.getInt(1)); + if (rs.next()) throw new DbStateException(); + rs.close(); + ps.close(); // Store the incoming keys - String sql = "INSERT INTO incomingKeys (contactId, transportId," + sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId," + " rotationPeriod, tagKey, headerKey, base, bitmap)" - + " VALUES (?, ?, ?, ?, ?, ?, ?)"; + + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); - ps.setString(2, k.getTransportId().getString()); + ps.setInt(1, keySetId.getInt()); + if (c == null) ps.setNull(2, INTEGER); + else ps.setInt(2, c.getInt()); + ps.setString(3, k.getTransportId().getString()); // Previous rotation period IncomingKeys inPrev = k.getPreviousIncomingKeys(); - ps.setLong(3, inPrev.getRotationPeriod()); - ps.setBytes(4, inPrev.getTagKey().getBytes()); - ps.setBytes(5, inPrev.getHeaderKey().getBytes()); - ps.setLong(6, inPrev.getWindowBase()); - ps.setBytes(7, inPrev.getWindowBitmap()); + ps.setLong(4, inPrev.getRotationPeriod()); + ps.setBytes(5, inPrev.getTagKey().getBytes()); + ps.setBytes(6, inPrev.getHeaderKey().getBytes()); + ps.setLong(7, inPrev.getWindowBase()); + ps.setBytes(8, inPrev.getWindowBitmap()); ps.addBatch(); // Current rotation period IncomingKeys inCurr = k.getCurrentIncomingKeys(); - ps.setLong(3, inCurr.getRotationPeriod()); - ps.setBytes(4, inCurr.getTagKey().getBytes()); - ps.setBytes(5, inCurr.getHeaderKey().getBytes()); - ps.setLong(6, inCurr.getWindowBase()); - ps.setBytes(7, inCurr.getWindowBitmap()); + ps.setLong(4, inCurr.getRotationPeriod()); + ps.setBytes(5, inCurr.getTagKey().getBytes()); + ps.setBytes(6, inCurr.getHeaderKey().getBytes()); + ps.setLong(7, inCurr.getWindowBase()); + ps.setBytes(8, inCurr.getWindowBitmap()); ps.addBatch(); // Next rotation period IncomingKeys inNext = k.getNextIncomingKeys(); - ps.setLong(3, inNext.getRotationPeriod()); - ps.setBytes(4, inNext.getTagKey().getBytes()); - ps.setBytes(5, inNext.getHeaderKey().getBytes()); - ps.setLong(6, inNext.getWindowBase()); - ps.setBytes(7, inNext.getWindowBitmap()); + ps.setLong(4, inNext.getRotationPeriod()); + ps.setBytes(5, inNext.getTagKey().getBytes()); + ps.setBytes(6, inNext.getHeaderKey().getBytes()); + ps.setLong(7, inNext.getWindowBase()); + ps.setBytes(8, inNext.getWindowBitmap()); ps.addBatch(); int[] batchAffected = ps.executeBatch(); if (batchAffected.length != 3) throw new DbStateException(); for (int rows : batchAffected) if (rows != 1) throw new DbStateException(); ps.close(); - // Store the outgoing keys - sql = "INSERT INTO outgoingKeys (contactId, transportId," - + " rotationPeriod, tagKey, headerKey, stream)" - + " VALUES (?, ?, ?, ?, ?, ?)"; - ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); - ps.setString(2, k.getTransportId().getString()); - OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); - ps.setLong(3, outCurr.getRotationPeriod()); - ps.setBytes(4, outCurr.getTagKey().getBytes()); - ps.setBytes(5, outCurr.getHeaderKey().getBytes()); - ps.setLong(6, outCurr.getStreamCounter()); - int affected = ps.executeUpdate(); - if (affected != 1) throw new DbStateException(); - ps.close(); + return keySetId; } catch (SQLException e) { + tryToClose(rs); tryToClose(ps); throw new DbException(e); } @@ -2078,8 +2102,8 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public Map<ContactId, TransportKeys> getTransportKeys(Connection txn, - TransportId t) throws DbException { + public Collection<KeySet> getTransportKeys(Connection txn, TransportId t) + throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { @@ -2088,7 +2112,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " base, bitmap" + " FROM incomingKeys" + " WHERE transportId = ?" - + " ORDER BY contactId, rotationPeriod"; + + " ORDER BY keySetId, rotationPeriod"; ps = txn.prepareStatement(sql); ps.setString(1, t.getString()); rs = ps.executeQuery(); @@ -2105,29 +2129,33 @@ abstract class JdbcDatabase implements Database<Connection> { rs.close(); ps.close(); // Retrieve the outgoing keys in the same order - sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream" + sql = "SELECT keySetId, contactId, rotationPeriod," + + " tagKey, headerKey, stream" + " FROM outgoingKeys" + " WHERE transportId = ?" - + " ORDER BY contactId, rotationPeriod"; + + " ORDER BY keySetId"; ps = txn.prepareStatement(sql); ps.setString(1, t.getString()); rs = ps.executeQuery(); - Map<ContactId, TransportKeys> keys = new HashMap<>(); + Collection<KeySet> keys = new ArrayList<>(); for (int i = 0; rs.next(); i++) { // There should be three times as many incoming keys if (inKeys.size() < (i + 1) * 3) throw new DbStateException(); - ContactId contactId = new ContactId(rs.getInt(1)); - long rotationPeriod = rs.getLong(2); - SecretKey tagKey = new SecretKey(rs.getBytes(3)); - SecretKey headerKey = new SecretKey(rs.getBytes(4)); - long streamCounter = rs.getLong(5); + KeySetId keySetId = new KeySetId(rs.getInt(1)); + ContactId contactId = new ContactId(rs.getInt(2)); + if (rs.wasNull()) contactId = null; + long rotationPeriod = rs.getLong(3); + SecretKey tagKey = new SecretKey(rs.getBytes(4)); + SecretKey headerKey = new SecretKey(rs.getBytes(5)); + long streamCounter = rs.getLong(6); OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, rotationPeriod, streamCounter); IncomingKeys inPrev = inKeys.get(i * 3); IncomingKeys inCurr = inKeys.get(i * 3 + 1); IncomingKeys inNext = inKeys.get(i * 3 + 2); - keys.put(contactId, new TransportKeys(t, inPrev, inCurr, - inNext, outCurr)); + TransportKeys transportKeys = new TransportKeys(t, inPrev, + inCurr, inNext, outCurr); + keys.add(new KeySet(keySetId, contactId, transportKeys)); } rs.close(); ps.close(); @@ -2791,18 +2819,18 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void setReorderingWindow(Connection txn, ContactId c, TransportId t, + public void setReorderingWindow(Connection txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException { PreparedStatement ps = null; try { String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?" - + " WHERE contactId = ? AND transportId = ?" + + " WHERE transportId = ? AND keySetId = ?" + " AND rotationPeriod = ?"; ps = txn.prepareStatement(sql); ps.setLong(1, base); ps.setBytes(2, bitmap); - ps.setInt(3, c.getInt()); - ps.setString(4, t.getString()); + ps.setString(3, t.getString()); + ps.setInt(4, k.getInt()); ps.setLong(5, rotationPeriod); int affected = ps.executeUpdate(); if (affected < 0 || affected > 1) throw new DbStateException(); @@ -2848,45 +2876,31 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void updateTransportKeys(Connection txn, - Map<ContactId, TransportKeys> keys) throws DbException { + public void updateTransportKeys(Connection txn, Collection<KeySet> keys) + throws DbException { PreparedStatement ps = null; try { - // Delete any existing incoming keys - String sql = "DELETE FROM incomingKeys" - + " WHERE contactId = ?" - + " AND transportId = ?"; + // Delete any existing outgoing keys - this will also remove any + // incoming keys with the same key set ID + String sql = "DELETE FROM outgoingKeys WHERE keySetId = ?"; ps = txn.prepareStatement(sql); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ps.setInt(1, e.getKey().getInt()); - ps.setString(2, e.getValue().getTransportId().getString()); + for (KeySet ks : keys) { + ps.setInt(1, ks.getKeySetId().getInt()); ps.addBatch(); } int[] batchAffected = ps.executeBatch(); if (batchAffected.length != keys.size()) throw new DbStateException(); - ps.close(); - // Delete any existing outgoing keys - sql = "DELETE FROM outgoingKeys" - + " WHERE contactId = ?" - + " AND transportId = ?"; - ps = txn.prepareStatement(sql); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ps.setInt(1, e.getKey().getInt()); - ps.setString(2, e.getValue().getTransportId().getString()); - ps.addBatch(); - } - batchAffected = ps.executeBatch(); - if (batchAffected.length != keys.size()) - throw new DbStateException(); + for (int rows: batchAffected) + if (rows < 0) throw new DbStateException(); ps.close(); } catch (SQLException e) { tryToClose(ps); throw new DbException(e); } // Store the new keys - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - addTransportKeys(txn, e.getKey(), e.getValue()); + for (KeySet ks : keys) { + addTransportKeys(txn, ks.getContactId(), ks.getTransportKeys()); } } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java new file mode 100644 index 0000000000000000000000000000000000000000..b55c5aef4f8551741c52e6a48d8630aebf86b29e --- /dev/null +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java @@ -0,0 +1,34 @@ +package org.briarproject.bramble.transport; + +import org.briarproject.bramble.api.contact.ContactId; +import org.briarproject.bramble.api.transport.KeySetId; + +import javax.annotation.Nullable; + +public class MutableKeySet { + + private final KeySetId keySetId; + @Nullable + private final ContactId contactId; + private final MutableTransportKeys transportKeys; + + public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId, + MutableTransportKeys transportKeys) { + this.keySetId = keySetId; + this.contactId = contactId; + this.transportKeys = transportKeys; + } + + public KeySetId getKeySetId() { + return keySetId; + } + + @Nullable + public ContactId getContactId() { + return contactId; + } + + public MutableTransportKeys getTransportKeys() { + return transportKeys; + } +} diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java index 60b48427ff18f07703a147f827ae39639c6f629e..1220beee580ea7af59c0b682af94546fd80e0068 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java @@ -11,19 +11,25 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Scheduler; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.transport.ReorderingWindow.Change; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; +import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -47,12 +53,13 @@ class TransportKeyManagerImpl implements TransportKeyManager { private final Clock clock; private final TransportId transportId; private final long rotationPeriodLength; - private final ReentrantLock lock; + private final AtomicBoolean used = new AtomicBoolean(false); + private final ReentrantLock lock = new ReentrantLock(); // The following are locking: lock - private final Map<Bytes, TagContext> inContexts; - private final Map<ContactId, MutableOutgoingKeys> outContexts; - private final Map<ContactId, MutableTransportKeys> keys; + private final Collection<MutableKeySet> keys = new ArrayList<>(); + private final Map<Bytes, TagContext> inContexts = new HashMap<>(); + private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>(); TransportKeyManagerImpl(DatabaseComponent db, TransportCrypto transportCrypto, Executor dbExecutor, @@ -65,20 +72,16 @@ class TransportKeyManagerImpl implements TransportKeyManager { this.clock = clock; this.transportId = transportId; rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; - lock = new ReentrantLock(); - inContexts = new HashMap<>(); - outContexts = new HashMap<>(); - keys = new HashMap<>(); } @Override public void start(Transaction txn) throws DbException { + if (used.getAndSet(true)) throw new IllegalStateException(); long now = clock.currentTimeMillis(); lock.lock(); try { // Load the transport keys from the DB - Map<ContactId, TransportKeys> loaded = - db.getTransportKeys(txn, transportId); + Collection<KeySet> loaded = db.getTransportKeys(txn, transportId); // Rotate the keys to the current rotation period RotationResult rotationResult = rotateKeys(loaded, now); // Initialise mutable state for all contacts @@ -93,41 +96,51 @@ class TransportKeyManagerImpl implements TransportKeyManager { scheduleKeyRotation(now); } - private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys, - long now) { + private RotationResult rotateKeys(Collection<KeySet> keys, long now) { RotationResult rotationResult = new RotationResult(); long rotationPeriod = now / rotationPeriodLength; - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ContactId c = e.getKey(); - TransportKeys k = e.getValue(); + for (KeySet ks : keys) { + TransportKeys k = ks.getTransportKeys(); TransportKeys k1 = transportCrypto.rotateTransportKeys(k, rotationPeriod); + KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1); if (k1.getRotationPeriod() > k.getRotationPeriod()) - rotationResult.rotated.put(c, k1); - rotationResult.current.put(c, k1); + rotationResult.rotated.add(ks1); + rotationResult.current.add(ks1); } return rotationResult; } // Locking: lock - private void addKeys(Map<ContactId, TransportKeys> m) { - for (Entry<ContactId, TransportKeys> e : m.entrySet()) - addKeys(e.getKey(), new MutableTransportKeys(e.getValue())); + private void addKeys(Collection<KeySet> keys) { + for (KeySet ks : keys) { + addKeys(ks.getKeySetId(), ks.getContactId(), + new MutableTransportKeys(ks.getTransportKeys())); + } } // Locking: lock - private void addKeys(ContactId c, MutableTransportKeys m) { - encodeTags(c, m.getPreviousIncomingKeys()); - encodeTags(c, m.getCurrentIncomingKeys()); - encodeTags(c, m.getNextIncomingKeys()); - outContexts.put(c, m.getCurrentOutgoingKeys()); - keys.put(c, m); + private void addKeys(KeySetId keySetId, @Nullable ContactId contactId, + MutableTransportKeys m) { + MutableKeySet ks = new MutableKeySet(keySetId, contactId, m); + keys.add(ks); + if (contactId != null) { + encodeTags(keySetId, contactId, m.getPreviousIncomingKeys()); + encodeTags(keySetId, contactId, m.getCurrentIncomingKeys()); + encodeTags(keySetId, contactId, m.getNextIncomingKeys()); + // Use the outgoing keys with the highest key set ID + MutableKeySet old = outContexts.get(contactId); + if (old == null || old.getKeySetId().getInt() < keySetId.getInt()) + outContexts.put(contactId, ks); + } } // Locking: lock - private void encodeTags(ContactId c, MutableIncomingKeys inKeys) { + private void encodeTags(KeySetId keySetId, ContactId contactId, + MutableIncomingKeys inKeys) { for (long streamNumber : inKeys.getWindow().getUnseen()) { - TagContext tagCtx = new TagContext(c, inKeys, streamNumber); + TagContext tagCtx = + new TagContext(keySetId, contactId, inKeys, streamNumber); byte[] tag = new byte[TAG_LENGTH]; transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, streamNumber); @@ -169,10 +182,10 @@ class TransportKeyManagerImpl implements TransportKeyManager { // Rotate the keys to the current rotation period if necessary rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength; k = transportCrypto.rotateTransportKeys(k, rotationPeriod); - // Initialise mutable state for the contact - addKeys(c, new MutableTransportKeys(k)); // Write the keys back to the DB - db.addTransportKeys(txn, c, k); + KeySetId keySetId = db.addTransportKeys(txn, c, k); + // Initialise mutable state for the contact + addKeys(keySetId, c, new MutableTransportKeys(k)); } finally { lock.unlock(); } @@ -183,12 +196,18 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Remove mutable state for the contact - Iterator<Entry<Bytes, TagContext>> it = + Iterator<Entry<Bytes, TagContext>> inContextsIter = inContexts.entrySet().iterator(); - while (it.hasNext()) - if (it.next().getValue().contactId.equals(c)) it.remove(); + while (inContextsIter.hasNext()) { + ContactId c1 = inContextsIter.next().getValue().contactId; + if (c1.equals(c)) inContextsIter.remove(); + } outContexts.remove(c); - keys.remove(c); + Iterator<MutableKeySet> keysIter = keys.iterator(); + while (keysIter.hasNext()) { + ContactId c1 = keysIter.next().getContactId(); + if (c1 != null && c1.equals(c)) keysIter.remove(); + } } finally { lock.unlock(); } @@ -200,8 +219,10 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Look up the outgoing keys for the contact - MutableOutgoingKeys outKeys = outContexts.get(c); - if (outKeys == null) return null; + MutableKeySet ks = outContexts.get(c); + if (ks == null) return null; + MutableOutgoingKeys outKeys = + ks.getTransportKeys().getCurrentOutgoingKeys(); if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null; // Create a stream context StreamContext ctx = new StreamContext(c, transportId, @@ -238,8 +259,9 @@ class TransportKeyManagerImpl implements TransportKeyManager { byte[] addTag = new byte[TAG_LENGTH]; transportCrypto.encodeTag(addTag, inKeys.getTagKey(), PROTOCOL_VERSION, streamNumber); - inContexts.put(new Bytes(addTag), new TagContext( - tagCtx.contactId, inKeys, streamNumber)); + TagContext tagCtx1 = new TagContext(tagCtx.keySetId, + tagCtx.contactId, inKeys, streamNumber); + inContexts.put(new Bytes(addTag), tagCtx1); } // Remove tags for any stream numbers removed from the window for (long streamNumber : change.getRemoved()) { @@ -250,7 +272,7 @@ class TransportKeyManagerImpl implements TransportKeyManager { inContexts.remove(new Bytes(removeTag)); } // Write the window back to the DB - db.setReorderingWindow(txn, tagCtx.contactId, transportId, + db.setReorderingWindow(txn, tagCtx.keySetId, transportId, inKeys.getRotationPeriod(), window.getBase(), window.getBitmap()); return ctx; @@ -264,9 +286,11 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Rotate the keys to the current rotation period - Map<ContactId, TransportKeys> snapshot = new HashMap<>(); - for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) - snapshot.put(e.getKey(), e.getValue().snapshot()); + Collection<KeySet> snapshot = new ArrayList<>(keys.size()); + for (MutableKeySet ks : keys) { + snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(), + ks.getTransportKeys().snapshot())); + } RotationResult rotationResult = rotateKeys(snapshot, now); // Rebuild the mutable state for all contacts inContexts.clear(); @@ -285,12 +309,14 @@ class TransportKeyManagerImpl implements TransportKeyManager { private static class TagContext { + private final KeySetId keySetId; private final ContactId contactId; private final MutableIncomingKeys inKeys; private final long streamNumber; - private TagContext(ContactId contactId, MutableIncomingKeys inKeys, - long streamNumber) { + private TagContext(KeySetId keySetId, ContactId contactId, + MutableIncomingKeys inKeys, long streamNumber) { + this.keySetId = keySetId; this.contactId = contactId; this.inKeys = inKeys; this.streamNumber = streamNumber; @@ -299,11 +325,7 @@ class TransportKeyManagerImpl implements TransportKeyManager { private static class RotationResult { - private final Map<ContactId, TransportKeys> current, rotated; - - private RotationResult() { - current = new HashMap<>(); - rotated = new HashMap<>(); - } + private final Collection<KeySet> current = new ArrayList<>(); + private final Collection<KeySet> rotated = new ArrayList<>(); } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java index 0ed4574618fc17e8eac6f0dcb59cb395f206e78d..eb31247ce074642cd6c9fcc29de75280def9151c 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java @@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.test.BrambleMockTestCase; @@ -55,12 +57,10 @@ import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; -import static java.util.Collections.singletonMap; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED; import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE; @@ -100,6 +100,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { private final int maxLatency; private final ContactId contactId; private final Contact contact; + private final KeySetId keySetId; public DatabaseComponentImplTest() { clientId = new ClientId(getRandomString(123)); @@ -121,6 +122,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { contactId = new ContactId(234); contact = new Contact(contactId, author, localAuthor.getId(), true, true); + keySetId = new KeySetId(345); } private DatabaseComponent createDatabaseComponent(Database<Object> database, @@ -282,11 +284,11 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { throws Exception { context.checking(new Expectations() {{ // Check whether the contact is in the DB (which it's not) - exactly(18).of(database).startTransaction(); + exactly(17).of(database).startTransaction(); will(returnValue(txn)); - exactly(18).of(database).containsContact(txn, contactId); + exactly(17).of(database).containsContact(txn, contactId); will(returnValue(false)); - exactly(18).of(database).abortTransaction(txn); + exactly(17).of(database).abortTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, eventBus, shutdown); @@ -454,17 +456,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { db.endTransaction(transaction); } - transaction = db.startTransaction(false); - try { - db.setReorderingWindow(transaction, contactId, transportId, 0, 0, - new byte[REORDERING_WINDOW_SIZE / 8]); - fail(); - } catch (NoSuchContactException expected) { - // Expected - } finally { - db.endTransaction(transaction); - } - transaction = db.startTransaction(false); try { db.setGroupVisibility(transaction, contactId, groupId, SHARED); @@ -779,7 +770,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { // Check whether the transport is in the DB (which it's not) exactly(4).of(database).startTransaction(); will(returnValue(txn)); - exactly(2).of(database).containsContact(txn, contactId); + oneOf(database).containsContact(txn, contactId); will(returnValue(true)); exactly(4).of(database).containsTransport(txn, transportId); will(returnValue(false)); @@ -830,7 +821,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.setReorderingWindow(transaction, contactId, transportId, 0, 0, + db.setReorderingWindow(transaction, keySetId, transportId, 0, 0, new byte[REORDERING_WINDOW_SIZE / 8]); fail(); } catch (NoSuchTransportException expected) { @@ -1303,15 +1294,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { @Test public void testTransportKeys() throws Exception { TransportKeys transportKeys = createTransportKeys(); - Map<ContactId, TransportKeys> keys = - singletonMap(contactId, transportKeys); + Collection<KeySet> keys = + singletonList(new KeySet(keySetId, contactId, transportKeys)); context.checking(new Expectations() {{ // startTransaction() oneOf(database).startTransaction(); will(returnValue(txn)); // updateTransportKeys() - oneOf(database).containsContact(txn, contactId); - will(returnValue(true)); oneOf(database).containsTransport(txn, transportId); will(returnValue(true)); oneOf(database).updateTransportKeys(txn, keys); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java index d81cfd85149cdc1ec6e2a53d0b1a4de2b4afeb7a..4237af191b8123ba70c3deacb79e21b0124ef725 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java @@ -19,6 +19,8 @@ import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.system.SystemClock; @@ -34,7 +36,6 @@ import java.sql.Connection; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -42,6 +43,10 @@ import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static java.util.concurrent.TimeUnit.SECONDS; import static org.briarproject.bramble.api.db.Metadata.REMOVE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; @@ -86,6 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { private final Message message; private final TransportId transportId; private final ContactId contactId; + private final KeySetId keySetId; JdbcDatabaseTest() throws Exception { groupId = new GroupId(getRandomId()); @@ -101,6 +107,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { message = new Message(messageId, groupId, timestamp, raw); transportId = new TransportId("id"); contactId = new ContactId(1); + keySetId = new KeySetId(1); } protected abstract JdbcDatabase createDatabase(DatabaseConfig config, @@ -190,9 +197,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The contact has not seen the message, so it should be sendable Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Changing the status to seen = true should make the message unsendable db.raiseSeenFlag(txn, contactId, messageId); @@ -228,9 +235,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Marking the message delivered should make it sendable db.setMessageState(txn, messageId, DELIVERED); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Marking the message invalid should make it unsendable db.setMessageState(txn, messageId, INVALID); @@ -279,9 +286,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Sharing the group should make the message sendable db.setGroupVisibility(txn, contactId, groupId, true); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Unsharing the group should make the message unsendable db.setGroupVisibility(txn, contactId, groupId, false); @@ -324,9 +331,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Sharing the message should make it sendable db.setMessageShared(txn, messageId); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.commitTransaction(txn); db.close(); @@ -352,7 +359,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The message is just the right size to send ids = db.getMessagesToSend(txn, contactId, size); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.commitTransaction(txn); db.close(); @@ -384,7 +391,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1)); // Both message IDs should have been removed - assertEquals(Collections.emptyList(), db.getMessagesToAck(txn, + assertEquals(emptyList(), db.getMessagesToAck(txn, contactId, 1234)); // Raise the ack flag again @@ -415,7 +422,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Retrieve the message from the database and mark it as sent Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE); // The message should no longer be sendable @@ -626,31 +633,31 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The group should not be visible to the contact assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.emptyMap(), + assertEquals(emptyMap(), db.getGroupVisibility(txn, groupId)); // Make the group visible to the contact db.addGroupVisibility(txn, contactId, groupId, false); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, false), + assertEquals(singletonMap(contactId, false), db.getGroupVisibility(txn, groupId)); // Share the group with the contact db.setGroupVisibility(txn, contactId, groupId, true); assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, true), + assertEquals(singletonMap(contactId, true), db.getGroupVisibility(txn, groupId)); // Unshare the group with the contact db.setGroupVisibility(txn, contactId, groupId, false); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, false), + assertEquals(singletonMap(contactId, false), db.getGroupVisibility(txn, groupId)); // Make the group invisible again db.removeGroupVisibility(txn, contactId, groupId); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.emptyMap(), + assertEquals(emptyMap(), db.getGroupVisibility(txn, groupId)); db.commitTransaction(txn); @@ -665,24 +672,22 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { Connection txn = db.startTransaction(); // Initially there should be no transport keys in the database - assertEquals(Collections.emptyMap(), - db.getTransportKeys(txn, transportId)); + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); // Add the contact, the transport and the transport keys db.addLocalAuthor(txn, localAuthor); assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.addTransportKeys(txn, contactId, keys); + assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys)); // Retrieve the transport keys - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); + Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId); assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); + KeySet ks = newKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + TransportKeys k = ks.getTransportKeys(); assertEquals(transportId, k.getTransportId()); assertKeysEquals(keys.getPreviousIncomingKeys(), k.getPreviousIncomingKeys()); @@ -695,8 +700,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Removing the contact should remove the transport keys db.removeContact(txn, contactId); - assertEquals(Collections.emptyMap(), - db.getTransportKeys(txn, transportId)); + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); db.commitTransaction(txn); db.close(); @@ -735,18 +739,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); + db.updateTransportKeys(txn, + singletonList(new KeySet(keySetId, contactId, keys))); // Increment the stream counter twice and retrieve the transport keys db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); + Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId); assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); + KeySet ks = newKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + TransportKeys k = ks.getTransportKeys(); assertEquals(transportId, k.getTransportId()); OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); assertEquals(rotationPeriod, outCurr.getRotationPeriod()); @@ -771,19 +775,19 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); + db.updateTransportKeys(txn, + singletonList(new KeySet(keySetId, contactId, keys))); // Update the reordering window and retrieve the transport keys new Random().nextBytes(bitmap); - db.setReorderingWindow(txn, contactId, transportId, rotationPeriod, + db.setReorderingWindow(txn, keySetId, transportId, rotationPeriod, base + 1, bitmap); - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); + Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId); assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); + KeySet ks = newKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + TransportKeys k = ks.getTransportKeys(); assertEquals(transportId, k.getTransportId()); IncomingKeys inCurr = k.getCurrentIncomingKeys(); assertEquals(rotationPeriod, inCurr.getRotationPeriod()); @@ -830,18 +834,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { db.addLocalAuthor(txn, localAuthor); Collection<ContactId> contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.emptyList(), contacts); + assertEquals(emptyList(), contacts); // Add a contact associated with the local author assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.singletonList(contactId), contacts); + assertEquals(singletonList(contactId), contacts); // Remove the local author - the contact should be removed db.removeLocalAuthor(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.emptyList(), contacts); + assertEquals(emptyList(), contacts); assertFalse(db.containsContact(txn, contactId)); db.commitTransaction(txn); @@ -1560,9 +1564,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The message should be sendable Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // The raw message should not be null assertNotNull(db.getRawMessage(txn, messageId)); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java index 4dc1f980240d9d91daa57860c5da1ea40d47b93d..213400f1e6096181c5c18a8fdc3bd437d64a4c6e 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java @@ -8,6 +8,8 @@ import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.TransportKeys; @@ -22,14 +24,13 @@ import org.junit.Test; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION; @@ -55,6 +56,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; private final ContactId contactId = new ContactId(123); private final ContactId contactId1 = new ContactId(234); + private final KeySetId keySetId = new KeySetId(345); private final SecretKey tagKey = TestUtils.getSecretKey(); private final SecretKey headerKey = TestUtils.getSecretKey(); private final SecretKey masterKey = TestUtils.getSecretKey(); @@ -62,11 +64,12 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { @Test public void testKeysAreRotatedAtStartup() throws Exception { - Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>(); TransportKeys shouldRotate = createTransportKeys(900, 0); TransportKeys shouldNotRotate = createTransportKeys(1000, 0); - loaded.put(contactId, shouldRotate); - loaded.put(contactId1, shouldNotRotate); + Collection<KeySet> loaded = asList( + new KeySet(keySetId, contactId, shouldRotate), + new KeySet(keySetId, contactId1, shouldNotRotate) + ); TransportKeys rotated = createTransportKeys(1000, 0); Transaction txn = new Transaction(null, false); @@ -91,7 +94,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { } // Save the keys that were rotated oneOf(db).updateTransportKeys(txn, - Collections.singletonMap(contactId, rotated)); + singletonList(new KeySet(keySetId, contactId, rotated))); // Schedule key rotation at the start of the next rotation period oneOf(scheduler).schedule(with(any(Runnable.class)), with(rotationPeriodLength - 1), with(MILLISECONDS)); @@ -129,6 +132,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { } // Save the keys oneOf(db).addTransportKeys(txn, contactId, rotated); + will(returnValue(keySetId)); }}); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( @@ -179,6 +183,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(transportKeys)); // Save the keys oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); }}); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( @@ -218,6 +223,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(transportKeys)); // Save the keys oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); // Increment the stream counter oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000); }}); @@ -268,6 +274,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(transportKeys)); // Save the keys oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); }}); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( @@ -308,13 +315,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(transportKeys)); // Save the keys oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); // Encode a new tag after sliding the window oneOf(transportCrypto).encodeTag(with(any(byte[].class)), with(tagKey), with(PROTOCOL_VERSION), with((long) REORDERING_WINDOW_SIZE)); will(new EncodeTagAction(tags)); // Save the reordering window (previous rotation period, base 1) - oneOf(db).setReorderingWindow(txn, contactId, transportId, 999, + oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999, 1, new byte[REORDERING_WINDOW_SIZE / 8]); }}); @@ -345,8 +353,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { @Test public void testKeysAreRotatedToCurrentPeriod() throws Exception { TransportKeys transportKeys = createTransportKeys(1000, 0); - Map<ContactId, TransportKeys> loaded = - Collections.singletonMap(contactId, transportKeys); + Collection<KeySet> loaded = + singletonList(new KeySet(keySetId, contactId, transportKeys)); TransportKeys rotated = createTransportKeys(1001, 0); Transaction txn = new Transaction(null, false); Transaction txn1 = new Transaction(null, false); @@ -393,7 +401,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { } // Save the keys that were rotated oneOf(db).updateTransportKeys(txn1, - Collections.singletonMap(contactId, rotated)); + singletonList(new KeySet(keySetId, contactId, rotated))); // Schedule key rotation at the start of the next rotation period oneOf(scheduler).schedule(with(any(Runnable.class)), with(rotationPeriodLength), with(MILLISECONDS));